Input Driven Observations (“GLM-HMM”)
Notebook prepared by Zoe Ashwood: feel free to email me with feedback or questions (zashwood at cs dot princeton dot edu).
This notebook demonstrates the “InputDrivenObservations” class and illustrates its use for modeling decision-making data, as in Ashwood et al. (2020), “Mice alternate between discrete strategies during perceptual decision-making”.
Compared to the model considered in the notebook “2 Input Driven HMM”, Ashwood et al. (2020) assume a stationary transition matrix whose transition probabilities do not depend on external inputs. However, the observation probabilities now do depend on external covariates, according to:
$$
\Pr(y_t = c \mid z_t = k, u_t) = \frac{\exp\{w_{kc}^\mathsf{T} u_t\}}{\sum_{c'=1}^{C} \exp\{w_{kc'}^\mathsf{T} u_t\}}
$$
for $c \in \{1, \dots, C\}$ and $k \in \{1, \dots, K\}$, where $C$ is the number of observation classes, $K$ is the number of discrete latent states, $u_t \in \mathbb{R}^{M}$ is the external input at trial $t$, and $w_{kc} \in \mathbb{R}^{M}$ is the GLM weight vector for state $k$ and class $c$.
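To make the softmax concrete, here is a minimal numpy sketch of the per-class probabilities for a single state and trial; the weight and input values are made up for illustration.

import numpy as np

# Weights for one state k, one row per class c (input_dim = 2: stimulus, bias);
# the weights for the final class are fixed to zero (see below)
w_k = np.array([[ 2.0, 0.5],
                [-1.0, 0.3],
                [ 0.0, 0.0]])
u_t = np.array([0.25, 1.0])  # input at trial t: stimulus value and a constant bias term

logits = w_k @ u_t                         # w_{kc}^T u_t for each class c
p = np.exp(logits) / np.exp(logits).sum()  # Pr(y_t = c | z_t = k, u_t)
print(p)                                   # probabilities sum to 1 across the C classes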
Equations 1 and 2 at the top of this notebook already take into account the fact that the weights for one class in each state are fixed to zero for identifiability (this is why the generative weights below are specified only for the first $C-1$ classes, and the final class is plotted with all-zero weights).
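The cells below assume the imports and generative-model setup from earlier in the notebook (numpy, matplotlib, ssm, the find_permutation utility, and the true_glmhmm instance). For a self-contained run, here is a minimal reconstruction of that setup; the dimensions match the code that follows, but the generative weight values are illustrative rather than the notebook’s originals.

import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
import ssm
from ssm.util import find_permutation

npr.seed(0)

num_states = 4       # number of discrete latent states
obs_dim = 1          # dimensionality of each observation (the choice index)
num_categories = 3   # number of output classes C (multinomial observations)
input_dim = 2        # number of covariates M (stimulus and bias)

# Generative GLM-HMM with input-driven multinomial observations
true_glmhmm = ssm.HMM(num_states, obs_dim, input_dim, observations="input_driven_obs",
                      observation_kwargs=dict(C=num_categories), transitions="standard")

# Generative GLM weights, shape (num_states, num_categories-1, input_dim);
# weights for the final class are implicitly zero. Values here are illustrative.
gen_weights = np.array([[[0.6, 3.0], [2.0, 3.0]],
                        [[6.0, 1.0], [6.0, -2.0]],
                        [[1.0, 1.0], [3.0, 1.0]],
                        [[2.0, 2.0], [0.0, 5.0]]])
true_glmhmm.observations.params = gen_weights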
# Set transition matrix of multinomial GLM-HMM
gen_log_trans_mat = np.log(np.array([[[0.90, 0.04, 0.05, 0.01], [0.05, 0.92, 0.01, 0.02], [0.03, 0.02, 0.94, 0.01], [0.09, 0.01, 0.01, 0.89]]]))
true_glmhmm.transitions.params = gen_log_trans_mat
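Since these are log probabilities, each row of the exponentiated transition matrix should sum to one; a quick sanity check (not part of the original notebook):

# Each row of the transition matrix must be a valid probability distribution
assert np.allclose(np.exp(gen_log_trans_mat).sum(axis=-1), 1.0)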
# Create an external input sequence; compared to the example above, we increase the number
# of examples (via the "num_trials_per_sess" parameter) since the number of parameters has increased
num_sess = 20 # number of example sessions
num_trials_per_sess = 1000 # number of trials in a session
inpts = np.ones((num_sess, num_trials_per_sess, input_dim)) # initialize inpts array
stim_vals = [-1, -0.5, -0.25, -0.125, -0.0625, 0, 0.0625, 0.125, 0.25, 0.5, 1]
inpts[:,:,0] = np.random.choice(stim_vals, (num_sess, num_trials_per_sess)) # generate random sequence of stimuli
inpts = list(inpts)
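ssm expects the inputs as a list of per-session arrays of shape (num_trials, input_dim); a quick check of the structure just built (not part of the original notebook):

# One (num_trials_per_sess, input_dim) array per session
assert len(inpts) == num_sess
assert inpts[0].shape == (num_trials_per_sess, input_dim)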
# Generate a sequence of latents and choices for each session
true_latents, true_choices = [], []
for sess in range(num_sess):
    true_z, true_y = true_glmhmm.sample(num_trials_per_sess, input=inpts[sess])
    true_latents.append(true_z)
    true_choices.append(true_y)
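Each sampled session consists of a latent state sequence and a choice sequence; the choices are integer class indices in {0, ..., num_categories-1}. A quick check (not part of the original notebook):

# Choices are (num_trials_per_sess, obs_dim) arrays of class indices
assert true_choices[0].shape == (num_trials_per_sess, obs_dim)
assert 0 <= true_choices[0].min() and true_choices[0].max() <= num_categories - 1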
# plot example data:
fig = plt.figure(figsize=(8, 3), dpi=80, facecolor='w', edgecolor='k')
plt.step(range(100), true_choices[0][:100], color="red")
plt.yticks([0, 1, 2])
plt.title("example data (multinomial GLM-HMM)")
plt.xlabel("trial #", fontsize = 15)
plt.ylabel("observation class", fontsize = 15)
# Calculate true loglikelihood
true_ll = true_glmhmm.log_probability(true_choices, inputs=inpts)
print("true ll = " + str(true_ll))
# fit GLM-HMM
new_glmhmm = ssm.HMM(num_states, obs_dim, input_dim, observations="input_driven_obs",
observation_kwargs=dict(C=num_categories), transitions="standard")
N_iters = 500 # maximum number of EM iterations. Fitting will stop earlier if the increase in LL falls below the tolerance specified by the tolerance parameter
fit_ll = new_glmhmm.fit(true_choices, inputs=inpts, method="em", num_iters=N_iters, tolerance=10**-4)
# Plot the log probabilities of the true and fit models. Fit model final LL should be greater
# than or equal to true LL.
fig = plt.figure(figsize=(4, 3), dpi=80, facecolor='w', edgecolor='k')
plt.plot(fit_ll, label="EM")
plt.plot([0, len(fit_ll)], true_ll * np.ones(2), ':k', label="True")
plt.legend(loc="lower right")
plt.xlabel("EM Iteration")
plt.xlim(0, len(fit_ll))
plt.ylabel("Log Probability")
plt.show()
# permute recovered state identities to match state identities of generative model
new_glmhmm.permute(find_permutation(true_latents[0], new_glmhmm.most_likely_states(true_choices[0], input=inpts[0])))
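With the states aligned, one quick (illustrative, not part of the original notebook) measure of state recovery is the fraction of trials on which the inferred most likely state matches the generative state:

# Fraction of trials where the permuted most likely state matches the true latent
accs = [np.mean(new_glmhmm.most_likely_states(true_choices[s], input=inpts[s]) == true_latents[s])
        for s in range(num_sess)]
print("mean state-assignment accuracy:", np.mean(accs))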
# Plot recovered parameters:
recovered_weights = new_glmhmm.observations.params
recovered_transitions = new_glmhmm.transitions.params
fig = plt.figure(figsize=(16, 8), dpi=80, facecolor='w', edgecolor='k')
plt.subplots_adjust(wspace=0.3, hspace=0.6)
cols = ['#ff7f00', '#4daf4a', '#377eb8', '#f781bf', '#a65628', '#984ea3', '#999999', '#e41a1c', '#dede00']
for c in range(num_categories):
    plt.subplot(2, num_categories+1, c+1)
    if c < num_categories-1:
        for k in range(num_states):
            plt.plot(range(input_dim), gen_weights[k,c], marker='o',
                     color=cols[k], lw=1.5, label="state " + str(k+1) + "; class " + str(c+1))
    else:
        # weights for the final class are fixed to zero
        for k in range(num_states):
            plt.plot(range(input_dim), np.zeros(input_dim), marker='o',
                     color=cols[k], lw=1.5, label="state " + str(k+1) + "; class " + str(c+1), alpha=0.5)
    plt.axhline(y=0, color="k", alpha=0.5, ls="--")
    plt.yticks(fontsize=10)
    plt.xticks([0, 1], ['', ''])
    if c == 0:
        plt.ylabel("GLM weight", fontsize=15)
    plt.legend()
    plt.title("Generative weights; class " + str(c+1), fontsize=15)
    plt.ylim((-3, 10))
plt.subplot(2, num_categories+1, num_categories+1)
gen_trans_mat = np.exp(gen_log_trans_mat)[0]
plt.imshow(gen_trans_mat, vmin=-0.8, vmax=1, cmap='bone')
for i in range(gen_trans_mat.shape[0]):
    for j in range(gen_trans_mat.shape[1]):
        text = plt.text(j, i, str(np.around(gen_trans_mat[i, j], decimals=2)), ha="center", va="center",
                        color="k", fontsize=12)
plt.xlim(-0.5, num_states - 0.5)
plt.xticks(range(0, num_states), ('1', '2', '3', '4'), fontsize=10)
plt.yticks(range(0, num_states), ('1', '2', '3', '4'), fontsize=10)
plt.ylim(num_states - 0.5, -0.5)
plt.ylabel("state t", fontsize = 15)
plt.xlabel("state t+1", fontsize = 15)
plt.title("Generative transition matrix", fontsize = 15)
for c in range(num_categories):
    plt.subplot(2, num_categories+1, num_categories + c + 2)
    if c < num_categories-1:
        for k in range(num_states):
            plt.plot(range(input_dim), recovered_weights[k,c], marker='o', linestyle='--',
                     color=cols[k], lw=1.5, label="state " + str(k+1) + "; class " + str(c+1))
    else:
        # weights for the final class are fixed to zero
        for k in range(num_states):
            plt.plot(range(input_dim), np.zeros(input_dim), marker='o', linestyle='--',
                     color=cols[k], lw=1.5, label="state " + str(k+1) + "; class " + str(c+1), alpha=0.5)
    plt.axhline(y=0, color="k", alpha=0.5, ls="--")
    plt.yticks(fontsize=10)
    plt.xlabel("covariate", fontsize=15)
    if c == 0:
        plt.ylabel("GLM weight", fontsize=15)
    plt.xticks([0, 1], ['stimulus', 'bias'], fontsize=12, rotation=45)
    plt.legend()
    plt.title("Recovered weights; class " + str(c+1), fontsize=15)
    plt.ylim((-3, 10))
plt.subplot(2, num_categories+1, 2*num_categories+2)
recovered_trans_mat = np.exp(recovered_transitions)[0]
plt.imshow(recovered_trans_mat, vmin=-0.8, vmax=1, cmap='bone')
for i in range(recovered_trans_mat.shape[0]):
    for j in range(recovered_trans_mat.shape[1]):
        text = plt.text(j, i, str(np.around(recovered_trans_mat[i, j], decimals=2)), ha="center", va="center",
                        color="k", fontsize=12)
plt.xlim(-0.5, num_states - 0.5)
plt.xticks(range(0, num_states), ('1', '2', '3', '4'), fontsize=10)
plt.yticks(range(0, num_states), ('1', '2', '3', '4'), fontsize=10)
plt.ylim(num_states - 0.5, -0.5)
plt.ylabel("state t", fontsize = 15)
plt.xlabel("state t+1", fontsize = 15)
plt.title("Recovered transition matrix", fontsize = 15)