---
jupytext:
  text_representation:
    extension: .md
    format_name: myst
    format_version: 0.13
    jupytext_version: 1.14.1
kernelspec:
  display_name: Python 3
  language: python
  name: python3
---

# Multi-population recurrent switching linear dynamical systems overview

+++

## If you want to quickly see how to fit your own data, jump down to the "Fit model to data" section

This notebook goes through the simulation example shown in our manuscript (Figure 2A,B). Below, we briefly describe the model. We also recommend looking at the "Recurrent SLDS" notebook, which provides more details on the standard rSLDS.

**1. Data**. Let $y_t^{(j)}$ denote a vector of activity measurements of the $N_j$ neurons in population $j$ in time bin $t$.

**2. Emissions**. Let $x_t^{(j)}$ denote the continuous latent state of population $j$ at time $t$. The population states may differ in dimensionality $D_j$, since populations may differ in size and complexity. The observed activity of population $j$ is modeled with a generalized linear model,

\begin{align}
E[y_t^{(j)}] &= f(C_j x_t^{(j)} + d_j),
\end{align}

where $f$ is a nonlinearity (softplus in the code below) and each population has its own linear mapping parameterized by $\{C_j, d_j\}$. In this notebook, we use a Poisson GLM. Inputs can also be passed into this GLM, as described in the rSLDS notebook. The multi-population emissions classes are loaded in the example below.
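
As a rough illustration of the emissions model above, the following NumPy sketch builds a block-structured $C$ so that each population's neurons read out only that population's latents, then draws Poisson counts through a softplus nonlinearity. The sizes, the random $C_j$ and $d_j$, and the helper name `softplus` are assumptions made for illustration; this is not the internal implementation of `PoissonOrthogonalCompoundEmissions`.

```python
import numpy as np

rng = np.random.default_rng(0)

D_vec = [2, 3]    # latents per population (illustrative sizes)
N_vec = [4, 5]    # neurons per population (illustrative sizes)
T = 100           # number of time bins

def softplus(a):
    # Numerically stable log(1 + exp(a)); maps real inputs to positive rates.
    return np.logaddexp(0.0, a)

# Block-structured emission matrix: block j maps x^(j) to population j's neurons.
C = np.zeros((sum(N_vec), sum(D_vec)))
n0 = d0 = 0
for N_j, D_j in zip(N_vec, D_vec):
    C[n0:n0 + N_j, d0:d0 + D_j] = rng.standard_normal((N_j, D_j))
    n0 += N_j
    d0 += D_j
d = rng.standard_normal(sum(N_vec))

x = rng.standard_normal((T, sum(D_vec)))   # stand-in continuous latents
rates = softplus(x @ C.T + d)              # E[y_t] = f(C x_t + d)
y = rng.poisson(rates)                     # spike counts, shape (T, sum(N_vec))
```

Because the off-diagonal blocks of `C` are zero, the activity of population $j$ depends only on $x_t^{(j)}$, which is what keeps the per-population latents interpretable.
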
**3. Continuous State Update (Dynamics)**. The dynamics of a switching linear dynamical system are piecewise linear, with the linear dynamics at a given time determined by a discrete state (more on discrete states below):

\begin{align}
x_t \sim A^{(z_t)} x_{t-1} + b^{(z_t)},
\end{align}

where $z_t$ is the discrete state, $A^{(z_t)}$ is the dynamics matrix for that discrete state, $b^{(z_t)}$ is the corresponding offset, and $x_t$ contains the latents from all populations, $[x_t^{(1)}, x_t^{(2)}, \ldots, x_t^{(J)}]$. We ignore the noise term here for simplicity.

Having unique continuous latents for each population allows us to decompose the dynamics in an interpretable manner. We model the temporal dynamics of the continuous states as

\begin{align}
x_t^{(j)} \sim A_{(j \: to \: j)}^{(z_t)} x_{t-1}^{(j)} + \sum_{i \neq j} A_{(i \: to \: j)}^{(z_t)} x_{t-1}^{(i)} + b_j^{(z_t)}.
\end{align}

In the full dynamics matrices $A^{(z_t)}$, which we show in the example below, the on-diagonal blocks represent the internal dynamics, $A_{(j \: to \: j)}^{(z_t)}$, and the off-diagonal blocks represent the external dynamics, $A_{(i \: to \: j)}^{(z_t)}$.

**4. Discrete State Update (Transitions)**. Recurrent transitions are based on the continuous latent state. Our recurrent transitions have a sticky component, $S$, that determines the probabilities of staying in a state, and a switching component, $R$, that determines the probabilities of switching to states. In the model we use in this notebook:

\begin{align}
p(z_t = i \mid z_{t-1} = k, x_{t-1}) &= \mathrm{softmax}\bigg\{ \Big( \big(R x_{t-1}\big) + r\Big) \odot (1 - e_{z_{t-1}}) + \Big( \big(S x_{t-1} \big) + s \Big) \odot e_{z_{t-1}} \bigg\},
\end{align}

where $e_{z_{t-1}} \in \{0,1\}^K$ is a one-hot encoding of $z_{t-1}$ and $\odot$ denotes elementwise multiplication. We write the previous discrete state as $k$ so that $j$ can keep indexing populations.

To understand which populations are contributing to the transitions, we can decompose this equation:

\begin{align}
p(z_t = i \mid z_{t-1} = k, x_{t-1}) &= \mathrm{softmax}\bigg\{ \Big( \sum_{j=1}^J \big(R_j x_{t-1}^{(j)}\big) + r\Big) \odot (1 - e_{z_{t-1}}) + \Big( \sum_{j=1}^J \big(S_j x_{t-1}^{(j)} \big) + s \Big) \odot e_{z_{t-1}} \bigg\},
\end{align}

where, for example, $R_j x_{t-1}^{(j)}$ contains the contribution of population $j$ towards switching to each state.

Additionally, we can include a dependency on the previous discrete state. This is included in the code package, but is not used in the example below.

\begin{align}
p(z_t = i \mid z_{t-1} = k, x_{t-1}) &= \mathrm{softmax}\bigg\{ \log(P_{k,i}) + \big(R x_{t-1}\big) \odot (1 - e_{z_{t-1}}) + \big(S x_{t-1} \big) \odot e_{z_{t-1}} \bigg\},
\end{align}

where $P_{k,i}$ denotes a baseline probability of transitioning from state $k$ to state $i$. The sticky multi-population transitions classes are loaded in the example below.
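
The per-population decomposition above can be checked numerically. The sketch below uses small random stand-ins for $R$, $S$, $r$, $s$ and $x_{t-1}$ (all hypothetical values, not taken from any fitted model) and a hand-written `softmax`; it is an illustration of the equations, not the implementation inside `StickyRecurrentOnlyTransitions`.

```python
import numpy as np

rng = np.random.default_rng(1)

K = 3                      # number of discrete states
D_vec = [2, 2, 2]          # latents per population (illustrative)
D = sum(D_vec)
edges = np.concatenate(([0], np.cumsum(D_vec)))

R = rng.standard_normal((K, D))   # switching weights
S = rng.standard_normal((K, D))   # sticky weights
r = np.zeros(K)
s = 5.0 * np.ones(K)
x_prev = rng.standard_normal(D)   # x_{t-1}
z_prev = 1                        # previous discrete state
e = np.eye(K)[z_prev]             # one-hot encoding of z_{t-1}

def softmax(a):
    a = a - a.max()               # subtract the max for numerical stability
    return np.exp(a) / np.exp(a).sum()

# Full form: switching term for the other states, sticky term for the current one.
logits = (R @ x_prev + r) * (1 - e) + (S @ x_prev + s) * e
p_next = softmax(logits)

# Per-population contributions R_j x^(j) and S_j x^(j), each a length-K vector.
switch_parts = [R[:, a:b] @ x_prev[a:b] for a, b in zip(edges[:-1], edges[1:])]
stay_parts = [S[:, a:b] @ x_prev[a:b] for a, b in zip(edges[:-1], edges[1:])]
logits_sum = (sum(switch_parts) + r) * (1 - e) + (sum(stay_parts) + s) * e
```

Here `switch_parts[j][i]` is population `j`'s contribution towards switching into state `i`, and `logits_sum` agrees with `logits` up to floating-point error, so summing the per-population terms recovers the full transition logits.
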

**5. Model fitting**. We fit the model with variational Laplace EM; see the "Variational Laplace EM for SLDS Tutorial" for more information.

+++ {"colab_type": "text", "id": "8OzC8q4bRFQv"}

## Import packages, including multipopulation extensions

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
  height: 581
colab_type: code
id: ruUnNqi5RZqT
outputId: 228b6c8e-c064-46c2-ce57-9ad88daca5c8
---
#Install ssm only if it is not already available
try:
    import ssm
except ImportError:
    !pip install git+https://github.com/lindermanlab/ssm.git#egg=ssm
    import ssm
```

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
  height: 71
colab_type: code
id: zDn3tEJhRFQv
outputId: 2f1ca1d0-8f17-404a-897f-57b8c5d353cb
---
#### General packages

from matplotlib import pyplot as plt
%matplotlib inline

import autograd.numpy as np
import autograd.numpy.random as npr

import seaborn as sns
sns.set_style("white")
sns.set_context("talk")
sns.set_style('ticks',{"xtick.major.size":8, "ytick.major.size":8})

from ssm.plots import gradient_cmap, white_to_color_cmap

color_names = [
    "purple",
    "red",
    "amber",
    "faded green",
    "windows blue",
    "orange"
    ]

colors = sns.xkcd_palette(color_names)
cmap = gradient_cmap(colors)
```

```{code-cell} ipython3
:colab: {}
:colab_type: code
:id: 0rq19iIQRFQy

#### SSM PACKAGES ###

import ssm
from ssm.variational import SLDSMeanFieldVariationalPosterior, SLDSTriDiagVariationalPosterior, \
    SLDSStructuredMeanFieldVariationalPosterior
from ssm.util import random_rotation, find_permutation, relu

#Load from extensions
from ssm.extensions.mp_srslds.emissions_ext import GaussianOrthogonalCompoundEmissions, PoissonOrthogonalCompoundEmissions
from ssm.extensions.mp_srslds.transitions_ext import StickyRecurrentOnlyTransitions, StickyRecurrentTransitions
```

+++ {"colab_type": "text", "id": "Ty3EOi8bRFQ1"}

## Simulate (somewhat realistic) data

+++ {"colab_type": "text", "id": "QxDoYCRDRFQ2"}

### Set parameters of simulation

```{code-cell} ipython3
:colab: {}
:colab_type: code
:id: dalqY6zvRFQ2

K=3 #Number of discrete states
num_gr=3 #Number of populations
num_per_gr=5 #Number of latents per population
neur_per_gr=75 #Number of neurons per population
t_end=3000 #Number of time bins
num_trials=1 #Number of trials
```

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
  height: 34
colab_type: code
id: OHTSTNbTRFQ4
outputId: fd3b833b-8df2-425b-d3ce-103cf42d5153
---
np.random.seed(108) #To create replicable dynamics

alphas=.03+.1*np.random.rand(K) #Determines the distribution of values in the dynamics matrix, for each discrete state
print('alphas:', alphas)

sparsity=.33 #Proportion of non-diagonal blocks in the dynamics matrix that are 0

e1=.1 #Amount of noise in the dynamics
```

+++ {"colab_type": "text", "id": "t91lYSkPRFQ7"}

### Get new emissions and transitions classes for the simulated data

```{code-cell} ipython3
:colab: {}
:colab_type: code
:id: tQx540b6RFQ8

#Vector containing number of latents per population
D_vec=[]
for i in range(num_gr):
    D_vec.append(num_per_gr)

#Vector containing number of neurons per population
N_vec=[]
for i in range(num_gr):
    N_vec.append(neur_per_gr)

D=np.sum(D_vec) #Total number of latents
num_gr=len(D_vec)
D_vec_cumsum = np.concatenate(([0], np.cumsum(D_vec))) #Start/end index of each population's latents

#Get new multipopulation emissions class for the simulation
# gauss_comp_emissions=GaussianOrthogonalCompoundEmissions(N=np.sum(N_vec),K=1,D=np.sum(D_vec),D_vec=D_vec,N_vec=N_vec)
poiss_comp_emissions=PoissonOrthogonalCompoundEmissions(N=np.sum(N_vec),K=1,D=np.sum(D_vec),D_vec=D_vec,N_vec=N_vec,link='softplus')

#Get transitions class
true_sro_trans=StickyRecurrentOnlyTransitions(K=K,D=np.sum(D_vec))
```

+++ {"colab_type": "text", "id": "DDu2GnRGRFQ-"}

### Create simulated data

```{code-cell} ipython3
:colab: {}
:colab_type: code
:id: VLN8FWLLRFQ_

np.random.seed(10) #To create replicable simulations

A_masks=[]

A_all=np.zeros([K,D,D]) #Initialize dynamics matrix
b_all=np.zeros([K,D]) #Initialize dynamics offset

#Create initial ground truth model, that we will modify
true_slds = ssm.SLDS(N=np.sum(N_vec),K=K,D=int(np.sum(D_vec)),
             dynamics="gaussian",
             emissions=poiss_comp_emissions,
             transitions=true_sro_trans)

#Create ground truth transitions
v=.2+.2*np.random.rand(1)
for k in range(K):
    inc=np.copy(k)
    true_slds.transitions.Rs[k,D_vec_cumsum[inc]:D_vec_cumsum[inc]+1]=v
    true_slds.transitions.Ss[k,D_vec_cumsum[inc]:D_vec_cumsum[inc]+1]=v-.1
true_slds.transitions.r=0*np.ones([K,1])
true_slds.transitions.s=5*np.ones([K,1])

#Create ground truth dynamics for each state
for k in range(K):

    ##Create dynamics##
    alpha=alphas[k]

    A_mask=np.random.rand(num_gr,num_gr)>sparsity #Make some blocks of the dynamics matrix 0
    A_masks.append(A_mask)
    for i in range(num_gr):
        A_mask[i,i]=1 #Always keep the on-diagonal (within-population) blocks

    A0=np.zeros([D,D])
    for i in range(D-1):
        A0[i,i+1:]=alpha*np.random.randn(D-1-i)
    A0=(A0-A0.T)
    for i in range(num_gr):
        A0[D_vec_cumsum[i]:D_vec_cumsum[i+1],D_vec_cumsum[i]:D_vec_cumsum[i+1]]=2*A0[D_vec_cumsum[i]:D_vec_cumsum[i+1],D_vec_cumsum[i]:D_vec_cumsum[i+1]]
    A0=A0+np.identity(D)
    A=A0*np.kron(A_mask, np.ones((num_per_gr, num_per_gr)))
    A=A/(np.max(np.abs(np.linalg.eigvals(A)))+.01) #.97
    b=1*np.random.rand(D)

    A_all[k]=A
    b_all[k]=b

true_slds.dynamics.As=A_all
true_slds.dynamics.bs=b_all

zs, xs, _ = true_slds.sample(t_end) #Sample discrete and continuous latents from model for simulation

#Get spike trains that have an average firing rate of 0.25 per bin
tmp=np.mean(relu(np.dot(true_slds.emissions.Cs[0],xs.T)+.1*true_slds.emissions.ds[0][:,None]).T)
mult=.25/tmp
lams=relu(mult*np.dot(true_slds.emissions.Cs[0],xs.T)+.1*true_slds.emissions.ds[0][:,None]).T

ys=np.random.poisson(lams) #Get spiking activity based on Poisson statistics
```

+++ {"colab_type": "text", "id": "twuPg8wRRFRC"}

## Plot simulated data

+++ {"colab_type": "text", "id": "1-VkH7xSRFRC"}

### Dynamics matrices ($A^z$)

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
  height: 797
colab_type: code
id: qv7hO5fgRFRD
outputId: 7b2a3151-a5dc-4ac0-c446-ecbb212ffb61
---
# vmin,vmax=[-1,1]
vmin,vmax=[-.5,.5] #zoom in to see colors more clearly

for k in range(K):
    plt.figure(figsize=(4,4))
    plt.imshow(true_slds.dynamics.As[k], aspect='auto', interpolation="none", vmin=vmin, vmax=vmax, cmap='RdBu')
    #Draw lines marking the population blocks
    offset=-.5
    for nf in D_vec:
        plt.plot([-0.5, D-0.5], [offset, offset], '-k')
        plt.plot([offset, offset], [-0.5, D-0.5], '-k')
        offset += nf
    plt.xticks([])
    plt.yticks([])
    plt.title('Actual State '+str(k))
```

+++ {"colab_type": "text", "id": "-Y7R3y6TRFRF"}

### Discrete states ($z$)

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
  height: 177
colab_type: code
id: Wmk2_uI-RFRG
outputId: 2f434446-fbaf-4a00-d186-81a557b35011
---
plt.figure(figsize=(8, 4))
plt.subplot(211)
plt.imshow(zs[None,:], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
plt.xlim(0, t_end)
plt.ylabel("$z_{\\mathrm{true}}$")
plt.yticks([])
```

+++ {"colab_type": "text", "id": "ufQwDDKTRFRI"}

### Transitions (in a shorter time window)

The contribution of population $j$ to staying in a state is $S_j x^{(j)}$ and the contribution to switching to a state is $R_j x^{(j)}$.

```{code-cell} ipython3
:colab: {}
:colab_type: code
:id: fp1QQ_7dRFRI

dur=200 #Length of the plotting window (time bins)
st_t=650 #Start of the plotting window
end_t=st_t+dur
```

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
  height: 310
colab_type: code
id: -FWpoP7-RFRK
outputId: bcc60411-fb9e-44d0-bc16-e382fe7919aa
---
plt.figure(figsize=(8, 4))

j=0 #Discrete state (purple)
plt.subplot(211)
for g in range(num_gr): #Loop over populations
    plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Rs[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
plt.ylabel('Contribution towards \n switching to \n purple state',rotation=60)
plt.xticks([])
plt.yticks([])

j=1 #Discrete state (red)
plt.subplot(212)
for g in range(num_gr): #Loop over populations
    plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Ss[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
plt.ylabel('Contribution towards \n staying in red state',rotation=60)
plt.legend(['Pop 1','Pop 2','Pop 3'])
plt.yticks([])
```

+++ {"colab_type": "text", "id": "wOO48DPARFRN"}

### Continuous latents ($x$) and spikes ($y$) for an example population

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
  height: 449
colab_type: code
id: YllKBbdhRFRN
outputId: dc7f9218-0100-4263-8501-9d2fb1696a9e
---
plt.figure(figsize=(8, 4))
plt.subplot(211)
plt.plot(xs[st_t:end_t,:num_per_gr]) #Show latents of first group
plt.xticks([])
plt.subplot(212)
plt.plot(ys[st_t:end_t,:10]) #Show first 10 neurons
```

+++ {"colab_type": "text", "id": "OZ_iLjiyRFRV"}

## Fit model to data

+++ {"colab_type": "text", "id": "g6QfY0j-RFRV"}

### To create the emissions classes for the multipopulation models, we need vectors containing the number of continuous latents per population ("D_vec") and neurons per population ("N_vec")

```{code-cell} ipython3
:colab: {}
:colab_type: code
:id: kVZDXyIiRFRW

num_gr=3 #Number of populations
num_per_gr=5 #Number of latents per population
neur_per_gr=75 #Number of neurons per population

#Vector containing number of latents per population
D_vec=[]
for i in range(num_gr):
    D_vec.append(num_per_gr)

#Vector containing number of neurons per population
N_vec=[]
for i in range(num_gr):
    N_vec.append(neur_per_gr)
```

+++ {"colab_type": "text", "id": "_I-JFqbHRFRZ"}

#### Now create the multipopulation emissions and transitions classes for our model

```{code-cell} ipython3
:colab: {}
:colab_type: code
:id: dhnavgIgRFRa

#Get new multipopulation emissions class
poiss_comp_emissions=PoissonOrthogonalCompoundEmissions(N=np.sum(N_vec),K=1,D=np.sum(D_vec),D_vec=D_vec,N_vec=N_vec,link='softplus')

#Get new transitions class
sro_trans=StickyRecurrentOnlyTransitions(K=K,D=np.sum(D_vec), l2_penalty_similarity=10, l1_penalty=10)
#The above L2 penalty is on the similarity between R and S (it assumes the activity to switch into a state is similar to the activity to stay in a state)
#The L1 penalty is on the entries of R and S
```

Note that another new emissions class is "GaussianOrthogonalCompoundEmissions"

Note that another new transitions class is "StickyRecurrentTransitions"

+++ {"colab_type": "text", "id": "syHDAea_RFRc"}

#### Now declare and fit the model

```{code-cell} ipython3
---
colab:
  base_uri: https://localhost:8080/
  height: 166
  referenced_widgets: [36b4e6c4b2064572a4412f8bc298caeb, 97247a689ff744689528aa9555c13c62, 581b2c1a39eb439581308a5cdea07389, e8e6a6b80ed14b77900e87e1c779f459, 2e200ca5be3141f2b2ee041ec8b08e49, f4e534f4d6ac45c6a7da04be7d341968, ced98407e9ee4ee4bd0baaa24d4765b9, 14de5bf9b23f47d4867e135cf156ac97, 6bd4692a8c95492d84f6353069342d4f, 62f38f8736314a098f91d7486b5b84f6, d7ffd5a5a8874c2d8e593cb29c8f14b2, 2ddbbe21934547db9e365d905ea1f024, 19ad77a391f7459f811262d5bdbeb783, 5da21c8eac7041898c7a468f0388e632, eb1c51978e0f4b79968cf6618a959007, c0204fcad9f64dd48580f7d17767abc6]
colab_type: code
id: n7WRqQufRFRd
outputId: 28ddbaaa-221c-4ca8-db59-b76bbeff1dd0
---
K=3 #Number of discrete states

rslds = ssm.SLDS(N=np.sum(N_vec),K=K,D=np.sum(D_vec),
             dynamics="gaussian",
             emissions=poiss_comp_emissions,
             transitions=sro_trans,
             dynamics_kwargs=dict(l2_penalty_A=100)) #Regularization on the dynamics matrix

q_elbos_ar, q_ar = rslds.fit(ys, method="laplace_em",
                              variational_posterior="structured_meanfield",
                              continuous_optimizer='newton',
                              initialize=True,
                              num_init_restarts=10,
                              num_iters=30,
                              alpha=0.25)
```

```{code-cell} ipython3
:colab: {}
:colab_type: code
:id: Oi8-3NcxRFRf
:outputId: 98075512-a594-4bae-fcf3-550affe796a6

plt.plot(q_elbos_ar[1:])
plt.xlabel("Iteration")
plt.ylabel("ELBO")
```

## Align solution with simulation for plotting

```{code-cell} ipython3
#The recovered discrete states can be permuted in any way.
#Find permutation to match the discrete states in the model and the ground truth
z_inferred=rslds.most_likely_states(q_ar.mean_continuous_states[0],ys)
rslds.permute(find_permutation(zs, z_inferred))
z_inferred2=rslds.most_likely_states(q_ar.mean_continuous_states[0],ys)
```

```{code-cell} ipython3
#Each population's latents can be multiplied by an arbitrary rotation matrix
#Additionally, there may be a change in scaling between the simulation ground truth and recovered latents,
#because the simulation didn't constrain the effective emissions (C) matrix to be orthonormal like in the model

from sklearn.linear_model import LinearRegression

R=np.zeros([D,D])
for g in range(num_gr):
    lr=LinearRegression(fit_intercept=False)
    lr.fit(q_ar.mean_continuous_states[0][:,D_vec_cumsum[g]:D_vec_cumsum[g+1]],xs[:,D_vec_cumsum[g]:D_vec_cumsum[g+1]])
    R[D_vec_cumsum[g]:D_vec_cumsum[g+1],D_vec_cumsum[g]:D_vec_cumsum[g+1]]=lr.coef_
```

+++ {"colab_type": "text", "id": "0CjkFH63RFRh"}

## Plot results

+++ {"colab_type": "text", "id": "HDlMQx4ZRFRh"}

### Discrete states ($z$)

```{code-cell} ipython3
:colab: {}
:colab_type: code
:id: TGk3zoBVRFRi
:outputId: d1cebd3c-17f7-4e48-a6d3-0b9f34fbd3af

plt.figure(figsize=(8, 4))
plt.subplot(211)
plt.imshow(zs[None,:], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
plt.xlim(0, t_end)
plt.ylabel("$z_{\\mathrm{true}}$")
plt.yticks([])

plt.subplot(212)
plt.imshow(z_inferred2[None,:], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
plt.xlim(0, t_end)
plt.ylabel("$z_{\\mathrm{inferred}}$")
plt.yticks([])
plt.xlabel("time")

plt.tight_layout()
```

```{code-cell} ipython3
:colab: {}
:colab_type: code
:id: HRPr3DJlRFRj
:outputId: 1d6b402b-fa6d-4040-f4bf-4e4c2188c8b8

print('Discrete state accuracy: ', np.mean(zs==z_inferred2))
```

+++ {"colab_type": "text", "id": "SX1go1AyRFRl"}

#### Shorter time window

```{code-cell} ipython3
:colab: {}
:colab_type: code
:id: UhjVHiYzRFRm
:outputId: 5d4cf636-9055-46fe-af2b-d51531c97317

plt.figure(figsize=(4, 2))
plt.subplot(211)
plt.imshow(zs[None,st_t:end_t], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
plt.xlim(0, dur)
plt.ylabel("$z_{\\mathrm{true}}$")
plt.yticks([])
plt.xticks([])

plt.subplot(212)
plt.imshow(z_inferred2[None,st_t:end_t], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
plt.xlim(0, dur)
plt.ylabel("$z_{\\mathrm{inferred}}$")
plt.yticks([])
```

+++ {"colab_type": "text", "id": "e4Cz1o8GRFRo"}

### Dynamics matrices ($A^z$)

+++

We show the $A$ matrices after aligning the continuous latents to the ground truth, which demonstrates that the ground truth dynamics can be recovered. We also show the raw recovered $A$ matrices, which demonstrate that the block structure can be learned regardless of scaling and rotations.

```{code-cell} ipython3
:colab: {}
:colab_type: code
:id: 0D2QHaLRRFRo
:outputId: 891cf1a1-4aea-4c5f-f2d5-9895478b5361

plt.figure(figsize=(12, 12))

q=1 #Subplot index
for k in range(K):

    #Ground truth dynamics
    plt.subplot(3,3,q)
    plt.imshow(true_slds.dynamics.As[k], aspect='auto', interpolation="none", vmin=-.5, vmax=.5, cmap='RdBu')
    offset=-.5
    for nf in D_vec:
        plt.plot([-0.5, D-0.5], [offset, offset], '-k')
        plt.plot([offset, offset], [-0.5, D-0.5], '-k')
        offset += nf
    plt.xticks([])
    plt.yticks([])
    plt.title('Actual State '+str(k))
    q=q+1

    #Recovered dynamics, mapped into the ground truth coordinates
    plt.subplot(3,3,q)
    # plt.imshow(rslds.dynamics.As[k], aspect='auto', interpolation="none", vmin=-.5, vmax=.5, cmap='RdBu')
    plt.imshow(R@rslds.dynamics.As[k]@np.linalg.inv(R), aspect='auto', interpolation="none", vmin=-.5, vmax=.5, cmap='RdBu')
    offset=-.5
    for nf in D_vec:
        plt.plot([-0.5, D-0.5], [offset, offset], '-k')
        plt.plot([offset, offset], [-0.5, D-0.5], '-k')
        offset += nf
    plt.xticks([])
    plt.yticks([])
    plt.title('Aligned Predicted State '+str(k))
    # plt.savefig(folder+'dyn_est'+str(k)+'.pdf')
    q=q+1

    #Recovered dynamics, without alignment
    plt.subplot(3,3,q)
    plt.imshow(rslds.dynamics.As[k], aspect='auto', interpolation="none", vmin=-.5, vmax=.5, cmap='RdBu')
    offset=-.5
    for nf in D_vec:
        plt.plot([-0.5, D-0.5], [offset, offset], '-k')
        plt.plot([offset, offset], [-0.5, D-0.5], '-k')
        offset += nf
    plt.xticks([])
    plt.yticks([])
    plt.title('Raw Predicted State '+str(k))
    # plt.savefig(folder+'dyn_est'+str(k)+'.pdf')
    q=q+1
```

+++ {"colab_type": "text", "id": "b14UYWuMRFRr"}

### Transitions

The contribution of population $j$ to staying in a state is $S_j x^{(j)}$ and the contribution to switching to a state is $R_j x^{(j)}$.

```{code-cell} ipython3
:colab: {}
:colab_type: code
:id: hja_9YLPRFRr
:outputId: 6eb42921-5277-43ea-b14b-2366d2128509

plt.figure(figsize=(15, 4))

### Actual

j=0 #Discrete state (purple)
plt.subplot(221)
for g in range(num_gr): #Loop over populations
    plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Rs[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
plt.ylabel('Contribution towards \n switching to \n purple state',rotation=60)
plt.xticks([])
plt.yticks([])
plt.title('Actual')

j=1 #Discrete state (red)
plt.subplot(223)
for g in range(num_gr): #Loop over populations
    plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Ss[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
plt.ylabel('Contribution towards \n staying in red state',rotation=60)
plt.legend(['Pop 1','Pop 2','Pop 3'])
plt.yticks([])

### Predicted

j=0 #Discrete state (purple)
plt.subplot(222)
for g in range(num_gr): #Loop over populations
    plt.plot(np.dot(q_ar.mean_continuous_states[0][st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],rslds.transitions.Rs[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
# plt.ylabel('Contribution towards \n switching to \n purple state',rotation=60)
plt.xticks([])
plt.yticks([])
plt.title('Predicted')

j=1 #Discrete state (red)
plt.subplot(224)
for g in range(num_gr): #Loop over populations
    plt.plot(np.dot(q_ar.mean_continuous_states[0][st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],rslds.transitions.Ss[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
# plt.ylabel('Stay in Red')
# plt.xticks([])
plt.yticks([])
```

+++ {"colab_type": "text", "id": "yz4GSxjrRFRt"}

### Example fit of neural activity ($y$)

```{code-cell} ipython3
:colab: {}
:colab_type: code
:id: cgEffzVKRFRt
:outputId: 10d64661-8fbd-4259-96fc-134c35980b88

preds=rslds.smooth(q_ar.mean_continuous_states[0],ys) #Get predictions

nrn=0 #Example neuron
plt.plot(ys[st_t:end_t,nrn],alpha=.5) #True spiking activity
plt.plot(lams[st_t:end_t,nrn]) #True firing rate
plt.plot(preds[st_t:end_t,nrn]) #Predicted firing rate
plt.legend(['True Spikes','True FR','Predicted FR'])
```
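
The plot above compares firing rates by eye. One possible numerical summary, offered only as a sketch, is the per-neuron Pearson correlation between the true and predicted firing rates. The arrays below are synthetic stand-ins with the same layout as `lams` and `preds` (time bins by neurons), and the helper `per_neuron_corr` is hypothetical; it is not part of `ssm`.

```python
import numpy as np

def per_neuron_corr(true_rates, pred_rates):
    """Pearson correlation between matching columns of two (T, N) arrays."""
    t = true_rates - true_rates.mean(axis=0)
    p = pred_rates - pred_rates.mean(axis=0)
    denom = np.sqrt((t ** 2).sum(axis=0) * (p ** 2).sum(axis=0))
    return (t * p).sum(axis=0) / denom

rng = np.random.default_rng(2)
T, N = 500, 4
true_rates = np.abs(rng.standard_normal((T, N)))              # stand-in for lams
pred_rates = true_rates + 0.1 * rng.standard_normal((T, N))   # stand-in for preds

corr = per_neuron_corr(true_rates, pred_rates)
print("mean correlation across neurons:", corr.mean())
```

In this notebook the same helper could be called as `per_neuron_corr(lams, preds)`, assuming both are `(T, N)` arrays. A neuron whose true or predicted rate is constant would give a zero denominator, so such columns would need to be excluded first.
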