# ---
# jupyter:
# jupytext:
# text_representation:
# extension: .py
# format_name: light
# format_version: '1.5'
# jupytext_version: 1.14.1
# kernelspec:
# display_name: Python 3
# language: python
# name: python3
# ---
"""
Multi-Population rSLDS
======================
"""
# ### If you want to quickly see how to fit your own data, jump down to the "Fit model to data" section
#
#
#
# # Multi-population recurrent switching linear dynamical systems overview
#
# This notebook goes through the simulation example shown in our manuscript (Figure 2A,B).
#
# Below, we briefly describe the model. We also recommend looking at the "Recurrent SLDS" notebook, which provides more details on the standard rSLDS.
#
#
#
# **1. Data**.
# Let $y_t^{(j)}$ denote a vector of activity measurements of the $N_j$ neurons in population $j$ in time bin $t$.
#
#
# **2. Emissions**.
# Let $x_t^{(j)}$ denote the continuous latent state of population $j$ at time $t$. The population states may differ in dimensionality $D_j$, since populations may differ in size and complexity. The observed activity of population $j$ is modeled with a generalized linear model,
# \begin{align}
# E[y_t^{(j)}] &= f(C_j x_t^{(j)} + d_j),
# \end{align}
# where each population has its own linear mapping parameterized by $\{C_j, d_j\}$. In this notebook, we use a Poisson GLM. Inputs can also be passed into this GLM, as described in the rSLDS notebook.
#
# There are multi-population emissions classes that will be loaded in the example below.
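# As a small illustrative sketch of this emission model (NumPy only; the variable names are made up here and are not the ssm API), the per-population mean with a softplus link can be computed as:

```python
import numpy as np

def softplus(a):
    # Numerically stable softplus: log(1 + exp(a))
    return np.log1p(np.exp(-np.abs(a))) + np.maximum(a, 0)

def emission_mean(x_j, C_j, d_j):
    # E[y_t^(j)] = f(C_j x_t^(j) + d_j) for a single population j
    return softplus(C_j @ x_j + d_j)

rng = np.random.default_rng(0)
D_j, N_j = 5, 75                     # latent dimension and neuron count for population j
C_j = rng.normal(size=(N_j, D_j))    # per-population emission matrix
d_j = rng.normal(size=N_j)           # per-population offset
x_j = rng.normal(size=D_j)           # latent state of population j at time t

rate = emission_mean(x_j, C_j, d_j)  # expected activity of each neuron in population j
```

# The softplus link guarantees a positive rate for every neuron, which is why it pairs naturally with the Poisson GLM used below.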
#
#
# **3. Continuous State Update (Dynamics)**.
# The dynamics of a switching linear dynamical system are piecewise linear, with the linear dynamics at a given time determined by a discrete state (more on discrete states below):
#
# \begin{align}
# x_t =
# A^{(z_t)} x_{t-1} + b^{(z_t)},
# \end{align}
#
# where $z_t$ is the discrete state, $A^{(z_t)}$ is the dynamics matrix for that discrete state, and $x_t$ contains the latents from all populations, $[x_t^{(1)}, x_t^{(2)}, \ldots, x_t^{(J)}]$. We ignore the noise term here for simplicity.
#
# Having unique continuous latents for each population allows us to decompose the dynamics in an interpretable manner. We model the temporal dynamics of the continuous states as
# \begin{align}
# x_t^{(j)} =
# A_{(j \: to \: j)}^{(z_t)} x_{t-1}^{(j)}
# + \sum_{i \neq j} A_{(i \: to \: j)}^{(z_t)} x_{t-1}^{(i)}
# + b_j^{(z_t)}.
# \end{align}
#
# In the full dynamics matrices $A^{(z_t)}$ shown in the example below, the on-diagonal blocks represent the internal dynamics, $A_{(j \: to \: j)}^{(z_t)}$, and the off-diagonal blocks represent the external dynamics, $A_{(i \: to \: j)}^{(z_t)}$.
#
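# To make the block structure concrete, here is a minimal NumPy sketch (the names are illustrative, not the ssm API) that slices a full dynamics matrix into its internal and external blocks:

```python
import numpy as np

rng = np.random.default_rng(1)
num_gr, num_per_gr = 3, 5
D = num_gr * num_per_gr
A = 0.1 * rng.normal(size=(D, D)) + np.eye(D)   # stand-in full dynamics matrix

# Slice A into (num_gr x num_gr) blocks of shape (num_per_gr x num_per_gr)
bounds = np.arange(0, D + 1, num_per_gr)
blocks = [[A[bounds[i]:bounds[i + 1], bounds[j]:bounds[j + 1]]
           for j in range(num_gr)] for i in range(num_gr)]

# Block (j, j) holds population j's internal dynamics; block (j, i), i != j,
# maps population i's previous latents into population j (external dynamics)
internal = blocks[0][0]   # A_(1 to 1)
external = blocks[1][0]   # A_(1 to 2)

# Reassembling the blocks recovers the full matrix
reassembled = np.block(blocks)
```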
#
# **4. Discrete State Update (Transitions)**.
# Recurrent transitions are based on the continuous latent state. Our recurrent transitions have a sticky component, $S$, that determines the probability of staying in the current state, and a switching component, $R$, that determines the probabilities of switching to other states. In the model we use in this notebook:
#
# \begin{align}
# p(z_t = i \mid z_{t-1} = j, x_{t-1}) &= \mathrm{softmax}\bigg\{ \Big( \big(R x_{t-1}\big) + r\Big) \odot (1 - e_{z_{t-1}}) + \Big( \big(S x_{t-1} \big) + s \Big) \odot e_{z_{t-1}} \bigg\},
# \end{align}
#
# where $e_{z_{t-1}} \in \{0,1\}^K$ is a one-hot encoding of $z_{t-1}$.
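# A minimal sketch of this transition rule (NumPy only; the function and argument names are hypothetical, not the ssm API) makes the role of the two components explicit:

```python
import numpy as np

def transition_probs(x_prev, z_prev, R, r, S, s):
    # p(z_t | z_{t-1}, x_{t-1}): the switching part (R, r) applies to states
    # other than z_{t-1}; the sticky part (S, s) applies to staying in z_{t-1}
    K = R.shape[0]
    e = np.eye(K)[z_prev]                                   # one-hot e_{z_{t-1}}
    logits = (R @ x_prev + r) * (1 - e) + (S @ x_prev + s) * e
    p = np.exp(logits - logits.max())                       # stable softmax
    return p / p.sum()

rng = np.random.default_rng(2)
K, D = 3, 15
R = rng.normal(size=(K, D))
S = rng.normal(size=(K, D))
r = np.zeros(K)
s = 5.0 * np.ones(K)      # a large sticky bias favors staying in the same state

probs = transition_probs(rng.normal(size=D), z_prev=1, R=R, r=r, S=S, s=s)
```

# Note how a large $s$, like the value used for the ground-truth model below, biases the chain toward self-transitions.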
#
# To understand which populations are contributing to the transitions, we can decompose this equation:
#
#
# \begin{align}
# p(z_t = i \mid z_{t-1}, x_{t-1}) &= \mathrm{softmax}\bigg\{ \Big( \sum_{j=1}^J \big(R_j x_{t-1}^{(j)}\big) + r\Big) \odot (1 - e_{z_{t-1}}) + \Big( \sum_{j=1}^J \big(S_j x_{t-1}^{(j)} \big) + s \Big) \odot e_{z_{t-1}} \bigg\},
# \end{align}
# where, for example, $R_j x_{t-1}^{(j)}$ contains the contribution of population $j$ towards switching to each state.
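# This decomposition amounts to slicing the weight matrix by population, as in this small sketch (illustrative names, not the ssm API):

```python
import numpy as np

rng = np.random.default_rng(3)
K, D_vec = 3, [5, 5, 5]                        # 3 discrete states, 3 populations
cumsum = np.concatenate(([0], np.cumsum(D_vec)))
R = rng.normal(size=(K, cumsum[-1]))           # switching weights over all latents
x_prev = rng.normal(size=cumsum[-1])           # stacked latents x_{t-1}

# Contribution of population j towards switching to each state: R_j x_{t-1}^{(j)}
contribs = [R[:, cumsum[j]:cumsum[j + 1]] @ x_prev[cumsum[j]:cumsum[j + 1]]
            for j in range(len(D_vec))]

# The per-population contributions sum back to the full switching term R x_{t-1}
total = np.sum(contribs, axis=0)
```

# The "Transitions" plots later in this notebook use exactly this kind of slicing of the fitted $R$ and $S$ matrices.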
#
#
# Additionally, we can include a dependency on the previous discrete state. This is included in the code package, but is not used in the example below.
#
# \begin{align}
# p(z_t = i \mid z_{t-1} = j, x_{t-1}) &= \mathrm{softmax}\bigg\{ \log(P_{j,i}) + \big(R x_{t-1}\big) \odot (1 - e_{z_{t-1}}) + \big(S x_{t-1} \big) \odot e_{z_{t-1}} \bigg\},
# \end{align}
#
# There are sticky multi-population transitions classes that will be loaded in the example below.
#
#
# **5. Model fitting**.
# We fit the model with variational Laplace EM; see the "Variational Laplace EM for SLDS Tutorial" notebook for more information.
#
# + [markdown] colab_type="text" id="8OzC8q4bRFQv"
# ## Import packages, including multipopulation extensions
# + colab={"base_uri": "https://localhost:8080/", "height": 581} colab_type="code" id="ruUnNqi5RZqT" outputId="228b6c8e-c064-46c2-ce57-9ad88daca5c8"
try:
import ssm
except ImportError:
# !pip install git+https://github.com/lindermanlab/ssm.git#egg=ssm
import ssm
# + colab={"base_uri": "https://localhost:8080/", "height": 71} colab_type="code" id="zDn3tEJhRFQv" outputId="2f1ca1d0-8f17-404a-897f-57b8c5d353cb"
#### General packages
from matplotlib import pyplot as plt
# %matplotlib inline
import autograd.numpy as np
import autograd.numpy.random as npr
import seaborn as sns
sns.set_style("white")
sns.set_context("talk")
sns.set_style('ticks',{"xtick.major.size":8,
"ytick.major.size":8})
from ssm.plots import gradient_cmap, white_to_color_cmap
color_names = [
"purple",
"red",
"amber",
"faded green",
"windows blue",
"orange"
]
colors = sns.xkcd_palette(color_names)
cmap = gradient_cmap(colors)
# + colab={} colab_type="code" id="0rq19iIQRFQy"
#### SSM PACKAGES ###
import ssm
from ssm.variational import SLDSMeanFieldVariationalPosterior, SLDSTriDiagVariationalPosterior, \
SLDSStructuredMeanFieldVariationalPosterior
from ssm.util import random_rotation, find_permutation, relu
#Load from extensions
from ssm.extensions.mp_srslds.emissions_ext import GaussianOrthogonalCompoundEmissions, PoissonOrthogonalCompoundEmissions
from ssm.extensions.mp_srslds.transitions_ext import StickyRecurrentOnlyTransitions, StickyRecurrentTransitions
# + [markdown] colab_type="text" id="Ty3EOi8bRFQ1"
# ## Simulate (somewhat realistic) data
# + [markdown] colab_type="text" id="QxDoYCRDRFQ2"
# ### Set parameters of simulation
# + colab={} colab_type="code" id="dalqY6zvRFQ2"
K=3 #Number of discrete states
num_gr=3 #Number of populations
num_per_gr=5 #Number of latents per population
neur_per_gr=75 #Number of neurons per population
t_end=3000 #number of time bins
num_trials=1 #number of trials
# + colab={"base_uri": "https://localhost:8080/", "height": 34} colab_type="code" id="OHTSTNbTRFQ4" outputId="fd3b833b-8df2-425b-d3ce-103cf42d5153"
np.random.seed(108) #To create replicable dynamics
alphas=.03+.1*np.random.rand(K) #Determines the distribution of values in the dynamics matrix, for each discrete state
print('alphas:', alphas)
sparsity=.33 #Proportion of non-diagonal blocks in the dynamics matrix that are 0
e1=.1 #Amount of noise in the dynamics
# + [markdown] colab_type="text" id="t91lYSkPRFQ7"
# ### Get new emissions and transitions classes for the simulated data
# + colab={} colab_type="code" id="tQx540b6RFQ8"
#Vector containing number of latents per population
D_vec = [num_per_gr] * num_gr
#Vector containing number of neurons per population
N_vec = [neur_per_gr] * num_gr
D=np.sum(D_vec)
num_gr=len(D_vec)
D_vec_cumsum = np.concatenate(([0], np.cumsum(D_vec)))
#Get new multipopulation emissions class for the simulation
# gauss_comp_emissions=GaussianOrthogonalCompoundEmissions(N=np.sum(N_vec),K=1,D=np.sum(D_vec),D_vec=D_vec,N_vec=N_vec)
poiss_comp_emissions=PoissonOrthogonalCompoundEmissions(N=np.sum(N_vec),K=1,D=np.sum(D_vec),D_vec=D_vec,N_vec=N_vec,link='softplus')
#Get transitions class
true_sro_trans=StickyRecurrentOnlyTransitions(K=K,D=np.sum(D_vec))
# + [markdown] colab_type="text" id="DDu2GnRGRFQ-"
# ### Create simulated data
# + colab={} colab_type="code" id="VLN8FWLLRFQ_"
np.random.seed(10) #To create replicable simulations
A_masks=[]
A_all=np.zeros([K,D,D]) #Initialize dynamics matrix
b_all=np.zeros([K,D]) #Initialize dynamics offset
#Create initial ground truth model, that we will modify
true_slds = ssm.SLDS(N=np.sum(N_vec),K=K,D=int(np.sum(D_vec)),
dynamics="gaussian",
emissions=poiss_comp_emissions,
transitions=true_sro_trans)
#Create ground truth transitions
v=.2+.2*np.random.rand(1)
for k in range(K):
inc=np.copy(k)
true_slds.transitions.Rs[k,D_vec_cumsum[inc]:D_vec_cumsum[inc]+1]=v
true_slds.transitions.Ss[k,D_vec_cumsum[inc]:D_vec_cumsum[inc]+1]=v-.1
true_slds.transitions.r=0*np.ones([K,1])
true_slds.transitions.s=5*np.ones([K,1])
#Create ground truth dynamics for each state
for k in range(K):
##Create dynamics##
alpha=alphas[k]
A_mask=np.random.rand(num_gr,num_gr)>sparsity #Make some blocks of the dynamics matrix 0
A_masks.append(A_mask)
for i in range(num_gr):
A_mask[i,i]=1
A0=np.zeros([D,D])
for i in range(D-1):
A0[i,i+1:]=alpha*np.random.randn(D-1-i)
A0=(A0-A0.T)
for i in range(num_gr):
A0[D_vec_cumsum[i]:D_vec_cumsum[i+1],D_vec_cumsum[i]:D_vec_cumsum[i+1]]=2*A0[D_vec_cumsum[i]:D_vec_cumsum[i+1],D_vec_cumsum[i]:D_vec_cumsum[i+1]]
A0=A0+np.identity(D)
A=A0*np.kron(A_mask, np.ones((num_per_gr, num_per_gr)))
A=A/(np.max(np.abs(np.linalg.eigvals(A)))+.01) #.97
b=1*np.random.rand(D)
A_all[k]=A
b_all[k]=b
true_slds.dynamics.As=A_all
true_slds.dynamics.bs=b_all
zs, xs, _ = true_slds.sample(t_end) #Sample discrete and continuous latents from model for simulation
#Get spike trains that have an average firing rate of 0.25 per bin
tmp=np.mean(relu(np.dot(true_slds.emissions.Cs[0],xs.T)+.1*true_slds.emissions.ds[0][:,None]).T)
mult=.25/tmp
lams=relu(mult*np.dot(true_slds.emissions.Cs[0],xs.T)+.1*true_slds.emissions.ds[0][:,None]).T
ys=np.random.poisson(lams) #Get spiking activity based on poisson statistics
# + [markdown] colab_type="text" id="twuPg8wRRFRC"
# ## Plot simulated data
# + [markdown] colab_type="text" id="1-VkH7xSRFRC"
# ### Dynamics matrices ($A^z$)
# + colab={"base_uri": "https://localhost:8080/", "height": 797} colab_type="code" id="qv7hO5fgRFRD" outputId="7b2a3151-a5dc-4ac0-c446-ecbb212ffb61"
# vmin,vmax=[-1,1]
vmin,vmax=[-.5,.5] #zoom in to see colors more clearly
for k in range(K):
plt.figure(figsize=(4,4))
plt.imshow(true_slds.dynamics.As[k], aspect='auto', interpolation="none", vmin=vmin, vmax=vmax, cmap='RdBu')
offset=-.5
for nf in D_vec:
plt.plot([-0.5, D-0.5], [offset, offset], '-k')
plt.plot([offset, offset], [-0.5, D-0.5], '-k')
offset += nf
plt.xticks([])
plt.yticks([])
plt.title('Actual State '+str(k))
# + [markdown] colab_type="text" id="-Y7R3y6TRFRF"
# ### Discrete states ($z$)
# + colab={"base_uri": "https://localhost:8080/", "height": 177} colab_type="code" id="Wmk2_uI-RFRG" outputId="2f434446-fbaf-4a00-d186-81a557b35011"
plt.figure(figsize=(8, 4))
plt.subplot(211)
plt.imshow(zs[None,:], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
plt.xlim(0, t_end)
plt.ylabel("$z_{\\mathrm{true}}$")
plt.yticks([])
# + [markdown] colab_type="text" id="ufQwDDKTRFRI"
# ### Transitions (in a shorter time window)
# The contribution of population $j$ to staying in a state is $S_j x^{(j)}$ and the contribution to switching to a state is $R_j x^{(j)}$
# + colab={} colab_type="code" id="fp1QQ_7dRFRI"
dur=200
st_t=650
end_t=st_t+dur
# + colab={"base_uri": "https://localhost:8080/", "height": 310} colab_type="code" id="-FWpoP7-RFRK" outputId="bcc60411-fb9e-44d0-bc16-e382fe7919aa"
plt.figure(figsize=(8, 4))
j=0
plt.subplot(211)
for g in range(num_gr): #Loop over populations
plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Rs[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
plt.ylabel('Contribution towards \n switching to \n purple state',rotation=60)
plt.xticks([])
plt.yticks([])
j=1
plt.subplot(212)
for g in range(num_gr): #Loop over populations
plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Ss[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
plt.ylabel('Contribution towards \n staying in red state',rotation=60)
plt.legend(['Pop 1','Pop 2','Pop 3'])
plt.yticks([])
# + [markdown] colab_type="text" id="wOO48DPARFRN"
# ### Continuous latents ($x$) and spikes ($y$) for an example population
# + colab={"base_uri": "https://localhost:8080/", "height": 449} colab_type="code" id="YllKBbdhRFRN" outputId="dc7f9218-0100-4263-8501-9d2fb1696a9e"
plt.figure(figsize=(8, 4))
plt.subplot(211)
plt.plot(xs[st_t:end_t,:num_per_gr]) #Show latents of first group
plt.xticks([])
plt.subplot(212)
plt.plot(ys[st_t:end_t,:10]) #Show first 10 neurons
# + [markdown] colab_type="text" id="OZ_iLjiyRFRV"
# ## Fit model to data
# + [markdown] colab_type="text" id="g6QfY0j-RFRV"
# #### To create the emissions classes for the multipopulation models, we need vectors containing the number of continuous latents per population ("D_vec") and neurons per population ("N_vec")
#
# + colab={} colab_type="code" id="kVZDXyIiRFRW"
num_gr=3 #Number of populations
num_per_gr=5 #Number of latents per population
neur_per_gr=75 #Number of neurons per population
#Vector containing number of latents per population
D_vec = [num_per_gr] * num_gr
#Vector containing number of neurons per population
N_vec = [neur_per_gr] * num_gr
# + [markdown] colab_type="text" id="_I-JFqbHRFRZ"
# #### Now create the multipopulation emissions and transitions classes for our model
# + colab={} colab_type="code" id="dhnavgIgRFRa"
#Get new multipopulation emissions class
poiss_comp_emissions=PoissonOrthogonalCompoundEmissions(N=np.sum(N_vec),K=1,D=np.sum(D_vec),D_vec=D_vec,N_vec=N_vec,link='softplus')
#Get new transitions class
sro_trans=StickyRecurrentOnlyTransitions(K=K,D=np.sum(D_vec), l2_penalty_similarity=10, l1_penalty=10)
#The above l2 penalty is on the similarity between R and S (it assumes the activity that switches into a state is similar to the activity that stays in that state)
#The l1 penalty is on the entries of R and S
# -
# Note that another new emissions class is "GaussianOrthogonalCompoundEmissions"
#
# Note that another new transitions class is "StickyRecurrentTransitions"
# + [markdown] colab_type="text" id="syHDAea_RFRc"
# #### Now declare and fit the model
# + colab={"base_uri": "https://localhost:8080/", "height": 166, "referenced_widgets": ["36b4e6c4b2064572a4412f8bc298caeb", "97247a689ff744689528aa9555c13c62", "581b2c1a39eb439581308a5cdea07389", "e8e6a6b80ed14b77900e87e1c779f459", "2e200ca5be3141f2b2ee041ec8b08e49", "f4e534f4d6ac45c6a7da04be7d341968", "ced98407e9ee4ee4bd0baaa24d4765b9", "14de5bf9b23f47d4867e135cf156ac97", "6bd4692a8c95492d84f6353069342d4f", "62f38f8736314a098f91d7486b5b84f6", "d7ffd5a5a8874c2d8e593cb29c8f14b2", "2ddbbe21934547db9e365d905ea1f024", "19ad77a391f7459f811262d5bdbeb783", "5da21c8eac7041898c7a468f0388e632", "eb1c51978e0f4b79968cf6618a959007", "c0204fcad9f64dd48580f7d17767abc6"]} colab_type="code" id="n7WRqQufRFRd" outputId="28ddbaaa-221c-4ca8-db59-b76bbeff1dd0"
K=3 #Number of discrete states
rslds = ssm.SLDS(N=np.sum(N_vec),K=K,D=np.sum(D_vec),
dynamics="gaussian",
emissions=poiss_comp_emissions,
transitions=sro_trans,
dynamics_kwargs=dict(l2_penalty_A=100)) #Regularization on the dynamics matrix
q_elbos_ar, q_ar = rslds.fit(ys, method="laplace_em",
variational_posterior="structured_meanfield",
continuous_optimizer='newton',
initialize=True,
num_init_restarts=10,
num_iters=30,
alpha=0.25)
# + colab={} colab_type="code" id="Oi8-3NcxRFRf" outputId="98075512-a594-4bae-fcf3-550affe796a6"
plt.plot(q_elbos_ar[1:])
plt.xlabel("Iteration")
plt.ylabel("ELBO")
# -
# ## Align solution with simulation for plotting
#The recovered discrete state labels are arbitrary (any permutation fits equally well).
#Find permutation to match the discrete states in the model and the ground truth
z_inferred=rslds.most_likely_states(q_ar.mean_continuous_states[0],ys)
rslds.permute(find_permutation(zs, z_inferred))
z_inferred2=rslds.most_likely_states(q_ar.mean_continuous_states[0],ys)
# +
#Each population's latents can be multiplied by an arbitrary rotation matrix
#Additionally, there may be a change in scaling between the simulation ground truth and recovered latents,
#because the simulation didn't constrain the effective emissions (C) matrix to be orthonormal like in the model
from sklearn.linear_model import LinearRegression
R=np.zeros([D,D])
for g in range(num_gr):
lr=LinearRegression(fit_intercept=False)
lr.fit(q_ar.mean_continuous_states[0][:,D_vec_cumsum[g]:D_vec_cumsum[g+1]],xs[:,D_vec_cumsum[g]:D_vec_cumsum[g+1]])
R[D_vec_cumsum[g]:D_vec_cumsum[g+1],D_vec_cumsum[g]:D_vec_cumsum[g+1]]=lr.coef_
# + [markdown] colab_type="text" id="0CjkFH63RFRh"
# ## Plot results
# + [markdown] colab_type="text" id="HDlMQx4ZRFRh"
# ### Discrete states ($z$)
# + colab={} colab_type="code" id="TGk3zoBVRFRi" outputId="d1cebd3c-17f7-4e48-a6d3-0b9f34fbd3af"
plt.figure(figsize=(8, 4))
plt.subplot(211)
plt.imshow(zs[None,:], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
plt.xlim(0, t_end)
plt.ylabel("$z_{\\mathrm{true}}$")
plt.yticks([])
plt.subplot(212)
plt.imshow(z_inferred2[None,:], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
plt.xlim(0, t_end)
plt.ylabel("$z_{\\mathrm{inferred}}$")
plt.yticks([])
plt.xlabel("time")
plt.tight_layout()
# + colab={} colab_type="code" id="HRPr3DJlRFRj" outputId="1d6b402b-fa6d-4040-f4bf-4e4c2188c8b8"
print('Discrete state accuracy: ', np.mean(zs==z_inferred2))
# + [markdown] colab_type="text" id="SX1go1AyRFRl"
# #### Shorter time window
# + colab={} colab_type="code" id="UhjVHiYzRFRm" outputId="5d4cf636-9055-46fe-af2b-d51531c97317"
plt.figure(figsize=(4, 2))
plt.subplot(211)
plt.imshow(zs[None,st_t:end_t], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
plt.xlim(0, dur)
plt.ylabel("$z_{\\mathrm{true}}$")
plt.yticks([])
plt.xticks([])
plt.subplot(212)
plt.imshow(z_inferred2[None,st_t:end_t], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
plt.xlim(0, dur)
plt.ylabel("$z_{\\mathrm{inferred}}$")
plt.yticks([])
# + [markdown] colab_type="text" id="e4Cz1o8GRFRo"
# ### Dynamics matrices ($A^z$)
# -
# We show the A matrices after the continuous latents have been aligned to the ground truth, demonstrating the ability to recover the ground truth dynamics.
#
# We also show the raw recovered A matrices, which demonstrate that we can learn the block structure regardless of scaling/rotations.
# + colab={} colab_type="code" id="0D2QHaLRRFRo" outputId="891cf1a1-4aea-4c5f-f2d5-9895478b5361"
plt.figure(figsize=(12, 12))
q=1
for k in range(K):
plt.subplot(3,3,q)
plt.imshow(true_slds.dynamics.As[k], aspect='auto', interpolation="none", vmin=-.5, vmax=.5, cmap='RdBu')
offset=-.5
for nf in D_vec:
plt.plot([-0.5, D-0.5], [offset, offset], '-k')
plt.plot([offset, offset], [-0.5, D-0.5], '-k')
offset += nf
plt.xticks([])
plt.yticks([])
plt.title('Actual State '+str(k))
q=q+1
plt.subplot(3,3,q)
# plt.imshow(rslds.dynamics.As[k], aspect='auto', interpolation="none", vmin=-.5, vmax=.5, cmap='RdBu')
plt.imshow(R@rslds.dynamics.As[k]@np.linalg.inv(R), aspect='auto', interpolation="none", vmin=-.5, vmax=.5, cmap='RdBu')
offset=-.5
for nf in D_vec:
plt.plot([-0.5, D-0.5], [offset, offset], '-k')
plt.plot([offset, offset], [-0.5, D-0.5], '-k')
offset += nf
plt.xticks([])
plt.yticks([])
plt.title('Aligned Predicted State '+str(k))
# plt.savefig(folder+'dyn_est'+str(k)+'.pdf')
q=q+1
plt.subplot(3,3,q)
# plt.imshow(rslds.dynamics.As[k], aspect='auto', interpolation="none", vmin=-.5, vmax=.5, cmap='RdBu')
plt.imshow(rslds.dynamics.As[k], aspect='auto', interpolation="none", vmin=-.5, vmax=.5, cmap='RdBu')
offset=-.5
for nf in D_vec:
plt.plot([-0.5, D-0.5], [offset, offset], '-k')
plt.plot([offset, offset], [-0.5, D-0.5], '-k')
offset += nf
plt.xticks([])
plt.yticks([])
plt.title('Raw Predicted State '+str(k))
# plt.savefig(folder+'dyn_est'+str(k)+'.pdf')
q=q+1
# + [markdown] colab_type="text" id="b14UYWuMRFRr"
# ### Transitions
# The contribution of population $j$ to staying in a state is $S_j x^{(j)}$ and the contribution to switching to a state is $R_j x^{(j)}$
#
# + colab={} colab_type="code" id="hja_9YLPRFRr" outputId="6eb42921-5277-43ea-b14b-2366d2128509"
plt.figure(figsize=(15, 4))
### Actual
j=0
plt.subplot(221)
for g in range(num_gr): #Loop over populations
plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Rs[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
plt.ylabel('Contribution towards \n switching to \n purple state',rotation=60)
plt.xticks([])
plt.yticks([])
plt.title('Actual')
j=1
plt.subplot(223)
for g in range(num_gr): #Loop over populations
plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Ss[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
plt.ylabel('Contribution towards \n staying in red state',rotation=60)
plt.legend(['Pop 1','Pop 2','Pop 3'])
plt.yticks([])
### Predicted
j=0
plt.subplot(222)
for g in range(num_gr): #Loop over populations
plt.plot(np.dot(q_ar.mean_continuous_states[0][st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],rslds.transitions.Rs[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
# plt.ylabel('Contribution towards \n switching to \n purple state',rotation=60)
plt.xticks([])
plt.yticks([])
plt.title('Predicted')
j=1
plt.subplot(224)
for g in range(num_gr): #Loop over populations
plt.plot(np.dot(q_ar.mean_continuous_states[0][st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],rslds.transitions.Ss[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
# plt.ylabel('Stay in Red')
# plt.xticks([])
plt.yticks([])
# + [markdown] colab_type="text" id="yz4GSxjrRFRt"
# ### Example fit of neural activity ($y$)
# + colab={} colab_type="code" id="cgEffzVKRFRt" outputId="10d64661-8fbd-4259-96fc-134c35980b88"
preds=rslds.smooth(q_ar.mean_continuous_states[0],ys) #get predictions
nrn=0 #Example neuron
plt.plot(ys[st_t:end_t,nrn],alpha=.5) #true spiking activity
plt.plot(lams[st_t:end_t,nrn]) #true firing rate
plt.plot(preds[st_t:end_t,nrn]) #predicted firing
plt.legend(['True Spikes','True FR','Predicted FR'])