
# # Input Driven Observations ("GLM-HMM")
#
# Notebook prepared by Zoe Ashwood: feel free to email me with feedback or questions (zashwood at cs dot princeton dot edu).
#
# This notebook demonstrates the "InputDrivenObservations" class, and illustrates its use in the context of modeling decision-making data as in Ashwood et al. (2020) ([Mice alternate between discrete strategies during perceptual
# decision-making](https://www.biorxiv.org/content/10.1101/2020.10.19.346353v1.full.pdf)).
#
# Compared to the model considered in the notebook ["2 Input Driven HMM"](https://github.com/lindermanlab/ssm/blob/master/notebooks/2%20Input%20Driven%20HMM.ipynb), Ashwood et al. (2020) assumes a stationary transition matrix where transition probabilities *do not* depend on external inputs. However, observation probabilities now *do* depend on external covariates according to:
#
#
# for $c \neq C$:
# $$
# \begin{align}
# \Pr(y_t = c \mid z_{t} = k, u_t, w_{kc}) =
# \frac{\exp\{w_{kc}^\mathsf{T} u_t\}}
# {1+\sum_{c'=1}^{C-1} \exp\{w_{kc'}^\mathsf{T} u_t\}}
# \end{align}
# $$
#
# and for $c = C$:
# $$
# \begin{align}
# \Pr(y_t = c \mid z_{t} = k, u_t, w_{kc}) =
# \frac{1}
# {1+\sum_{c'=1}^{C-1} \exp\{w_{kc'}^\mathsf{T} u_t\}}
# \end{align}
# $$
#
# where $c \in \{1, ..., C\}$ indicates the categorical class for the observation, $u_{t} \in \mathbb{R}^{M}$ is the set of input covariates, and $w_{kc} \in \mathbb{R}^{M}$ is the set of input weights associated with state $k$ and class $c$. These weights, along with the transition matrix and initial state probabilities, will be learned.
#
# In Ashwood et al. (2020), $C = 2$ as $y_{t}$ represents the binary choice made by an animal during a 2AFC (2-Alternative Forced Choice) task. The above equations then reduce to:
#
# $$
# \begin{align}
# \Pr(y_t = 0 \mid z_{t} = k, u_t, w_{k}) =
# \frac{\exp\{w_{k}^\mathsf{T} u_t\}}
# {1 + \exp\{w_{k}^\mathsf{T} u_t\}} = \frac{1}
# {1 + \exp\{-w_{k}^\mathsf{T} u_t\}}.
# \end{align}
# $$
#
# $$
# \begin{align}
# \Pr(y_t = 1 \mid z_{t} = k, u_t, w_{k}) =
# \frac{1}
# {1 + \exp\{w_{k}^\mathsf{T} u_t\}}.
# \end{align}
# $$
#
# and only a single weight vector, $w_{k}$, is associated with each state.
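#
# To make these observation equations concrete, here is a minimal numpy sketch (not part of the original notebook; the helper name and example values are illustrative) of how the per-class probabilities in Equations 1 and 2 could be computed for a single trial, with the weights of class $C$ implicitly fixed to zero:
#
# ```python
# import numpy as np
#
# def observation_probs(w_k, u_t):
#     """Per-class probabilities for one state k and one trial.
#     w_k: (C-1, M) weights for classes 1..C-1; u_t: (M,) input covariates."""
#     logits = np.append(w_k @ u_t, 0.0)   # class C has its weights fixed to zero
#     p = np.exp(logits - logits.max())    # subtract the max for numerical stability
#     return p / p.sum()
#
# # e.g. a state with C = 3 classes and stimulus + bias covariates (M = 2)
# w_k = np.array([[6.0, 1.0], [2.0, -3.0]])
# u_t = np.array([0.5, 1.0])               # stimulus value, constant bias term
# print(observation_probs(w_k, u_t))       # the three probabilities sum to 1
# ```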

# ## 1. Setup
# The line `import ssm` imports the package for use. Here, we have also imported a few other packages for plotting.

# +
import numpy as np
import numpy.random as npr
import matplotlib.pyplot as plt
import ssm
from ssm.util import find_permutation

npr.seed(0)
# -

# ## 2. Input Driven Observations
# We create a HMM with input-driven observations and 'standard' (stationary) transitions with the following line:
# ```python
# ssm.HMM(num_states, obs_dim, input_dim, observations="input_driven_obs",
#         observation_kwargs=dict(C=num_categories), transitions="standard")
# ```
#
# As in Ashwood et al. (2020), we are going to model an animal's binary choice data during a decision-making task, so we will set `num_categories=2` because the animal only has two options available to it. We will also set `obs_dim = 1` because the dimensionality of the observation data is 1 (if we were also modeling, for example, the binned reaction time of the animal, we could set `obs_dim = 2`). For the sake of simplicity, we will assume that an animal's choice in a particular state is only affected by the external stimulus associated with that particular trial and its innate choice bias. Thus, we will set `input_dim = 2`, and we will simulate input data that resembles sequences of stimuli in what follows. Ashwood et al. (2020) found that many mice used 3 decision-making states when performing 2AFC tasks, so we will set `num_states = 3`.

# ### 2a. Initialize GLM-HMM

# +
# Set the parameters of the GLM-HMM
num_states = 3        # number of discrete states
obs_dim = 1           # number of observed dimensions
num_categories = 2    # number of categories for output
input_dim = 2         # input dimensions

# Make a GLM-HMM
true_glmhmm = ssm.HMM(num_states, obs_dim, input_dim, observations="input_driven_obs",
                   observation_kwargs=dict(C=num_categories), transitions="standard")
# -

# ### 2b. Specify parameters of generative GLM-HMM

# Let's update the weights and transition matrix for the true GLM-HMM so as to bring the GLM-HMM to the parameter regime that real animals use (according to Ashwood et al. (2020)):

gen_weights = np.array([[[6, 1]], [[2, -3]], [[2, 3]]])
gen_log_trans_mat = np.log(np.array([[[0.98, 0.01, 0.01], [0.05, 0.92, 0.03], [0.03, 0.03, 0.94]]]))
true_glmhmm.observations.params = gen_weights
true_glmhmm.transitions.params = gen_log_trans_mat

# +
# Plot generative parameters:
fig = plt.figure(figsize=(8, 3), dpi=80, facecolor='w', edgecolor='k')
plt.subplot(1, 2, 1)
cols = ['#ff7f00', '#4daf4a', '#377eb8']
for k in range(num_states):
    plt.plot(range(input_dim), gen_weights[k][0], marker='o',
             color=cols[k], linestyle='-',
             lw=1.5, label="state " + str(k+1))
plt.yticks(fontsize=10)
plt.ylabel("GLM weight", fontsize=15)
plt.xlabel("covariate", fontsize=15)
plt.xticks([0, 1], ['stimulus', 'bias'], fontsize=12, rotation=45)
plt.axhline(y=0, color="k", alpha=0.5, ls="--")
plt.legend()
plt.title("Generative weights", fontsize = 15)

plt.subplot(1, 2, 2)
gen_trans_mat = np.exp(gen_log_trans_mat)[0]
plt.imshow(gen_trans_mat, vmin=-0.8, vmax=1, cmap='bone')
for i in range(gen_trans_mat.shape[0]):
    for j in range(gen_trans_mat.shape[1]):
        text = plt.text(j, i, str(np.around(gen_trans_mat[i, j], decimals=2)), ha="center", va="center",
                        color="k", fontsize=12)
plt.xlim(-0.5, num_states - 0.5)
plt.xticks(range(0, num_states), ('1', '2', '3'), fontsize=10)
plt.yticks(range(0, num_states), ('1', '2', '3'), fontsize=10)
plt.ylim(num_states - 0.5, -0.5)
plt.ylabel("state t", fontsize = 15)
plt.xlabel("state t+1", fontsize = 15)
plt.title("Generative transition matrix", fontsize = 15)
# -

# ### 2c. Create external input sequences

# Simulate an example set of external inputs for each trial in a session. We will create an array of size `(num_sess x num_trials_per_sess x num_covariates)`. As in Ashwood et al. (2020), for each trial in a session we will include the stimulus presented to the animal at that trial, as well as a '1' as the second covariate (so as to capture the animal's innate bias for one of the two options available to it). We will simulate stimuli sequences so as to resemble the sequences of stimuli in the International Brain Laboratory et al. (2020) task.

num_sess = 20 # number of example sessions
num_trials_per_sess = 100 # number of trials in a session
inpts = np.ones((num_sess, num_trials_per_sess, input_dim)) # initialize inpts array
stim_vals = [-1, -0.5, -0.25, -0.125, -0.0625, 0, 0.0625, 0.125, 0.25, 0.5, 1]
inpts[:,:,0] = np.random.choice(stim_vals, (num_sess, num_trials_per_sess)) # generate random sequence of stimuli
inpts = list(inpts) #convert inpts to correct format

# ### 2d. Simulate states and observations with generative model

# Generate a sequence of latents and choices for each session
true_latents, true_choices = [], []
for sess in range(num_sess):
    true_z, true_y = true_glmhmm.sample(num_trials_per_sess, input=inpts[sess])
    true_latents.append(true_z)
    true_choices.append(true_y)

# Calculate true loglikelihood
true_ll = true_glmhmm.log_probability(true_choices, inputs=inpts)
print("true ll = " + str(true_ll))

# ## 3. Fit GLM-HMM and perform recovery analysis

# ### 3a. Maximum Likelihood Estimation

# Now we instantiate a new GLM-HMM and check that we can recover the generative parameters in simulated data:

# +
new_glmhmm = ssm.HMM(num_states, obs_dim, input_dim, observations="input_driven_obs",
                   observation_kwargs=dict(C=num_categories), transitions="standard")

N_iters = 200 # maximum number of EM iterations. Fitting will stop earlier if the increase in LL is below the tolerance specified by the tolerance parameter
fit_ll = new_glmhmm.fit(true_choices, inputs=inpts, method="em", num_iters=N_iters, tolerance=10**-4)
# -

# Plot the log probabilities of the true and fit models. Fit model final LL should be greater
# than or equal to true LL.
fig = plt.figure(figsize=(4, 3), dpi=80, facecolor='w', edgecolor='k')
plt.plot(fit_ll, label="EM")
plt.plot([0, len(fit_ll)], true_ll * np.ones(2), ':k', label="True")
plt.legend(loc="lower right")
plt.xlabel("EM Iteration")
plt.xlim(0, len(fit_ll))
plt.ylabel("Log Probability")
plt.show()

# ### 3b. Retrieved parameters

# Compare retrieved weights and transition matrices to generative parameters. To do this, we may first need to permute the states of the fit GLM-HMM relative to the
# generative model.  One way to do this uses the `find_permutation` function from `ssm`:

new_glmhmm.permute(find_permutation(true_latents[0], new_glmhmm.most_likely_states(true_choices[0], input=inpts[0])))

# Now plot generative and retrieved weights for GLMs (analogous plot to Figure S1c in
# Ashwood et al. (2020)):

fig = plt.figure(figsize=(4, 3), dpi=80, facecolor='w', edgecolor='k')
cols = ['#ff7f00', '#4daf4a', '#377eb8']
recovered_weights = new_glmhmm.observations.params
for k in range(num_states):
    if k == 0: # show labels only for first state
        plt.plot(range(input_dim), gen_weights[k][0], marker='o',
                 color=cols[k], linestyle='-',
                 lw=1.5, label="generative")
        plt.plot(range(input_dim), recovered_weights[k][0], color=cols[k],
                     lw=1.5,  label = "recovered", linestyle = '--')
    else:
        plt.plot(range(input_dim), gen_weights[k][0], marker='o',
                 color=cols[k], linestyle='-',
                 lw=1.5, label="")
        plt.plot(range(input_dim), recovered_weights[k][0], color=cols[k],
                     lw=1.5,  label = '', linestyle = '--')
plt.yticks(fontsize=10)
plt.ylabel("GLM weight", fontsize=15)
plt.xlabel("covariate", fontsize=15)
plt.xticks([0, 1], ['stimulus', 'bias'], fontsize=12, rotation=45)
plt.axhline(y=0, color="k", alpha=0.5, ls="--")
plt.legend()
plt.title("Weight recovery", fontsize=15)

# Now plot generative and retrieved transition matrices (analogous plot to Figure S1c in
# Ashwood et al. (2020)):

# +
fig = plt.figure(figsize=(5, 2.5), dpi=80, facecolor='w', edgecolor='k')
plt.subplot(1, 2, 1)
gen_trans_mat = np.exp(gen_log_trans_mat)[0]
plt.imshow(gen_trans_mat, vmin=-0.8, vmax=1, cmap='bone')
for i in range(gen_trans_mat.shape[0]):
    for j in range(gen_trans_mat.shape[1]):
        text = plt.text(j, i, str(np.around(gen_trans_mat[i, j], decimals=2)), ha="center", va="center",
                        color="k", fontsize=12)
plt.xlim(-0.5, num_states - 0.5)
plt.xticks(range(0, num_states), ('1', '2', '3'), fontsize=10)
plt.yticks(range(0, num_states), ('1', '2', '3'), fontsize=10)
plt.ylim(num_states - 0.5, -0.5)
plt.ylabel("state t", fontsize = 15)
plt.xlabel("state t+1", fontsize = 15)
plt.title("generative", fontsize = 15)


plt.subplot(1, 2, 2)
recovered_trans_mat = np.exp(new_glmhmm.transitions.log_Ps)
plt.imshow(recovered_trans_mat, vmin=-0.8, vmax=1, cmap='bone')
for i in range(recovered_trans_mat.shape[0]):
    for j in range(recovered_trans_mat.shape[1]):
        text = plt.text(j, i, str(np.around(recovered_trans_mat[i, j], decimals=2)), ha="center", va="center",
                        color="k", fontsize=12)
plt.xlim(-0.5, num_states - 0.5)
plt.xticks(range(0, num_states), ('1', '2', '3'), fontsize=10)
plt.yticks(range(0, num_states), ('1', '2', '3'), fontsize=10)
plt.ylim(num_states - 0.5, -0.5)
plt.title("recovered", fontsize = 15)
plt.subplots_adjust(0, 0, 1, 1)

# -

# ### 3c. Posterior State Probabilities

# Let's now plot $p(z_{t} = k \mid \mathbf{y}, \{u_{t}\}_{t=1}^{T})$, the posterior state probabilities, which give the probability of the animal being in state $k$ at trial $t$.

# Get expected states (expected_states returns a tuple whose first element contains the posterior state probabilities):
posterior_probs = [new_glmhmm.expected_states(data=data, input=inpt)[0]
                for data, inpt
                in zip(true_choices, inpts)]

fig = plt.figure(figsize=(5, 2.5), dpi=80, facecolor='w', edgecolor='k')
sess_id = 0 #session id; can choose any index between 0 and num_sess-1
for k in range(num_states):
    plt.plot(posterior_probs[sess_id][:, k], label="State " + str(k + 1), lw=2,
             color=cols[k])
plt.ylim((-0.01, 1.01))
plt.yticks([0, 0.5, 1], fontsize = 10)
plt.xlabel("trial #", fontsize = 15)
plt.ylabel("p(state)", fontsize = 15)

# With these posterior state probabilities, we can assign trials to states and then plot the fractional occupancy of each state:

# concatenate posterior probabilities across sessions
posterior_probs_concat = np.concatenate(posterior_probs)
# get state with maximum posterior probability at particular trial:
state_max_posterior = np.argmax(posterior_probs_concat, axis = 1)
# now obtain state fractional occupancies:
_, state_occupancies = np.unique(state_max_posterior, return_counts=True)
state_occupancies = state_occupancies/np.sum(state_occupancies)

fig = plt.figure(figsize=(2, 2.5), dpi=80, facecolor='w', edgecolor='k')
for z, occ in enumerate(state_occupancies):
    plt.bar(z, occ, width = 0.8, color = cols[z])
plt.ylim((0, 1))
plt.xticks([0, 1, 2], ['1', '2', '3'], fontsize = 10)
plt.yticks([0, 0.5, 1], ['0', '0.5', '1'], fontsize=10)
plt.xlabel('state', fontsize = 15)
plt.ylabel('frac. occupancy', fontsize=15)

# ## 4. Fit GLM-HMM and perform recovery analysis: Maximum A Posteriori Estimation

# Above, we performed Maximum Likelihood Estimation to retrieve the generative parameters of the GLM-HMM from simulated data. In the small data regime, where we do not have many trials available to us, we may instead want to perform Maximum A Posteriori (MAP) Estimation in order to incorporate a prior term and restrict the range of the best-fitting parameters. Unfortunately, what is meant by 'small data regime' is problem dependent and will be affected by the number of states in the generative GLM-HMM and the specific parameters of the generative model, amongst other things. In practice, we may perform both Maximum Likelihood Estimation and MAP estimation and compare the ability of the fit models to make predictions on held-out data (see Section 5 on Cross-Validation below).
#
# The prior we consider for the GLM-HMM is the product of a Gaussian prior on the GLM weights, $W$, and a Dirichlet prior on the transition matrix, $A$:
#
# $$
# \begin{align}
# \Pr(W, A) &= \mathcal{N}(W|0, \Sigma) \Pr(A|\alpha) \\&= \mathcal{N}(W|0, \mathrm{diag}(\sigma^{2}, \cdots, \sigma^{2})) \prod_{j=1}^{K} \dfrac{1}{B(\alpha)} \prod_{k=1}^{K} A_{jk}^{\alpha -1}
# \end{align}
# $$
#
# There are two hyperparameters controlling the strength of the prior: $\sigma$ and $\alpha$. As $\sigma$ becomes large and with $\alpha = 1$, MAP estimation becomes increasingly similar to Maximum Likelihood Estimation: the prior term reduces to an additive offset to the objective function of the GLM-HMM that is independent of the values of $W$ and $A$. In comparison, setting, for example, $\sigma = 2$ and $\alpha = 2$ results in a prior that is no longer independent of $W$ and $A$.
#
# In order to perform MAP estimation for the GLM-HMM with `ssm`, the new syntax is:
#
# ```python
# ssm.HMM(num_states, obs_dim, input_dim, observations="input_driven_obs",
#              observation_kwargs=dict(C=num_categories,prior_sigma=prior_sigma),
#              transitions="sticky", transition_kwargs=dict(alpha=prior_alpha,kappa=0))
# ```
#
# where `prior_sigma` is the $\sigma$ parameter from above, and `prior_alpha` is the $\alpha$ parameter.
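#
# As a rough illustration (a sketch, not how `ssm` evaluates the prior internally; the helper name and toy values below are ours), the log of this prior could be computed directly from its definition:
#
# ```python
# import numpy as np
# from scipy.stats import dirichlet, norm
#
# def log_prior(W, A, sigma, alpha):
#     """Log of the Gaussian x Dirichlet prior written above.
#     W: GLM weights (any shape); A: (K, K) transition matrix with rows summing to 1."""
#     lp_W = norm.logpdf(W.ravel(), loc=0.0, scale=sigma).sum()
#     K = A.shape[0]
#     lp_A = sum(dirichlet.logpdf(A[j], alpha * np.ones(K)) for j in range(K))
#     return lp_W + lp_A
#
# # e.g. with K = 3 states, one 2-dimensional weight vector per state, and sigma = alpha = 2
# print(log_prior(np.zeros((3, 1, 2)), np.full((3, 3), 1.0 / 3), sigma=2.0, alpha=2.0))
# ```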

# Instantiate GLM-HMM and set prior hyperparameters
prior_sigma = 2
prior_alpha = 2
map_glmhmm = ssm.HMM(num_states, obs_dim, input_dim, observations="input_driven_obs",
             observation_kwargs=dict(C=num_categories,prior_sigma=prior_sigma),
             transitions="sticky", transition_kwargs=dict(alpha=prior_alpha,kappa=0))

# Fit GLM-HMM with MAP estimation:
_ = map_glmhmm.fit(true_choices, inputs=inpts, method="em", num_iters=N_iters, tolerance=10**-4)

# Compare the final likelihood of the data under MAP estimation and MLE to the likelihood under the generative model (note: we cannot use the log probability returned by the `fit` function, as it incorporates the prior term, which is not comparable between the generative and MAP models). We want to check that the MAP and MLE likelihood values are higher than the true likelihood; if they are not, this may indicate a poor initialization, and we should refit these models.

true_likelihood = true_glmhmm.log_likelihood(true_choices, inputs=inpts)
mle_final_ll = new_glmhmm.log_likelihood(true_choices, inputs=inpts)
map_final_ll = map_glmhmm.log_likelihood(true_choices, inputs=inpts)

# Plot these values
fig = plt.figure(figsize=(2, 2.5), dpi=80, facecolor='w', edgecolor='k')
loglikelihood_vals = [true_likelihood, mle_final_ll, map_final_ll]
colors = ['Red', 'Navy', 'Purple']
for z, occ in enumerate(loglikelihood_vals):
    plt.bar(z, occ, width = 0.8, color = colors[z])
plt.ylim((true_likelihood-5, true_likelihood+15))
plt.xticks([0, 1, 2], ['true', 'mle', 'map'], fontsize = 10)
plt.xlabel('model', fontsize = 15)
plt.ylabel('loglikelihood', fontsize=15)

# ## 5. Cross Validation

# To assess which model is better - the model fit via Maximum Likelihood Estimation, or the model fit via MAP estimation - we can investigate the predictive power of these fit models on held-out test data sets.

# Create additional input sequences to be used as held-out test data
num_test_sess = 10
test_inpts = np.ones((num_test_sess, num_trials_per_sess, input_dim))
test_inpts[:,:,0] = np.random.choice(stim_vals, (num_test_sess, num_trials_per_sess))
test_inpts = list(test_inpts)

# Create set of test latents and choices to accompany input sequences:
test_latents, test_choices = [], []
for sess in range(num_test_sess):
    test_z, test_y = true_glmhmm.sample(num_trials_per_sess, input=test_inpts[sess])
    test_latents.append(test_z)
    test_choices.append(test_y)

# Compare likelihood of test_choices for model fit with MLE and MAP:
mle_test_ll = new_glmhmm.log_likelihood(test_choices, inputs=test_inpts)
map_test_ll = map_glmhmm.log_likelihood(test_choices, inputs=test_inpts)

fig = plt.figure(figsize=(2, 2.5), dpi=80, facecolor='w', edgecolor='k')
loglikelihood_vals = [mle_test_ll, map_test_ll]
colors = ['Navy', 'Purple']
for z, occ in enumerate(loglikelihood_vals):
    plt.bar(z, occ, width = 0.8, color = colors[z])
plt.ylim((mle_test_ll-2, mle_test_ll+5))
plt.xticks([0, 1], ['mle', 'map'], fontsize = 10)
plt.xlabel('model', fontsize = 15)
plt.ylabel('loglikelihood', fontsize=15)

# Here we see that the model fit with MAP estimation achieves higher likelihood on the held-out dataset than the model fit with MLE, so we would choose this model as the best model of animal decision-making behavior (although we'd probably want to perform multi-fold cross-validation to be sure that this holds across instantiations of the test data, as sketched below).
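#
# A sketch of such a session-level multi-fold procedure (not part of the original notebook; it reuses the variables defined earlier and a hypothetical `n_folds`) might look like:
#
# ```python
# n_folds = 5
# sess_ids = np.arange(len(true_choices))
# folds = np.array_split(npr.permutation(sess_ids), n_folds)
#
# mle_scores, map_scores = [], []
# for heldout in folds:
#     train = [s for s in sess_ids if s not in heldout]
#     train_y, train_u = [true_choices[s] for s in train], [inpts[s] for s in train]
#     test_y, test_u = [true_choices[s] for s in heldout], [inpts[s] for s in heldout]
#
#     # refit both models on the training sessions only
#     mle = ssm.HMM(num_states, obs_dim, input_dim, observations="input_driven_obs",
#                   observation_kwargs=dict(C=num_categories), transitions="standard")
#     mle.fit(train_y, inputs=train_u, method="em", num_iters=N_iters, tolerance=10**-4)
#     mle_scores.append(mle.log_likelihood(test_y, inputs=test_u))
#
#     map_model = ssm.HMM(num_states, obs_dim, input_dim, observations="input_driven_obs",
#                         observation_kwargs=dict(C=num_categories, prior_sigma=prior_sigma),
#                         transitions="sticky", transition_kwargs=dict(alpha=prior_alpha, kappa=0))
#     map_model.fit(train_y, inputs=train_u, method="em", num_iters=N_iters, tolerance=10**-4)
#     map_scores.append(map_model.log_likelihood(test_y, inputs=test_u))
#
# print("mean held-out LL  MLE:", np.mean(mle_scores), " MAP:", np.mean(map_scores))
# ```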
#
# Let's finish by comparing the retrieved weights and transition matrices from MAP estimation to the generative parameters.

map_glmhmm.permute(find_permutation(true_latents[0], map_glmhmm.most_likely_states(true_choices[0], input=inpts[0])))

# +
fig = plt.figure(figsize=(6, 3), dpi=80, facecolor='w', edgecolor='k')
cols = ['#ff7f00', '#4daf4a', '#377eb8']
plt.subplot(1,2,1)
recovered_weights = new_glmhmm.observations.params
for k in range(num_states):
    if k == 0: # show labels only for first state
        plt.plot(range(input_dim), gen_weights[k][0], marker='o',
                 color=cols[k],
                 lw=1.5, label="generative")
        plt.plot(range(input_dim), recovered_weights[k][0], color=cols[k],
                     lw=1.5,  label = 'recovered', linestyle='--')
    else:
        plt.plot(range(input_dim), gen_weights[k][0], marker='o',
                 color=cols[k],
                 lw=1.5, label="")
        plt.plot(range(input_dim), recovered_weights[k][0], color=cols[k],
                     lw=1.5,  label = '', linestyle='--')
plt.yticks(fontsize=10)
plt.ylabel("GLM weight", fontsize=15)
plt.xlabel("covariate", fontsize=15)
plt.xticks([0, 1], ['stimulus', 'bias'], fontsize=12, rotation=45)
plt.axhline(y=0, color="k", alpha=0.5, ls="--")
plt.title("MLE", fontsize = 15)
plt.legend()

plt.subplot(1,2,2)
recovered_weights = map_glmhmm.observations.params
for k in range(num_states):
    plt.plot(range(input_dim), gen_weights[k][0], marker='o',
             color=cols[k],
             lw=1.5, label="", linestyle = '-')
    plt.plot(range(input_dim), recovered_weights[k][0], color=cols[k],
                 lw=1.5,  label = '', linestyle='--')
plt.yticks(fontsize=10)
plt.xticks([0, 1], ['', ''], fontsize=12, rotation=45)
plt.axhline(y=0, color="k", alpha=0.5, ls="--")
plt.title("MAP", fontsize = 15)

# +
fig = plt.figure(figsize=(7, 2.5), dpi=80, facecolor='w', edgecolor='k')
plt.subplot(1, 3, 1)
gen_trans_mat = np.exp(gen_log_trans_mat)[0]
plt.imshow(gen_trans_mat, vmin=-0.8, vmax=1, cmap='bone')
for i in range(gen_trans_mat.shape[0]):
    for j in range(gen_trans_mat.shape[1]):
        text = plt.text(j, i, str(np.around(gen_trans_mat[i, j], decimals=2)), ha="center", va="center",
                        color="k", fontsize=12)
plt.xlim(-0.5, num_states - 0.5)
plt.xticks(range(0, num_states), ('1', '2', '3'), fontsize=10)
plt.yticks(range(0, num_states), ('1', '2', '3'), fontsize=10)
plt.ylim(num_states - 0.5, -0.5)
plt.ylabel("state t", fontsize = 15)
plt.xlabel("state t+1", fontsize = 15)
plt.title("generative", fontsize = 15)


plt.subplot(1, 3, 2)
recovered_trans_mat = np.exp(new_glmhmm.transitions.log_Ps)
plt.imshow(recovered_trans_mat, vmin=-0.8, vmax=1, cmap='bone')
for i in range(recovered_trans_mat.shape[0]):
    for j in range(recovered_trans_mat.shape[1]):
        text = plt.text(j, i, str(np.around(recovered_trans_mat[i, j], decimals=2)), ha="center", va="center",
                        color="k", fontsize=12)
plt.xlim(-0.5, num_states - 0.5)
plt.xticks(range(0, num_states), ('1', '2', '3'), fontsize=10)
plt.yticks(range(0, num_states), ('1', '2', '3'), fontsize=10)
plt.ylim(num_states - 0.5, -0.5)
plt.title("recovered - MLE", fontsize = 15)
plt.subplots_adjust(0, 0, 1, 1)


plt.subplot(1, 3, 3)
recovered_trans_mat = np.exp(map_glmhmm.transitions.log_Ps)
plt.imshow(recovered_trans_mat, vmin=-0.8, vmax=1, cmap='bone')
for i in range(recovered_trans_mat.shape[0]):
    for j in range(recovered_trans_mat.shape[1]):
        text = plt.text(j, i, str(np.around(recovered_trans_mat[i, j], decimals=2)), ha="center", va="center",
                        color="k", fontsize=12)
plt.xlim(-0.5, num_states - 0.5)
plt.xticks(range(0, num_states), ('1', '2', '3'), fontsize=10)
plt.yticks(range(0, num_states), ('1', '2', '3'), fontsize=10)
plt.ylim(num_states - 0.5, -0.5)
plt.title("recovered - MAP", fontsize = 15)
plt.subplots_adjust(0, 0, 1, 1)
# -

# ## 6. Multinomial GLM-HMM

# Until now, we have only considered the case where there are 2 output classes (the Bernoulli GLM-HMM corresponding to `C=num_categories=2`), yet the `ssm` framework is sufficiently general to allow us to fit the multinomial GLM-HMM described in Equations 1 and 2. Here we demonstrate a recovery analysis for the multinomial GLM-HMM.

# +
# Set the parameters of the GLM-HMM
num_states = 4        # number of discrete states
obs_dim = 1           # number of observed dimensions
num_categories = 3    # number of categories for output
input_dim = 2         # input dimensions

# Make a GLM-HMM
true_glmhmm = ssm.HMM(num_states, obs_dim, input_dim, observations="input_driven_obs",
                   observation_kwargs=dict(C=num_categories), transitions="standard")
# -

# Set weights of multinomial GLM-HMM
gen_weights = np.array([[[0.6,3], [2,3]], [[6,1], [6,-2]], [[1,1], [3,1]], [[2,2], [0,5]]])
print(gen_weights.shape)
true_glmhmm.observations.params = gen_weights

# In the above, notice that the shape of the weights for the multinomial GLM-HMM is `(num_states, num_categories-1, input_dim)`. Specifically, we only learn `num_categories-1` weight vectors (of size `input_dim`) for a given state, and we set the weights for the remaining observation class to zero. Constraining the weight vectors for one class is important if we want to be able to identify the generative weights from simulated data. If we didn't do this, one could generate the same observation probabilities with a set of weight vectors that are offset by a constant vector $w_{k}$ (the index $k$ indicates that a different offset vector could exist per state):
# $$
# \begin{align}
# \Pr(y_t = c \mid z_{t} = k, u_t, w_{kc}) =
# \frac{\exp\{w_{kc}^\mathsf{T} u_t\}}
# {\sum_{c'=1}^C \exp\{w_{kc'}^\mathsf{T} u_t\}} = \frac{\exp\{(w_{kc}-w_{k})^\mathsf{T} u_t\}}
# {\sum_{c'=1}^C \exp\{(w_{kc'}-w_{k})^\mathsf{T} u_t\}}
# \end{align}
# $$
#
# Equations 1 and 2 at the top of this notebook already take into account the fact that the weights for a particular class for a given state are fixed to zero (this is why $c = C$ is handled differently).
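#
# As a quick numerical check of this invariance (a sketch, not part of the original notebook; the helper and example values are illustrative), shifting every class's weight vector in a state by a common offset leaves the class probabilities unchanged:
#
# ```python
# import numpy as np
#
# def class_probs(W_k, u_t):
#     """Softmax class probabilities for one state; W_k has one row of weights per class."""
#     logits = W_k @ u_t
#     p = np.exp(logits - logits.max())
#     return p / p.sum()
#
# W_k = np.array([[0.6, 3.0], [2.0, 3.0], [0.0, 0.0]])   # last class fixed to zero weights
# u_t = np.array([0.25, 1.0])                             # stimulus, bias
# offset = np.array([1.5, -0.7])                          # an arbitrary per-state offset w_k
# print(np.allclose(class_probs(W_k, u_t), class_probs(W_k + offset, u_t)))  # True
# ```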

# Set transition matrix of multinomial GLM-HMM
gen_log_trans_mat = np.log(np.array([[[0.90, 0.04, 0.05, 0.01], [0.05, 0.92, 0.01, 0.02], [0.03, 0.02, 0.94, 0.01], [0.09, 0.01, 0.01, 0.89]]]))
true_glmhmm.transitions.params = gen_log_trans_mat

# Create external input sequences; compared to the example above, we will increase the number of trials per session
# (through the `num_trials_per_sess` parameter) since the number of parameters has increased
num_sess = 20 # number of example sessions
num_trials_per_sess = 1000 # number of trials in a session
inpts = np.ones((num_sess, num_trials_per_sess, input_dim)) # initialize inpts array
stim_vals = [-1, -0.5, -0.25, -0.125, -0.0625, 0, 0.0625, 0.125, 0.25, 0.5, 1]
inpts[:,:,0] = np.random.choice(stim_vals, (num_sess, num_trials_per_sess)) # generate random sequence of stimuli
inpts = list(inpts)

# Generate a sequence of latents and choices for each session
true_latents, true_choices = [], []
for sess in range(num_sess):
    true_z, true_y = true_glmhmm.sample(num_trials_per_sess, input=inpts[sess])
    true_latents.append(true_z)
    true_choices.append(true_y)

# plot example data:
fig = plt.figure(figsize=(8, 3), dpi=80, facecolor='w', edgecolor='k')
plt.step(range(100),true_choices[0][range(100)], color = "red")
plt.yticks([0, 1, 2])
plt.title("example data (multinomial GLM-HMM)")
plt.xlabel("trial #", fontsize = 15)
plt.ylabel("observation class", fontsize = 15)

# Calculate true loglikelihood
true_ll = true_glmhmm.log_probability(true_choices, inputs=inpts)
print("true ll = " + str(true_ll))

# +
# fit GLM-HMM
new_glmhmm = ssm.HMM(num_states, obs_dim, input_dim, observations="input_driven_obs",
                   observation_kwargs=dict(C=num_categories), transitions="standard")

N_iters = 500 # maximum number of EM iterations. Fitting will stop earlier if the increase in LL is below the tolerance specified by the tolerance parameter
fit_ll = new_glmhmm.fit(true_choices, inputs=inpts, method="em", num_iters=N_iters, tolerance=10**-4)
# -

# Plot the log probabilities of the true and fit models. Fit model final LL should be greater
# than or equal to true LL.
fig = plt.figure(figsize=(4, 3), dpi=80, facecolor='w', edgecolor='k')
plt.plot(fit_ll, label="EM")
plt.plot([0, len(fit_ll)], true_ll * np.ones(2), ':k', label="True")
plt.legend(loc="lower right")
plt.xlabel("EM Iteration")
plt.xlim(0, len(fit_ll))
plt.ylabel("Log Probability")
plt.show()

# permute recovered state identities to match state identities of generative model
new_glmhmm.permute(find_permutation(true_latents[0], new_glmhmm.most_likely_states(true_choices[0], input=inpts[0])))

# +
# Plot recovered parameters:
recovered_weights = new_glmhmm.observations.params
recovered_transitions = new_glmhmm.transitions.params

fig = plt.figure(figsize=(16, 8), dpi=80, facecolor='w', edgecolor='k')
plt.subplots_adjust(wspace=0.3, hspace=0.6)

plt.subplot(2, 2, 1)
cols = ['#ff7f00', '#4daf4a', '#377eb8', '#f781bf', '#a65628', '#984ea3', '#999999', '#e41a1c', '#dede00']

for c in range(num_categories):
    plt.subplot(2, num_categories+1, c+1)
    if c < num_categories-1:
        for k in range(num_states):
            plt.plot(range(input_dim), gen_weights[k,c], marker='o',
                 color=cols[k], lw=1.5, label="state " + str(k+1) + "; class " + str(c+1))
    else:
        for k in range(num_states):
            plt.plot(range(input_dim), np.zeros(input_dim), marker='o',
                     color=cols[k], lw=1.5, label="state " + str(k+1) + "; class " + str(c+1), alpha = 0.5)

    plt.axhline(y=0, color="k", alpha=0.5, ls="--")
    plt.yticks(fontsize=10)
    plt.xticks([0, 1], ['', ''])
    if c == 0:
        plt.ylabel("GLM weight", fontsize=15)
    plt.legend()
    plt.title("Generative weights; class " + str(c+1), fontsize = 15)
    plt.ylim((-3, 10))

plt.subplot(2, num_categories+1, num_categories+1)
gen_trans_mat = np.exp(gen_log_trans_mat)[0]
plt.imshow(gen_trans_mat, vmin=-0.8, vmax=1, cmap='bone')
for i in range(gen_trans_mat.shape[0]):
    for j in range(gen_trans_mat.shape[1]):
        text = plt.text(j, i, str(np.around(gen_trans_mat[i, j], decimals=2)), ha="center", va="center",
                        color="k", fontsize=12)
plt.xlim(-0.5, num_states - 0.5)
plt.xticks(range(0, num_states), ('1', '2', '3', '4'), fontsize=10)
plt.yticks(range(0, num_states), ('1', '2', '3', '4'), fontsize=10)
plt.ylim(num_states - 0.5, -0.5)
plt.ylabel("state t", fontsize = 15)
plt.xlabel("state t+1", fontsize = 15)
plt.title("Generative transition matrix", fontsize = 15)


cols = ['#ff7f00', '#4daf4a', '#377eb8', '#f781bf', '#a65628', '#984ea3', '#999999', '#e41a1c', '#dede00']

for c in range(num_categories):
    plt.subplot(2, num_categories+1, num_categories + c + 2)
    if c < num_categories-1:
        for k in range(num_states):
            plt.plot(range(input_dim), recovered_weights[k,c], marker='o', linestyle = '--',
                 color=cols[k], lw=1.5, label="state " + str(k+1) + "; class " + str(c+1))
    else:
        for k in range(num_states):
            plt.plot(range(input_dim), np.zeros(input_dim), marker='o', linestyle = '--',
                     color=cols[k], lw=1.5, label="state " + str(k+1) + "; class " + str(c+1), alpha = 0.5)

    plt.axhline(y=0, color="k", alpha=0.5, ls="--")
    plt.yticks(fontsize=10)
    plt.xlabel("covariate", fontsize=15)
    if c == 0:
        plt.ylabel("GLM weight", fontsize=15)
    plt.xticks([0, 1], ['stimulus', 'bias'], fontsize=12, rotation=45)
    plt.legend()
    plt.title("Recovered weights; class " + str(c+1), fontsize = 15)
    plt.ylim((-3,10))

plt.subplot(2, num_categories+1, 2*num_categories+2)
recovered_trans_mat = np.exp(recovered_transitions)[0]
plt.imshow(recovered_trans_mat, vmin=-0.8, vmax=1, cmap='bone')
for i in range(recovered_trans_mat.shape[0]):
    for j in range(recovered_trans_mat.shape[1]):
        text = plt.text(j, i, str(np.around(recovered_trans_mat[i, j], decimals=2)), ha="center", va="center",
                        color="k", fontsize=12)
plt.xlim(-0.5, num_states - 0.5)
plt.xticks(range(0, num_states), ('1', '2', '3', '4'), fontsize=10)
plt.yticks(range(0, num_states), ('1', '2', '3', '4'), fontsize=10)
plt.ylim(num_states - 0.5, -0.5)
plt.ylabel("state t", fontsize = 15)
plt.xlabel("state t+1", fontsize = 15)
plt.title("Recovered transition matrix", fontsize = 15)
