.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/4-Recurrent-SLDS.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here <sphx_glr_download_auto_examples_4-Recurrent-SLDS.py>` to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_4-Recurrent-SLDS.py:

Recurrent SLDS
===============

.. GENERATED FROM PYTHON SOURCE LINES 19-387

.. rst-class:: sphx-glr-horizontal

    *

      .. image-sg:: /auto_examples/images/sphx_glr_4-Recurrent-SLDS_001.png
         :alt: True Trajectory, Observations for first 1000 time steps
         :srcset: /auto_examples/images/sphx_glr_4-Recurrent-SLDS_001.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_4-Recurrent-SLDS_002.png
         :alt: 4 Recurrent SLDS
         :srcset: /auto_examples/images/sphx_glr_4-Recurrent-SLDS_002.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_4-Recurrent-SLDS_003.png
         :alt: True, Inferred, BBVI, Inferred, Laplace-EM
         :srcset: /auto_examples/images/sphx_glr_4-Recurrent-SLDS_003.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_4-Recurrent-SLDS_004.png
         :alt: True Dynamics
         :srcset: /auto_examples/images/sphx_glr_4-Recurrent-SLDS_004.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_4-Recurrent-SLDS_005.png
         :alt: Inferred Dynamics, Laplace-EM
         :srcset: /auto_examples/images/sphx_glr_4-Recurrent-SLDS_005.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_4-Recurrent-SLDS_006.png
         :alt: Inferred Dynamics, BBVI
         :srcset: /auto_examples/images/sphx_glr_4-Recurrent-SLDS_006.png
         :class: sphx-glr-multi-img

.. code-block:: python

    # ## 1. Setup

    import copy

    import autograd.numpy as np
    import autograd.numpy.random as npr
    npr.seed(0)

    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec

    import seaborn as sns
    color_names = ["windows blue", "red", "amber", "faded green"]
    colors = sns.xkcd_palette(color_names)

    import ssm
    from ssm.util import random_rotation, find_permutation

    # +
    # Helper functions for plotting the latent trajectories, the
    # observations, and the most likely dynamics in each region of the
    # latent space, color coded by the discrete state.

    def plot_trajectory(z, x, ax=None, ls="-"):
        # Find the time points where the discrete state changes so that
        # each segment can be drawn in its state's color
        zcps = np.concatenate(([0], np.where(np.diff(z))[0] + 1, [z.size]))
        if ax is None:
            fig = plt.figure(figsize=(4, 4))
            ax = fig.gca()
        for start, stop in zip(zcps[:-1], zcps[1:]):
            ax.plot(x[start:stop + 1, 0],
                    x[start:stop + 1, 1],
                    lw=1, ls=ls,
                    color=colors[z[start] % len(colors)],
                    alpha=1.0)
        return ax


    def plot_observations(z, y, ax=None, ls="-", lw=1):
        # Plot each observation trace, colored by the current discrete state
        zcps = np.concatenate(([0], np.where(np.diff(z))[0] + 1, [z.size]))
        T, N = y.shape
        t = np.arange(T)
        if ax is None:
            fig = plt.figure(figsize=(8, 4))
            ax = fig.gca()
        for n in range(N):
            for start, stop in zip(zcps[:-1], zcps[1:]):
                ax.plot(t[start:stop + 1], y[start:stop + 1, n],
                        lw=lw, ls=ls,
                        color=colors[z[start] % len(colors)],
                        alpha=1.0)
        return ax


    def plot_most_likely_dynamics(model,
                                  xlim=(-4, 4), ylim=(-3, 3),
                                  nxpts=20, nypts=10,
                                  alpha=0.8, ax=None, figsize=(3, 3)):
        K = model.K
        assert model.D == 2

        # Lay down a grid of points in the latent space
        x = np.linspace(*xlim, nxpts)
        y = np.linspace(*ylim, nypts)
        X, Y = np.meshgrid(x, y)
        xy = np.column_stack((X.ravel(), Y.ravel()))

        # Find the most likely discrete state at each grid point
        log_Ps = model.transitions.log_transition_matrices(
            xy, np.zeros((nxpts * nypts, 0)), np.ones_like(xy, dtype=bool), None)
        z = np.argmax(log_Ps[:, 0, :], axis=-1)
        z = np.concatenate([[z[0]], z])

        if ax is None:
            fig = plt.figure(figsize=figsize)
            ax = fig.add_subplot(111)

        # Plot the flow field of each state's linear dynamics in the
        # region where that state is most likely
        for k, (A, b) in enumerate(zip(model.dynamics.As, model.dynamics.bs)):
            dxydt_m = xy.dot(A.T) + b - xy
            zk = z == k
            if zk.sum(0) > 0:
                ax.quiver(xy[zk, 0], xy[zk, 1],
                          dxydt_m[zk, 0], dxydt_m[zk, 1],
                          color=colors[k % len(colors)], alpha=alpha)

        ax.set_xlabel('$x_1$')
        ax.set_ylabel('$x_2$')
        plt.tight_layout()
        return ax
    # -

    # ## 2. Simulating Data from an rSLDS
    #
    # Below, we create a simulated dataset from a non-linear system, which
    # we'll call the "Nascar" dataset. The Nascar dataset is meant to emulate
    # cars going around a track. There are 4 states in total: one for driving
    # along each of the two straightaways, and one for each of the
    # semicircular turns at the ends of the track.
    #
    # You'll note that in creating the rSLDS we use
    # **transitions="recurrent_only"**. This means that the transition
    # probabilities are determined only by the previous continuous state
    # $x_{t-1}$ (and by the inputs, if present); there is no dependence on
    # the previous discrete state $z_{t-1}$. Instead, each state simply has a
    # constant bias $r_i$ which biases the transitions toward state $i$. This
    # model is strictly less flexible than the full rSLDS formulation. By
    # setting the recurrent weights to be very large, we make the discrete
    # state transitions essentially deterministic functions of the continuous
    # state. After creating the rSLDS and sampling a trajectory, we plot the
    # true trajectory below.
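    # A quick note on the construction below: for the rotational states we
    # will choose the offset $b$ so that the dynamics $x_{t+1} = A x_t + b$
    # have their fixed point at a chosen center $c$. Setting $b = -(A - I)c$
    # gives $Ac + b = Ac - (A - I)c = c$, so each turn rotates around its
    # center. The short sketch below is just a numerical check of this
    # identity; the names `_A`, `_c`, and `_b` are local to the sketch and
    # not used elsewhere.

    # +
    # Sketch: verify that b = -(A - I) c makes c a fixed point of x -> Ax + b
    _A = random_rotation(2, np.pi / 24.)
    _c = np.array([+2.0, 0.])
    _b = -(_A - np.eye(2)).dot(_c)
    assert np.allclose(_A.dot(_c) + _b, _c)
    # -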
    # # Global parameters
    T = 10000     # number of time steps to sample
    K = 4         # number of discrete states
    D_obs = 10    # dimension of the observations
    D_latent = 2  # dimension of the continuous latent states

    # +
    # Simulate the "nascar" data
    def make_nascar_model():
        # Two rotational states for the turns
        As = [random_rotation(D_latent, np.pi/24.),
              random_rotation(D_latent, np.pi/48.)]

        # Set the center points for each system; b = -(A - I) c places the
        # fixed point of x -> Ax + b at the center c
        centers = [np.array([+2.0, 0.]), np.array([-2.0, 0.])]
        bs = [-(A - np.eye(D_latent)).dot(center)
              for A, center in zip(As, centers)]

        # Add a "right" state
        As.append(np.eye(D_latent))
        bs.append(np.array([+0.1, 0.]))

        # Add a "left" state
        As.append(np.eye(D_latent))
        bs.append(np.array([-0.25, 0.]))

        # Construct multinomial regression to divvy up the space
        w1, b1 = np.array([+1.0, 0.0]), np.array([-2.0])   # x + b > 0 -> x > -b
        w2, b2 = np.array([-1.0, 0.0]), np.array([-2.0])   # -x + b > 0 -> x < b
        w3, b3 = np.array([0.0, +1.0]), np.array([0.0])    # y > 0
        w4, b4 = np.array([0.0, -1.0]), np.array([0.0])    # y < 0
        Rs = np.row_stack((100*w1, 100*w2, 10*w3, 10*w4))
        r = np.concatenate((100*b1, 100*b2, 10*b3, 10*b4))

        true_rslds = ssm.SLDS(D_obs, K, D_latent,
                              transitions="recurrent_only",
                              dynamics="diagonal_gaussian",
                              emissions="gaussian_orthog",
                              single_subspace=True)
        true_rslds.dynamics.mu_init = np.tile(np.array([[0, 1]]), (K, 1))
        true_rslds.dynamics.sigmasq_init = 1e-4 * np.ones((K, D_latent))
        true_rslds.dynamics.As = np.array(As)
        true_rslds.dynamics.bs = np.array(bs)
        true_rslds.dynamics.sigmasq = 1e-4 * np.ones((K, D_latent))
        true_rslds.transitions.Rs = Rs
        true_rslds.transitions.r = r
        true_rslds.emissions.inv_etas = np.log(1e-2) * np.ones((1, D_obs))
        return true_rslds

    # Sample from the model
    true_rslds = make_nascar_model()
    z, x, y = true_rslds.sample(T=T)
    # -

    # **Visualizing Trajectories**
    #
    # We've defined some helper functions above to plot the latent state
    # trajectories and color code them according to the discrete state. In
    # the left panel, we show the continuous state trajectories. In the right
    # panel, we show the observations over the first 1000 time steps (our
    # observations are 10-dimensional, but we plot only 3 of the traces to
    # reduce clutter).

    # +
    fig = plt.figure(figsize=(15, 6))
    gs = gridspec.GridSpec(1, 2, width_ratios=[2, 3])

    ax0 = plt.subplot(gs[0])
    plot_trajectory(z, x, ax=ax0)
    plt.title("True Trajectory")

    ax1 = plt.subplot(gs[1])
    plot_observations(z[:1000], y[:1000, :3], ax=ax1)
    plt.title("Observations for first 1000 time steps")

    plt.tight_layout()
    # -

    # ## 3. Exercise
    #
    # ### 3.1 Linear vs. Non-Linear Systems
    # It's worth looking at the plot of the trajectories and considering the
    # following: what behavior does the latent state show that could not be
    # captured by a single linear dynamical system?
    #
    # ### 3.2 Understanding the discrete state transitions
    # Let's look again at the form of the discrete state transition
    # probabilities:
    #
    # $$
    # p(z_t = i \mid z_{t-1} = j, x_{t-1}) \propto
    # \exp\left( \log P_{j,i} + w_i^\top u_t + R_i^\top x_{t-1} \right)
    # $$
    #
    # In this case, we have used the **recurrent only** transitions class,
    # which means we get rid of the transition matrix $P$ and replace it with
    # a per-state bias $r_i$. Since we don't have any external inputs here,
    # we can also leave out the input term:
    #
    # $$
    # p(z_t = i \mid x_{t-1}) \propto
    # \exp\left( r_i + R_i^\top x_{t-1} \right)
    # $$
    #
    # What happens as the magnitude of the entries in $R_i$ becomes very
    # large (compared to the entries of $R_j$ for the other states)? Do the
    # transitions become more or less random?
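    # To build intuition for the exercise above, here is a small sketch
    # (an illustration added here, not part of the original pipeline): with
    # recurrent-only transitions the next-state distribution is
    # $\mathrm{softmax}(r + R x_{t-1})$, and scaling up $R$ sharpens it.
    # All names below (`transition_probs`, `R_small`, `r_zero`, `x_test`)
    # are local to this sketch.

    # +
    def transition_probs(R, r, x):
        # softmax(r + R x), computed stably by subtracting the max logit
        logits = r + R.dot(x)
        logits = logits - logits.max()
        p = np.exp(logits)
        return p / p.sum()

    x_test = np.array([1.0, 0.5])
    R_small = np.row_stack(([+1.0, 0.0], [-1.0, 0.0], [0.0, +1.0], [0.0, -1.0]))
    r_zero = np.zeros(K)

    print(transition_probs(R_small, r_zero, x_test))        # fairly spread out
    print(transition_probs(100 * R_small, r_zero, x_test))  # nearly one-hot
    # -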
    # ## 4. Fitting an rSLDS
    # Below, we create a new rSLDS object and fit it to the data generated
    # above (note that our new rSLDS will only have access to the
    # observations $y$, not to the true states $z$ or $x$).
    #
    # ### 4.1 Fitting Methods
    # The fitting methods available for the rSLDS are the same as those
    # available for the SLDS. We've reproduced the section on fitting methods
    # from the SLDS notebook below.
    #
    # **Important Note:**
    #
    # Understanding the following section is not necessary to use SSM! _For
    # practical purposes, it is almost always best to use the Laplace-EM
    # method with the structured mean-field posterior, which is the default._
    # Running the cells below will be a bit slow on a typical laptop (around
    # 5 minutes). We're working on speeding things up in future releases of
    # SSM.
    #
    # **Parameter Learning for rSLDS**
    # Parameter learning in an rSLDS requires approximate methods. SSM
    # provides two approximate inference algorithms: Black Box Variational
    # Inference (`"bbvi"`) and Laplace Variational EM (`"laplace_em"`). We
    # don't have the space to describe these methods in detail here, but
    # Black Box Variational Inference is described in
    # ["Variational Inference: A Review for Statisticians"](https://arxiv.org/pdf/1601.00670.pdf)
    # by Blei et al. The Laplace approximation is described in several
    # sources, but a good reference in the context of state-space models is
    # ["Estimating State and Parameters in State-Space Models of Spike Trains"](https://pdfs.semanticscholar.org/a71e/bf112cabd47cc67284dc8c12ab7644195d60.pdf),
    # a book chapter by Macke et al. The specific method used in this
    # notebook is described in
    # ["Zoltowski et al. (2020)"](https://arxiv.org/abs/2001.04571).
    #
    # **Approximate Posterior Distributions**
    # When using approximate methods, we must choose the form of the
    # distribution we use to approximate the posterior. Here, SSM provides
    # three options:
    # 1. `variational_posterior="meanfield"`
    # The mean-field approximation uses a fully factorized distribution as
    # the approximating posterior. Compatible with the BBVI method.
    #
    # 2. `variational_posterior="tridiag"`
    # This approximates the posterior over the continuous states with a
    # Gaussian whose inverse covariance is block tridiagonal, which can be
    # thought of as approximating the SLDS posterior with the posterior from
    # an LDS. Compatible with the BBVI method.
    #
    # 3. `variational_posterior="structured_meanfield"`
    # This assumes a posterior in which the joint distribution over the
    # continuous and discrete latent states factors as follows: if
    # $q(z, x \mid y)$ is the joint posterior of the discrete and continuous
    # states given the data, we use the approximation
    # $q(z, x \mid y) \approx q(z \mid y)\, q(x \mid y)$, where
    # $q(z \mid y)$ is the posterior of a Markov chain. Compatible with the
    # Laplace-EM method.
    #
    # **Calling the Fit Function in SSM**
    # All models in SSM share the same general syntax for fitting a model
    # from data. Below, we call the fit function with both methods and
    # compare convergence. The syntax is as follows:
    # ```python
    # elbos, posterior = rslds.fit(data, method="...",
    #                              variational_posterior="...",
    #                              num_iters=...)
    # ```
    # In the call to `fit`, `method` should be one of {`"bbvi"`,
    # `"laplace_em"`}. The `variational_posterior` argument should be one of
    # {`"meanfield"`, `"tridiag"`, `"structured_meanfield"`}; however, when
    # using Laplace-EM, _only_ structured mean field is supported.
    #
    # ### 4.2 Getting the Inferred States
    # For every LDS, SLDS, and rSLDS model in SSM, calling `fit` returns a
    # tuple of `(elbos, posterior)`. `elbos` is a list containing a lower
    # bound on the log-likelihood of the data (the ELBO) at each iteration,
    # which is used to check the convergence of the fitting algorithm.
    # `posterior` is a posterior object (the exact type depends on which
    # posterior is used). The posterior object is used to get an estimate of
    # the latent variables (in this case $x$ and $z$) at each time step.
    #
    # Below, we use the line:
    # ```python
    # xhat_lem = q_lem.mean_continuous_states[0]
    # ```
    # to get an estimate $\hat{x}$ of the continuous state over time. The
    # reason for the index `[0]` is that `posterior.mean_continuous_states`
    # returns a list with one entry per trial. In this case, we only have a
    # single trial, so we take the first (and only) element of the list.
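    # One more detail before fitting: the labels of the discrete states are
    # arbitrary, so before comparing inferred states to the true ones the
    # cells below use `find_permutation` (imported above) to find the
    # relabeling with the best overlap, and apply it with `model.permute`.
    # Here is a tiny sketch of the matching step on toy label sequences;
    # `z_true_toy`, `z_inf_toy`, and `perm` are names local to this sketch.

    # +
    # Sketch: two segmentations that agree up to a shuffling of the labels
    z_true_toy = np.array([0, 0, 1, 1, 2, 2, 3, 3])
    z_inf_toy = np.array([2, 2, 0, 0, 3, 3, 1, 1])
    perm = find_permutation(z_true_toy, z_inf_toy)
    print(perm)  # one entry per state: the best matching between labelings
    # -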
    # +
    # Fit an rSLDS with its default initialization, using Laplace-EM with a
    # structured variational posterior
    rslds = ssm.SLDS(D_obs, K, D_latent,
                     transitions="recurrent_only",
                     dynamics="diagonal_gaussian",
                     emissions="gaussian_orthog",
                     single_subspace=True)
    rslds.initialize(y)

    q_elbos_lem, q_lem = rslds.fit(y, method="laplace_em",
                                   variational_posterior="structured_meanfield",
                                   initialize=False, num_iters=100, alpha=0.0)
    xhat_lem = q_lem.mean_continuous_states[0]
    rslds.permute(find_permutation(z, rslds.most_likely_states(xhat_lem, y)))
    zhat_lem = rslds.most_likely_states(xhat_lem, y)

    # Store this rSLDS for later comparison
    rslds_lem = copy.deepcopy(rslds)
    # -

    # +
    # Fit an rSLDS with its default initialization, using BBVI with a
    # mean-field variational posterior
    rslds = ssm.SLDS(D_obs, K, D_latent,
                     transitions="recurrent_only",
                     dynamics="diagonal_gaussian",
                     emissions="gaussian_orthog",
                     single_subspace=True)
    rslds.initialize(y)
    q_elbos_bbvi, q_bbvi = rslds.fit(y, method="bbvi",
                                     variational_posterior="meanfield",
                                     initialize=False, num_iters=1000)
    # -

    # +
    # Get the posterior mean of the continuous states
    xhat_bbvi = q_bbvi.mean[0]

    # Find the permutation that matches the true and inferred states
    rslds.permute(find_permutation(z, rslds.most_likely_states(xhat_bbvi, y)))
    zhat_bbvi = rslds.most_likely_states(xhat_bbvi, y)
    # -

    # ### 4.3 Checking Convergence
    # Below, we plot the ELBOs obtained via Laplace-EM and BBVI. We see that
    # the Laplace-EM algorithm tends to converge much faster (and to a better
    # lower bound).

    # Plot some results
    plt.figure()
    plt.plot(q_elbos_bbvi, label="BBVI")
    plt.plot(q_elbos_lem[1:], label="Laplace-EM")
    plt.legend()
    plt.xlabel("Iteration")
    plt.ylabel("ELBO")

    # ## 5. Visualizing True and Inferred States
    # We wrote some helper functions above that plot a state trajectory, with
    # different colors corresponding to the discrete latent states. **Note**:
    # we can only recover the true system up to an affine transformation.
    # That's why, even though we have permuted the discrete states to match
    # the true system, the colors don't always match up.
    #
    # In the trajectory plots below, we see that the estimated latent
    # trajectories found using Laplace-EM match the ground truth more
    # closely. After that, we extract the dynamics matrices $A_k$ and use
    # them to plot the system dynamics in each state. Note that the
    # Laplace-EM algorithm does a better job of finding the positions in
    # state space which trigger discrete state transitions.
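    # As an optional check of the "up to an affine transformation" caveat
    # above, one can regress the true latents onto the inferred ones and
    # look at the quality of the fit. This is a small added sketch, not part
    # of the original notebook; `X_aug`, `W`, `x_aligned`, and `r2` are
    # names local to the sketch, and it assumes the variables from the
    # Laplace-EM fit above.

    # +
    # Sketch: least-squares affine alignment of inferred latents to truth
    X_aug = np.column_stack((xhat_lem, np.ones(T)))  # append a bias column
    # Solve the normal equations for x ~= [xhat, 1] @ W
    W = np.linalg.solve(X_aug.T.dot(X_aug), X_aug.T.dot(x))
    x_aligned = X_aug.dot(W)
    r2 = 1 - np.sum((x - x_aligned)**2) / np.sum((x - x.mean(0))**2)
    print("R^2 of affine alignment:", r2)
    # -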
    plt.figure(figsize=[10, 4])
    ax1 = plt.subplot(131)
    plot_trajectory(z, x, ax=ax1)
    plt.title("True")
    ax2 = plt.subplot(132)
    plot_trajectory(zhat_bbvi, xhat_bbvi, ax=ax2)
    plt.title("Inferred, BBVI")
    ax3 = plt.subplot(133)
    plot_trajectory(zhat_lem, xhat_lem, ax=ax3)
    plt.title("Inferred, Laplace-EM")
    plt.tight_layout()

    # +
    plt.figure(figsize=(6, 4))
    ax = plt.subplot(111)
    lim = abs(x).max(axis=0) + 1
    plot_most_likely_dynamics(true_rslds,
                              xlim=(-lim[0], lim[0]),
                              ylim=(-lim[1], lim[1]), ax=ax)
    plt.title("True Dynamics")

    plt.figure(figsize=(6, 4))
    ax = plt.subplot(111)
    lim = abs(xhat_lem).max(axis=0) + 1
    plot_most_likely_dynamics(rslds_lem,
                              xlim=(-lim[0], lim[0]),
                              ylim=(-lim[1], lim[1]), ax=ax)
    plt.title("Inferred Dynamics, Laplace-EM")

    plt.figure(figsize=(6, 4))
    ax = plt.subplot(111)
    lim = abs(xhat_bbvi).max(axis=0) + 1
    plot_most_likely_dynamics(rslds,
                              xlim=(-lim[0], lim[0]),
                              ylim=(-lim[1], lim[1]), ax=ax)
    plt.title("Inferred Dynamics, BBVI")
    # -


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 3 minutes 53.641 seconds)

.. _sphx_glr_download_auto_examples_4-Recurrent-SLDS.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 4-Recurrent-SLDS.py <4-Recurrent-SLDS.py>`

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 4-Recurrent-SLDS.ipynb <4-Recurrent-SLDS.ipynb>`

.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery <https://sphinx-gallery.github.io>`_