Hi, I’m working on a multinomial model in which not all choices are available at all times. The standard PyMC and native NumPyro models and samplers look fine, but I see strange traces when using the PyMC jaxified models.
import pandas as pd
import numpy as np
import pymc as pm
import pymc.sampling.jax
import arviz as az
Generating a sample dataset:
# Generate a synthetic dataset: N_OBS multinomial observations over CHOICES,
# where each observation only has a random subset of choices available.
rng = np.random.default_rng(322)
CHOICES = ['A', 'B', 'C', 'D']
N_CHOICES = len(CHOICES)
true_splits = np.array([0.3, 0.2, 0.1, 0.4])
N_OBS = 100
# Each choice is independently "active" with probability 0.5 per observation.
true_active = rng.binomial(1, 0.5, size=(N_OBS, N_CHOICES)).astype(bool)
# Guarantee at least one active choice per row. Vectorized fancy-indexed
# assignment replaces the Python loop: row indices (arange) are unique, and the
# rng.choice call is unchanged, so the RNG stream and result are identical.
true_active[np.arange(N_OBS), rng.choice(range(N_CHOICES), size=N_OBS)] = True
# Mask the global splits by availability and renormalize each row to sum to 1.
true_renorm_splits = true_active * true_splits[None, :]
true_renorm_splits /= true_renorm_splits.sum(axis=1, keepdims=True)
# Draw N_CHOICES (=4) trials per observation from the renormalized splits.
obs = pd.DataFrame(
    rng.multinomial(n=N_CHOICES, pvals=true_renorm_splits, size=N_OBS),
    columns=CHOICES,
)
And a minimal pymc model:
# Minimal PyMC model: a Dirichlet prior over the global choice probabilities,
# masked by the per-observation availability matrix and renormalized row-wise.
with pm.Model() as model:
    active = pm.MutableData("active", true_active)
    conc = 10
    prob = pm.Dirichlet("prob", a=conc * np.ones(N_CHOICES))
    # Zero out unavailable choices, then renormalize each row to sum to 1.
    masked_prob = pm.Deterministic("masked_prob", prob * active)
    normed_prob = pm.Deterministic(
        "normed_prob",
        masked_prob / masked_prob.sum(axis=1, keepdims=True),
    )
    # Likelihood: N_CHOICES (=4) trials per observation.
    pm.Multinomial("counts", n=N_CHOICES, p=normed_prob, observed=obs.values)
The default pymc sampler produces the following results (which look fine):
# Reference run: PyMC's built-in NUTS sampler (posteriors look correct).
idata_pymc_nuts = pm.sample(model=model)
az.plot_posterior(
    idata_pymc_nuts,
    var_names=["prob"],
    ref_val=list(true_splits),
)
And with the jax numpyro nuts sampler:
# Same model through PyMC's jaxified graph with the NumPyro NUTS backend.
idata_jaxnp = pm.sampling.jax.sample_numpyro_nuts(model=model)
az.plot_posterior(
    idata_jaxnp,
    var_names=["prob"],
    ref_val=list(true_splits),
)
Similar results using blackjax:
# And again through the BlackJAX NUTS backend — same pathology as NumPyro.
idata_blackjax = pm.sampling.jax.sample_blackjax_nuts(model=model)
az.plot_posterior(
    idata_blackjax,
    var_names=["prob"],
    ref_val=list(true_splits),
)
Both JAX-based samplers appear to get stuck at certain values in their traces.
But interestingly enough, if I re-create the model directly with numpyro, the results are fine:
import numpyro
import numpyro.distributions as dist
import jax
from numpyro.infer import MCMC, NUTS
rng_key = jax.random.PRNGKey(0)


def model_numpyro(obs, active=true_active):
    """Same masked-multinomial model written directly in NumPyro.

    Parameters
    ----------
    obs : array-like, shape (N_OBS, N_CHOICES)
        Observed counts per observation.
    active : bool array, shape (N_OBS, N_CHOICES), optional
        Availability mask. Defaults to the globally generated ``true_active``
        (the original hard-coded the global inside the body), so existing
        calls keep working unchanged.
    """
    prob = numpyro.sample('prob', dist.Dirichlet(10 * np.ones(N_CHOICES)))
    # Zero out unavailable choices and renormalize each row to sum to 1.
    masked_prob = prob * active
    normed_prob = masked_prob / masked_prob.sum(axis=1, keepdims=True)
    numpyro.sample('obs', dist.Multinomial(N_CHOICES, normed_prob), obs=obs)


nuts_kernel = NUTS(model_numpyro)
mcmc = MCMC(nuts_kernel, num_warmup=500, num_samples=1000)
mcmc.run(rng_key, obs=obs.values)
idata_numpyro_native = az.from_numpyro(mcmc)
az.plot_posterior(idata_numpyro_native, var_names=['prob'], ref_val=list(true_splits))
Since the behavior is present in both pymc’s usage of numpyro nuts and blackjax, my best guess would be that the culprit is in pymc.sampling.jax.get_jaxified_graph
Ran using
(pm.__version__, jax.__version__, numpyro.__version__)
('5.0.1', '0.3.25', '0.11.0')
# EDIT: and also
(pm.__version__, pytensor.__version__, jax.__version__, numpyro.__version__)
('5.3.0', '2.11.1', '0.4.8', '0.11.0')
Any help / pointers would be greatly appreciated! The motive for using the jax samplers is for speed since the real model I’m working with is much more complicated. So I’m also wondering if I’m approaching the modeling part of this correctly too.