Wild results for masked multinomial with jax sampler(s)

Hi, I’m working on a multinomial model in which not every choice is available in every observation. The standard pymc sampler and an equivalent native numpyro model both produce posteriors that look fine, but I see strange traces when I sample the pymc model with the JAX-based samplers.

import pandas as pd
import numpy as np
import pymc as pm
import pytensor
import pymc.sampling.jax
import arviz as az

Generating a sample dataset:

rng = np.random.default_rng(322)

CHOICES = ['A', 'B', 'C', 'D']
N_CHOICES = len(CHOICES)

true_splits = np.array([0.3, 0.2, 0.1, 0.4])
N_OBS = 100

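# randomly mark each choice as available (True) or not in each observation ...
true_active = rng.binomial(1, 0.5, size=(N_OBS, N_CHOICES)).astype(bool)
# ... and make sure every observation has at least one available choice
for row, col in enumerate(rng.choice(range(N_CHOICES), size=N_OBS)):
    true_active[row, col] = True

# zero out unavailable choices and renormalize each row to sum to 1
true_renorm_splits = true_active * true_splits[None, :]
true_renorm_splits /= true_renorm_splits.sum(axis=1)[:, None]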

obs = pd.DataFrame(
    rng.multinomial(n=N_CHOICES, pvals=true_renorm_splits, size=N_OBS),
    columns=CHOICES,
)
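Here is a quick sanity check of the simulated data. It is a standalone sketch that re-runs the generation code above, and counts is simply my name for the raw array behind obs. It confirms the property the masking relies on: unavailable choices get exactly zero probability, so they always have zero counts.

```python
import numpy as np

# Re-create the simulated data from the post (same seed and settings).
rng = np.random.default_rng(322)
CHOICES = ['A', 'B', 'C', 'D']
N_CHOICES = len(CHOICES)
N_OBS = 100
true_splits = np.array([0.3, 0.2, 0.1, 0.4])

true_active = rng.binomial(1, 0.5, size=(N_OBS, N_CHOICES)).astype(bool)
for row, col in enumerate(rng.choice(range(N_CHOICES), size=N_OBS)):
    true_active[row, col] = True

true_renorm_splits = true_active * true_splits[None, :]
true_renorm_splits /= true_renorm_splits.sum(axis=1)[:, None]
counts = rng.multinomial(n=N_CHOICES, pvals=true_renorm_splits, size=N_OBS)

# Every row of the renormalized probabilities sums to 1 ...
assert np.allclose(true_renorm_splits.sum(axis=1), 1.0)
# ... unavailable choices get exactly zero probability ...
assert np.all(true_renorm_splits[~true_active] == 0.0)
# ... so they are never drawn: each row has N_CHOICES draws, none in inactive cells.
assert np.all(counts.sum(axis=1) == N_CHOICES)
assert np.all(counts[~true_active] == 0)

# Each inactive cell contributes a 0 * log(0) term to the multinomial log-likelihood.
n_inactive = int((~true_active).sum())
print(f"{n_inactive} of {true_active.size} (observation, choice) cells are inactive")
```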

And a minimal pymc model:

with pm.Model() as model:
    active = pm.MutableData("active", true_active)
    conc = 10
    prob = pm.Dirichlet("prob", a=conc * np.ones(N_CHOICES),)
    masked_prob = pm.Deterministic("masked_prob", prob * active)
    normed_prob = pm.Deterministic("normed_prob",
                                   masked_prob / masked_prob.sum(axis=1)[:, None])
    # Likelihood
    pm.Multinomial("obs",
                   n=N_CHOICES, p=normed_prob,
                   observed=obs)
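Written out, the model above is as follows. The notation is mine: a_{ij} ∈ {0, 1} is active, π is prob, \tilde p_{ij} is normed_prob, y_{ij} is the observed count, and n = N_CHOICES = 4.

$$
\begin{aligned}
\pi &\sim \mathrm{Dirichlet}(10 \cdot \mathbf{1}), \\
\tilde p_{ij} &= \frac{a_{ij}\,\pi_j}{\sum_k a_{ik}\,\pi_k}, \\
y_i &\sim \mathrm{Multinomial}(n, \tilde p_i), \\
\log p(y_i \mid \pi) &= \log n! - \sum_j \log y_{ij}! + \sum_j y_{ij} \log \tilde p_{ij}.
\end{aligned}
$$

For an unavailable choice, both y_{ij} and \tilde p_{ij} are exactly 0, so the last sum contains 0 · log 0 terms. By convention these are taken to be 0, which makes the log-likelihood well defined. How a particular backend evaluates those terms and their gradients numerically is a separate question.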

The default pymc sampler produces the following results (which look fine):

idata_pymc_nuts = pm.sample(model=model)
az.plot_posterior(idata_pymc_nuts, var_names=['prob'], ref_val=list(true_splits));

And with the jax numpyro nuts sampler:

idata_jaxnp = pm.sampling.jax.sample_numpyro_nuts(model=model)
az.plot_posterior(idata_jaxnp, var_names=['prob'], ref_val=list(true_splits));

Similar results using blackjax:

idata_blackjax = pm.sampling.jax.sample_blackjax_nuts(model=model)
az.plot_posterior(idata_blackjax, var_names=['prob'], ref_val=list(true_splits));

With both JAX samplers, the chains get stuck at certain values.

Interestingly, though, if I re-create the model directly in numpyro, the results are fine:

import numpyro
import numpyro.distributions as dist
import jax
from numpyro.infer import MCMC, NUTS
rng_key = jax.random.PRNGKey(0)

def model_numpyro(obs):
    prob = numpyro.sample('prob', dist.Dirichlet(10 * np.ones(N_CHOICES)))
    masked_prob = prob * true_active
    normed_prob = masked_prob / masked_prob.sum(axis=1)[:, None]
    numpyro.sample('obs', dist.Multinomial(N_CHOICES, normed_prob), obs=obs)

nuts_kernel = NUTS(model_numpyro)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(rng_key, obs=obs.values)

idata_numpyro_native = az.from_numpyro(mcmc)
az.plot_posterior(idata_numpyro_native, var_names=['prob'], ref_val=list(true_splits));

Since the behavior shows up with both of pymc’s JAX samplers (numpyro NUTS and blackjax) but not with the native numpyro model, my best guess is that the culprit is in pymc.sampling.jax.get_jaxified_graph.
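One hypothesis is that the NaN gradients come from the 0 · log 0 terms. I have not verified this, and nothing in this thread confirms it. With the mask, the likelihood contains terms x · log p where x = p = 0 for unavailable choices. If such a term is guarded only by a where, JAX returns the correct value but a NaN gradient, because the discarded branch is still differentiated. This pitfall is documented in the JAX FAQ, and the usual fix is a second where that also guards the argument of the log. The two functions below are hypothetical illustrations of that JAX behavior only; they are not the code pymc actually generates.

```python
import jax
import jax.numpy as jnp

# Hypothetical helpers: two ways to evaluate x * log(p) when both x and p
# can be exactly zero (as for an unavailable choice in the masked model).

def term_naive(p, x):
    # The forward value is fine: where() returns 0.0 whenever x == 0 ...
    return jnp.where(x == 0, 0.0, x * jnp.log(p))

def term_double_where(p, x):
    # ... but the discarded branch is still differentiated. Replacing p by a
    # harmless value (1.0) *before* the log keeps that branch finite.
    safe_p = jnp.where(x == 0, 1.0, p)
    return jnp.where(x == 0, 0.0, x * jnp.log(safe_p))

p, x = 0.0, 0.0  # an unavailable choice: zero probability, zero count
print(term_naive(p, x), term_double_where(p, x))  # both values are 0.0
print(jax.grad(term_naive)(p, x), jax.grad(term_double_where)(p, x))  # nan vs 0.0
```

If this is what's happening, the affected runs would plausibly show many divergences in their sample stats. That is an assumption worth checking before digging into get_jaxified_graph itself.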

Ran using:

(pm.__version__, jax.__version__, numpyro.__version__)
('5.0.1', '0.3.25', '0.11.0')
# EDIT: and also
(pm.__version__, pytensor.__version__, jax.__version__, numpyro.__version__)
('5.3.0', '2.11.1', '0.4.8', '0.11.0')

Any help or pointers would be greatly appreciated! I’m using the JAX samplers for speed, since the real model I’m working with is much more complicated. I’m also wondering whether I’m approaching the modeling part of this correctly.
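On the modeling question: an unavailable choice always has a zero count, so the masked likelihood is just a multinomial over the available choices in each row. That means the log only ever has to be taken of strictly positive probabilities. Below is a plain NumPy sketch of that formulation. masked_multinomial_logp is a hypothetical helper, not a pymc function, and it assumes zero counts wherever a choice is inactive. Carrying the same guard into the pymc model (for example with pm.math.switch inside a pm.Potential) is an untested suggestion.

```python
import math

import numpy as np

def masked_multinomial_logp(counts, prob, active):
    """Hypothetical helper: per-row log-likelihood of `counts` under a
    multinomial whose probabilities are `prob` restricted to the `active`
    choices and renormalized per row. Assumes counts == 0 where inactive."""
    masked = prob[None, :] * active
    normed = masked / masked.sum(axis=1, keepdims=True)
    # Guard the log: inactive cells get log(1) = 0 instead of log(0) = -inf,
    # so no 0 * log(0) term is ever formed.
    safe_log_p = np.log(np.where(active, normed, 1.0))
    n = counts.sum(axis=1)
    lgamma = np.vectorize(math.lgamma)
    return (lgamma(n + 1)
            - lgamma(counts + 1).sum(axis=1)
            + (counts * safe_log_p).sum(axis=1))

counts = np.array([[3, 0, 1, 0], [1, 1, 1, 1]])
active = np.array([[True, False, True, False], [True, True, True, True]])
prob = np.array([0.3, 0.2, 0.1, 0.4])
# Row 1 reduces to Binomial(4, 0.75) at 3 successes; row 2 to 24 * 0.3*0.2*0.1*0.4.
print(np.exp(masked_multinomial_logp(counts, prob, active)))  # ≈ [0.421875, 0.0576]
```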

@jason, do you mind trying with pymc>=5.3? 5.0 is already pretty old, although I suspect you will see the same strange behavior.

@ricardoV94 I’ve updated to:

(pm.__version__, pytensor.__version__, jax.__version__, numpyro.__version__)
('5.3.0', '2.11.1', '0.4.8', '0.11.0')

with the same results, unfortunately.