Looking at the PyMC3 and Aesara source code, it seems relatively easy to switch to NEP-47 (which is basically the `numpy` API).

In Aesara, most of the JAX-related code lives in `aesara\link\jax\dispatch.py`.

Code snippet:

```
import jax
import jax.numpy as jnp
import jax.scipy as jsp
...

@jax_funcify.register(Cholesky)
def jax_funcify_Cholesky(op, **kwargs):
    lower = op.lower

    def cholesky(a, lower=lower):
        return jsp.linalg.cholesky(a, lower=lower).astype(a.dtype)

    return cholesky


@jax_funcify.register(Solve)
def jax_funcify_Solve(op, **kwargs):
    if op.assume_a != "gen" and op.lower:
        lower = True
    else:
        lower = False

    def solve(a, b, lower=lower):
        return jsp.linalg.solve(a, b, lower=lower)

    return solve


@jax_funcify.register(Det)
def jax_funcify_Det(op, **kwargs):
    def det(x):
        return jnp.linalg.det(x)

    return det


@jax_funcify.register(Eig)
def jax_funcify_Eig(op, **kwargs):
    def eig(x):
        return jnp.linalg.eig(x)

    return eig
```
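The dispatch functions above hard-code `jnp`/`jsp`. A minimal sketch of the same idea written against a NEP-47 namespace, looked up from the input array via `__array_namespace__`, might look like the block below. `get_namespace`, `make_cholesky`, `make_det` and `make_solve` are hypothetical helpers, not Aesara API, and the sketch assumes NumPy as the fallback for arrays that do not implement `__array_namespace__` (e.g. `np.ndarray` before NumPy 2.0).

```python
import numpy as np


def get_namespace(x):
    """Return the NEP-47 array namespace of ``x``.

    Hypothetical helper: falls back to NumPy for arrays that do not
    implement ``__array_namespace__``.
    """
    ns = getattr(x, "__array_namespace__", None)
    return ns() if ns is not None else np


def make_cholesky(lower):
    """Hypothetical backend-agnostic counterpart of jax_funcify_Cholesky."""

    def cholesky(a):
        xp = get_namespace(a)
        # Called without keywords, linalg.cholesky returns the lower factor;
        # transpose it (2-D inputs only) when the upper factor is requested.
        factor = xp.linalg.cholesky(a)
        return factor if lower else factor.T

    return cholesky


def make_det():
    """Hypothetical backend-agnostic counterpart of jax_funcify_Det."""

    def det(x):
        xp = get_namespace(x)
        return xp.linalg.det(x)

    return det


def make_solve():
    """Hypothetical backend-agnostic counterpart of jax_funcify_Solve.

    The standard's ``solve(x1, x2)`` takes no ``lower``/``assume_a`` hints,
    so those scipy-style options are simply dropped here.
    """

    def solve(a, b):
        xp = get_namespace(a)
        return xp.linalg.solve(a, b)

    return solve
```

One caveat, hedged: as far as I can tell the standard's `linalg` extension offers `eigh` but no general `eig`, so an `Eig` Op would probably still need a backend-specific path.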

In PyMC3, the JAX-related code is in `pymc3\sampling_jax.py`, for example:

```
def sample_numpyro_nuts(
    ...
):
    model = modelcontext(model)
    seed = jax.random.PRNGKey(random_seed)
    rv_names = [rv.name for rv in model.value_vars]
    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
    init_state_batched_at = [at.as_tensor(v) for v in init_state_batched]
```
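The batching step in `sample_numpyro_nuts` (`np.repeat(x[None, ...], chains, axis=0)`) could, as a rough sketch, be expressed with `stack`, which has been part of the array API standard since its first revision. `batch_init_state` and `get_namespace` are hypothetical helpers, and the sketch assumes `init_state` is a plain list of arrays, as in the snippet above, so no `jax.tree_map` is needed.

```python
import numpy as np


def get_namespace(x):
    """Return the NEP-47 namespace of ``x``, falling back to NumPy (hypothetical helper)."""
    ns = getattr(x, "__array_namespace__", None)
    return ns() if ns is not None else np


def batch_init_state(init_state, chains):
    """Hypothetical helper: add a leading ``chains`` axis to each initial value.

    Produces the same result as ``np.repeat(x[None, ...], chains, axis=0)``,
    but only relies on ``stack`` from the array API standard.
    """
    batched = []
    for x in init_state:
        xp = get_namespace(x)
        # Stacking `chains` references to the same array copies it along axis 0.
        batched.append(xp.stack([x] * chains, axis=0))
    return batched
```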

If I understand correctly, JAX-specific code makes up only a small fraction of the PyMC source code. Most of the code still uses vanilla `numpy`, which is actually great for NEP-47.
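Because that code already calls `np.*`, porting it could often be a mechanical swap of `np` for a namespace `xp` taken from the inputs. The function below is a hypothetical example (a summed Normal log-density), not actual PyMC3 code; it only uses `asarray`, `log` and `sum`, which exist both in NumPy and in the array API standard, and it assumes NumPy as the fallback namespace.

```python
import math

import numpy as np


def get_namespace(x):
    """Return the NEP-47 namespace of ``x``, falling back to NumPy (hypothetical helper)."""
    ns = getattr(x, "__array_namespace__", None)
    return ns() if ns is not None else np


def normal_logp(x, mu, sigma):
    """Hypothetical example: summed Normal log-density written against ``xp``."""
    xp = get_namespace(x)
    # Convert sigma to an array so xp.log works in strict array-API libraries too.
    sigma = xp.asarray(sigma, dtype=x.dtype)
    z = (x - mu) / sigma
    return xp.sum(-0.5 * z**2 - xp.log(sigma) - 0.5 * math.log(2 * math.pi))
```

The only change from a NumPy-only version is that `np.log`/`np.sum` become `xp.log`/`xp.sum`.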