Multi-Population rSLDS

[Figures: actual discrete states 0, 1, 2; Multi Population rSLDS panels; actual vs. aligned and raw predicted states 0, 1, 2; actual vs. predicted]
alphas: [0.05336111 0.03343876 0.08800889]

Initializing with an ARHMM using 25 steps of EM (repeated for each of 10 restarts).
Final LP of each restart: -86116.7, -85998.0, -86263.5, -85570.7, -86011.9, -85838.7, -85282.0, -85218.0, -86239.3, -85214.4
ARHMM Initialization restarts: 100%|##########| 10/10 [00:17<00:00,  1.76s/it]

ELBO at initialization: -464532.6
ELBO: -295957.2: 100%|##########| 30/30 [03:47<00:00,  7.57s/it]
Discrete state accuracy:  0.978


# + [markdown] colab_type="text" id="view-in-github"
# <a href="https://colab.research.google.com/github/lindermanlab/ssm/blob/master/notebooks/Multi-Population%20rSLDS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>
# -

# ### If you want to quickly see how to fit your own data, jump down to the "Fit model to data" section
# <br />
# <br />
#
# # Multi-population recurrent switching linear dynamical systems overview
#
# This notebook goes through the simulation example shown in our manuscript (Figure 2A,B).
#
# Below, we briefly describe the model. We also recommend looking at the "Recurrent SLDS" notebook, which provides more details on the standard rSLDS.
# <br />
# <br />
#
# **1. Data**.
# Let $y_t^{(j)}$ denote a vector of activity measurements of the $N_j$ neurons in population $j$ in time bin $t$.
# <br />
#
# **2. Emissions**.
# Let $x_t^{(j)}$ denote the continuous latent state of population $j$ at time $t$. The population states may differ in dimensionality $D_j$, since populations may differ in size and complexity. The observed activity of population $j$ is modeled with a generalized linear model,
# \begin{align}
#     E[y_t^{(j)}] &= f(C_j x_t^{(j)} + d_j),
# \end{align}
# where each population has its own linear mapping parameterized by $\{C_j, d_j\}$, and $f$ is the GLM's (inverse link) nonlinearity. In this notebook, we use a Poisson GLM. Inputs can also be passed into this GLM, as described in the rSLDS notebook.
#
# The example below loads multi-population emission classes that implement this model.
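#
# As an illustration of the emission model above, here is a minimal NumPy sketch (not the ssm package's implementation) that draws Poisson counts for two populations. The population sizes, the parameter values, and the exponential nonlinearity $f = \exp$ are assumptions chosen for the example; the notebook's emission classes may use a different link.

```python
import numpy as np

# Sketch of multi-population Poisson GLM emissions:
# E[y_t^(j)] = f(C_j x_t^(j) + d_j), with f = exp assumed for illustration.
rng = np.random.default_rng(0)

latent_dims = [2, 3]      # hypothetical D_j for each population
neuron_counts = [10, 15]  # hypothetical N_j for each population
T = 100                   # number of time bins

# Per-population emission parameters {C_j, d_j}
Cs = [rng.normal(scale=0.3, size=(N, D)) for N, D in zip(neuron_counts, latent_dims)]
ds = [np.full(N, -1.0) for N in neuron_counts]

# Random stand-in latents for each population, x^(j) with shape (T, D_j)
xs = [rng.normal(size=(T, D)) for D in latent_dims]


def poisson_rates(x, C, d):
    """Expected counts for one population: exp(C x_t + d) for every time bin."""
    return np.exp(x @ C.T + d)


rates = [poisson_rates(x, C, d) for x, C, d in zip(xs, Cs, ds)]
ys = [rng.poisson(lam) for lam in rates]

# Concatenate populations along the neuron axis to form the full observation matrix.
y = np.concatenate(ys, axis=1)
print(y.shape)  # (100, 25)
```

# Each population keeps its own $\{C_j, d_j\}$; only the observations are concatenated, so the neurons of population $j$ depend only on $x_t^{(j)}$.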
# <br />
#
# **3. Continuous State Update (Dynamics)**.
# The dynamics of a switching linear dynamical system are piecewise linear, with the linear dynamics at a given time determined by a discrete state (more on discrete states below).
#
# \begin{align}
#     x_t \sim
#     A^{(z_t)} x_{t-1} + b^{(z_t)}
# \end{align}
#
# where $z_t$ is the discrete state, $A^{(z_t)}$ and $b^{(z_t)}$ are the dynamics matrix and bias for that discrete state, and $x_t = [x_t^{(1)}, x_t^{(2)}, \dots, x_t^{(J)}]$ stacks the latents from all populations. We ignore the noise term here for simplicity.
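#
# A minimal NumPy sketch of these noiseless switching dynamics, assuming a hand-picked discrete state sequence and randomly drawn, stable dynamics matrices (all sizes and parameters are hypothetical). In the rSLDS itself, $z_t$ would be generated by the recurrent transitions described below.

```python
import numpy as np

# Sketch: noiseless switching linear dynamics x_t = A^(z_t) x_{t-1} + b^(z_t).
rng = np.random.default_rng(1)
K, D, T = 3, 4, 50  # discrete states, total latent dim (populations stacked), time bins

# One stable dynamics matrix per discrete state: 0.9 times a random orthogonal matrix.
As = np.stack([0.9 * np.linalg.qr(rng.normal(size=(D, D)))[0] for _ in range(K)])
bs = rng.normal(scale=0.1, size=(K, D))

# Hypothetical discrete sequence: a block of 0s, then 1s, then 2s.
z = np.repeat(np.arange(K), T // K + 1)[:T]

x = np.zeros((T, D))
x[0] = rng.normal(size=D)
for t in range(1, T):
    k = z[t]
    x[t] = As[k] @ x[t - 1] + bs[k]  # dynamics selected by the current discrete state
```

# Scaling an orthogonal matrix by 0.9 keeps every eigenvalue inside the unit circle, so the simulated latents stay bounded within each state.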
#
# Having unique continuous latents for each population allows us to decompose the dynamics in an interpretable manner. We model the temporal dynamics of the continuous states as
# \begin{align}
#     x_t^{(j)} \sim
#     A_{(j \: to \: j)}^{(z_t)} x_{t-1}^{(j)}
#     + \sum_{i \neq j} A_{(i \: to \: j)}^{(z_t)} x_{t-1}^{(i)}
#     + b_j^{(z_t)}.
# \end{align}
#
# In the full dynamics matrices $A^{(z_t)}$ shown in the example below, the diagonal blocks represent the internal dynamics $A_{(j \: to \: j)}^{(z_t)}$, and the off-diagonal blocks represent the external dynamics $A_{(i \: to \: j)}^{(z_t)}$.
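# The block decomposition can be checked numerically. The sketch below, with hypothetical sizes and random matrices, computes $A x_{t-1} + b$ in one shot and again population by population as internal plus external terms, using the convention that $A_{(i \: to \: j)}$ is the block in the rows of population $j$ and the columns of population $i$. The helpers `block` and `pop` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(1)
D_vec = [5, 5, 5]                        # hypothetical latent sizes per population
edges = np.concatenate(([0], np.cumsum(D_vec)))
D = int(edges[-1])
A = 0.1 * rng.standard_normal((D, D))    # one discrete state's full dynamics matrix
b = rng.standard_normal(D)
x_prev = rng.standard_normal(D)

def block(M, i, j):
    # Block of M mapping population i's latents to population j's latents
    return M[edges[j]:edges[j + 1], edges[i]:edges[i + 1]]

def pop(v, j):
    # Entries of a stacked vector that belong to population j
    return v[edges[j]:edges[j + 1]]

J = len(D_vec)
x_full = A @ x_prev + b                  # full-matrix update (noise omitted)
x_blocks = []
for j in range(J):
    internal = block(A, j, j) @ pop(x_prev, j)
    external = sum(block(A, i, j) @ pop(x_prev, i) for i in range(J) if i != j)
    x_blocks.append(internal + external + pop(b, j))
x_blockwise = np.concatenate(x_blocks)
print(np.allclose(x_full, x_blockwise))
```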
#
#
# **4. Discrete State Update (Transitions)**.
# Recurrent transitions depend on the continuous latent state. Our recurrent transitions have a sticky component, $S$, that determines the probability of staying in the current state, and a switching component, $R$, that determines the probabilities of switching to other states. In the model we use in this notebook:
#
# \begin{align}
#     p(z_t = i \mid z_{t-1} = j, x_{t-1}) &= \mathrm{softmax}\bigg\{ \Big( \big(R x_{t-1}\big) + r\Big) \odot (1 - e_{z_{t-1}})  + \Big( \big(S x_{t-1} \big) + s \Big) \odot e_{z_{t-1}}  \bigg\},
# \end{align}
#
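# where $e_{z_{t-1}} \in \{0,1\}^K$ is a one-hot encoding of $z_{t-1}$ and $\odot$ denotes elementwise multiplication. The masks select which term sets each logit: the logit of the current state $z_{t-1}$ comes from the sticky term $S x_{t-1} + s$, and the logits of all other states come from the switching term $R x_{t-1} + r$.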
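# A minimal NumPy sketch of the transition probabilities in the equation above. `transition_probs` is a hypothetical helper and the weights are random; the offsets $r = 0$ and $s = 5$ mirror the values used in the simulation in this notebook, so the current state receives a large sticky bonus.

```python
import numpy as np

def transition_probs(x_prev, z_prev, R, r, S, s):
    """Hypothetical helper: p(z_t | z_{t-1}=z_prev, x_{t-1}=x_prev).

    R, S: (K, D) weight matrices; r, s: (K,) offsets.
    """
    K = R.shape[0]
    e = np.zeros(K)
    e[z_prev] = 1.0                      # one-hot encoding of z_{t-1}
    logits = (R @ x_prev + r) * (1 - e) + (S @ x_prev + s) * e
    logits = logits - logits.max()       # stabilize the softmax
    p = np.exp(logits)
    return p / p.sum()

rng = np.random.default_rng(2)
K, D = 3, 15                             # hypothetical sizes
R = 0.1 * rng.standard_normal((K, D))
S = 0.1 * rng.standard_normal((K, D))
r, s = np.zeros(K), 5 * np.ones(K)       # large s favors staying in the current state
x_prev = rng.standard_normal(D)
p = transition_probs(x_prev, z_prev=1, R=R, r=r, S=S, s=s)
print(p.round(3), p.sum())
```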
#
# To understand which populations are contributing to the transitions, we can decompose this equation:
#
#
# \begin{align}
#     p(z_t = i \mid z_{t-1} = k, x_{t-1}) &= \mathrm{softmax}\bigg\{ \Big( \sum_{j=1}^J \big(R_j x_{t-1}^{(j)}\big) + r\Big) \odot (1 - e_{z_{t-1}})  + \Big( \sum_{j=1}^J \big(S_j x_{t-1}^{(j)} \big) + s \Big) \odot e_{z_{t-1}}  \bigg\},
# \end{align}
# where, for example, $R_j x_{t-1}^{(j)}$ contains the contribution of population $j$ towards switching to each state.
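# The per-population contributions $R_j x_{t-1}^{(j)}$ are the column blocks of $R$ applied to each population's latents, and they sum to $R x_{t-1}$. A short check with hypothetical sizes and random weights (the same slicing is used in the transition plots of this notebook):

```python
import numpy as np

rng = np.random.default_rng(3)
K = 3                                    # hypothetical number of discrete states
D_vec = [5, 5, 5]                        # hypothetical latent sizes per population
edges = np.concatenate(([0], np.cumsum(D_vec)))
R = rng.standard_normal((K, int(edges[-1])))
x_prev = rng.standard_normal(int(edges[-1]))

# contributions[j, k] = (R_j x^{(j)})_k: population j's push toward switching into state k
contributions = np.stack([
    R[:, edges[j]:edges[j + 1]] @ x_prev[edges[j]:edges[j + 1]]
    for j in range(len(D_vec))
])
print(contributions.shape)               # (3, 3): populations x states
print(np.allclose(contributions.sum(axis=0), R @ x_prev))
```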
#
#
# Additionally, we can include a dependency on the previous discrete state. This is included in the code package, but is not used in the example below.
#
# \begin{align}
#     p(z_t = i \mid z_{t-1} = j, x_{t-1}) &= \mathrm{softmax}\bigg\{ \log(P_{j,i}) +  \big(R x_{t-1}\big) \odot (1 - e_{z_{t-1}})  + \big(S x_{t-1} \big) \odot e_{z_{t-1}}  \bigg\},
# \end{align}
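# With the previous-state term, each logit gains a baseline $\log P_{j,i}$ taken from row $j = z_{t-1}$ of a transition matrix $P$. In the sketch below (hypothetical `P`, `transition_probs_with_P`, and sizes), setting $R = S = 0$ removes the recurrent terms, so the output reduces to the corresponding row of $P$.

```python
import numpy as np

def transition_probs_with_P(x_prev, z_prev, log_P, R, S):
    """Hypothetical helper: softmax over next states with a log P[z_prev, :] baseline."""
    K = R.shape[0]
    e = np.eye(K)[z_prev]                # one-hot encoding of z_{t-1}
    logits = log_P[z_prev] + (R @ x_prev) * (1 - e) + (S @ x_prev) * e
    p = np.exp(logits - logits.max())
    return p / p.sum()

K, D = 3, 15                             # hypothetical sizes
P = np.array([[0.90, 0.05, 0.05],
              [0.10, 0.80, 0.10],
              [0.30, 0.30, 0.40]])       # hypothetical baseline transition matrix
R = np.zeros((K, D))
S = np.zeros((K, D))
x_prev = np.ones(D)
for z_prev in range(K):
    print(z_prev, transition_probs_with_P(x_prev, z_prev, np.log(P), R, S).round(2))
```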
#
# The sticky recurrent transitions classes are loaded in the example below.
# <br />
#
# **5. Model fitting**.
# We fit the model with variational Laplace EM; see the "Variational Laplace EM for SLDS Tutorial" for more information.
#

# + [markdown] colab_type="text" id="8OzC8q4bRFQv"
# ## Import packages, including multipopulation extensions

# + colab={"base_uri": "https://localhost:8080/", "height": 581} colab_type="code" id="ruUnNqi5RZqT" outputId="228b6c8e-c064-46c2-ce57-9ad88daca5c8"
try:
    import ssm
except ImportError:
    # !pip install git+https://github.com/lindermanlab/ssm.git#egg=ssm
    import ssm

# + colab={"base_uri": "https://localhost:8080/", "height": 71} colab_type="code" id="zDn3tEJhRFQv" outputId="2f1ca1d0-8f17-404a-897f-57b8c5d353cb"
#### General packages

from matplotlib import pyplot as plt
# %matplotlib inline
import autograd.numpy as np
import autograd.numpy.random as npr

import seaborn as sns
sns.set_style("white")
sns.set_context("talk")
sns.set_style('ticks', {"xtick.major.size": 8,
                        "ytick.major.size": 8})
from ssm.plots import gradient_cmap, white_to_color_cmap

color_names = [
    "purple",
    "red",
    "amber",
    "faded green",
    "windows blue",
    "orange"
    ]

colors = sns.xkcd_palette(color_names)
cmap = gradient_cmap(colors)

# + colab={} colab_type="code" id="0rq19iIQRFQy"
#### SSM PACKAGES ###

import ssm
from ssm.variational import SLDSMeanFieldVariationalPosterior, SLDSTriDiagVariationalPosterior, \
    SLDSStructuredMeanFieldVariationalPosterior
from ssm.util import random_rotation, find_permutation, relu

#Load from extensions
from ssm.extensions.mp_srslds.emissions_ext import GaussianOrthogonalCompoundEmissions, PoissonOrthogonalCompoundEmissions
from ssm.extensions.mp_srslds.transitions_ext import StickyRecurrentOnlyTransitions, StickyRecurrentTransitions

# + [markdown] colab_type="text" id="Ty3EOi8bRFQ1"
# ## Simulate (somewhat realistic) data

# + [markdown] colab_type="text" id="QxDoYCRDRFQ2"
# ### Set parameters of simulation

# + colab={} colab_type="code" id="dalqY6zvRFQ2"
K=3 #Number of discrete states

num_gr=3 #Number of populations
num_per_gr=5 #Number of latents per population
neur_per_gr=75 #Number of neurons per population

t_end=3000 #number of time bins
num_trials=1 #number of trials

# + colab={"base_uri": "https://localhost:8080/", "height": 34} colab_type="code" id="OHTSTNbTRFQ4" outputId="fd3b833b-8df2-425b-d3ce-103cf42d5153"
np.random.seed(108) #To create replicable dynamics

alphas=.03+.1*np.random.rand(K) #Scale of the random entries of each discrete state's dynamics matrix
print('alphas:', alphas)

sparsity=.33 #Probability that each off-diagonal block of the dynamics matrix is set to 0

e1=.1 #Amount of noise in the dynamics (note: e1 is not used in the code below)

# + [markdown] colab_type="text" id="t91lYSkPRFQ7"
# ### Get new emissions and transitions classes for the simulated data

# + colab={} colab_type="code" id="tQx540b6RFQ8"
#Vector containing number of latents per population
D_vec=[]
for i in range(num_gr):
    D_vec.append(num_per_gr)

#Vector containing number of neurons per population
N_vec=[]
for i in range(num_gr):
    N_vec.append(neur_per_gr)

D=np.sum(D_vec)
num_gr=len(D_vec)
D_vec_cumsum = np.concatenate(([0], np.cumsum(D_vec)))

#Get new multipopulation emissions class for the simulation

# gauss_comp_emissions=GaussianOrthogonalCompoundEmissions(N=np.sum(N_vec),K=1,D=np.sum(D_vec),D_vec=D_vec,N_vec=N_vec)
poiss_comp_emissions=PoissonOrthogonalCompoundEmissions(N=np.sum(N_vec),K=1,D=np.sum(D_vec),D_vec=D_vec,N_vec=N_vec,link='softplus')

#Get transitions class
true_sro_trans=StickyRecurrentOnlyTransitions(K=K,D=np.sum(D_vec))

# + [markdown] colab_type="text" id="DDu2GnRGRFQ-"
# ### Create simulated data

# + colab={} colab_type="code" id="VLN8FWLLRFQ_"
np.random.seed(10) #To create replicable simulations

A_masks=[]

A_all=np.zeros([K,D,D]) #Initialize dynamics matrix
b_all=np.zeros([K,D]) #Initialize dynamics offset


#Create an initial ground truth model, which we modify below
true_slds = ssm.SLDS(N=np.sum(N_vec),K=K,D=int(np.sum(D_vec)),
             dynamics="gaussian",
             emissions=poiss_comp_emissions,
             transitions=true_sro_trans)

#Create ground truth transitions: state k is driven by the first latent of population k
v=.2+.2*np.random.rand(1)
for k in range(K):
    inc=np.copy(k)
    true_slds.transitions.Rs[k,D_vec_cumsum[inc]:D_vec_cumsum[inc]+1]=v
    true_slds.transitions.Ss[k,D_vec_cumsum[inc]:D_vec_cumsum[inc]+1]=v-.1

true_slds.transitions.r=0*np.ones([K,1])
true_slds.transitions.s=5*np.ones([K,1])

#Create ground truth dynamics for each state
for k in range(K):

    ##Create dynamics##
    alpha=alphas[k]

    A_mask=np.random.rand(num_gr,num_gr)>sparsity #Make some blocks of the dynamics matrix 0

    A_masks.append(A_mask)

    for i in range(num_gr):
        A_mask[i,i]=1

    A0=np.zeros([D,D])
    for i in range(D-1):
        A0[i,i+1:]=alpha*np.random.randn(D-1-i)
    A0=(A0-A0.T)

    for i in range(num_gr):
        A0[D_vec_cumsum[i]:D_vec_cumsum[i+1],D_vec_cumsum[i]:D_vec_cumsum[i+1]]=2*A0[D_vec_cumsum[i]:D_vec_cumsum[i+1],D_vec_cumsum[i]:D_vec_cumsum[i+1]]


    A0=A0+np.identity(D)
    A=A0*np.kron(A_mask, np.ones((num_per_gr, num_per_gr)))

    A=A/(np.max(np.abs(np.linalg.eigvals(A)))+.01) #Rescale so the spectral radius is just below 1 (stable dynamics)

    b=1*np.random.rand(D)

    A_all[k]=A
    b_all[k]=b

true_slds.dynamics.As=A_all
true_slds.dynamics.bs=b_all


zs, xs, _ = true_slds.sample(t_end) #Sample discrete and continuous latents from model for simulation

#Scale the emissions so the average firing rate is approximately 0.25 spikes per bin (the scaling ignores the ReLU, so the match is not exact)
tmp=np.mean(relu(np.dot(true_slds.emissions.Cs[0],xs.T)+.1*true_slds.emissions.ds[0][:,None]).T)
mult=.25/tmp
lams=relu(mult*np.dot(true_slds.emissions.Cs[0],xs.T)+.1*true_slds.emissions.ds[0][:,None]).T
ys=np.random.poisson(lams) #Get spiking activity based on poisson statistics
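# The simulated dynamics matrices above combine an identity with a random skew-symmetric matrix (whose eigenvalues are purely imaginary, i.e. rotations), then divide by the spectral radius plus 0.01. The sketch below isolates that recipe with hypothetical sizes, omitting the block mask and the doubling of within-population blocks. The mask generally breaks the skew symmetry, but the final rescaling still guarantees a spectral radius of $\rho/(\rho + 0.01) < 1$.

```python
import numpy as np

rng = np.random.default_rng(4)
D, alpha = 15, 0.05                      # hypothetical size and entry scale

# Random skew-symmetric matrix: purely imaginary eigenvalues (rotations)
upper = np.triu(alpha * rng.standard_normal((D, D)), k=1)
W = upper - upper.T

# Identity plus skew-symmetric part: eigenvalues 1 + i*omega, so modulus >= 1
A = np.eye(D) + W
rho = np.max(np.abs(np.linalg.eigvals(A)))

# Dividing by (rho + 0.01), as in the simulation, pulls the spectral radius below 1
A_stable = A / (rho + 0.01)
rho_stable = np.max(np.abs(np.linalg.eigvals(A_stable)))
print(rho, rho_stable)
```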

# + [markdown] colab_type="text" id="twuPg8wRRFRC"
# ## Plot simulated data

# + [markdown] colab_type="text" id="1-VkH7xSRFRC"
# ### Dynamics matrices ($A^z$)

# + colab={"base_uri": "https://localhost:8080/", "height": 797} colab_type="code" id="qv7hO5fgRFRD" outputId="7b2a3151-a5dc-4ac0-c446-ecbb212ffb61"
# vmin,vmax=[-1,1]
vmin,vmax=[-.5,.5] #zoom in to see colors more clearly


for k in range(K):

    plt.figure(figsize=(4,4))
    plt.imshow(true_slds.dynamics.As[k], aspect='auto', interpolation="none", vmin=vmin, vmax=vmax, cmap='RdBu')
    offset=-.5
    for nf in D_vec:
        plt.plot([-0.5, D-0.5], [offset, offset], '-k')
        plt.plot([offset, offset], [-0.5, D-0.5], '-k')
        offset += nf
    plt.xticks([])
    plt.yticks([])
    plt.title('Actual State '+str(k))

# + [markdown] colab_type="text" id="-Y7R3y6TRFRF"
# ### Discrete states ($z$)

# + colab={"base_uri": "https://localhost:8080/", "height": 177} colab_type="code" id="Wmk2_uI-RFRG" outputId="2f434446-fbaf-4a00-d186-81a557b35011"
plt.figure(figsize=(8, 4))
plt.subplot(211)
plt.imshow(zs[None,:], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
plt.xlim(0, t_end)
plt.ylabel("$z_{\\mathrm{true}}$")
plt.yticks([])

# + [markdown] colab_type="text" id="ufQwDDKTRFRI"
# ### Transitions (in a shorter time window)
# The contribution of population $j$ to staying in a state is $S_j x^{(j)}$, and its contribution to switching into a state is $R_j x^{(j)}$.

# + colab={} colab_type="code" id="fp1QQ_7dRFRI"
dur=200
st_t=650
end_t=st_t+dur

# + colab={"base_uri": "https://localhost:8080/", "height": 310} colab_type="code" id="-FWpoP7-RFRK" outputId="bcc60411-fb9e-44d0-bc16-e382fe7919aa"
plt.figure(figsize=(8, 4))

j=0

plt.subplot(211)
for g in range(num_gr): #Loop over populations
    plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Rs[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
plt.ylabel('Contribution towards \n switching to \n purple state',rotation=60)
plt.xticks([])
plt.yticks([])

j=1
plt.subplot(212)
for g in range(num_gr): #Loop over populations
    plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Ss[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
plt.ylabel('Contribution towards \n staying in red state',rotation=60)
plt.legend(['Pop 1','Pop 2','Pop 3'])
plt.yticks([])

# + [markdown] colab_type="text" id="wOO48DPARFRN"
# ### Continuous latents ($x$) and spikes ($y$) for an example population

# + colab={"base_uri": "https://localhost:8080/", "height": 449} colab_type="code" id="YllKBbdhRFRN" outputId="dc7f9218-0100-4263-8501-9d2fb1696a9e"
plt.figure(figsize=(8, 4))

plt.subplot(211)
plt.plot(xs[st_t:end_t,:num_per_gr]) #Show latents of first group
plt.xticks([])

plt.subplot(212)
plt.plot(ys[st_t:end_t,:10]) #Show first 10 neurons


# + [markdown] colab_type="text" id="OZ_iLjiyRFRV"
# ## Fit model to data

# + [markdown] colab_type="text" id="g6QfY0j-RFRV"
# #### To create the emissions classes for the multipopulation models, we need vectors containing the number of continuous latents per population ("D_vec") and neurons per population ("N_vec")
#

# + colab={} colab_type="code" id="kVZDXyIiRFRW"
num_gr=3 #Number of populations
num_per_gr=5 #Number of latents per population
neur_per_gr=75 #Number of neurons per population

#Vector containing number of latents per population
D_vec=[]
for i in range(num_gr):
    D_vec.append(num_per_gr)

#Vector containing number of neurons per population
N_vec=[]
for i in range(num_gr):
    N_vec.append(neur_per_gr)

# + [markdown] colab_type="text" id="_I-JFqbHRFRZ"
# #### Now create the multipopulation emissions and transitions classes for our model

# + colab={} colab_type="code" id="dhnavgIgRFRa"
#Get new multipopulation emissions class
poiss_comp_emissions=PoissonOrthogonalCompoundEmissions(N=np.sum(N_vec),K=1,D=np.sum(D_vec),D_vec=D_vec,N_vec=N_vec,link='softplus')

#Get new transitions class
sro_trans=StickyRecurrentOnlyTransitions(K=K,D=np.sum(D_vec), l2_penalty_similarity=10, l1_penalty=10)
#The l2 penalty above encourages R and S to be similar (i.e., it assumes that the activity driving a switch into a state is similar to the activity that keeps the system in that state)
#The l1 penalty is on the entries of R and S
# -
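# To make the two regularizers concrete, here is one plausible form they could take, written as a penalty added to the negative log-likelihood. This is an assumption for illustration only: the exact expressions inside `StickyRecurrentOnlyTransitions` may differ (for example in scaling), and `transition_penalty` is a hypothetical helper.

```python
import numpy as np

def transition_penalty(R, S, l2_penalty_similarity, l1_penalty):
    """Hypothetical regularizer: squared penalty on R - S plus absolute-value penalty on R and S."""
    similarity = l2_penalty_similarity * np.sum((R - S) ** 2)   # keeps R close to S
    sparsity = l1_penalty * (np.sum(np.abs(R)) + np.sum(np.abs(S)))  # encourages sparse weights
    return similarity + sparsity

K, D = 3, 15                             # hypothetical sizes
R = np.zeros((K, D))
R[0, 0] = 0.3
S = R.copy()
S[0, 0] = 0.2
print(transition_penalty(R, S, l2_penalty_similarity=10, l1_penalty=10))
```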

# The extension also provides a Gaussian emissions class, `GaussianOrthogonalCompoundEmissions`, and a second transitions class, `StickyRecurrentTransitions`.

# + [markdown] colab_type="text" id="syHDAea_RFRc"
# #### Now declare and fit the model

# + colab={"base_uri": "https://localhost:8080/", "height": 166, "referenced_widgets": ["36b4e6c4b2064572a4412f8bc298caeb", "97247a689ff744689528aa9555c13c62", "581b2c1a39eb439581308a5cdea07389", "e8e6a6b80ed14b77900e87e1c779f459", "2e200ca5be3141f2b2ee041ec8b08e49", "f4e534f4d6ac45c6a7da04be7d341968", "ced98407e9ee4ee4bd0baaa24d4765b9", "14de5bf9b23f47d4867e135cf156ac97", "6bd4692a8c95492d84f6353069342d4f", "62f38f8736314a098f91d7486b5b84f6", "d7ffd5a5a8874c2d8e593cb29c8f14b2", "2ddbbe21934547db9e365d905ea1f024", "19ad77a391f7459f811262d5bdbeb783", "5da21c8eac7041898c7a468f0388e632", "eb1c51978e0f4b79968cf6618a959007", "c0204fcad9f64dd48580f7d17767abc6"]} colab_type="code" id="n7WRqQufRFRd" outputId="28ddbaaa-221c-4ca8-db59-b76bbeff1dd0"
K=3 #Number of discrete states

rslds = ssm.SLDS(N=np.sum(N_vec),K=K,D=np.sum(D_vec),
             dynamics="gaussian",
             emissions=poiss_comp_emissions,
             transitions=sro_trans,
             dynamics_kwargs=dict(l2_penalty_A=100)) #Regularization on the dynamics matrix

q_elbos_ar, q_ar = rslds.fit(ys, method="laplace_em",
                             variational_posterior="structured_meanfield",
                             continuous_optimizer='newton',
                             initialize=True,
                             num_init_restarts=10,
                             num_iters=30,
                             alpha=0.25)

# + colab={} colab_type="code" id="Oi8-3NcxRFRf" outputId="98075512-a594-4bae-fcf3-550affe796a6"
plt.plot(q_elbos_ar[1:])
plt.xlabel("Iteration")
plt.ylabel("ELBO")
# -

# ## Align solution with simulation for plotting

#The discrete states are only identifiable up to a permutation of their labels.
#Find permutation to match the discrete states in the model and the ground truth
z_inferred=rslds.most_likely_states(q_ar.mean_continuous_states[0],ys)
rslds.permute(find_permutation(zs, z_inferred))
z_inferred2=rslds.most_likely_states(q_ar.mean_continuous_states[0],ys)
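# `find_permutation` matches inferred discrete-state labels to the true ones. A minimal sketch of the same idea, assuming the goal is to maximize the number of time bins where relabeled states agree, solved as a linear assignment problem with SciPy's `linear_sum_assignment`. This illustrates the idea and is not necessarily how `ssm.util.find_permutation` is implemented; `match_states` is a hypothetical helper.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def match_states(z_true, z_inf, K):
    """Hypothetical helper: perm[k] is the inferred label matched to true state k."""
    overlap = np.zeros((K, K))
    for k_true, k_inf in zip(z_true, z_inf):
        overlap[k_true, k_inf] += 1      # co-occurrence counts
    rows, cols = linear_sum_assignment(-overlap)  # maximize total overlap
    return cols[np.argsort(rows)]

z_true = np.array([0, 0, 1, 1, 2, 2, 0])
z_inf = np.array([2, 2, 0, 0, 1, 1, 2])  # same segmentation, different labels
perm = match_states(z_true, z_inf, K=3)
relabel = np.argsort(perm)               # maps each inferred label to its true label
print(perm, np.mean(relabel[z_inf] == z_true))
```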

# +
#Each population's latents are only identifiable up to an arbitrary rotation.
#Additionally, the scaling of the recovered latents may differ from the simulation ground truth,
#because the simulation didn't constrain the effective emissions (C) matrix to be orthonormal as the model does.
#We therefore fit a linear map from inferred to true latents separately for each population (a block-diagonal R).

from sklearn.linear_model import LinearRegression

R=np.zeros([D,D])
for g in range(num_gr):
    lr=LinearRegression(fit_intercept=False)
    lr.fit(q_ar.mean_continuous_states[0][:,D_vec_cumsum[g]:D_vec_cumsum[g+1]],xs[:,D_vec_cumsum[g]:D_vec_cumsum[g+1]])
    R[D_vec_cumsum[g]:D_vec_cumsum[g+1],D_vec_cumsum[g]:D_vec_cumsum[g+1]]=lr.coef_
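# Why compare $R A R^{-1}$ with the true dynamics: if $x_t^{\mathrm{true}} \approx R\, x_t^{\mathrm{inf}}$ and $x_t^{\mathrm{inf}} \approx A\, x_{t-1}^{\mathrm{inf}}$, then $x_t^{\mathrm{true}} \approx R A R^{-1} x_{t-1}^{\mathrm{true}}$. The sketch below checks this on synthetic data. The "inferred" latents are the true latents transformed by a hypothetical block-diagonal matrix `M`. The per-population `LinearRegression` fit (as in the cell above) recovers `R = M`, so `R @ A_inf @ inv(R)` returns the true dynamics. Sizes and matrices are hypothetical.

```python
import numpy as np
from scipy.linalg import block_diag
from sklearn.linear_model import LinearRegression

rng = np.random.default_rng(5)
D_vec = [2, 3]                           # hypothetical latent sizes per population
edges = np.concatenate(([0], np.cumsum(D_vec)))
T, D = 500, int(edges[-1])

# Ground-truth dynamics and noise-driven latents
A_true = 0.9 * np.linalg.qr(rng.standard_normal((D, D)))[0]
x_true = np.zeros((T, D))
for t in range(1, T):
    x_true[t] = A_true @ x_true[t - 1] + rng.standard_normal(D)

# "Inferred" latents: each population transformed by its own invertible matrix
M = block_diag(*[rng.standard_normal((d, d)) + 2 * np.eye(d) for d in D_vec])
x_inf = x_true @ np.linalg.inv(M).T      # x_inf_t = M^{-1} x_true_t
A_inf = np.linalg.inv(M) @ A_true @ M    # dynamics in inferred coordinates

# Per-population regression of true latents on inferred latents
R = np.zeros((D, D))
for g in range(len(D_vec)):
    sl = slice(edges[g], edges[g + 1])
    lr = LinearRegression(fit_intercept=False).fit(x_inf[:, sl], x_true[:, sl])
    R[sl, sl] = lr.coef_                 # coef_ maps inferred -> true for this population

print(np.allclose(R @ A_inf @ np.linalg.inv(R), A_true, atol=1e-6))
```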

# + [markdown] colab_type="text" id="0CjkFH63RFRh"
# ## Plot results

# + [markdown] colab_type="text" id="HDlMQx4ZRFRh"
# ### Discrete states ($z$)

# + colab={} colab_type="code" id="TGk3zoBVRFRi" outputId="d1cebd3c-17f7-4e48-a6d3-0b9f34fbd3af"
plt.figure(figsize=(8, 4))
plt.subplot(211)
plt.imshow(zs[None,:], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
plt.xlim(0, t_end)
plt.ylabel("$z_{\\mathrm{true}}$")
plt.yticks([])

plt.subplot(212)
plt.imshow(z_inferred2[None,:], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
plt.xlim(0, t_end)
plt.ylabel("$z_{\\mathrm{inferred}}$")
plt.yticks([])
plt.xlabel("time")

plt.tight_layout()

# + colab={} colab_type="code" id="HRPr3DJlRFRj" outputId="1d6b402b-fa6d-4040-f4bf-4e4c2188c8b8"
print('Discrete state accuracy: ', np.mean(zs==z_inferred2))

# + [markdown] colab_type="text" id="SX1go1AyRFRl"
# #### Shorter time window

# + colab={} colab_type="code" id="UhjVHiYzRFRm" outputId="5d4cf636-9055-46fe-af2b-d51531c97317"
plt.figure(figsize=(4, 2))
plt.subplot(211)
plt.imshow(zs[None,st_t:end_t], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
plt.xlim(0, dur)
plt.ylabel("$z_{\\mathrm{true}}$")
plt.yticks([])
plt.xticks([])

plt.subplot(212)
plt.imshow(z_inferred2[None,st_t:end_t], aspect="auto", cmap=cmap, vmin=0, vmax=len(colors)-1)
plt.xlim(0, dur)
plt.ylabel("$z_{\\mathrm{inferred}}$")
plt.yticks([])

# + [markdown] colab_type="text" id="e4Cz1o8GRFRo"
# ### Dynamics matrices ($A^z$)
# -

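# We show the recovered $A$ matrices after aligning the continuous latents to the ground truth (that is, $R A R^{-1}$, using the per-population alignment $R$ computed above), demonstrating that the model can recover the ground-truth dynamics.
#
# We also show the raw recovered $A$ matrices. Even without alignment, the block structure (which populations interact) is visible, because the per-population rotations and scalings act within blocks and do not mix populations.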
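# Why the raw matrices already show the block structure: the alignment map $R$ is block diagonal, so the $(i \: to \: j)$ block of $R A R^{-1}$ equals $R_j A_{(i \: to \: j)} R_i^{-1}$, where $R_j$ is population $j$'s diagonal block of $R$. A zero block therefore stays zero. A quick numerical check, with hypothetical sizes and a random block-diagonal `R`:

```python
import numpy as np
from scipy.linalg import block_diag

rng = np.random.default_rng(6)
D_vec = [3, 3, 3]                        # hypothetical latent sizes per population
edges = np.concatenate(([0], np.cumsum(D_vec)))
D = int(edges[-1])

A = rng.standard_normal((D, D))
A[edges[0]:edges[1], edges[2]:edges[3]] = 0.0   # no input from population 3 to population 1

R = block_diag(*[rng.standard_normal((d, d)) + 2 * np.eye(d) for d in D_vec])
A_rot = R @ A @ np.linalg.inv(R)         # same dynamics in rotated/scaled coordinates

print(np.allclose(A_rot[edges[0]:edges[1], edges[2]:edges[3]], 0.0))
```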

# + colab={} colab_type="code" id="0D2QHaLRRFRo" outputId="891cf1a1-4aea-4c5f-f2d5-9895478b5361"
plt.figure(figsize=(12, 12))

q=1

for k in range(K):

    plt.subplot(3,3,q)
    plt.imshow(true_slds.dynamics.As[k], aspect='auto', interpolation="none", vmin=-.5, vmax=.5, cmap='RdBu')
    offset=-.5
    for nf in D_vec:
        plt.plot([-0.5, D-0.5], [offset, offset], '-k')
        plt.plot([offset, offset], [-0.5, D-0.5], '-k')
        offset += nf
    plt.xticks([])
    plt.yticks([])
    plt.title('Actual State '+str(k))

    q=q+1

    plt.subplot(3,3,q)
    plt.imshow(R@rslds.dynamics.As[k]@np.linalg.inv(R), aspect='auto', interpolation="none", vmin=-.5, vmax=.5, cmap='RdBu')

    offset=-.5
    for nf in D_vec:
        plt.plot([-0.5, D-0.5], [offset, offset], '-k')
        plt.plot([offset, offset], [-0.5, D-0.5], '-k')
        offset += nf
    plt.xticks([])
    plt.yticks([])
    plt.title('Aligned Predicted State '+str(k))

    q=q+1


    plt.subplot(3,3,q)
    plt.imshow(rslds.dynamics.As[k], aspect='auto', interpolation="none", vmin=-.5, vmax=.5, cmap='RdBu')

    offset=-.5
    for nf in D_vec:
        plt.plot([-0.5, D-0.5], [offset, offset], '-k')
        plt.plot([offset, offset], [-0.5, D-0.5], '-k')
        offset += nf
    plt.xticks([])
    plt.yticks([])
    plt.title('Raw Predicted State '+str(k))

    q=q+1






# + [markdown] colab_type="text" id="b14UYWuMRFRr"
# ### Transitions
# The contribution of population $j$ to staying in a state is $S_j x^{(j)}$, and its contribution to switching into a state is $R_j x^{(j)}$.
#

# + colab={} colab_type="code" id="hja_9YLPRFRr" outputId="6eb42921-5277-43ea-b14b-2366d2128509"
plt.figure(figsize=(15, 4))


### Actual

j=0

plt.subplot(221)
for g in range(num_gr): #Loop over populations
    plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Rs[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
plt.ylabel('Contribution towards \n switching to \n purple state',rotation=60)
plt.xticks([])
plt.yticks([])
plt.title('Actual')

j=1
plt.subplot(223)
for g in range(num_gr): #Loop over populations
    plt.plot(np.dot(xs[st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],true_slds.transitions.Ss[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
plt.ylabel('Contribution towards \n staying in red state',rotation=60)
plt.legend(['Pop 1','Pop 2','Pop 3'])
plt.yticks([])




### Predicted

j=0

plt.subplot(222)
for g in range(num_gr): #Loop over populations
    plt.plot(np.dot(q_ar.mean_continuous_states[0][st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],rslds.transitions.Rs[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
# plt.ylabel('Contribution towards \n switching to \n purple state',rotation=60)
plt.xticks([])
plt.yticks([])
plt.title('Predicted')

j=1
plt.subplot(224)
for g in range(num_gr): #Loop over populations
    plt.plot(np.dot(q_ar.mean_continuous_states[0][st_t:end_t,D_vec_cumsum[g]:D_vec_cumsum[g+1]],rslds.transitions.Ss[j,D_vec_cumsum[g]:D_vec_cumsum[g+1]].T))
plt.xlim(0, dur)
# plt.ylabel('Stay in Red')
# plt.xticks([])
plt.yticks([])

# + [markdown] colab_type="text" id="yz4GSxjrRFRt"
# ### Example fit of neural activity ($y$)

# + colab={} colab_type="code" id="cgEffzVKRFRt" outputId="10d64661-8fbd-4259-96fc-134c35980b88"
preds=rslds.smooth(q_ar.mean_continuous_states[0],ys) #get predictions

nrn=0 #Example neuron
plt.plot(ys[st_t:end_t,nrn],alpha=.5) #true spiking activity
plt.plot(lams[st_t:end_t,nrn]) #true firing rate
plt.plot(preds[st_t:end_t,nrn]) #predicted firing

plt.legend(['True Spikes','True FR','Predicted FR'])

Total running time of the script: ( 4 minutes 8.973 seconds)
