The recently proposed NEP-47 attempts to unify the APIs of various tensor frameworks (NumPy, TensorFlow, PyTorch, Dask, JAX, CuPy, MXNet, etc.) via the Python array API standard.
It is a much more compact version of the original NumPy API, dropping functions that are not friendly to heterogeneous hardware like GPUs.
Since PyMC3 is currently using JAX via Aesara, it should be quite painless to adopt NEP-47 for multi-backend support, I guess?
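To make the idea concrete, here is a minimal sketch of what backend-agnostic code looks like under the array API standard: a function asks its input for a namespace via __array_namespace__ and then only calls functions from that namespace. The helper names get_namespace and standardize are made up for illustration, and the fallback to plain NumPy is my assumption for NumPy releases whose arrays lack that method.

```python
import numpy as np


def get_namespace(x):
    """Return the array API namespace of x (hypothetical helper)."""
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    # Assumption: fall back to NumPy for arrays that predate the standard.
    return np


def standardize(x):
    """Center and scale x using only functions looked up on its namespace."""
    xp = get_namespace(x)
    return (x - xp.mean(x)) / xp.std(x)


z = standardize(np.asarray([1.0, 2.0, 3.0, 4.0]))
```

In principle the same standardize function would work unchanged on any array type whose library implements the standard, which is the property a multi-backend PyMC3 would rely on.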
Related topic: NumPy array protocols
Looking at PyMC3 and Aesara source code, it seems relatively easy to switch to NEP-47 (basically numpy
In Aesara, most of the JAX-related code is in aesara\link\jax\
Code snippet:
import jax
import jax.numpy as jnp
import jax.scipy as jsp


def jax_funcify_Cholesky(op, **kwargs):
    lower = op.lower

    def cholesky(a, lower=lower):
        # Cast back so the result keeps the dtype of the input.
        return jsp.linalg.cholesky(a, lower=lower).astype(a.dtype)

    return cholesky


def jax_funcify_Solve(op, **kwargs):
    # Pass lower=True only for non-general matrices flagged as lower.
    if op.assume_a != "gen" and op.lower:
        lower = True
    else:
        lower = False

    def solve(a, b, lower=lower):
        return jsp.linalg.solve(a, b, lower=lower)

    return solve


def jax_funcify_Det(op, **kwargs):
    def det(x):
        return jnp.linalg.det(x)

    return det


def jax_funcify_Eig(op, **kwargs):
    def eig(x):
        return jnp.linalg.eig(x)

    return eig
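The functions above only touch JAX through jnp and jsp, so the same pattern could in principle take the array namespace as a parameter. The sketch below is my guess at how that might look, not Aesara code: funcify_det, funcify_cholesky, the xp argument and the stand-in op object are all hypothetical, and NumPy is used as the namespace only to keep the example runnable. As far as I can tell, the standard's linalg extension specifies eigh rather than a general eig, so Eig might still need backend-specific handling.

```python
from types import SimpleNamespace

import numpy as np


def funcify_det(op, xp=np):
    """Build a determinant function from the namespace xp (hypothetical)."""

    def det(x):
        return xp.linalg.det(x)

    return det


def funcify_cholesky(op, xp=np):
    """Build a Cholesky function from the namespace xp (hypothetical)."""
    if not op.lower:
        # Left out of this sketch: only the lower-triangular factor is handled.
        raise NotImplementedError("upper factor not covered in this sketch")

    def cholesky(a):
        # NumPy returns the lower-triangular factor by default.
        return xp.linalg.cholesky(a)

    return cholesky


# Stand-in for an Aesara Op instance; only the attribute used above is set.
fake_op = SimpleNamespace(lower=True)

a = np.asarray([[4.0, 2.0], [2.0, 3.0]])
det_a = funcify_det(fake_op)(a)        # 4*3 - 2*2 = 8
chol_a = funcify_cholesky(fake_op)(a)  # lower-triangular factor of a
```

Passing xp explicitly keeps the builder functions free of any import of a specific backend, at the cost of threading one more argument through the dispatch layer.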
In PyMC3, the JAX-related code is in pymc3\, for example:
def sample_numpyro_nuts(
    model = modelcontext(model)
    seed = jax.random.PRNGKey(random_seed)
    rv_names = [ for rv in model.value_vars]
    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
    init_state_batched_at = [at.as_tensor(v) for v in init_state_batched]
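For anyone unfamiliar with the batching line, the NumPy part just stacks one copy of each initial value per chain. A small illustration with made-up values (the numbers and the single three-element variable are mine, not from PyMC3):

```python
import numpy as np

chains = 4
# Hypothetical initial point for a single variable with three elements.
x = np.asarray([0.1, 0.2, 0.3])

# x[None, ...] adds a leading axis of length 1; repeat copies it once per chain.
batched = np.repeat(x[None, ...], chains, axis=0)

print(batched.shape)  # (4, 3)
```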
If I understand correctly, JAX-specific code makes up only a small fraction of the PyMC source code. Most of the code still uses vanilla numpy, which is actually great for NEP-47.
If a tensor framework supports basic array operations (the Array API core, see API specification), as well as higher-level solvers like linalg.cholesky, linalg.det, linalg.eig, then it should be quite straightforward to add this new backend, I guess?
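One cheap way to test that for a candidate backend would be to check its namespace for the functions the JAX link already relies on. A rough sketch, assuming the backend exposes a NumPy-like module; missing_features, REQUIRED and toy_backend are hypothetical names, and the list only covers the functions mentioned in this post:

```python
from types import SimpleNamespace

import numpy as np

# Dotted names of the solver functions discussed in this post.
REQUIRED = ["linalg.cholesky", "linalg.det", "linalg.eig", "linalg.solve"]


def missing_features(xp, required=REQUIRED):
    """Return the dotted names in `required` that xp does not provide."""
    missing = []
    for dotted in required:
        obj = xp
        for part in dotted.split("."):
            obj = getattr(obj, part, None)
            if obj is None:
                missing.append(dotted)
                break
    return missing


# NumPy provides all four functions.
numpy_missing = missing_features(np)

# A made-up backend that lacks a general eigendecomposition.
placeholder = lambda *args: None
toy_backend = SimpleNamespace(
    linalg=SimpleNamespace(cholesky=placeholder, det=placeholder, solve=placeholder)
)
toy_missing = missing_features(toy_backend)
```

A check like this only confirms that the names exist; it says nothing about matching signatures or numerical behaviour, which would still need real tests per backend.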
Are there any important features that the backend framework must support? For example, static vs. dynamic graphs, or automatic differentiation?
That’s an interesting idea and could definitely be done. You already identified that we could probably just copy a lot from the JAX implementation. This is probably better discussed as an Aesara issue.