.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/Multi-Population-rSLDS.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        Click :ref:`here `
        to download the full example code

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_Multi-Population-rSLDS.py:


Multi-Population rSLDS
======================

.. GENERATED FROM PYTHON SOURCE LINES 19-620



.. rst-class:: sphx-glr-horizontal


    *

      .. image-sg:: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_001.png
         :alt: Actual State 0
         :srcset: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_001.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_002.png
         :alt: Actual State 1
         :srcset: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_002.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_003.png
         :alt: Actual State 2
         :srcset: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_003.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_004.png
         :alt: Multi Population rSLDS
         :srcset: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_004.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_005.png
         :alt: Multi Population rSLDS
         :srcset: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_005.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_006.png
         :alt: Multi Population rSLDS
         :srcset: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_006.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_007.png
         :alt: Multi Population rSLDS
         :srcset: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_007.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_008.png
         :alt: Multi Population rSLDS
         :srcset: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_008.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_009.png
         :alt: Actual State 0, Aligned Predicted State 0, Raw Predicted State 0, Actual State 1, Aligned Predicted State 1, Raw Predicted State 1, Actual State 2, Aligned Predicted State 2, Raw Predicted State 2
         :srcset: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_009.png
         :class: sphx-glr-multi-img

    *

      .. image-sg:: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_010.png
         :alt: Actual, Predicted
         :srcset: /auto_examples/images/sphx_glr_Multi-Population-rSLDS_010.png
         :class: sphx-glr-multi-img


.. rst-class:: sphx-glr-script-out

 .. code-block:: none

    alphas: [0.05336111 0.03343876 0.08800889]

|

.. code-block:: default


    # + [markdown] colab_type="text" id="view-in-github"
    # Open In Colab

    # -

    # ### If you want to quickly see how to fit your own data, jump down to the "Fit model to data" section
    #
    #
    # # Multi-population recurrent switching linear dynamical systems overview
    #
    # This notebook goes through the simulation example shown in our manuscript (Figure 2A,B).
    #
    # Below, we briefly describe the model. We also recommend looking at the "Recurrent SLDS" notebook, which provides more details on the standard rSLDS.
    #
    #
    # **1. Data**.
    # Let $y_t^{(j)}$ denote a vector of activity measurements of the $N_j$ neurons in population $j$ in time bin $t$.
    #
    #
    # **2. Emissions**.
    # Let $x_t^{(j)}$ denote a continuous latent state of population $j$ at time $t$. The population states may differ in dimensionality $D_j$, since populations may differ in size and complexity. The observed activity of population $j$ is modeled with a generalized linear model,
    # \begin{align}
    # E[y_t^{(j)}] &= f(C_j x_t^{(j)} + d_j),
    # \end{align}
    # where $f$ is an elementwise nonlinearity and each population has its own linear mapping, parameterized by $\{C_j, d_j\}$. In this notebook, we use a Poisson GLM. Inputs can also be passed into this GLM, as described in the rSLDS notebook.
    #
    # There are multi-population emissions classes that will be loaded in the example below.
    #
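
    # +
    # Sketch (illustration only): a NumPy version of the Poisson GLM emissions above.
    # The helper names (softplus, emission_rates) and the demo_* toy sizes are
    # hypothetical and are not part of the ssm API. It assumes f is a softplus, in line
    # with the link='softplus' option passed to the emissions class later on, and that
    # C is block-diagonal, so population j is read out only from its own latents x_t^(j).
    import numpy as np

    def softplus(u):
        # Numerically stable log(1 + exp(u))
        return np.logaddexp(0.0, u)

    def emission_rates(x, C_blocks, d_blocks):
        """x: (T, sum_j D_j) latents; C_blocks[j]: (N_j, D_j); d_blocks[j]: (N_j,).
        Returns the (T, sum_j N_j) rates f(C_j x_t^(j) + d_j), population by population."""
        rates, start = [], 0
        for C_j, d_j in zip(C_blocks, d_blocks):
            x_j = x[:, start:start + C_j.shape[1]]     # latents of population j only
            rates.append(softplus(x_j @ C_j.T + d_j))
            start += C_j.shape[1]
        return np.concatenate(rates, axis=1)

    demo_rng = np.random.default_rng(0)
    demo_D_vec, demo_N_vec = [2, 3], [4, 6]          # toy latents / neurons per population
    demo_C = [demo_rng.standard_normal((n, d)) for n, d in zip(demo_N_vec, demo_D_vec)]
    demo_d = [np.zeros(n) for n in demo_N_vec]
    demo_x = demo_rng.standard_normal((100, sum(demo_D_vec)))
    demo_rates = emission_rates(demo_x, demo_C, demo_d)
    demo_y = demo_rng.poisson(demo_rates)            # spike counts, shape (100, 10)

    # The same rates come from one block-diagonal C acting on the stacked latents x_t
    demo_C_full = np.zeros((sum(demo_N_vec), sum(demo_D_vec)))
    demo_row = demo_col = 0
    for C_j in demo_C:
        demo_C_full[demo_row:demo_row + C_j.shape[0], demo_col:demo_col + C_j.shape[1]] = C_j
        demo_row, demo_col = demo_row + C_j.shape[0], demo_col + C_j.shape[1]
    assert np.allclose(demo_rates, softplus(demo_x @ demo_C_full.T))
    # -
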
    #
    # **3. Continuous State Update (Dynamics)**.
    # The dynamics of a switching linear dynamical system are piecewise linear, with the linear dynamics at a given time determined by a discrete state (more on discrete states below):
    #
    # \begin{align}
    # x_t \sim
    # A^{(z_t)} x_{t-1} + b^{(z_t)}
    # \end{align}
    #
    # where $z_t$ is the discrete state, $A^{(z_t)}$ and $b^{(z_t)}$ are the dynamics matrix and offset for that discrete state, and $x_t$ contains the latents from all populations, $[x_t^{(1)}, x_t^{(2)}, \dots, x_t^{(J)}]$. We omit the noise term here for simplicity.
    #
    # Having separate continuous latents for each population allows us to decompose the dynamics in an interpretable manner. We model the temporal dynamics of the continuous states as
    # \begin{align}
    # x_t^{(j)} \sim
    # A_{(j \: to \: j)}^{(z_t)} x_{t-1}^{(j)}
    # + \sum_{i \neq j} A_{(i \: to \: j)}^{(z_t)} x_{t-1}^{(i)}
    # + b_j^{(z_t)}.
    # \end{align}
    #
    # In the full dynamics matrices $A^{(z_t)}$, which we show in the example below, the on-diagonal blocks represent the internal dynamics, $A_{(j \: to \: j)}^{(z_t)}$, and the off-diagonal blocks represent the external dynamics, $A_{(i \: to \: j)}^{(z_t)}$.
    #
    #
    # **4. Discrete State Update (Transitions)**.
    # Recurrent transitions depend on the continuous latent state. Our recurrent transitions have a sticky component, $S$, which determines the probability of staying in the current state, and a switching component, $R$, which determines the probabilities of switching to each other state. In the model we use in this notebook:
    #
    # \begin{align}
    # p(z_t = i \mid z_{t-1} = k, x_{t-1}) &= \mathrm{softmax}\bigg\{ \Big( \big(R x_{t-1}\big) + r\Big) \odot (1 - e_{z_{t-1}}) + \Big( \big(S x_{t-1} \big) + s \Big) \odot e_{z_{t-1}} \bigg\}_i,
    # \end{align}
    #
    # where $e_{z_{t-1}} \in \{0,1\}^K$ is a one-hot encoding of $z_{t-1}$, and the subscript $i$ denotes the $i$-th entry of the softmax output.
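
    # +
    # Sketch (illustration only, not the ssm implementation): one noiseless step of the
    # model above in NumPy. The helper names and demo_* toy values are hypothetical.
    # dynamics_by_block builds x_t^(j) from the internal block A_(j to j) plus the
    # external blocks A_(i to j), and checks that this equals the full update A x + b.
    # sticky_recurrent_probs evaluates the sticky recurrent transition softmax using
    # the one-hot vector e_{z_{t-1}}; here r and s are assumed to be length-K vectors.
    import numpy as np

    def softmax(v):
        v = v - v.max()                      # subtract the max for numerical stability
        return np.exp(v) / np.exp(v).sum()

    def dynamics_by_block(A, b, x_prev, D_vec):
        """Return x_t with x_t^(j) = sum_i A_(i to j) x_{t-1}^(i) + b_j for every population j."""
        edges = np.concatenate(([0], np.cumsum(D_vec)))
        x_next = np.zeros_like(x_prev)
        for j in range(len(D_vec)):          # receiving population j (block row)
            rows = slice(edges[j], edges[j + 1])
            for i in range(len(D_vec)):      # sending population i (block column)
                cols = slice(edges[i], edges[i + 1])
                x_next[rows] += A[rows, cols] @ x_prev[cols]
            x_next[rows] += b[rows]
        return x_next

    def sticky_recurrent_probs(x_prev, z_prev, R, r, S, s):
        """p(z_t = . | z_{t-1} = z_prev, x_{t-1}) for R, S of shape (K, D) and r, s of shape (K,)."""
        e = np.zeros(R.shape[0])
        e[z_prev] = 1.0                      # one-hot encoding of z_{t-1}
        logits = (R @ x_prev + r) * (1 - e) + (S @ x_prev + s) * e
        return softmax(logits)

    demo_rng = np.random.default_rng(1)
    demo_D_vec, demo_K = [2, 3, 2], 3
    demo_D = sum(demo_D_vec)
    demo_A = 0.1 * demo_rng.standard_normal((demo_D, demo_D))
    demo_b = demo_rng.standard_normal(demo_D)
    demo_x_prev = demo_rng.standard_normal(demo_D)
    demo_x_next = dynamics_by_block(demo_A, demo_b, demo_x_prev, demo_D_vec)
    assert np.allclose(demo_x_next, demo_A @ demo_x_prev + demo_b)

    # A large offset s for the current state (as in the simulation below, s=5, r=0) favors staying
    demo_R = demo_rng.standard_normal((demo_K, demo_D))
    demo_S = demo_rng.standard_normal((demo_K, demo_D))
    demo_p = sticky_recurrent_probs(demo_x_prev, 0, demo_R, np.zeros(demo_K),
                                    demo_S, 5 * np.ones(demo_K))
    # -
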
    #
    # To understand which populations contribute to the transitions, we can decompose this equation:
    #
    #
    # \begin{align}
    # p(z_t = i \mid z_{t-1} = k, x_{t-1}) &= \mathrm{softmax}\bigg\{ \Big( \sum_{j=1}^J \big(R_j x_{t-1}^{(j)}\big) + r\Big) \odot (1 - e_{z_{t-1}}) + \Big( \sum_{j=1}^J \big(S_j x_{t-1}^{(j)} \big) + s \Big) \odot e_{z_{t-1}} \bigg\}_i,
    # \end{align}
    # where $R_j$ and $S_j$ are the columns of $R$ and $S$ that act on population $j$'s latents; for example, $R_j x_{t-1}^{(j)}$ contains the contribution of population $j$ towards switching to each state.
    #
    #
    # Additionally, we can include a dependency on the previous discrete state through a transition matrix $P$. This is included in the code package, but is not used in the example below.
    #
    # \begin{align}
    # p(z_t = i \mid z_{t-1} = k, x_{t-1}) &= \mathrm{softmax}\bigg\{ \log(P_{k,:}) + \big(R x_{t-1}\big) \odot (1 - e_{z_{t-1}}) + \big(S x_{t-1} \big) \odot e_{z_{t-1}} \bigg\}_i.
    # \end{align}
    #
    # There are sticky recurrent transitions classes in the multi-population extension that will be loaded in the example below.
    #
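
    # +
    # Sketch (illustration only; hypothetical names, not the ssm API) of the decomposition
    # above: splitting the columns of R by population gives per-population contributions
    # R_j x_{t-1}^(j) that sum to R x_{t-1}. Row j, column k of the returned array is the
    # kind of trace plotted later as population j's contribution towards state k. The
    # last lines evaluate the variant that adds log P[z_{t-1}, :] to the logits, with a
    # toy sticky P whose rows sum to 1.
    import numpy as np

    def population_contributions(W, x_prev, D_vec):
        """Return a (J, K) array whose row j is W_j x_{t-1}^(j), for W of shape (K, sum_j D_j)."""
        edges = np.concatenate(([0], np.cumsum(D_vec)))
        return np.stack([W[:, edges[j]:edges[j + 1]] @ x_prev[edges[j]:edges[j + 1]]
                         for j in range(len(D_vec))])

    demo_rng = np.random.default_rng(2)
    demo_D_vec, demo_K = [2, 3, 2], 3
    demo_R = demo_rng.standard_normal((demo_K, sum(demo_D_vec)))
    demo_S = demo_rng.standard_normal((demo_K, sum(demo_D_vec)))
    demo_x_prev = demo_rng.standard_normal(sum(demo_D_vec))
    demo_contrib = population_contributions(demo_R, demo_x_prev, demo_D_vec)
    assert np.allclose(demo_contrib.sum(axis=0), demo_R @ demo_x_prev)

    # Variant with an explicit dependence on the previous discrete state via log P
    demo_P = np.full((demo_K, demo_K), 0.05) + 0.85 * np.eye(demo_K)  # each row sums to 1
    demo_z_prev = 1
    demo_e = np.eye(demo_K)[demo_z_prev]                               # one-hot e_{z_{t-1}}
    demo_logits = (np.log(demo_P[demo_z_prev]) + (demo_R @ demo_x_prev) * (1 - demo_e)
                   + (demo_S @ demo_x_prev) * demo_e)
    demo_probs = np.exp(demo_logits - demo_logits.max())
    demo_probs /= demo_probs.sum()
    # -
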
    #
    # **5. Model fitting**.
    # We fit the model with variational Laplace EM; see the "Variational Laplace EM for SLDS Tutorial" for more information.
    #

    # + [markdown] colab_type="text" id="8OzC8q4bRFQv"
    # ## Import packages, including multipopulation extensions

    # + colab={"base_uri": "https://localhost:8080/", "height": 581} colab_type="code" id="ruUnNqi5RZqT" outputId="228b6c8e-c064-46c2-ce57-9ad88daca5c8"
    try:
        import ssm
    except ImportError:
        # !pip install git+https://github.com/lindermanlab/ssm.git#egg=ssm
        import ssm

    # + colab={"base_uri": "https://localhost:8080/", "height": 71} colab_type="code" id="zDn3tEJhRFQv" outputId="2f1ca1d0-8f17-404a-897f-57b8c5d353cb"
    #### General packages
    from matplotlib import pyplot as plt
    # %matplotlib inline

    import autograd.numpy as np
    import autograd.numpy.random as npr

    import seaborn as sns
    sns.set_style("white")
    sns.set_context("talk")
    sns.set_style('ticks',{"xtick.major.size":8, "ytick.major.size":8})

    from ssm.plots import gradient_cmap, white_to_color_cmap

    color_names = [
        "purple",
        "red",
        "amber",
        "faded green",
        "windows blue",
        "orange"
        ]

    colors = sns.xkcd_palette(color_names)
    cmap = gradient_cmap(colors)

    # + colab={} colab_type="code" id="0rq19iIQRFQy"
    #### SSM PACKAGES ###

    import ssm
    from ssm.variational import SLDSMeanFieldVariationalPosterior, SLDSTriDiagVariationalPosterior, \
        SLDSStructuredMeanFieldVariationalPosterior
    from ssm.util import random_rotation, find_permutation, relu

    #Load from extensions
    from ssm.extensions.mp_srslds.emissions_ext import GaussianOrthogonalCompoundEmissions, PoissonOrthogonalCompoundEmissions
    from ssm.extensions.mp_srslds.transitions_ext import StickyRecurrentOnlyTransitions, StickyRecurrentTransitions

    # + [markdown] colab_type="text" id="Ty3EOi8bRFQ1"
    # ## Simulate (somewhat realistic) data

    # + [markdown] colab_type="text" id="QxDoYCRDRFQ2"
    # ### Set parameters of simulation

    # + colab={} colab_type="code" id="dalqY6zvRFQ2"
    K=3 #Number of discrete states
    num_gr=3 #Number of populations
    num_per_gr=5 #Number of latents per population
    neur_per_gr=75 #Number of neurons per population
    t_end=3000 #Number of time bins
    num_trials=1 #Number of trials

    # + colab={"base_uri": "https://localhost:8080/", "height": 34} colab_type="code" id="OHTSTNbTRFQ4" outputId="fd3b833b-8df2-425b-d3ce-103cf42d5153"
    np.random.seed(108) #To create replicable dynamics
    alphas=.03+.1*np.random.rand(K) #Scale of the random entries of the dynamics matrix, for each discrete state
    print('alphas:', alphas)
    sparsity=.33 #Probability that each off-diagonal block of the dynamics matrix is set to 0
    e1=.1 #Amount of noise in the dynamics

    # + [markdown] colab_type="text" id="t91lYSkPRFQ7"
    # ### Get new emissions and transitions classes for the simulated data

    # + colab={} colab_type="code" id="tQx540b6RFQ8"
    #Vector containing number of latents per population
    D_vec=[]
    for i in range(num_gr):
        D_vec.append(num_per_gr)

    #Vector containing number of neurons per population
    N_vec=[]
    for i in range(num_gr):
        N_vec.append(neur_per_gr)

    D=np.sum(D_vec)
    num_gr=len(D_vec)
    D_vec_cumsum = np.concatenate(([0], np.cumsum(D_vec)))

    #Get new multipopulation emissions class for the simulation
    # gauss_comp_emissions=GaussianOrthogonalCompoundEmissions(N=np.sum(N_vec),K=1,D=np.sum(D_vec),D_vec=D_vec,N_vec=N_vec)
    poiss_comp_emissions=PoissonOrthogonalCompoundEmissions(N=np.sum(N_vec),K=1,D=np.sum(D_vec),D_vec=D_vec,N_vec=N_vec,link='softplus')

    #Get transitions class
    true_sro_trans=StickyRecurrentOnlyTransitions(K=K,D=np.sum(D_vec))

    # + [markdown] colab_type="text" id="DDu2GnRGRFQ-"
    # ### Create simulated data

    # + colab={} colab_type="code" id="VLN8FWLLRFQ_"
    np.random.seed(10) #To create replicable simulations

    A_masks=[]

    A_all=np.zeros([K,D,D]) #Initialize dynamics matrix
    b_all=np.zeros([K,D]) #Initialize dynamics offset

    #Create an initial ground-truth model, which we will modify
    true_slds = ssm.SLDS(N=np.sum(N_vec),K=K,D=int(np.sum(D_vec)),
                 dynamics="gaussian",
                 emissions=poiss_comp_emissions,
                 transitions=true_sro_trans)

    #Create ground-truth transitions:
    #switching into (and staying in) state k is driven by the first latent of population k
    v=.2+.2*np.random.rand(1)
    for k in range(K):
        inc=np.copy(k)
        true_slds.transitions.Rs[k,D_vec_cumsum[inc]:D_vec_cumsum[inc]+1]=v
        true_slds.transitions.Ss[k,D_vec_cumsum[inc]:D_vec_cumsum[inc]+1]=v-.1

    true_slds.transitions.r=0*np.ones([K,1])
    true_slds.transitions.s=5*np.ones([K,1]) #Large s biases the model toward staying in the current state

    #Create ground-truth dynamics for each state
    for k in range(K):

        ##Create dynamics##
        alpha=alphas[k]
        A_mask=np.random.rand(num_gr,num_gr)>sparsity #Make some blocks of the dynamics matrix 0
        A_masks.append(A_mask)
        for i in range(num_gr):
            A_mask[i,i]=1 #Always keep the within-population (diagonal) blocks
        A0=np.zeros([D,D])
        for i in range(D-1):
            A0[i,i+1:]=alpha*np.random.randn(D-1-i)
        A0=(A0-A0.T) #Random skew-symmetric matrix
        for i in range(num_gr):
            A0[D_vec_cumsum[i]:D_vec_cumsum[i+1],D_vec_cumsum[i]:D_vec_cumsum[i+1]]=2*A0[D_vec_cumsum[i]:D_vec_cumsum[i+1],D_vec_cumsum[i]:D_vec_cumsum[i+1]]
        A0=A0+np.identity(D)
        A=A0*np.kron(A_mask, np.ones((num_per_gr, num_per_gr))) #Zero out the masked blocks
        A=A/(np.max(np.abs(np.linalg.eigvals(A)))+.01) #Rescale so the spectral radius is just below 1
        b=1*np.random.rand(D)
        A_all[k]=A
        b_all[k]=b

    true_slds.dynamics.As=A_all
    true_slds.dynamics.bs=b_all

    zs, xs, _ = true_slds.sample(t_end) #Sample discrete and continuous latents from model for simulation

    #Get spike trains with an average firing rate of roughly 0.25 spikes per bin (ReLU link)
    tmp=np.mean(relu(np.dot(true_slds.emissions.Cs[0],xs.T)+.1*true_slds.emissions.ds[0][:,None]).T)
    mult=.25/tmp
    lams=relu(mult*np.dot(true_slds.emissions.Cs[0],xs.T)+.1*true_slds.emissions.ds[0][:,None]).T
    ys=np.random.poisson(lams) #Get spiking activity based on Poisson statistics

    # + [markdown] colab_type="text" id="twuPg8wRRFRC"
    # ## Plot simulated data

    # + [markdown] colab_type="text" id="1-VkH7xSRFRC"
    # ### Dynamics matrices ($A^z$)

    # + colab={"base_uri": "https://localhost:8080/", "height": 797} colab_type="code" id="qv7hO5fgRFRD" outputId="7b2a3151-a5dc-4ac0-c446-ecbb212ffb61"
    # vmin,vmax=[-1,1]
    vmin,vmax=[-.5,.5] #Zoom in to see colors more clearly
    for k in range(K):
        plt.figure(figsize=(4,4))
        plt.imshow(true_slds.dynamics.As[k], aspect='auto', interpolation="none", vmin=vmin, vmax=vmax, cmap='RdBu')
        offset=-.5
        for nf in D_vec:
            plt.plot([-0.5, D-0.5], [offset, offset], '-k')
            plt.plot([offset, offset], [-0.5, D-0.5], '-k')
            offset += nf
        plt.xticks([])
        plt.yticks([])
        plt.title('Actual State '+str(k))

    # + [markdown] colab_type="text" id="-Y7R3y6TRFRF"
    # ### Discrete states ($z$)

    # + colab={"base_uri": "https://localhost:8080/", "height": 177} colab_type="code" id="Wmk2_uI-RFRG" outputId="2f434446-fbaf-4a00-d186-81a557b35011"
    plt.figure(figsize=(8, 4))
    plt.subplot(211)
    plt.imshow(zs[None,:], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
    plt.xlim(0, t_end)
    plt.ylabel("$z_{\\mathrm{true}}$")
    plt.yticks([])

    # + [markdown] colab_type="text" id="ufQwDDKTRFRI"
    # ### Transitions (in a shorter time window)
    # The contribution of population $j$ to staying in a state is $S_j x^{(j)}$, and its contribution to switching to a state is $R_j x^{(j)}$.

    # + colab={} colab_type="code" id="fp1QQ_7dRFRI"
    dur=200
    st_t=650
    end_t=st_t+dur

    # + colab={"base_uri": "https://localhost:8080/", "height": 310} colab_type="code" id="-FWpoP7-RFRK" outputId="bcc60411-fb9e-44d0-bc16-e382fe7919aa"
    plt.figure(figsize=(8, 4))

    j=0 #Discrete state (purple)
    plt.subplot(211)
    for g in range(num_gr): #Loop over populations
        plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Rs[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
    plt.xlim(0, dur)
    plt.ylabel('Contribution towards \n switching to \n purple state',rotation=60)
    plt.xticks([])
    plt.yticks([])

    j=1 #Discrete state (red)
    plt.subplot(212)
    for g in range(num_gr): #Loop over populations
        plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Ss[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
    plt.xlim(0, dur)
    plt.ylabel('Contribution towards \n staying in red state',rotation=60)
    plt.legend(['Pop 1','Pop 2','Pop 3'])
    plt.yticks([])

    # + [markdown] colab_type="text" id="wOO48DPARFRN"
    # ### Continuous latents ($x$) and spikes ($y$) for an example population

    # + colab={"base_uri": "https://localhost:8080/", "height": 449} colab_type="code" id="YllKBbdhRFRN" outputId="dc7f9218-0100-4263-8501-9d2fb1696a9e"
    plt.figure(figsize=(8, 4))
    plt.subplot(211)
    plt.plot(xs[st_t:end_t,:num_per_gr]) #Show latents of first group
    plt.xticks([])
    plt.subplot(212)
    plt.plot(ys[st_t:end_t,:10]) #Show first 10 neurons

    # + [markdown] colab_type="text" id="OZ_iLjiyRFRV"
    # ## Fit model to data

    # + [markdown] colab_type="text" id="g6QfY0j-RFRV"
    # #### To create the emissions classes for the multipopulation models, we need vectors containing the number of continuous latents per population ("D_vec") and neurons per population ("N_vec")
    #

    # + colab={} colab_type="code" id="kVZDXyIiRFRW"
    num_gr=3 #Number of populations
    num_per_gr=5 #Number of latents per population
    neur_per_gr=75 #Number of neurons per population

    #Vector containing number of latents per population
    D_vec=[]
    for i in range(num_gr):
        D_vec.append(num_per_gr)

    #Vector containing number of neurons per population
    N_vec=[]
    for i in range(num_gr):
        N_vec.append(neur_per_gr)

    # + [markdown] colab_type="text" id="_I-JFqbHRFRZ"
    # #### Now create the multipopulation emissions and transitions classes for our model

    # + colab={} colab_type="code" id="dhnavgIgRFRa"
    #Get new multipopulation emissions class
    poiss_comp_emissions=PoissonOrthogonalCompoundEmissions(N=np.sum(N_vec),K=1,D=np.sum(D_vec),D_vec=D_vec,N_vec=N_vec,link='softplus')

    #Get new transitions class
    sro_trans=StickyRecurrentOnlyTransitions(K=K,D=np.sum(D_vec), l2_penalty_similarity=10, l1_penalty=10)
    #The above L2 penalty is on the similarity between R and S (it assumes the activity that drives switching into a state is similar to the activity that drives staying in it)
    #The L1 penalty is on the entries of R and S
    # -

    # Note that another new emissions class is "GaussianOrthogonalCompoundEmissions"
    #
    # Note that another new transitions class is "StickyRecurrentTransitions"

    # + [markdown] colab_type="text" id="syHDAea_RFRc"
    # #### Now declare and fit the model

    # + colab={"base_uri": "https://localhost:8080/", "height": 166, "referenced_widgets": ["36b4e6c4b2064572a4412f8bc298caeb", "97247a689ff744689528aa9555c13c62", "581b2c1a39eb439581308a5cdea07389", "e8e6a6b80ed14b77900e87e1c779f459", "2e200ca5be3141f2b2ee041ec8b08e49", "f4e534f4d6ac45c6a7da04be7d341968", "ced98407e9ee4ee4bd0baaa24d4765b9", "14de5bf9b23f47d4867e135cf156ac97", "6bd4692a8c95492d84f6353069342d4f", "62f38f8736314a098f91d7486b5b84f6", "d7ffd5a5a8874c2d8e593cb29c8f14b2", "2ddbbe21934547db9e365d905ea1f024", "19ad77a391f7459f811262d5bdbeb783", "5da21c8eac7041898c7a468f0388e632", "eb1c51978e0f4b79968cf6618a959007", "c0204fcad9f64dd48580f7d17767abc6"]} colab_type="code" id="n7WRqQufRFRd" outputId="28ddbaaa-221c-4ca8-db59-b76bbeff1dd0"
    K=3 #Number of discrete states
    rslds = ssm.SLDS(N=np.sum(N_vec),K=K,D=np.sum(D_vec),
                 dynamics="gaussian",
                 emissions=poiss_comp_emissions,
                 transitions=sro_trans,
                 dynamics_kwargs=dict(l2_penalty_A=100)) #Regularization on the dynamics matrix

    q_elbos_ar, q_ar = rslds.fit(ys, method="laplace_em",
                                   variational_posterior="structured_meanfield",
                                   continuous_optimizer='newton',
                                   initialize=True,
                                   num_init_restarts=10,
                                   num_iters=30,
                                   alpha=0.25)

    # + colab={} colab_type="code" id="Oi8-3NcxRFRf" outputId="98075512-a594-4bae-fcf3-550affe796a6"
    plt.plot(q_elbos_ar[1:])
    plt.xlabel("Iteration")
    plt.ylabel("ELBO")
    # -

    # ## Align solution with simulation for plotting

    #The recovered discrete states can be permuted in any way.
    #Find the permutation that matches the discrete states in the model to the ground truth
    z_inferred=rslds.most_likely_states(q_ar.mean_continuous_states[0],ys)
    rslds.permute(find_permutation(zs, z_inferred))
    z_inferred2=rslds.most_likely_states(q_ar.mean_continuous_states[0],ys)

    # +
    #Each population's latents can be multiplied by an arbitrary rotation matrix.
    #Additionally, the scaling may differ between the simulation ground truth and the recovered latents,
    #because the simulation didn't constrain the effective emissions (C) matrix to be orthonormal like in the model.
    #We therefore regress each population's true latents on its inferred latents, giving a block-diagonal
    #alignment matrix R with x_true approx. R x_inferred (this R is unrelated to the transition weights R).

    from sklearn.linear_model import LinearRegression

    R=np.zeros([D,D])
    for g in range(num_gr):
        lr=LinearRegression(fit_intercept=False)
        lr.fit(q_ar.mean_continuous_states[0][:,D_vec_cumsum[g]:D_vec_cumsum[g+1]],xs[:,D_vec_cumsum[g]:D_vec_cumsum[g+1]])
        R[D_vec_cumsum[g]:D_vec_cumsum[g+1],D_vec_cumsum[g]:D_vec_cumsum[g+1]]=lr.coef_

    # + [markdown] colab_type="text" id="0CjkFH63RFRh"
    # ## Plot results

    # + [markdown] colab_type="text" id="HDlMQx4ZRFRh"
    # ### Discrete states ($z$)

    # + colab={} colab_type="code" id="TGk3zoBVRFRi" outputId="d1cebd3c-17f7-4e48-a6d3-0b9f34fbd3af"
    plt.figure(figsize=(8, 4))
    plt.subplot(211)
    plt.imshow(zs[None,:], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
    plt.xlim(0, t_end)
    plt.ylabel("$z_{\\mathrm{true}}$")
    plt.yticks([])

    plt.subplot(212)
    plt.imshow(z_inferred2[None,:], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
    plt.xlim(0, t_end)
    plt.ylabel("$z_{\\mathrm{inferred}}$")
    plt.yticks([])
    plt.xlabel("time")

    plt.tight_layout()

    # + colab={} colab_type="code" id="HRPr3DJlRFRj" outputId="1d6b402b-fa6d-4040-f4bf-4e4c2188c8b8"
    print('Discrete state accuracy: ', np.mean(zs==z_inferred2))

    # + [markdown] colab_type="text" id="SX1go1AyRFRl"
    # #### Shorter time window

    # + colab={} colab_type="code" id="UhjVHiYzRFRm" outputId="5d4cf636-9055-46fe-af2b-d51531c97317"
    plt.figure(figsize=(4, 2))
    plt.subplot(211)
    plt.imshow(zs[None,st_t:end_t], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
    plt.xlim(0, dur)
    plt.ylabel("$z_{\\mathrm{true}}$")
    plt.yticks([])
    plt.xticks([])

    plt.subplot(212)
    plt.imshow(z_inferred2[None,st_t:end_t], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
    plt.xlim(0, dur)
    plt.ylabel("$z_{\\mathrm{inferred}}$")
    plt.yticks([])

    # + [markdown] colab_type="text" id="e4Cz1o8GRFRo"
    # ### Dynamics matrices ($A^z$)
    # -

    # We show the A matrix after aligning the continuous latents to the ground truth ($R A R^{-1}$), which demonstrates that the ground-truth dynamics can be recovered.
    #
    # We also show the raw recovered A matrix. Because the per-population rotations and scalings are block-diagonal, they preserve which blocks are zero, so the block structure can be learned regardless of scaling or rotation.

    # + colab={} colab_type="code" id="0D2QHaLRRFRo" outputId="891cf1a1-4aea-4c5f-f2d5-9895478b5361"
    plt.figure(figsize=(12, 12))
    q=1
    for k in range(K):

        plt.subplot(3,3,q)
        plt.imshow(true_slds.dynamics.As[k], aspect='auto', interpolation="none", vmin=-.5, vmax=.5, cmap='RdBu')
        offset=-.5
        for nf in D_vec:
            plt.plot([-0.5, D-0.5], [offset, offset], '-k')
            plt.plot([offset, offset], [-0.5, D-0.5], '-k')
            offset += nf
        plt.xticks([])
        plt.yticks([])
        plt.title('Actual State '+str(k))
        q=q+1

        plt.subplot(3,3,q)
        # plt.imshow(rslds.dynamics.As[k], aspect='auto', interpolation="none", vmin=-.5, vmax=.5, cmap='RdBu')
        plt.imshow(R@rslds.dynamics.As[k]@np.linalg.inv(R), aspect='auto', interpolation="none", vmin=-.5, vmax=.5, cmap='RdBu')
        offset=-.5
        for nf in D_vec:
            plt.plot([-0.5, D-0.5], [offset, offset], '-k')
            plt.plot([offset, offset], [-0.5, D-0.5], '-k')
            offset += nf
        plt.xticks([])
        plt.yticks([])
        plt.title('Aligned Predicted State '+str(k))
        # plt.savefig(folder+'dyn_est'+str(k)+'.pdf')
        q=q+1

        plt.subplot(3,3,q)
        # plt.imshow(rslds.dynamics.As[k], aspect='auto', interpolation="none", vmin=-.5, vmax=.5, cmap='RdBu')
        plt.imshow(rslds.dynamics.As[k], aspect='auto', interpolation="none", vmin=-.5, vmax=.5, cmap='RdBu')
        offset=-.5
        for nf in D_vec:
            plt.plot([-0.5, D-0.5], [offset, offset], '-k')
            plt.plot([offset, offset], [-0.5, D-0.5], '-k')
            offset += nf
        plt.xticks([])
        plt.yticks([])
        plt.title('Raw Predicted State '+str(k))
        # plt.savefig(folder+'dyn_est'+str(k)+'.pdf')
        q=q+1

    # + [markdown] colab_type="text" id="b14UYWuMRFRr"
    # ### Transitions
    # The contribution of population $j$ to staying in a state is $S_j x^{(j)}$, and its contribution to switching to a state is $R_j x^{(j)}$.
    #

    # + colab={} colab_type="code" id="hja_9YLPRFRr" outputId="6eb42921-5277-43ea-b14b-2366d2128509"
    plt.figure(figsize=(15, 4))

    ### Actual

    j=0 #Discrete state (purple)
    plt.subplot(221)
    for g in range(num_gr): #Loop over populations
        plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Rs[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
    plt.xlim(0, dur)
    plt.ylabel('Contribution towards \n switching to \n purple state',rotation=60)
    plt.xticks([])
    plt.yticks([])
    plt.title('Actual')

    j=1 #Discrete state (red)
    plt.subplot(223)
    for g in range(num_gr):
        plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Ss[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
    plt.xlim(0, dur)
    plt.ylabel('Contribution towards \n staying in red state',rotation=60)
    plt.legend(['Pop 1','Pop 2','Pop 3'])
    plt.yticks([])

    ### Predicted

    j=0
    plt.subplot(222)
    for g in range(num_gr):
        plt.plot(np.dot(q_ar.mean_continuous_states[0][st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],rslds.transitions.Rs[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
    plt.xlim(0, dur)
    # plt.ylabel('Contribution towards \n switching to \n purple state',rotation=60)
    plt.xticks([])
    plt.yticks([])
    plt.title('Predicted')

    j=1
    plt.subplot(224)
    for g in range(num_gr):
        plt.plot(np.dot(q_ar.mean_continuous_states[0][st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],rslds.transitions.Ss[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
    plt.xlim(0, dur)
    # plt.ylabel('Stay in Red')
    # plt.xticks([])
    plt.yticks([])

    # + [markdown] colab_type="text" id="yz4GSxjrRFRt"
    # ### Example fit of neural activity ($y$)

    # + colab={} colab_type="code" id="cgEffzVKRFRt" outputId="10d64661-8fbd-4259-96fc-134c35980b88"
    preds=rslds.smooth(q_ar.mean_continuous_states[0],ys) #Get predictions
    nrn=0 #Example neuron
    plt.plot(ys[st_t:end_t,nrn],alpha=.5) #True spiking activity
    plt.plot(lams[st_t:end_t,nrn]) #True firing rate
    plt.plot(preds[st_t:end_t,nrn]) #Predicted firing rate
    plt.legend(['True Spikes','True FR','Predicted FR'])


.. rst-class:: sphx-glr-timing

   **Total running time of the script:** ( 4 minutes 8.973 seconds)


.. _sphx_glr_download_auto_examples_Multi-Population-rSLDS.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example


    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: Multi-Population-rSLDS.py `

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: Multi-Population-rSLDS.ipynb `


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_