Sampling_jax issues

Here are my imports and the sampling call I use when running JAX or NumPyro in a Jupyter notebook:

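import pymc as pm
import pymc.sampling.jax  # module that provides pm.sampling.jax.* (the call below uses this path)
import numpyro
import blackjax
import jax

# mdl is the pm.Model defined earlier in the notebook
with mdl:
    idata = pm.sampling.jax.sample_numpyro_nuts(
        postprocessing_backend="cpu",             # run post-processing on the CPU
        idata_kwargs=dict(log_likelihood=False),  # skip computing the log-likelihood group
    )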

The same call should also work with pm.sampling.jax.sample_blackjax_nuts.