{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"(Forecasting Hurricane Trajectories with State Space Models)=\n",
"# Forecasting Hurricane Trajectories with State Space Models\n",
"\n",
":::{post} June 15, 2025 \n",
":tags: state space model\n",
":category: intermediate, tutorial\n",
":author: Jonathan Dekermanjian\n",
":::"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Introduction\n",
"In this case study we are going to forecast the paths of hurricanes by applying several State Space Models (SSM). We will begin with a simple two-dimensional constant acceleration tracking model, where we only have one parameter to estimate. Subsequently, we will progressively add complexity and parameters as we develop up our model. \n",
"\n",
"As a brief introduction to SSMs, the general idea is that we define our system using two equations.
\n",
"The state equation and the observation equation. \n",
"\n",
"$$\n",
"x_{t+1} = T_{t}x_{t} + c_{t} + R_{t}\\epsilon_{t}\n",
"$$ \n",
"\n",
"$$\n",
"y_{t} = Z_{t}x_{t} + d_{t} + \\eta_{t}\n",
"$$\n",
"\n",
"The process/state covariance is given by $\\epsilon_{t} \\sim N(0, Q_{t})$ where $Q_{t}$ is the process/state innovations and the observation/measurement covariance is given by $\\eta_{t} \\sim N(0, H_{t})$ where $H_{t}$ describes the uncertainty in the measurement device or measurement procedure. \n",
"\n",
"We have the following matrices:\n",
"|State Equation variables|Definition|\n",
"| --- | --- |\n",
"| $T_{t}$ | The state transition matrix at time $t$ defines the kinematics of the process generating the series.\n",
"| $x_{t}$ | The state vector at time $t$ describes the current state of the system.\n",
"| $c_{t}$ | Intercept vector at time $t$ can include covariates/control/exogenous variables that are deterministically measured.\n",
"| $R_{t}$ | Selection matrix at time $t$ selects which process innovations are allowed to affect the next state.\n",
"| $\\epsilon_{t}$ | State/Process innovations at time $t$ defines the shocks influencing the changes in the state matrix.\n",
"\n",
"
\n",
"\n",
"|Observation Equation variables|Definition|\n",
"| --- | --- |\n",
"| $Z_{t}$ | The design matrix at time $t$ defines which states directly influence the observed variables.\n",
"| $x_{t}$ | The state vector at time $t$ describes the current state of the system.\n",
"| $d_{t}$ | Intercept vector at time $t$ can include covariates/control/exogenous variables that are deterministically measured.\n",
"| $\\eta_{t}$ | observation/measurement error at time $t$ defines the uncertainty in the observation.\n",
"\n",
"Estimation occurs in an iterative fashion (after an initialization step). In which the following steps are repeated:\n",
"1. Predict the next state vector $x_{t+1|t}$ and the next state/process covariance matrix $P_{t+1|t}$\n",
"2. Compute the Kalman gain\n",
"3. Estimate the current state vector and the current state/process covariance matrix\n",
"\n",
"Where $P_{t}$ is the uncertainty in the state predictions at time $t$.\n",
"\n",
"The general idea is that we make predictions based on our current state vector and state/process covariance (uncertainty) then we correct these predictions once we have our observations.\n",
"\n",
"The following equations define the process:\n",
"|Description|Equation|\n",
"| --- | --- |\n",
"|Predict the next state vector| $\\hat{x}_{t+1\\|t} = T_{t}\\hat{x}_{t\\|t}$ |\n",
"|Predict the next state/process covariance| $P_{t+1\\|t} = T_{t}P_{t+1\\|t}T_{t}^{T} + Q$ |\n",
"|Compute Kalman Gain | $K_{t} = P_{t\\|t-1}Z^{T}(ZP_{t\\|t-1}Z^{T} + H_{t})^{-1}$ |\n",
"|Estimate current state vector| $\\hat{x}_{t\\|t} = \\hat{x}_{t\\|t-1} + K_{t}(y_{t} - Z\\hat{x}_{t\\|t-1})$ |\n",
"|Estimate current state/process covariance| $P_{t\\|t} = (I - K_{t}Z_{t})P_{t\\|t-1}(I - K_{t}Z_{t})^{T} + K_{t}H_{t}K_{t}^{T}$ |\n",
"\n",
":::{note}\n",
"We wrote the equation for $P_{t\\|t}$ above using Joseph form, which is more numerically stable but also wordier. In different texts you may encounter this equation written in \"standard\" form.\n",
":::"
]
},
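The predict/correct cycle described above can be sketched in a few lines of NumPy. This is a minimal illustration, not the filter used by `pymc_extras`: the intercepts $c_{t}$ and $d_{t}$ are taken to be zero, the state ordering `[lon, lat, v_lon, v_lat, a_lon, a_lat]` is an assumption, and the values of `Q`, `H`, and the observation `y1` are made up for the example.

```python
import numpy as np


def kalman_step(x_filt, P_filt, y, T, Z, R, Q, H):
    """Run one predict/correct cycle, assuming zero intercepts (c_t = d_t = 0)."""
    # 1. Predict the next state and its covariance from the current filtered estimate
    x_pred = T @ x_filt
    P_pred = T @ P_filt @ T.T + R @ Q @ R.T
    # 2. Kalman gain: how strongly the new observation corrects the prediction
    S = Z @ P_pred @ Z.T + H  # innovation covariance
    K = P_pred @ Z.T @ np.linalg.inv(S)
    # 3. Correct the prediction; the covariance update uses the Joseph form
    x_new = x_pred + K @ (y - Z @ x_pred)
    A = np.eye(x_pred.shape[0]) - K @ Z
    P_new = A @ P_pred @ A.T + K @ H @ K.T
    return x_new, P_new


# Constant-acceleration kinematics for a single axis with time step dt
dt = 1.0
T_axis = np.array([[1.0, dt, 0.5 * dt**2], [0.0, 1.0, dt], [0.0, 0.0, 1.0]])
# Assumed state ordering: [lon, lat, v_lon, v_lat, a_lon, a_lat]
T = np.kron(T_axis, np.eye(2))
Z = np.hstack([np.eye(2), np.zeros((2, 4))])  # only the two positions are observed
R = np.eye(6)  # every innovation is allowed to affect the state
Q = 0.01 * np.eye(6)  # hypothetical innovation covariance
H = 0.1 * np.eye(2)  # hypothetical measurement covariance

x0, P0 = np.zeros(6), np.eye(6)
y1 = np.array([1.0, -1.0])  # made-up observation
x1, P1 = kalman_step(x0, P0, y1, T, Z, R, Q, H)
print(x1[:2])  # corrected position, pulled from the prediction toward y1
```

Because the predicted position variance is much larger than the measurement variance here, the gain is close to one and the corrected position lands near the observation.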
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"# Import libraries\n",
"import re\n",
"import warnings\n",
"\n",
"warnings.filterwarnings(\"ignore\", message=\"The RandomType SharedVariables\", category=UserWarning)\n",
"\n",
"import arviz as az\n",
"import arviz.labels as azl\n",
"import numpy as np\n",
"import pymc as pm\n",
"import pytensor.tensor as pt\n",
"import xarray as xr"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
":::{include} ../extra_installs.md\n",
":::"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Required Extra Dependencies\n",
"import plotly.graph_objects as go\n",
"import plotly.io as pio\n",
"import polars as pl\n",
"\n",
"from patsy import dmatrix\n",
"from pymc_extras.statespace.core.statespace import PyMCStateSpace\n",
"from pymc_extras.statespace.models.utilities import make_default_coords\n",
"from pymc_extras.statespace.utils.constants import (\n",
" ALL_STATE_AUX_DIM,\n",
" ALL_STATE_DIM,\n",
" TIME_DIM,\n",
")\n",
"\n",
"# make all plotly figures static\n",
"pio.renderers.default = \"svg\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Helper Functions"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def ellipse_covariance(covariance: np.ndarray) -> np.ndarray:\n",
" \"\"\"\n",
" Generates a 95% CI ellipse via a chi-square multivariate normal approximation.\n",
"\n",
" Parameters\n",
" ----------\n",
" covariance : ndarray\n",
" The estimated covariance matrix\n",
"\n",
" Returns\n",
" -------\n",
" ndarray\n",
" matrix of ellipse points\n",
" \"\"\"\n",
" evals, evects = np.linalg.eig(covariance)\n",
" largest_evect = evects[np.argmax(evals)]\n",
" largest_eval = np.max(evals)\n",
" smallest_eval = np.min(evals)\n",
" angle = np.arctan2(largest_evect[1], largest_evect[0])\n",
" if angle < 0:\n",
" angle = angle + 2 * np.pi\n",
" chisquare_val = 2.4477 # 95% CI MVN\n",
" theta_grid = np.linspace(0, 2 * np.pi)\n",
" phi = angle\n",
" a = chisquare_val * np.sqrt(largest_eval) # half-major axis scaled by k corresponding to 95% CI\n",
" b = chisquare_val * np.sqrt(smallest_eval) # half-minor axis scaled by k\n",
" ellipse_x_r = a * np.cos(theta_grid)\n",
" ellipse_y_r = b * np.sin(theta_grid)\n",
" R = np.array([[np.cos(phi), np.sin(phi)], [-np.sin(phi), np.cos(phi)]])\n",
" return np.array([ellipse_x_r, ellipse_y_r]).T @ R"
]
},
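The scaling constant 2.4477 in `ellipse_covariance` is the square root of the 95% quantile of a chi-square distribution with two degrees of freedom, which has the closed form $-2\ln(1 - p)$. The sketch below derives that constant and checks it empirically; the covariance matrix and the random seed are hypothetical values chosen only for this check.

```python
import numpy as np

# For a 2-D Gaussian the squared Mahalanobis distance is chi-square with 2 dof,
# whose quantile function has the closed form -2 * ln(1 - p).
coverage = 0.95
scale = np.sqrt(-2.0 * np.log(1.0 - coverage))

# Hypothetical covariance, used only for this check
cov = np.array([[2.0, 0.8], [0.8, 1.0]])
rng = np.random.default_rng(0)
samples = rng.multivariate_normal(mean=[0.0, 0.0], cov=cov, size=100_000)

# A point lies inside the ellipse when its Mahalanobis distance is at most `scale`
mahalanobis_sq = np.einsum("ni,ij,nj->n", samples, np.linalg.inv(cov), samples)
inside_fraction = np.mean(mahalanobis_sq <= scale**2)
print(f"scale = {scale:.4f}, empirical coverage = {inside_fraction:.3f}")
```

Roughly 95% of the simulated points should fall inside the ellipse, which is what the plotted uncertainty regions are meant to convey.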
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def plot_hurricane_path(\n",
" data: pl.DataFrame,\n",
" posterior_mean: np.ndarray,\n",
" predicted_covariance: np.ndarray,\n",
" uncertainty_index: int = 3,\n",
") -> go.Figure:\n",
" \"\"\"\n",
" Plots actual vs predicted Hurricane path.\n",
"\n",
" Parameters\n",
" ----------\n",
" data : DataFrame\n",
" dataframe containing the actual values\n",
" posterior_mean : ndarray\n",
" The posterior mean of the estimated distributions\n",
" predicted_covariance : ndarray\n",
" The predicted covariance matrices at each time point\n",
" uncertainty_index : int\n",
" When to start drawing the uncertainty on the map (due to huge uncertainty in the begining of the process)\n",
"\n",
" Returns\n",
" -------\n",
" Figure\n",
" Plotly Hurricane Figure\n",
" \"\"\"\n",
" fig = go.Figure()\n",
" for i in range(predicted_covariance.shape[0]):\n",
" if uncertainty_index and (\n",
" i < uncertainty_index\n",
" ): # The uncertainty can be quite large depending on how you initialze P0\n",
" continue\n",
" r_ellipse = ellipse_covariance(predicted_covariance[i, :2, :2])\n",
" means = posterior_mean[i]\n",
" fig.add_trace(\n",
" go.Scattermap(\n",
" lon=r_ellipse[:, 0].astype(float) + means[0].values,\n",
" lat=r_ellipse[:, 1].astype(float) + means[1].values,\n",
" mode=\"lines\",\n",
" fill=\"toself\",\n",
" showlegend=True if i == uncertainty_index else False,\n",
" legendgroup=\"HDI\",\n",
" hoverinfo=\"skip\",\n",
" marker_color=\"blue\",\n",
" name=\"95% CI\",\n",
" )\n",
" )\n",
" fig.add_traces(\n",
" [\n",
" go.Scattermap(\n",
" lon=posterior_mean[:, 0],\n",
" lat=posterior_mean[:, 1],\n",
" name=\"predictions\",\n",
" mode=\"lines+markers\",\n",
" line=dict(color=\"lightblue\"),\n",
" hovertemplate=[\n",
" f\"\"\"Period: {i+1}
Longitude: {posterior[0]:.1f}
Latitude: {posterior[1]:.1f}
Longitude: {row['longitude']:.1f}
Latitude: {row['latitude']:.1f}
Miles Away: %{y}\",\n",
"            ),\n",
"            go.Scatter(\n",
"                x=np.arange(len(cumulative_errors)) + 1,\n",
"                y=cumulative_errors,\n",
"                name=\"Cumulative Average Error\",\n",
"                hovertemplate=\"Period: %{x}<br>Miles Away: %{y}\",\n",
"            ),\n",
"        ]\n",
"    )\n",
"    fig.add_shape(\n",
"        type=\"line\",\n",
"        y0=mean_error,\n",
"        y1=mean_error,\n",
"        x0=0,\n",
"        x1=errors.shape[0],\n",
"        line_dash=\"dash\",\n",
"        line_color=\"black\",\n",
"        name=\"Overall Mean Error\",\n",
"        label=dict(text=f\"Mean: {mean_error:.2f}\", textposition=\"start\"),\n",
"        showlegend=True,\n",
"    )\n",
"    fig.update_layout(\n",
"        title=f\"{main_title} Model Evaluation\",\n",
"        xaxis=dict(title=\"Time Period\"),\n",
"        yaxis=dict(title=\"Miles Away from Actual\"),\n",
"    )\n",
"    return fig"
]
},
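The evaluation plot reports forecast error as miles away from the actual position, together with a cumulative average. One way such errors could be computed is the great-circle (haversine) distance. The sketch below is an assumption about the approach rather than the notebook's own error function: `haversine_miles`, the Earth radius constant, and the `predicted` offsets are all hypothetical, while the `actual` positions are the first rounded ABLE coordinates.

```python
import numpy as np

EARTH_RADIUS_MILES = 3958.8  # mean Earth radius, an assumption of this sketch


def haversine_miles(lon1, lat1, lon2, lat2):
    """Great-circle distance in miles between points given in decimal degrees."""
    lon1, lat1, lon2, lat2 = map(np.radians, (lon1, lat1, lon2, lat2))
    dlon = lon2 - lon1
    dlat = lat2 - lat1
    a = np.sin(dlat / 2) ** 2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon / 2) ** 2
    return 2 * EARTH_RADIUS_MILES * np.arcsin(np.sqrt(a))


# Columns are (longitude, latitude)
actual = np.array([[-55.5, 17.1], [-56.3, 17.7], [-57.4, 18.2]])
# Hypothetical forecast: the actual track plus small made-up offsets
predicted = actual + np.array([[0.0, 0.0], [0.1, -0.1], [-0.2, 0.2]])

errors = haversine_miles(predicted[:, 0], predicted[:, 1], actual[:, 0], actual[:, 1])
cumulative_errors = np.cumsum(errors) / np.arange(1, len(errors) + 1)  # running mean
mean_error = errors.mean()
print(errors.round(1), cumulative_errors.round(1), round(float(mean_error), 1))
```

The running mean ends at the overall mean error, which corresponds to the dashed horizontal line drawn in the evaluation figure.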
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load and Process the Dataset\n",
"The data comes from the National Oceanic and Atmospheric Administration (NOAA) and is stored in an odd format (likely to save space). We need to wrangle it before we can proceed."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"parsed_data = list()\n",
"try:\n",
" with open(\"../data/hurdat2-1851-2023-051124.txt\") as f:\n",
" lines = f.readlines()\n",
" for line in lines:\n",
" commas = re.findall(\",\", line)\n",
" if len(commas) < 4:\n",
" hsep = line.split(\",\")\n",
" storm_id = hsep[0]\n",
" storm_name = hsep[1].strip()\n",
" else:\n",
" dsep = line.split(\", \")\n",
" year = dsep[0][:4]\n",
" month = dsep[0][4:6]\n",
" day = dsep[0][6:8]\n",
" hours = dsep[1][:2]\n",
" minutes = dsep[1][-2:]\n",
" record_identifier = dsep[2]\n",
" latitude = dsep[4]\n",
" longitude = dsep[5]\n",
" max_wind = dsep[6]\n",
" min_pressure = dsep[7]\n",
" parsed_data.append(\n",
" [\n",
" storm_id,\n",
" storm_name,\n",
" year,\n",
" month,\n",
" day,\n",
" hours,\n",
" minutes,\n",
" record_identifier,\n",
" latitude,\n",
" longitude,\n",
" max_wind,\n",
" min_pressure,\n",
" ]\n",
" )\n",
"except FileNotFoundError:\n",
" stream = pm.get_data(\"hurdat2-1851-2023-051124.txt\")\n",
" lines = stream.readlines()\n",
" for line in lines:\n",
" commas = re.findall(\",\", line)\n",
" if len(commas) < 4:\n",
" hsep = line.split(\",\")\n",
" storm_id = hsep[0]\n",
" storm_name = hsep[1].strip()\n",
" else:\n",
" dsep = line.split(\", \")\n",
" year = dsep[0][:4]\n",
" month = dsep[0][4:6]\n",
" day = dsep[0][6:8]\n",
" hours = dsep[1][:2]\n",
" minutes = dsep[1][-2:]\n",
" record_identifier = dsep[2]\n",
" latitude = dsep[4]\n",
" longitude = dsep[5]\n",
" max_wind = dsep[6]\n",
" min_pressure = dsep[7]\n",
" parsed_data.append(\n",
" [\n",
" storm_id,\n",
" storm_name,\n",
" year,\n",
" month,\n",
" day,\n",
" hours,\n",
" minutes,\n",
" record_identifier,\n",
" latitude,\n",
" longitude,\n",
" max_wind,\n",
" min_pressure,\n",
" ]\n",
" )"
]
},
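To make the two line types concrete, here is a small sketch that classifies a header line and a data line and converts the coordinate fields to signed decimal degrees. The sample lines are illustrative: their values mirror the first ABLE record shown in this notebook, but the exact spacing, the `TS` status code, the row count in the header, and the trailing field are assumptions. The `parse_coordinate` helper is hypothetical and, unlike the cleaning step in this notebook, reads the hemisphere letter instead of always negating the longitude.

```python
import re

# Illustrative lines in the HURDAT2 layout (spacing and some fields are assumed)
header = "AL011950,               ABLE,     51,"
record = "19500812, 0000,  , TS, 17.1N,  55.5W,  35, -999, -999"

COORD = re.compile(r"(\d+\.\d)([NSEW])")


def parse_coordinate(field: str) -> float:
    """Convert a field such as ' 55.5W' to signed decimal degrees."""
    match = COORD.search(field)
    if match is None:
        raise ValueError(f"unrecognized coordinate: {field!r}")
    value, hemisphere = float(match.group(1)), match.group(2)
    # South and west are negative by convention
    return -value if hemisphere in "SW" else value


# Header lines carry fewer than four commas, data lines carry many more
is_header = header.count(",") < 4
fields = record.split(", ")
latitude = parse_coordinate(fields[4])
longitude = parse_coordinate(fields[5])
print(is_header, latitude, longitude, fields[6].strip(), fields[7].strip())
```

Matching one or more leading digits (`\d+`) keeps three-digit longitudes such as `100.5W` intact.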
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"df = pl.DataFrame(\n",
" parsed_data,\n",
" orient=\"row\",\n",
" schema={\n",
" \"storm_id\": pl.String,\n",
" \"storm_name\": pl.String,\n",
" \"year\": pl.String,\n",
" \"month\": pl.String,\n",
" \"day\": pl.String,\n",
" \"hour\": pl.String,\n",
" \"minute\": pl.String,\n",
" \"record_identifier\": pl.String,\n",
" \"latitude\": pl.String,\n",
" \"longitude\": pl.String,\n",
" \"max_wind\": pl.String,\n",
" \"min_pressure\": pl.String,\n",
" },\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"df_clean = (\n",
" df.with_columns(\n",
" pl.concat_str( # combine columns to generate a datetime string field\n",
" \"year\", \"month\", \"day\", \"hour\", \"minute\"\n",
" ).alias(\"datetime\")\n",
" )\n",
" .with_columns( # Cast fields to appropriate data types\n",
" pl.col(\"datetime\")\n",
" .str.strptime(dtype=pl.Datetime, format=\"%Y%m%d%H%M\")\n",
" .dt.replace_time_zone(\"UTC\")\n",
" .name.keep(),\n",
" pl.col(\"latitude\").str.extract(r\"(\\d+\\.\\d)\").cast(pl.Float32).name.keep(),\n",
" (pl.col(\"longitude\").str.extract(r\"(\\d+\\.\\d)\").cast(pl.Float32) * -1).name.keep(),\n",
" pl.col(\"max_wind\").str.strip_chars().cast(pl.Float32).name.keep(),\n",
" pl.col(\"min_pressure\").str.strip_chars().cast(pl.Float32).name.keep(),\n",
" )\n",
" .drop(\"year\", \"month\", \"day\", \"hour\", \"minute\") # Drop redundant fields\n",
" .filter(pl.col(\"storm_name\") != \"UNNAMED\") # remove unnamed hurricanes\n",
" .with_columns(\n",
" category=( # Create hurricane intensity category level\n",
" pl.when(pl.col(\"max_wind\") > 155)\n",
" .then(pl.lit(5.0))\n",
" .when(pl.col(\"max_wind\").is_between(131, 155))\n",
" .then(pl.lit(4.0))\n",
" .when(pl.col(\"max_wind\").is_between(111, 130))\n",
" .then(pl.lit(3.0))\n",
" .when(pl.col(\"max_wind\").is_between(96, 110))\n",
" .then(pl.lit(2.0))\n",
" .when(pl.col(\"max_wind\").is_between(74, 95))\n",
" .then(pl.lit(1.0))\n",
" .otherwise(pl.lit(0.0))\n",
" )\n",
" )\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"
storm_id | storm_name | record_identifier | latitude | longitude | max_wind | min_pressure | datetime | category |
---|---|---|---|---|---|---|---|---|
str | str | str | f32 | f32 | f32 | f32 | datetime[μs, UTC] | f64 |
"AL011950" | "ABLE" | " " | 17.1 | -55.5 | 35.0 | -999.0 | 1950-08-12 00:00:00 UTC | 0.0 |
"AL011950" | "ABLE" | " " | 17.700001 | -56.299999 | 40.0 | -999.0 | 1950-08-12 06:00:00 UTC | 0.0 |
"AL011950" | "ABLE" | " " | 18.200001 | -57.400002 | 45.0 | -999.0 | 1950-08-12 12:00:00 UTC | 0.0 |
"AL011950" | "ABLE" | " " | 19.0 | -58.599998 | 50.0 | -999.0 | 1950-08-12 18:00:00 UTC | 0.0 |
"AL011950" | "ABLE" | " " | 20.0 | -60.0 | 50.0 | -999.0 | 1950-08-13 00:00:00 UTC | 0.0 |
Model Requirements \n", " \n", " Variable Shape Constraints Dimensions \n", " ───────────────────────────────────────────────────────────────────────────────────── \n", " x0 (6,) None ('state',) \n", " P0 (6, 6) Positive Semi-definite ('state', 'state_aux') \n", " acceleration_innovations (1,) Positive (1,) \n", " \n", "These parameters should be assigned priors inside a PyMC model block before calling the\n", " build_statespace_graph method. \n", "\n" ], "text/plain": [ "\u001b[3m Model Requirements \u001b[0m\n", " \n", " \u001b[1m \u001b[0m\u001b[1mVariable \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mShape \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mConstraints \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1m Dimensions\u001b[0m\u001b[1m \u001b[0m \n", " ───────────────────────────────────────────────────────────────────────────────────── \n", " x0 \u001b[1m(\u001b[0m\u001b[1;36m6\u001b[0m,\u001b[1m)\u001b[0m \u001b[3;35mNone\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m,\u001b[1m)\u001b[0m \n", " P0 \u001b[1m(\u001b[0m\u001b[1;36m6\u001b[0m, \u001b[1;36m6\u001b[0m\u001b[1m)\u001b[0m Positive Semi-definite \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m, \u001b[32m'state_aux'\u001b[0m\u001b[1m)\u001b[0m \n", " acceleration_innovations \u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m,\u001b[1m)\u001b[0m Positive \u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m,\u001b[1m)\u001b[0m \n", " \n", "\u001b[2;3mThese parameters should be assigned priors inside a PyMC model block before calling the\u001b[0m\n", "\u001b[2;3m build_statespace_graph method. 
\u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "n_ssm = NewtonianSSM(mode=\"JAX\")" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/miniconda3/envs/pymc_examples_dev/lib/python3.13/site-packages/pymc_extras/statespace/utils/data_tools.py:74: UserWarning:\n", "\n", "No time index found on the supplied data. A simple range index will be automatically generated.\n", "\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "
Sampler Progress: 4 chains in total, 0 active, 4 finished\n",
"\n",
"| Draws | Divergences | Step Size | Gradients/Draw |\n",
"| --- | --- | --- | --- |\n",
"| 2000 | 0 | 1.38 | 3 |\n",
"| 2000 | 0 | 1.35 | 3 |\n",
"| 2000 | 0 | 1.38 | 3 |\n",
"| 2000 | 0 | 1.38 | 1 |\n",
"
\n",
"| | mean | sd | hdi_3% | hdi_97% |\n",
"| --- | --- | --- | --- | --- |\n",
"| acceleration_innovations[0] | 0.004 | 0.001 | 0.002 | 0.007 |\n",
"
Model Requirements \n", " \n", " Variable Shape Constraints Dimensions \n", " ──────────────────────────────────────────────────────────────────────────────────────── \n", " x0 (6,) None (6,) \n", " P0 (12, 12) Positive Semi-definite ('state', 'state_aux') \n", " acceleration_innovations (1,) Positive (1,) \n", " beta_exog (6,) None ('exog_dims',) \n", " \n", " exogenous_data (None, 6) pm.Data ('time', 'exog_dims') \n", " \n", " These parameters should be assigned priors inside a PyMC model block before calling the \n", " build_statespace_graph method. \n", "\n" ], "text/plain": [ "\u001b[3m Model Requirements \u001b[0m\n", " \n", " \u001b[1m \u001b[0m\u001b[1mVariable \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mShape \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mConstraints \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1m Dimensions\u001b[0m\u001b[1m \u001b[0m \n", " ──────────────────────────────────────────────────────────────────────────────────────── \n", " x0 \u001b[1m(\u001b[0m\u001b[1;36m6\u001b[0m,\u001b[1m)\u001b[0m \u001b[3;35mNone\u001b[0m \u001b[1m(\u001b[0m\u001b[1;36m6\u001b[0m,\u001b[1m)\u001b[0m \n", " P0 \u001b[1m(\u001b[0m\u001b[1;36m12\u001b[0m, \u001b[1;36m12\u001b[0m\u001b[1m)\u001b[0m Positive Semi-definite \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m, \u001b[32m'state_aux'\u001b[0m\u001b[1m)\u001b[0m \n", " acceleration_innovations \u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m,\u001b[1m)\u001b[0m Positive \u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m,\u001b[1m)\u001b[0m \n", " beta_exog \u001b[1m(\u001b[0m\u001b[1;36m6\u001b[0m,\u001b[1m)\u001b[0m \u001b[3;35mNone\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'exog_dims'\u001b[0m,\u001b[1m)\u001b[0m \n", " \n", " exogenous_data \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[1;36m6\u001b[0m\u001b[1m)\u001b[0m pm.Data \u001b[1m(\u001b[0m\u001b[32m'time'\u001b[0m, \u001b[32m'exog_dims'\u001b[0m\u001b[1m)\u001b[0m \n", " \n", "\u001b[2;3m These parameters should be 
assigned priors inside a PyMC model block before calling the \u001b[0m\n", "\u001b[2;3m build_statespace_graph method. \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "exog_ssm = ExogenousSSM(k_exog=X_exog.shape[1], mode=\"JAX\")" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/miniconda3/envs/pymc_examples_dev/lib/python3.13/site-packages/pymc_extras/statespace/utils/data_tools.py:74: UserWarning:\n", "\n", "No time index found on the supplied data. A simple range index will be automatically generated.\n", "\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "
Sampler Progress: 4 chains in total, 0 active, 4 finished (sampling took about a minute)\n",
"\n",
"| Draws | Divergences | Step Size | Gradients/Draw |\n",
"| --- | --- | --- | --- |\n",
"| 2000 | 0 | 0.62 | 7 |\n",
"| 2000 | 0 | 0.61 | 7 |\n",
"| 2000 | 0 | 0.62 | 7 |\n",
"| 2000 | 0 | 0.59 | 7 |\n",
"
Model Requirements \n", " \n", " Variable Shape Constraints Dimensions \n", " ───────────────────────────────────────────────────────────────────────────────────────── \n", " x0 (6,) None (6,) \n", " P0 (40, 40) Positive Semi-definite ('state', 'state_aux') \n", " acceleration_innovations (1,) Positive (1,) \n", " beta_exog (34,) None ('exog_dims',) \n", " \n", " exogenous_data (None, 34) pm.Data ('time', 'exog_dims') \n", " \n", " These parameters should be assigned priors inside a PyMC model block before calling the \n", " build_statespace_graph method. \n", "\n" ], "text/plain": [ "\u001b[3m Model Requirements \u001b[0m\n", " \n", " \u001b[1m \u001b[0m\u001b[1mVariable \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mShape \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1mConstraints \u001b[0m\u001b[1m \u001b[0m \u001b[1m \u001b[0m\u001b[1m Dimensions\u001b[0m\u001b[1m \u001b[0m \n", " ───────────────────────────────────────────────────────────────────────────────────────── \n", " x0 \u001b[1m(\u001b[0m\u001b[1;36m6\u001b[0m,\u001b[1m)\u001b[0m \u001b[3;35mNone\u001b[0m \u001b[1m(\u001b[0m\u001b[1;36m6\u001b[0m,\u001b[1m)\u001b[0m \n", " P0 \u001b[1m(\u001b[0m\u001b[1;36m40\u001b[0m, \u001b[1;36m40\u001b[0m\u001b[1m)\u001b[0m Positive Semi-definite \u001b[1m(\u001b[0m\u001b[32m'state'\u001b[0m, \u001b[32m'state_aux'\u001b[0m\u001b[1m)\u001b[0m \n", " acceleration_innovations \u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m,\u001b[1m)\u001b[0m Positive \u001b[1m(\u001b[0m\u001b[1;36m1\u001b[0m,\u001b[1m)\u001b[0m \n", " beta_exog \u001b[1m(\u001b[0m\u001b[1;36m34\u001b[0m,\u001b[1m)\u001b[0m \u001b[3;35mNone\u001b[0m \u001b[1m(\u001b[0m\u001b[32m'exog_dims'\u001b[0m,\u001b[1m)\u001b[0m \n", " \n", " exogenous_data \u001b[1m(\u001b[0m\u001b[3;35mNone\u001b[0m, \u001b[1;36m34\u001b[0m\u001b[1m)\u001b[0m pm.Data \u001b[1m(\u001b[0m\u001b[32m'time'\u001b[0m, \u001b[32m'exog_dims'\u001b[0m\u001b[1m)\u001b[0m \n", " \n", "\u001b[2;3m These parameters should be 
assigned priors inside a PyMC model block before calling the \u001b[0m\n", "\u001b[2;3m build_statespace_graph method. \u001b[0m\n" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "spline_ssm = SplineSSM(k_exog=exog_data.shape[1], mode=\"JAX\")" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/miniconda3/envs/pymc_examples_dev/lib/python3.13/site-packages/pymc_extras/statespace/utils/data_tools.py:74: UserWarning:\n", "\n", "No time index found on the supplied data. A simple range index will be automatically generated.\n", "\n" ] }, { "data": { "text/html": [ "\n", "\n" ], "text/plain": [ "
Sampler Progress: 4 chains in total, 4 active, 0 finished (sampling for a minute, about 2 minutes remaining)\n",
"\n",
"| Draws | Divergences | Step Size | Gradients/Draw |\n",
"| --- | --- | --- | --- |\n",
"| 437 | 0 | 0.62 | 7 |\n",
"| 425 | 0 | 0.53 | 7 |\n",
"| 435 | 0 | 0.54 | 15 |\n",
"| 434 | 0 | 0.58 | 7 |\n",
"