Numpyro sampler issue

When running MCMC with nuts_sampler="numpyro", the following error occurs:
"NotImplementedError: jax.experimental.host_callback has been deprecated since March 2024 and is now no longer supported." The message points to JAX issue #20385 on GitHub (jax-ml/jax), "Deprecate jax.experimental.host_callback in favor of JAX external callbacks".
This appears to be caused by changes in JAX. What should I do to fix it?

The code is below; X is the iris data.

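```python
import pymc as pm

# X: the iris data, assumed to be already loaded (e.g. as a NumPy array)
model1 = pm.Model()
with model1:
    mu = pm.Normal('mu', mu=0.0, sigma=10.0)
    sigma = pm.HalfNormal('sigma', sigma=10.0)
    X_obs = pm.Normal('X_obs', mu=mu, sigma=sigma, observed=X)

with model1:
    # Selecting the NumPyro NUTS backend is what triggers the error
    idata1 = pm.sample(
        chains=4,
        tune=1000,
        draws=1000,
        nuts_sampler="numpyro",
        random_seed=123,
    )
```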

Try installing older versions of jax and jaxlib. NumPyro probably needs to be updated to accommodate the JAX changes.
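Before pinning older versions, it may help to confirm which versions of the relevant packages are actually installed in the environment PyMC runs in. The following is a minimal sketch using only the standard library's importlib.metadata; the helper name installed_version is hypothetical, and the list of packages to check is an assumption about which ones matter for this error.

```python
from importlib import metadata


def installed_version(package):
    """Return the installed version string of `package`, or None if absent.

    installed_version is a hypothetical helper written for this sketch.
    """
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Packages assumed to be involved in the nuts_sampler="numpyro" code path
    for pkg in ("pymc", "numpyro", "jax", "jaxlib"):
        print(f"{pkg}: {installed_version(pkg) or 'not installed'}")
```

The printed versions are what you would compare against the JAX and NumPyro release notes when deciding which versions to pin.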