When running MCMC with nuts_sampler="numpyro", the following error occurs:
"NotImplementedError: jax.experimental.host_callback has been deprecated since March 2024 and is now no longer supported. See jax-ml/jax issue #20385 (Deprecate jax.experimental.host_callback in favor of JAX external callbacks)."
It appears that this error is caused by a change in JAX. What should be done to fix it?
The code is below. X is the iris dataset.
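import pymc as pm
from sklearn.datasets import load_iris

# Assumption: X is the iris feature matrix as loaded by scikit-learn
X = load_iris().data

model1 = pm.Model()
with model1:
    mu = pm.Normal("mu", mu=0.0, sigma=10.0)
    sigma = pm.HalfNormal("sigma", sigma=10.0)
    X_obs = pm.Normal("X_obs", mu=mu, sigma=sigma, observed=X)

with model1:
    # Sample with NumPyro's NUTS implementation (JAX backend)
    idata1 = pm.sample(
        chains=4,
        tune=1000,
        draws=1000,
        nuts_sampler="numpyro",
        random_seed=123,
    )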
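Because the error appears to come from a mismatch between the installed JAX and the library that still calls jax.experimental.host_callback, it may help to report the exact package versions alongside the traceback. Below is a minimal sketch that does this using only the standard library's importlib.metadata. The helper name installed_versions and the package list are illustrative assumptions, not part of PyMC or NumPyro.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_versions(packages=("pymc", "numpyro", "jax", "jaxlib")):
    """Return {package name: installed version}, or None if not installed.

    Hypothetical helper: the default package list is an assumption about
    which libraries are relevant to the host_callback error.
    """
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            # Package is not installed in the current environment
            found[name] = None
    return found


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver if ver is not None else 'not installed'}")
```

Posting this output makes it easier to tell whether the JAX version is newer than what the installed NumPyro (or PyMC) release supports.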