Here are the imports and the sampling call I used to run JAX/NumPyro sampling in a Jupyter notebook:
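```python
import pymc as pm
import pymc.sampling.jax  # exposes pm.sampling.jax (PyMC 4.x used pymc.sampling_jax)
import numpyro  # not strictly needed here, but numpyro and blackjax must be installed
import blackjax
import jax

# `mdl` is a pm.Model defined earlier in the notebook
with mdl:
    idata = pm.sampling.jax.sample_numpyro_nuts(
        postprocessing_backend="cpu",  # post-process the draws on the CPU
        idata_kwargs=dict(log_likelihood=False),  # skip the log_likelihood group
    )
```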
This should also work with `pm.sampling.jax.sample_blackjax_nuts`.