How to sample with posterior_predictive using a nutpie compiled model?

Hopefully a trivial question regarding nutpie. Suppose I have the model:

import nutpie
import pymc as pm

with pm.Model() as gp_model:
    # data containers
    X = pm.Data("X", times[:, None])
    trials = pm.Data("trials", trials)
    sends = pm.Data("sends", sends)
    # priors
    ell1 = pm.HalfNormal("ell1", sigma=1.25)
    ell2 = pm.HalfNormal("ell2", sigma=1.25)
    eta1 = pm.HalfNormal("eta1", sigma=1.0)
    eta2 = pm.HalfNormal("eta2", sigma=1.0)

    # define the kernel
    cov = eta1 * pm.gp.cov.Matern12(1, ell1) + eta2 * pm.gp.cov.Periodic(1, 12, ls=ell2)

    gp = pm.gp.Latent(cov_func=cov)
    f = gp.prior("f", X=X)

    # logit link and Binomial likelihood
    p = pm.Deterministic("p", pm.math.invlogit(f))
    lik = pm.Binomial("lik", n=trials, logit_p=f, observed=sends)

in which case I run

    compiled_gp_model = nutpie.compile_pymc_model(gp_model)
    idata = nutpie.sample(compiled_gp_model, chains=4, target_accept=0.9, maxdepth=13)

Now I want to sample the posterior predictive values at new times:

with gp_model:
    # add the GP conditional to the model, given the new X values
    f_pred = gp.conditional("f_pred", new_times[:, None], jitter=1e-4)
    # Sample from the GP conditional distribution
    idata.extend(pm.sample_posterior_predictive(idata, var_names=["f_pred"]))

First question: is this the way to do it?

If so, then how do I replicate this for the compiled model on new data? e.g.

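# with_data returns a new compiled model rather than updating the existing one in place
compiled_gp_model = compiled_gp_model.with_data(X=times1, sends=sends1, trials=trials1)
idata = nutpie.sample(compiled_gp_model, chains=4, target_accept=0.9, maxdepth=13)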

since there is no nutpie.sample_posterior_predictive()?

Do I need to place the gp.conditional() inside the model prior to compiling?

Nutpie doesn’t provide forward sampling. Use the original PyMC model and pm.sample_posterior_predictive for that: set the data to what you want before sampling (e.g. with pm.set_data) and pass in the idata returned by nutpie.

If you were hoping to use nutpie for the forward sampling because PyMC's forward sampling is very slow, you can instead compile the forward sampling functions (in fact, all PyMC functions) to either JAX or Numba with the compile_kwargs argument. For example:

with gp_model:
    f_pred = gp.conditional("f_pred", new_times[:, None], jitter=1e-4)
    idata = pm.sample_posterior_predictive(
        idata,
        var_names=["f_pred"],
        compile_kwargs={"mode": "NUMBA"},
        extend_inferencedata=True,
    )

I do this often for time series models, as they’re much snappier in the alternative backends.
