diff --git a/pymc_extras/model/marginal/marginal_model.py b/pymc_extras/model/marginal/marginal_model.py index b6ca25bf..6fed7043 100644 --- a/pymc_extras/model/marginal/marginal_model.py +++ b/pymc_extras/model/marginal/marginal_model.py @@ -11,7 +11,7 @@ from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform from pymc.distributions.transforms import Chain from pymc.logprob.transforms import IntervalTransform -from pymc.model import Model +from pymc.model import Model, modelcontext from pymc.model.fgraph import ( ModelFreeRV, ModelValuedVar, @@ -337,8 +337,8 @@ def transform_posterior_pts(model, posterior_pts): def recover_marginals( - model: Model, idata: InferenceData, + model: Model | None = None, var_names: Sequence[str] | None = None, return_samples: bool = True, extend_inferencedata: bool = True, @@ -389,6 +389,11 @@ def recover_marginals( """ + if isinstance(idata, Model): + raise TypeError("The first argument of `recover_marginals` must be an idata") + + model = modelcontext(model) + unmarginal_model = unmarginalize(model) # Find the names of the marginalized variables diff --git a/tests/model/marginal/test_marginal_model.py b/tests/model/marginal/test_marginal_model.py index d9e50569..e7557f38 100644 --- a/tests/model/marginal/test_marginal_model.py +++ b/tests/model/marginal/test_marginal_model.py @@ -837,7 +837,9 @@ def test_basic(self): ) idata = InferenceData(posterior=dict_to_dataset(prior)) - idata = recover_marginals(marginal_m, idata, return_samples=True) + with marginal_m: + idata = recover_marginals(idata, return_samples=True) + post = idata.posterior assert "k" in post assert "lp_k" in post @@ -881,7 +883,8 @@ def test_coords(self): posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior}) ) - idata = recover_marginals(marginal_m, idata, return_samples=True) + with marginal_m: + idata = recover_marginals(idata, return_samples=True) post = idata.posterior assert post.idx.dims == ("chain", "draw", "year") assert post.lp_idx.dims == ("chain", "draw", "year", "lp_idx_dim") @@ -907,7 +910,7 @@ def test_batched(self): posterior=dict_to_dataset({k: np.expand_dims(prior[k], axis=0) for k in prior}) ) - idata = recover_marginals(marginal_m, idata, return_samples=True) + idata = recover_marginals(idata, return_samples=True) post = idata.posterior assert post["y"].shape == (1, 20, 2, 3) assert post["idx"].shape == (1, 20, 3, 2) @@ -933,7 +936,7 @@ def test_nested(self): ) idata = InferenceData(posterior=dict_to_dataset(prior)) - idata = recover_marginals(marginal_m, idata, return_samples=True) + idata = recover_marginals(idata, return_samples=True) post = idata.posterior assert "idx" in post assert "lp_idx" in post