@@ -837,7 +837,9 @@ def test_basic(self):
837837 )
838838 idata = InferenceData (posterior = dict_to_dataset (prior ))
839839
840- idata = recover_marginals (marginal_m , idata , return_samples = True )
840+ with marginal_m :
841+ idata = recover_marginals (idata , return_samples = True )
842+
841843 post = idata .posterior
842844 assert "k" in post
843845 assert "lp_k" in post
@@ -881,7 +883,8 @@ def test_coords(self):
881883 posterior = dict_to_dataset ({k : np .expand_dims (prior [k ], axis = 0 ) for k in prior })
882884 )
883885
884- idata = recover_marginals (marginal_m , idata , return_samples = True )
886+ with marginal_m :
887+ idata = recover_marginals (idata , return_samples = True )
885888 post = idata .posterior
886889 assert post .idx .dims == ("chain" , "draw" , "year" )
887890 assert post .lp_idx .dims == ("chain" , "draw" , "year" , "lp_idx_dim" )
@@ -907,7 +910,7 @@ def test_batched(self):
907910 posterior = dict_to_dataset ({k : np .expand_dims (prior [k ], axis = 0 ) for k in prior })
908911 )
909912
910- idata = recover_marginals (marginal_m , idata , return_samples = True )
913+ idata = recover_marginals (idata , return_samples = True )
911914 post = idata .posterior
912915 assert post ["y" ].shape == (1 , 20 , 2 , 3 )
913916 assert post ["idx" ].shape == (1 , 20 , 3 , 2 )
@@ -933,7 +936,7 @@ def test_nested(self):
933936 )
934937 idata = InferenceData (posterior = dict_to_dataset (prior ))
935938
936- idata = recover_marginals (marginal_m , idata , return_samples = True )
939+ idata = recover_marginals (idata , return_samples = True )
937940 post = idata .posterior
938941 assert "idx" in post
939942 assert "lp_idx" in post
0 commit comments