{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"%matplotlib inline"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"\n# Multi-Population rSLDS\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# + [markdown] colab_type=\"text\" id=\"view-in-github\"\n#
\n# -\n\n# ### If you want to quickly see how to fit your own data, jump down to the \"Fit model to data\" section\n#
\n#
\n#\n# # Multi-population recurrent switching linear dynamical systems overview\n#\n# This notebook goes through the simulation example shown in our manuscript (Figure 2A,B).\n#\n# Below, we briefly describe the model. We also recommend looking at the \"Recurrent SLDS\" notebook, which provides more details on the standard rSLDS.\n#
\n#\n# **1. Data**.\n# Let $y_t^{(j)}$ denote a vector of activity measurements of the $N_j$ neurons in population $j$ in time bin $t$.\n#
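# The sketch below shows this data layout. It assumes, as in the simulation below, that all populations share the same time bins. The per-population count matrices are concatenated column-wise into one array `ys` of shape (time bins, total neurons), and `N_vec` records how many columns belong to each population. The helpers `stack_populations` and `split_populations` are hypothetical illustrations and are not part of `ssm`.
#
```python
import numpy as np


def stack_populations(pop_counts):
    # pop_counts: list of (T, N_j) spike-count arrays, one per population
    T = pop_counts[0].shape[0]
    assert all(y.shape[0] == T for y in pop_counts), 'populations must share time bins'
    ys = np.concatenate(pop_counts, axis=1)  # (T, sum_j N_j), populations in list order
    N_vec = [y.shape[1] for y in pop_counts]  # neurons per population
    return ys, N_vec


def split_populations(ys, N_vec):
    # Inverse of stack_populations: slice out each population's columns
    edges = np.cumsum([0] + list(N_vec))
    return [ys[:, edges[j]:edges[j + 1]] for j in range(len(N_vec))]


# Toy example with hypothetical data: 3 populations of 75 neurons, 100 time bins
rng = np.random.default_rng(0)
pops_demo = [rng.poisson(0.25, size=(100, 75)) for _ in range(3)]
ys_demo, N_vec_demo = stack_populations(pops_demo)
print(ys_demo.shape, N_vec_demo)  # (100, 225) [75, 75, 75]
```
# The simulation below uses the same cumulative-sum indexing (`D_vec_cumsum`) to slice each population's latents out of the full state vector.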
\n#\n# **2. Emissions**.\n# Let $x_t^{(j)}$ denote the continuous latent state of population $j$ at time $t$. The population states may differ in dimensionality $D_j$, since populations may differ in size and complexity. The observed activity of population $j$ is modeled with a generalized linear model (GLM),\n# \\begin{align}\n# E[y_t^{(j)}] &= f(C_j x_t^{(j)} + d_j),\n# \\end{align}\n# where each population has its own linear mapping parameterized by $\\{C_j, d_j\\}$ and $f$ is a pointwise nonlinearity. In this notebook, we use a Poisson GLM with a softplus nonlinearity (`link='softplus'`). Inputs can also be passed into this GLM, as described in the rSLDS notebook.\n#\n# The multi-population emissions classes used in the example below are loaded from the `ssm.extensions.mp_srslds` package.\n#
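# The NumPy sketch below illustrates this emission model. It assumes a softplus nonlinearity f(u) = log(1 + exp(u)), as selected by `link='softplus'` below, and uses toy dimensions that match the simulation. It shows the equation only and is not the `PoissonOrthogonalCompoundEmissions` implementation. In particular, the model also constrains the emissions matrix to be orthonormal, and this sketch does not. The function `population_rates` is hypothetical.
#
```python
import numpy as np


def softplus(u):
    # Numerically stable log(1 + exp(u))
    return np.logaddexp(0.0, u)


def population_rates(x, Cs, ds):
    # x: (T, sum_j D_j) latents, populations stacked column-wise
    # Cs[j]: (N_j, D_j) loading matrix, ds[j]: (N_j,) offset for population j
    rates, start = [], 0
    for C_j, d_j in zip(Cs, ds):
        D_j = C_j.shape[1]
        x_j = x[:, start:start + D_j]  # latents of population j only
        rates.append(softplus(x_j @ C_j.T + d_j))  # E[y^(j)] = f(C_j x^(j) + d_j)
        start += D_j
    return np.concatenate(rates, axis=1)  # (T, sum_j N_j)


rng = np.random.default_rng(1)
D_vec_demo, N_vec_demo = [5, 5, 5], [75, 75, 75]
Cs_demo = [rng.standard_normal((N, D)) for N, D in zip(N_vec_demo, D_vec_demo)]
ds_demo = [rng.standard_normal(N) for N in N_vec_demo]
x_demo = rng.standard_normal((100, sum(D_vec_demo)))
lam_demo = population_rates(x_demo, Cs_demo, ds_demo)
y_demo = rng.poisson(lam_demo)  # Poisson spike counts

# Equivalent single product with a block-diagonal C (no cross-population loadings)
C_block_demo = np.zeros((sum(N_vec_demo), sum(D_vec_demo)))
r0, c0 = 0, 0
for C_j in Cs_demo:
    C_block_demo[r0:r0 + C_j.shape[0], c0:c0 + C_j.shape[1]] = C_j
    r0, c0 = r0 + C_j.shape[0], c0 + C_j.shape[1]
assert np.allclose(lam_demo, softplus(x_demo @ C_block_demo.T + np.concatenate(ds_demo)))
```
# The block-diagonal form shows that each neuron reads out only its own population's latents. Any interaction between populations must therefore pass through the dynamics.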
\n#\n# **3. Continuous State Update (Dynamics)**.\n# The dynamics of a switching linear dynamical system are piecewise linear. The linear dynamics in effect at a given time are determined by a discrete state (more on discrete states below).\n#\n# \\begin{align}\n# x_t \\sim \n# A^{(z_t)} x_{t-1} + b^{(z_t)}\n# \\end{align} \n#\n# where $z_t$ is the discrete state, $A^{(z_t)}$ and $b^{(z_t)}$ are the dynamics matrix and offset for that discrete state, and $x_t = [x_t^{(1)}, x_t^{(2)}, \\ldots, x_t^{(J)}]$ concatenates the latents from all populations. We omit the Gaussian noise term here for simplicity.\n#\n# Giving each population its own continuous latents allows us to decompose the dynamics in an interpretable manner. We model the temporal dynamics of the continuous states as\n# \\begin{align}\n# x_t^{(j)} \\sim\n# A_{(j \\: to \\: j)}^{(z_t)} x_{t-1}^{(j)} \n# + \\sum_{i \\neq j} A_{(i \\: to \\: j)}^{(z_t)} x_{t-1}^{(i)} \n# + b_j^{(z_t)}.\n# \\end{align} \n#\n# In the full dynamics matrix $A^{(z_t)}$ (plotted for each state in the example below), the on-diagonal blocks $A_{(j \\: to \\: j)}^{(z_t)}$ represent the internal dynamics of each population. The off-diagonal blocks $A_{(i \\: to \\: j)}^{(z_t)}$ represent the external influence of population $i$ on population $j$.\n#\n#\n# **4. Discrete State Update (Transitions)**.\n# Recurrent transitions depend on the continuous latent state. Our recurrent transitions have a sticky component, $S$, which determines the probability of staying in the current state, and a switching component, $R$, which determines the probabilities of switching to each of the other states. In the model we use in this notebook,\n#\n# \\begin{align}\n# p(z_t \\mid z_{t-1}, x_{t-1}) &= \\mathrm{softmax}\\bigg\\{ \\Big( \\big(R x_{t-1}\\big) + r\\Big) \\odot (1 - e_{z_{t-1}}) + \\Big( \\big(S x_{t-1} \\big) + s \\Big) \\odot e_{z_{t-1}} \\bigg\\},\n# \\end{align}\n#\n# where $e_{z_{t-1}} \\in \\{0,1\\}^K$ is a one-hot encoding of $z_{t-1}$ and the softmax is taken over the $K$ possible values of $z_t$. Row $k$ of $R$ (with offset $r_k$) drives switching into state $k$, and row $k$ of $S$ (with offset $s_k$) drives staying in state $k$.\n#\n# To understand which populations contribute to the transitions, we can split $R = [R_1, \\ldots, R_J]$ and $S = [S_1, \\ldots, S_J]$ into column blocks, one per population, and decompose this equation:\n#\n# \\begin{align}\n# p(z_t \\mid z_{t-1}, x_{t-1}) &= \\mathrm{softmax}\\bigg\\{ \\Big( \\sum_{j=1}^J \\big(R_j x_{t-1}^{(j)}\\big) + r\\Big) \\odot (1 - e_{z_{t-1}}) + \\Big( \\sum_{j=1}^J \\big(S_j x_{t-1}^{(j)} \\big) + s \\Big) \\odot e_{z_{t-1}} \\bigg\\},\n# \\end{align}\n# where, for example, $R_j x_{t-1}^{(j)}$ contains the contribution of population $j$ towards switching to each state.\n#\n# We can also include a dependency on the previous discrete state through a baseline transition matrix $P$. This option is included in the code package but is not used in the example below:\n#\n# \\begin{align}\n# p(z_t \\mid z_{t-1}, x_{t-1}) &= \\mathrm{softmax}\\bigg\\{ \\log P_{z_{t-1}} + \\big(R x_{t-1}\\big) \\odot (1 - e_{z_{t-1}}) + \\big(S x_{t-1} \\big) \\odot e_{z_{t-1}} \\bigg\\},\n# \\end{align}\n#\n# where $P_{z_{t-1}}$ is row $z_{t-1}$ of $P$.\n#\n# The sticky multi-population transitions classes are loaded in the example below.\n#
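# The dynamics and transition equations above can be combined into one generative step, as in the NumPy sketch below. It uses toy sizes (K = 3 states, three populations of 5 latents each) and random hypothetical parameters. It takes the offsets r = 0 and s = 5 from the simulation below and includes no dynamics noise. The helper functions are hypothetical and do not reflect how `ssm` implements these classes. The final check confirms that the per-population terms R_j x^(j) add up to R x, which is why the decomposition above is exact.
#
```python
import numpy as np


def softmax(a):
    e = np.exp(a - a.max())  # subtract the max for numerical stability
    return e / e.sum()


def blocks(D_vec):
    # (start, end) index of each population's latents in the stacked state x_t
    edges = np.cumsum([0] + list(D_vec))
    return [(edges[j], edges[j + 1]) for j in range(len(D_vec))]


def dynamics_step(x_prev, z, As, bs, D_vec):
    # x_t^(j) = sum_i A_(i to j)^(z) x_{t-1}^(i) + b_j^(z), with the noise term omitted;
    # the i == j term is the internal (on-diagonal block) dynamics
    x_new = np.zeros_like(x_prev)
    for j0, j1 in blocks(D_vec):  # receiving population j
        for i0, i1 in blocks(D_vec):  # sending population i
            x_new[j0:j1] += As[z][j0:j1, i0:i1] @ x_prev[i0:i1]
        x_new[j0:j1] += bs[z][j0:j1]
    return x_new


def transition_probs(x_prev, z_prev, R, r, S, s):
    # Switching logits for states other than z_prev, sticky logit for z_prev
    e = np.zeros(R.shape[0])
    e[z_prev] = 1.0  # one-hot encoding of z_{t-1}
    return softmax((R @ x_prev + r) * (1 - e) + (S @ x_prev + s) * e)


def population_contributions(x_prev, W, D_vec):
    # Row j is W_j x_{t-1}^(j): population j's contribution to each state's logit
    return np.stack([W[:, i0:i1] @ x_prev[i0:i1] for i0, i1 in blocks(D_vec)])


# Hypothetical toy parameters (not fitted values)
rng = np.random.default_rng(2)
K_demo, D_vec_demo = 3, [5, 5, 5]
D_demo = sum(D_vec_demo)
As_demo = 0.1 * rng.standard_normal((K_demo, D_demo, D_demo))
bs_demo = rng.random((K_demo, D_demo))
R_demo = rng.standard_normal((K_demo, D_demo))
S_demo = rng.standard_normal((K_demo, D_demo))
r_demo, s_demo = np.zeros(K_demo), 5 * np.ones(K_demo)

x_prev_demo, z_prev_demo = rng.standard_normal(D_demo), 1
p_demo = transition_probs(x_prev_demo, z_prev_demo, R_demo, r_demo, S_demo, s_demo)
z_t_demo = rng.choice(K_demo, p=p_demo)  # sample the next discrete state
x_t_demo = dynamics_step(x_prev_demo, z_t_demo, As_demo, bs_demo, D_vec_demo)

# The per-population terms add up to the full switching logits R x_{t-1}
contrib_demo = population_contributions(x_prev_demo, R_demo, D_vec_demo)  # (J, K)
assert np.allclose(contrib_demo.sum(axis=0), R_demo @ x_prev_demo)
```
# With r = 0 and s = 5, the sticky logit of the current state starts 5 units ahead. The system therefore tends to stay put unless some switching term R_k x_{t-1} grows large enough to overtake S_{z_{t-1}} x_{t-1} + 5.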
\n#\n# **5. Model fitting**.\n# We fit the model with variational Laplace EM; see the \"Variational Laplace EM for SLDS Tutorial\" for more information.\n#\n\n# + [markdown] colab_type=\"text\" id=\"8OzC8q4bRFQv\"\n# ## Import packages, including multipopulation extensions\n\n# + colab={\"base_uri\": \"https://localhost:8080/\", \"height\": 581} colab_type=\"code\" id=\"ruUnNqi5RZqT\" outputId=\"228b6c8e-c064-46c2-ce57-9ad88daca5c8\"\ntry:\n import ssm\nexcept ImportError:\n #If ssm is not installed, run the (commented) pip command below first\n # !pip install git+https://github.com/lindermanlab/ssm.git#egg=ssm\n import ssm\n\n# + colab={\"base_uri\": \"https://localhost:8080/\", \"height\": 71} colab_type=\"code\" id=\"zDn3tEJhRFQv\" outputId=\"2f1ca1d0-8f17-404a-897f-57b8c5d353cb\"\n#### General packages\n\nfrom matplotlib import pyplot as plt\n# %matplotlib inline\nimport autograd.numpy as np\nimport autograd.numpy.random as npr\n\nimport seaborn as sns\nsns.set_style(\"white\")\nsns.set_context(\"talk\")\nsns.set_style('ticks',{\"xtick.major.size\":8,\n\"ytick.major.size\":8})\nfrom ssm.plots import gradient_cmap, white_to_color_cmap\n\ncolor_names = [\n \"purple\",\n \"red\",\n \"amber\",\n \"faded green\",\n \"windows blue\",\n \"orange\"\n ]\n\ncolors = sns.xkcd_palette(color_names)\ncmap = gradient_cmap(colors)\n\n# + colab={} colab_type=\"code\" id=\"0rq19iIQRFQy\"\n#### SSM PACKAGES ###\n\nimport ssm\nfrom ssm.variational import SLDSMeanFieldVariationalPosterior, SLDSTriDiagVariationalPosterior, \\\n SLDSStructuredMeanFieldVariationalPosterior\nfrom ssm.util import random_rotation, find_permutation, relu\n\n#Load from extensions\nfrom ssm.extensions.mp_srslds.emissions_ext import GaussianOrthogonalCompoundEmissions, PoissonOrthogonalCompoundEmissions\nfrom ssm.extensions.mp_srslds.transitions_ext import StickyRecurrentOnlyTransitions, StickyRecurrentTransitions\n\n# + [markdown] colab_type=\"text\" id=\"Ty3EOi8bRFQ1\"\n# ## Simulate (somewhat realistic) data\n\n# + [markdown] colab_type=\"text\" id=\"QxDoYCRDRFQ2\"\n# ### Set parameters of simulation\n\n# + colab={} colab_type=\"code\" id=\"dalqY6zvRFQ2\"\nK=3 #Number of discrete states\n\nnum_gr=3 #Number of populations\nnum_per_gr=5 #Number of latents per population\nneur_per_gr=75 #Number of neurons per population\n\nt_end=3000 #Number of time bins\nnum_trials=1 #Number of trials (a single trial is simulated)\n\n# + colab={\"base_uri\": \"https://localhost:8080/\", \"height\": 34} colab_type=\"code\" id=\"OHTSTNbTRFQ4\" outputId=\"fd3b833b-8df2-425b-d3ce-103cf42d5153\"\nnp.random.seed(108) #To create replicable dynamics\n\nalphas=.03+.1*np.random.rand(K) #Scale of the random entries of the dynamics matrix, one value per discrete state\nprint('alphas:', alphas)\n\nsparsity=.33 #Probability that each off-diagonal block of the dynamics matrix is set to 0\n\ne1=.1 #Amount of noise in the dynamics (not used in the simulation below)\n\n# + [markdown] colab_type=\"text\" id=\"t91lYSkPRFQ7\"\n# ### Get new emissions and transitions classes for the simulated data\n\n# + colab={} colab_type=\"code\" id=\"tQx540b6RFQ8\"\n#Vector containing number of latents per population\nD_vec=[num_per_gr]*num_gr\n\n#Vector containing number of neurons per population\nN_vec=[neur_per_gr]*num_gr\n\nD=np.sum(D_vec) #Total number of latents\nnum_gr=len(D_vec)\nD_vec_cumsum = np.concatenate(([0], np.cumsum(D_vec))) #Start/end index of each population's latents\n\n#Get new multipopulation emissions class for the simulation\n\n# gauss_comp_emissions=GaussianOrthogonalCompoundEmissions(N=np.sum(N_vec),K=1,D=np.sum(D_vec),D_vec=D_vec,N_vec=N_vec)\npoiss_comp_emissions=PoissonOrthogonalCompoundEmissions(N=np.sum(N_vec),K=1,D=np.sum(D_vec),D_vec=D_vec,N_vec=N_vec,link='softplus')\n\n#Get transitions class\ntrue_sro_trans=StickyRecurrentOnlyTransitions(K=K,D=np.sum(D_vec)) \n\n# + [markdown] colab_type=\"text\" id=\"DDu2GnRGRFQ-\"\n# ### Create simulated data\n\n# + colab={} colab_type=\"code\" id=\"VLN8FWLLRFQ_\"\nnp.random.seed(10) #To create replicable simulations\n\nA_masks=[]\n\nA_all=np.zeros([K,D,D]) #Initialize dynamics matrix\nb_all=np.zeros([K,D]) #Initialize dynamics offset\n\n\n#Create an initial ground-truth model, which we will modify\ntrue_slds = ssm.SLDS(N=np.sum(N_vec),K=K,D=int(np.sum(D_vec)),\n dynamics=\"gaussian\",\n emissions=poiss_comp_emissions,\n transitions=true_sro_trans)\n\n#Create ground-truth transitions: switching into (and staying in) state k is driven by the first latent of population k (this assumes K <= num_gr)\nv=.2+.2*np.random.rand(1)\nfor k in range(K):\n true_slds.transitions.Rs[k,D_vec_cumsum[k]:D_vec_cumsum[k]+1]=v\n true_slds.transitions.Ss[k,D_vec_cumsum[k]:D_vec_cumsum[k]+1]=v-.1\n\ntrue_slds.transitions.r=0*np.ones([K,1])\ntrue_slds.transitions.s=5*np.ones([K,1]) #Large sticky offset: states tend to persist\n\n#Create ground truth dynamics for each state\nfor k in range(K):\n\n ##Create dynamics##\n alpha=alphas[k]\n\n A_mask=np.random.rand(num_gr,num_gr)>sparsity #Randomly zero out some blocks of the dynamics matrix\n\n A_masks.append(A_mask)\n\n for i in range(num_gr): \n A_mask[i,i]=1 #Always keep the within-population (diagonal) blocks\n\n #Random skew-symmetric matrix; adding the identity below gives approximately rotational dynamics\n A0=np.zeros([D,D])\n for i in range(D-1):\n A0[i,i+1:]=alpha*np.random.randn(D-1-i)\n A0=(A0-A0.T)\n\n #Double the within-population interactions\n for i in range(num_gr):\n A0[D_vec_cumsum[i]:D_vec_cumsum[i+1],D_vec_cumsum[i]:D_vec_cumsum[i+1]]=2*A0[D_vec_cumsum[i]:D_vec_cumsum[i+1],D_vec_cumsum[i]:D_vec_cumsum[i+1]]\n\n\n A0=A0+np.identity(D)\n A=A0*np.kron(A_mask, np.ones((num_per_gr, num_per_gr))) #Apply the block mask (assumes equal-sized populations)\n\n A=A/(np.max(np.abs(np.linalg.eigvals(A)))+.01) #Rescale so the spectral radius is just below 1 (stable dynamics)\n\n b=1*np.random.rand(D)\n\n A_all[k]=A\n b_all[k]=b\n\ntrue_slds.dynamics.As=A_all\ntrue_slds.dynamics.bs=b_all\n\n\nzs, xs, _ = true_slds.sample(t_end) #Sample discrete and continuous latents from model for simulation\n\n#Get spike trains with an average firing rate of approximately 0.25 per bin\n#(note: the simulation uses a rectified-linear rate, while the fitted model uses a softplus nonlinearity)\ntmp=np.mean(relu(np.dot(true_slds.emissions.Cs[0],xs.T)+.1*true_slds.emissions.ds[0][:,None]).T)\nmult=.25/tmp\nlams=relu(mult*np.dot(true_slds.emissions.Cs[0],xs.T)+.1*true_slds.emissions.ds[0][:,None]).T\nys=np.random.poisson(lams) #Get spiking activity based on Poisson statistics\n\n# + [markdown] colab_type=\"text\" id=\"twuPg8wRRFRC\"\n# ## Plot simulated data\n\n# + [markdown] colab_type=\"text\" id=\"1-VkH7xSRFRC\"\n# ### Dynamics matrices ($A^z$)\n\n# + colab={\"base_uri\": \"https://localhost:8080/\", \"height\": 797} colab_type=\"code\" id=\"qv7hO5fgRFRD\" outputId=\"7b2a3151-a5dc-4ac0-c446-ecbb212ffb61\"\n# vmin,vmax=[-1,1]\nvmin,vmax=[-.5,.5] #zoom in to see colors more clearly\n\n\nfor k in range(K):\n \n plt.figure(figsize=(4,4))\n plt.imshow(true_slds.dynamics.As[k], aspect='auto', interpolation=\"none\", vmin=vmin, vmax=vmax, cmap='RdBu')\n offset=-.5\n for nf in D_vec: \n plt.plot([-0.5, D-0.5], [offset, offset], '-k')\n plt.plot([offset, offset], [-0.5, D-0.5], '-k')\n offset += nf\n plt.xticks([])\n plt.yticks([])\n plt.title('Actual State '+str(k))\n\n# + [markdown] colab_type=\"text\" id=\"-Y7R3y6TRFRF\"\n# ### Discrete states ($z$)\n\n# + colab={\"base_uri\": \"https://localhost:8080/\", \"height\": 177} colab_type=\"code\" id=\"Wmk2_uI-RFRG\" outputId=\"2f434446-fbaf-4a00-d186-81a557b35011\"\nplt.figure(figsize=(8, 4))\nplt.subplot(211)\nplt.imshow(zs[None,:], aspect=\"auto\", cmap=cmap, vmin=0, vmax=len(colors)-1)\nplt.xlim(0, t_end)\nplt.ylabel(\"$z_{\\\\mathrm{true}}$\")\nplt.yticks([])\n\n# + [markdown] colab_type=\"text\" id=\"ufQwDDKTRFRI\"\n# ### Transitions (in a shorter time window)\n# The contribution of population $j$ to staying in a state is $S_j x^{(j)}$ and the contribution to switching to a state is $R_j x^{(j)}$.\n\n# + colab={} colab_type=\"code\" id=\"fp1QQ_7dRFRI\"\ndur=200\nst_t=650\nend_t=st_t+dur\n\n# + colab={\"base_uri\": \"https://localhost:8080/\", \"height\": 310} colab_type=\"code\" id=\"-FWpoP7-RFRK\" outputId=\"bcc60411-fb9e-44d0-bc16-e382fe7919aa\"\nplt.figure(figsize=(8, 4))\n\nj=0 #Discrete state 0 (purple)\n\nplt.subplot(211)\nfor g in range(num_gr): #Loop over populations\n plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Rs[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))\nplt.xlim(0, dur)\nplt.ylabel('Contribution towards \\n switching to \\n purple state',rotation=60)\nplt.xticks([])\nplt.yticks([])\n\nj=1 #Discrete state 1 (red)\nplt.subplot(212)\nfor g in range(num_gr):\n plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Ss[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))\nplt.xlim(0, dur)\nplt.ylabel('Contribution towards \\n staying in red state',rotation=60)\nplt.legend(['Pop 1','Pop 2','Pop 3'])\nplt.yticks([])\n\n# + [markdown] colab_type=\"text\" id=\"wOO48DPARFRN\"\n# ### Continuous latents ($x$) and spikes ($y$) for an example population\n\n# + colab={\"base_uri\": \"https://localhost:8080/\", \"height\": 449} colab_type=\"code\" id=\"YllKBbdhRFRN\" outputId=\"dc7f9218-0100-4263-8501-9d2fb1696a9e\"\nplt.figure(figsize=(8, 4))\n\nplt.subplot(211)\nplt.plot(xs[st_t:end_t,:num_per_gr]) #Show latents of the first population\nplt.xticks([])\n\nplt.subplot(212)\nplt.plot(ys[st_t:end_t,:10]) #Show first 10 neurons\n\n\n# + [markdown] colab_type=\"text\" id=\"OZ_iLjiyRFRV\"\n# ## Fit model to data\n\n# + [markdown] colab_type=\"text\" id=\"g6QfY0j-RFRV\"\n# #### To create the emissions classes for the multipopulation models, we need vectors containing the number of continuous latents per population (\"D_vec\") and neurons per population (\"N_vec\")\n# The columns (neurons) of the data array `ys` should be grouped by population, in the same order as \"N_vec\".\n#\n\n# + colab={} colab_type=\"code\" id=\"kVZDXyIiRFRW\"\nK=3 #Number of discrete states\nnum_gr=3 #Number of populations\nnum_per_gr=5 #Number of latents per population\nneur_per_gr=75 #Number of neurons per population\n\n#Vector containing number of latents per population\nD_vec=[num_per_gr]*num_gr\n\n#Vector containing number of neurons per population\nN_vec=[neur_per_gr]*num_gr\n\n# + [markdown] colab_type=\"text\" id=\"_I-JFqbHRFRZ\"\n# #### Now create the multipopulation emissions and transitions classes for our model\n\n# + colab={} colab_type=\"code\" id=\"dhnavgIgRFRa\"\n#Get new multipopulation emissions class\npoiss_comp_emissions=PoissonOrthogonalCompoundEmissions(N=np.sum(N_vec),K=1,D=np.sum(D_vec),D_vec=D_vec,N_vec=N_vec,link='softplus')\n\n#Get new transitions class\nsro_trans=StickyRecurrentOnlyTransitions(K=K,D=np.sum(D_vec), l2_penalty_similarity=10, l1_penalty=10) \n#The L2 penalty encourages R and S to be similar (it assumes that the activity driving a switch into a state is similar to the activity that keeps the system in that state)\n#The L1 penalty is on the entries of R and S, encouraging sparsity\n# -\n\n# Note that another new emissions class is \"GaussianOrthogonalCompoundEmissions\"
\n#\n# Note that another new transitions class is \"StickyRecurrentTransitions\"\n\n# + [markdown] colab_type=\"text\" id=\"syHDAea_RFRc\"\n# #### Now declare and fit the model\n\n# + colab={\"base_uri\": \"https://localhost:8080/\", \"height\": 166, \"referenced_widgets\": [\"36b4e6c4b2064572a4412f8bc298caeb\", \"97247a689ff744689528aa9555c13c62\", \"581b2c1a39eb439581308a5cdea07389\", \"e8e6a6b80ed14b77900e87e1c779f459\", \"2e200ca5be3141f2b2ee041ec8b08e49\", \"f4e534f4d6ac45c6a7da04be7d341968\", \"ced98407e9ee4ee4bd0baaa24d4765b9\", \"14de5bf9b23f47d4867e135cf156ac97\", \"6bd4692a8c95492d84f6353069342d4f\", \"62f38f8736314a098f91d7486b5b84f6\", \"d7ffd5a5a8874c2d8e593cb29c8f14b2\", \"2ddbbe21934547db9e365d905ea1f024\", \"19ad77a391f7459f811262d5bdbeb783\", \"5da21c8eac7041898c7a468f0388e632\", \"eb1c51978e0f4b79968cf6618a959007\", \"c0204fcad9f64dd48580f7d17767abc6\"]} colab_type=\"code\" id=\"n7WRqQufRFRd\" outputId=\"28ddbaaa-221c-4ca8-db59-b76bbeff1dd0\"\nK=3 #Number of discrete states\n\nrslds = ssm.SLDS(N=np.sum(N_vec),K=K,D=np.sum(D_vec),\n dynamics=\"gaussian\",\n emissions=poiss_comp_emissions,\n transitions=sro_trans,\n dynamics_kwargs=dict(l2_penalty_A=100)) #Regularization on the dynamics matrix\n\nq_elbos_ar, q_ar = rslds.fit(ys, method=\"laplace_em\",\n variational_posterior=\"structured_meanfield\", \n continuous_optimizer='newton',\n initialize=True, \n num_init_restarts=10,\n num_iters=30, \n alpha=0.25)\n\n# + colab={} colab_type=\"code\" id=\"Oi8-3NcxRFRf\" outputId=\"98075512-a594-4bae-fcf3-550affe796a6\"\nplt.plot(q_elbos_ar[1:])\nplt.xlabel(\"Iteration\")\nplt.ylabel(\"ELBO\")\n# -\n\n# ## Align solution with simulation for plotting\n\n#The recovered discrete states can be permuted in any way. 
\n#Find permutation to match the discrete states in the model and the ground truth\nz_inferred=rslds.most_likely_states(q_ar.mean_continuous_states[0],ys)\nrslds.permute(find_permutation(zs, z_inferred))\nz_inferred2=rslds.most_likely_states(q_ar.mean_continuous_states[0],ys)\n\n# +\n#Each population's latents can be multiplied by an arbitrary rotation matrix\n#Additionally, there may be a change in scaling between the simulation ground truth and recovered latents,\n#because the simulation didn't constrain the effective emissions (C) matrix to be orthonormal like in the model\n\nfrom sklearn.linear_model import LinearRegression\n\nR=np.zeros([D,D])\nfor g in range(num_gr):\n lr=LinearRegression(fit_intercept=False)\n lr.fit(q_ar.mean_continuous_states[0][:,D_vec_cumsum[g]:D_vec_cumsum[g+1]],xs[:,D_vec_cumsum[g]:D_vec_cumsum[g+1]])\n R[D_vec_cumsum[g]:D_vec_cumsum[g+1],D_vec_cumsum[g]:D_vec_cumsum[g+1]]=lr.coef_\n\n# + [markdown] colab_type=\"text\" id=\"0CjkFH63RFRh\"\n# ## Plot results\n\n# + [markdown] colab_type=\"text\" id=\"HDlMQx4ZRFRh\"\n# ### Discrete states ($z$)\n\n# + colab={} colab_type=\"code\" id=\"TGk3zoBVRFRi\" outputId=\"d1cebd3c-17f7-4e48-a6d3-0b9f34fbd3af\"\nplt.figure(figsize=(8, 4))\nplt.subplot(211)\nplt.imshow(zs[None,:], aspect=\"auto\", cmap=cmap, vmin=0, vmax=len(colors)-1)\nplt.xlim(0, t_end)\nplt.ylabel(\"$z_{\\\\mathrm{true}}$\")\nplt.yticks([])\n\nplt.subplot(212)\nplt.imshow(z_inferred2[None,:], aspect=\"auto\", cmap=cmap, vmin=0, vmax=len(colors)-1)\nplt.xlim(0, t_end)\nplt.ylabel(\"$z_{\\\\mathrm{inferred}}$\")\nplt.yticks([])\nplt.xlabel(\"time\")\n\nplt.tight_layout()\n\n# + colab={} colab_type=\"code\" id=\"HRPr3DJlRFRj\" outputId=\"1d6b402b-fa6d-4040-f4bf-4e4c2188c8b8\"\nprint('Discrete state accuracy: ', np.mean(zs==z_inferred2))\n\n# + [markdown] colab_type=\"text\" id=\"SX1go1AyRFRl\"\n# #### Shorter time window\n\n# + colab={} colab_type=\"code\" id=\"UhjVHiYzRFRm\" 
outputId=\"5d4cf636-9055-46fe-af2b-d51531c97317\"\nplt.figure(figsize=(4, 2))\nplt.subplot(211)\nplt.imshow(zs[None,st_t:end_t], aspect=\"auto\", cmap=cmap, vmin=0, vmax=len(colors)-1)\nplt.xlim(0, dur)\nplt.ylabel(\"$z_{\\\\mathrm{true}}$\")\nplt.yticks([])\nplt.xticks([])\n\nplt.subplot(212)\nplt.imshow(z_inferred2[None,st_t:end_t], aspect=\"auto\", cmap=cmap, vmin=0, vmax=len(colors)-1)\nplt.xlim(0, dur)\nplt.ylabel(\"$z_{\\\\mathrm{inferred}}$\")\nplt.yticks([])\n\n# + [markdown] colab_type=\"text\" id=\"e4Cz1o8GRFRo\"\n# ### Dynamics matrices ($A^z$)\n# -\n\n# For each state, we show the recovered dynamics matrix after aligning the inferred continuous latents to the ground truth, i.e. $M A^{(z)} M^{-1}$. Here $M$ is the block-diagonal alignment matrix computed above (named `R` in the code, and unrelated to the transition weights $R$). The aligned matrix can be compared directly with the true dynamics to assess how well they are recovered.\n#\n# We also show the raw recovered dynamics matrix. Because the alignment acts separately on each population, it maps each block $A_{(i \\: to \\: j)}$ to $M_j A_{(i \\: to \\: j)} M_i^{-1}$, so blocks that are zero stay zero. The block structure (which populations influence which) can therefore be read off the raw matrix, regardless of the scaling/rotation of each population's latents.\n\n# + colab={} colab_type=\"code\" id=\"0D2QHaLRRFRo\" outputId=\"891cf1a1-4aea-4c5f-f2d5-9895478b5361\"\nplt.figure(figsize=(12, 12))\n\nq=1 #Subplot index\n\nfor k in range(K):\n \n plt.subplot(3,3,q)\n plt.imshow(true_slds.dynamics.As[k], aspect='auto', interpolation=\"none\", vmin=-.5, vmax=.5, cmap='RdBu')\n offset=-.5\n for nf in D_vec: \n plt.plot([-0.5, D-0.5], [offset, offset], '-k')\n plt.plot([offset, offset], [-0.5, D-0.5], '-k')\n offset += nf\n plt.xticks([])\n plt.yticks([])\n plt.title('Actual State '+str(k))\n \n q=q+1\n\n plt.subplot(3,3,q)\n #Aligned: since x_true ~= R x_inferred, the dynamics in ground-truth coordinates are R A R^-1\n plt.imshow(R@rslds.dynamics.As[k]@np.linalg.inv(R), aspect='auto', interpolation=\"none\", vmin=-.5, vmax=.5, cmap='RdBu')\n\n offset=-.5\n for nf in D_vec: \n plt.plot([-0.5, D-0.5], [offset, offset], '-k')\n plt.plot([offset, offset], [-0.5, D-0.5], '-k')\n offset += nf\n plt.xticks([])\n plt.yticks([])\n plt.title('Aligned Predicted State '+str(k))\n\n q=q+1\n\n plt.subplot(3,3,q)\n #Raw: dynamics matrix in the model's own latent coordinates\n plt.imshow(rslds.dynamics.As[k], aspect='auto', interpolation=\"none\", vmin=-.5, vmax=.5, cmap='RdBu')\n\n offset=-.5\n for nf in D_vec: \n plt.plot([-0.5, D-0.5], [offset, offset], '-k')\n plt.plot([offset, offset], [-0.5, D-0.5], '-k')\n offset += nf\n plt.xticks([])\n plt.yticks([])\n plt.title('Raw Predicted State '+str(k))\n\n q=q+1\n\n\n# + [markdown] colab_type=\"text\" id=\"b14UYWuMRFRr\"\n# ### Transitions\n# The contribution of population $j$ to staying in a state is $S_j x^{(j)}$ and the contribution to switching to a state is $R_j x^{(j)}$.\n#\n\n# + colab={} colab_type=\"code\" id=\"hja_9YLPRFRr\" outputId=\"6eb42921-5277-43ea-b14b-2366d2128509\"\nplt.figure(figsize=(15, 4))\n\n\n### Actual\n\nj=0 #Discrete state 0 (purple)\n\nplt.subplot(221)\nfor g in range(num_gr): #Loop over populations\n plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Rs[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))\nplt.xlim(0, dur)\nplt.ylabel('Contribution towards \\n switching to \\n purple state',rotation=60)\nplt.xticks([])\nplt.yticks([])\nplt.title('Actual')\n\nj=1 #Discrete state 1 (red)\nplt.subplot(223)\nfor g in range(num_gr):\n plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Ss[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))\nplt.xlim(0, dur)\nplt.ylabel('Contribution towards \\n staying in red state',rotation=60)\nplt.legend(['Pop 1','Pop 2','Pop 3'])\nplt.yticks([])\n\n\n### Predicted (inferred latents and fitted R, S)\n\nj=0\n\nplt.subplot(222)\nfor g in range(num_gr):\n plt.plot(np.dot(q_ar.mean_continuous_states[0][st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],rslds.transitions.Rs[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))\nplt.xlim(0, dur)\nplt.xticks([])\nplt.yticks([])\nplt.title('Predicted')\n\nj=1\nplt.subplot(224)\nfor g in range(num_gr):\n plt.plot(np.dot(q_ar.mean_continuous_states[0][st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],rslds.transitions.Ss[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))\nplt.xlim(0, dur)\nplt.yticks([])\n\n# + [markdown] colab_type=\"text\" id=\"yz4GSxjrRFRt\"\n# ### Example fit of neural activity ($y$)\n\n# + colab={} colab_type=\"code\" id=\"cgEffzVKRFRt\" outputId=\"10d64661-8fbd-4259-96fc-134c35980b88\"\npreds=rslds.smooth(q_ar.mean_continuous_states[0],ys) #Predicted firing rates given the inferred latents\n\nnrn=0 #Example neuron\nplt.plot(ys[st_t:end_t,nrn],alpha=.5) #True spiking activity\nplt.plot(lams[st_t:end_t,nrn]) #True firing rate\nplt.plot(preds[st_t:end_t,nrn]) #Predicted firing rate\n\nplt.legend(['True Spikes','True FR','Predicted FR'])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.13"
}
},
"nbformat": 4,
"nbformat_minor": 0
}