Integrating custom JAX-based likelihoods (e.g. auto-differentiable stochastic filters) with the PyMC sampling workflow

I’m interested in Bayesian parameter inference for partially observed stochastic differential equations, and have been building a JAX-based repo that computes the marginal log-probability of given time-series data under a specified parametric model (exactly for linear SDEs, approximately for nonlinear SDEs, using Kalman filters and their nonlinear variants). I’ve written demos showing how to integrate with numpyro via a simple numpyro.factor() call. I’m curious what it would take to write a similar demo for PyMC. Can PyMC use, and benefit from, a JAX-based likelihood?
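
For context, the numpyro demos boil down to something like this (a toy quadratic stands in for the actual filter’s marginal log-likelihood; only numpyro.factor() is the real integration point):

import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from jax import random
from numpyro.infer import MCMC, NUTS

def marginal_loglik(theta, y_obs):
    # placeholder for e.g. a Kalman-filter marginal log-likelihood
    return -0.5 * jnp.sum((y_obs - theta) ** 2)

def model(y_obs):
    theta = numpyro.sample("theta", dist.Normal(0.0, 5.0))
    # add the externally computed log-density to the joint
    numpyro.factor("loglik", marginal_loglik(theta, y_obs))

mcmc = MCMC(NUTS(model), num_warmup=500, num_samples=1000)
mcmc.run(random.PRNGKey(0), y_obs=jnp.ones(10))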

Check How to use JAX ODEs and Neural Networks in PyMC

And soon more easily: Implement as_jax_op by aseyboldt · Pull Request #1614 · pymc-devs/pytensor · GitHub
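
Until that PR lands, the pattern in that notebook is a hand-written pair of pytensor Ops (value + gradient) around jitted JAX callables. A rough sketch, with a toy log-likelihood where your filter code would go:

import jax
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as pt
from pytensor.graph import Apply, Op

def jax_loglike(beta):
    # toy stand-in for a filter's marginal log-likelihood
    return -0.5 * jnp.sum(beta ** 2)

value_fn = jax.jit(jax_loglike)
grad_fn = jax.jit(jax.grad(jax_loglike))

class LogLikeGrad(Op):
    def make_node(self, beta):
        beta = pt.as_tensor_variable(beta)
        return Apply(self, [beta], [beta.type()])

    def perform(self, node, inputs, outputs):
        (beta,) = inputs
        outputs[0][0] = np.asarray(grad_fn(beta), dtype=node.outputs[0].dtype)

class LogLike(Op):
    def make_node(self, beta):
        beta = pt.as_tensor_variable(beta)
        return Apply(self, [beta], [pt.dscalar()])

    def perform(self, node, inputs, outputs):
        (beta,) = inputs
        outputs[0][0] = np.asarray(value_fn(beta), dtype=node.outputs[0].dtype)

    def grad(self, inputs, output_gradients):
        (beta,) = inputs
        # chain rule: upstream scalar gradient times d(loglike)/d(beta)
        return [output_gradients[0] * LogLikeGrad()(beta)]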

Thanks! The PR you linked looks a lot easier to work with. Maybe I should wait until that is released to make my demo? Also, does this look like the correct way to use (a) the as_jax_op syntax and (b) the pm.Potential feature? (Sorry, I’m new to PyMC.)

import numpy as np
import jax.numpy as jnp
import pymc as pm
from pytensor.link.jax import as_jax_op  # from the linked PR; import path may differ once released

# ---- Example data ----
rng = np.random.default_rng(0)
n, p = 50, 3
X = rng.normal(size=(n, p))
true_beta = np.array([2.0, -1.0, 0.5])
y = X @ true_beta + rng.normal(scale=1.0, size=n)

# ---- Define JAX log-likelihood for regression ----
def loglike_fn(beta, X_obs, y_obs):
    # unnormalized Gaussian log-likelihood with unit noise variance
    mu = X_obs @ beta
    return jnp.sum(-0.5 * (y_obs - mu) ** 2)

# ---- Wrap JAX function as pytensor Op ----
@as_jax_op
def loglike(beta, X_obs, y_obs):
    return loglike_fn(beta, X_obs, y_obs)

# ---- Build and sample PyMC model ----
with pm.Model() as model:
    beta = pm.Uniform("beta", lower=-5, upper=5, shape=p)  # flat prior on coefficients
    # add the JAX log-likelihood term to the model's joint log-density
    pm.Potential("likelihood", loglike(beta, X, y))
    idata = pm.sample(nuts_sampler="numpyro", draws=1000, tune=500)

print(idata.posterior)
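
One caveat I noticed in the docs: pm.Potential terms are ignored by pm.sample_prior_predictive and pm.sample_posterior_predictive (PyMC emits a warning about this), so I assume this pattern is only meant for posterior inference?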