Support multiple tensor backends via NEP-47 (Python array API standard)

The recently proposed NEP-47 attempts to unify the APIs of various tensor frameworks (NumPy, TensorFlow, PyTorch, Dask, JAX, CuPy, MXNet, etc.) through the Python array API standard.

The standard is much more compact than the original NumPy API: it removes unnecessary functions that are not friendly to heterogeneous hardware such as GPUs.
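
To make the idea concrete, here is a minimal sketch of namespace-agnostic code in the style the standard encourages: the function looks up the array's namespace via __array_namespace__() and then only calls functions the standard defines (max, sum, exp, log, squeeze). The get_namespace helper and its fallback to the numpy module (for older NumPy releases whose ndarray has no __array_namespace__) are my own assumptions, not something NEP-47 prescribes.

```python
import numpy as np


def get_namespace(x):
    """Return the array API namespace of x (assumed helper, not part of NEP-47).

    Arrays that follow the standard expose x.__array_namespace__(); older
    NumPy releases do not, so fall back to the numpy module itself.
    """
    if hasattr(x, "__array_namespace__"):
        return x.__array_namespace__()
    return np


def logsumexp(x):
    """Numerically stable log(sum(exp(x))) over the last axis, written only
    against functions defined by the array API standard."""
    xp = get_namespace(x)
    # Subtract the per-row max so exp() cannot overflow.
    m = xp.max(x, axis=-1, keepdims=True)
    s = xp.sum(xp.exp(x - m), axis=-1, keepdims=True)
    return xp.squeeze(xp.log(s) + m, axis=-1)


print(logsumexp(np.array([1000.0, 1000.0])))  # about 1000.6931
```

The function body never names NumPy or JAX, so in principle any compliant library's arrays could be passed in unchanged.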

Since PyMC3 already uses JAX via Aesara, adopting NEP-47 for multi-backend support should be fairly painless, I guess?

Related topic: NumPy array protocols


Looking at the PyMC3 and Aesara source code, it seems relatively easy to switch to NEP-47 (which is basically NumPy).

In Aesara, most of the JAX-related code is in aesara/link/jax/dispatch.py.

Code snippet:

import jax
import jax.numpy as jnp
import jax.scipy as jsp

...

@jax_funcify.register(Cholesky)
def jax_funcify_Cholesky(op, **kwargs):
    lower = op.lower

    def cholesky(a, lower=lower):
        return jsp.linalg.cholesky(a, lower=lower).astype(a.dtype)

    return cholesky


@jax_funcify.register(Solve)
def jax_funcify_Solve(op, **kwargs):

    if op.assume_a != "gen" and op.lower:
        lower = True
    else:
        lower = False

    def solve(a, b, lower=lower):
        return jsp.linalg.solve(a, b, lower=lower)

    return solve


@jax_funcify.register(Det)
def jax_funcify_Det(op, **kwargs):
    def det(x):
        return jnp.linalg.det(x)

    return det


@jax_funcify.register(Eig)
def jax_funcify_Eig(op, **kwargs):
    def eig(x):
        return jnp.linalg.eig(x)

    return eig
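
A minimal sketch of how this dispatch pattern could target any array API namespace instead of jnp/jsp: the namespace xp is passed in rather than imported. array_api_funcify and the Cholesky, Solve and Det classes below are hypothetical stand-ins, not Aesara's real Ops or API. Writing it out surfaces one gap: the standard's linalg.solve takes no assume_a or lower argument, so the structured-solve hint used in the JAX version would have to be dropped or handled elsewhere.

```python
from functools import singledispatch

import numpy as np


# Hypothetical stand-ins for Aesara Ops (NOT the real aesara classes).
class Cholesky:
    def __init__(self, lower=True):
        self.lower = lower


class Solve:
    def __init__(self, assume_a="gen", lower=False):
        self.assume_a = assume_a
        self.lower = lower


class Det:
    pass


@singledispatch
def array_api_funcify(op, xp, **kwargs):
    """Hypothetical analogue of jax_funcify that takes the namespace xp."""
    raise NotImplementedError(f"no array API conversion for {type(op).__name__}")


@array_api_funcify.register(Cholesky)
def _(op, xp, **kwargs):
    lower = op.lower

    def cholesky(a):
        # The standard's linalg.cholesky returns the lower factor by default.
        factor = xp.linalg.cholesky(a)
        # For 2-D input, the upper factor is the transpose of the lower one.
        return factor if lower else factor.T

    return cholesky


@array_api_funcify.register(Solve)
def _(op, xp, **kwargs):
    def solve(a, b):
        # No assume_a/lower in the standard: always use the general solver.
        return xp.linalg.solve(a, b)

    return solve


@array_api_funcify.register(Det)
def _(op, xp, **kwargs):
    def det(x):
        return xp.linalg.det(x)

    return det


det = array_api_funcify(Det(), np)
print(det(np.array([[2.0, 0.0], [0.0, 3.0]])))  # about 6.0
```

Here NumPy plays the role of the backend; another compliant namespace could be passed as xp instead.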

In PyMC3, the JAX-related code is in pymc3/sampling_jax.py, for example:

def sample_numpyro_nuts(
...
):
    model = modelcontext(model)

    seed = jax.random.PRNGKey(random_seed)

    rv_names = [rv.name for rv in model.value_vars]
    init_state = [model.initial_point[rv_name] for rv_name in rv_names]
    init_state_batched = jax.tree_map(lambda x: np.repeat(x[None, ...], chains, axis=0), init_state)
    init_state_batched_at = [at.as_tensor(v) for v in init_state_batched]
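
As a sketch of how little here is JAX-specific: since init_state is a plain list, jax.tree_map can be replaced by a list comprehension, and np.repeat(x[None, ...], chains, axis=0) can be expressed with asarray and stack, both of which the standard defines. batch_initial_points is a hypothetical helper name, and defaulting xp to NumPy is an assumption for illustration.

```python
import numpy as np


def batch_initial_points(init_state, chains, xp=np):
    """Hypothetical helper: repeat each initial value along a new leading
    'chain' axis, using only xp.asarray and xp.stack."""
    return [xp.stack([xp.asarray(x)] * chains, axis=0) for x in init_state]


batched = batch_initial_points([np.zeros(3), np.float64(1.0)], chains=4)
print([b.shape for b in batched])  # [(4, 3), (4,)]
```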

If I understand correctly, JAX-specific code makes up only a small fraction of the PyMC3 source code. Most of the code still uses vanilla NumPy, which is actually great for NEP-47.

If a tensor framework supports the basic array operations (the Array API core; see the API specification) as well as higher-level solvers like linalg.cholesky, linalg.solve, linalg.det and linalg.eig, then it should be quite straightforward to add it as a new backend, I guess?
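
A rough way to check the "if a framework supports these functions" condition is to probe a namespace for the names a backend would need. The REQUIRED list below is illustrative, not an official requirement. Note that linalg.eig is included because the Aesara JAX backend uses it, although, as far as I can tell, the standard's linalg extension only specifies the symmetric/Hermitian eigh and eigvalsh, so a strictly compliant backend might not provide it.

```python
import numpy as np

# Illustrative list only; the real set would come from the Ops a backend must cover.
REQUIRED = [
    "asarray", "exp", "log", "sum", "matmul",
    "linalg.cholesky", "linalg.solve", "linalg.det", "linalg.eig",
]


def missing_functions(xp, names=REQUIRED):
    """Return the dotted names from `names` that the namespace xp lacks."""
    missing = []
    for name in names:
        obj = xp
        for part in name.split("."):
            obj = getattr(obj, part, None)
            if obj is None:
                break
        if obj is None:
            missing.append(name)
    return missing


print(missing_functions(np))  # []
```

This only checks that the names exist; it says nothing about whether their signatures and semantics match what the backend code expects.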

Are there any important features the backend framework must support? For example, a static vs. dynamic graph, or automatic differentiation?

That’s an interesting idea and could definitely be done. You already identified that we could probably just copy a lot from the JAX implementation. This is probably better discussed as an Aesara issue.