I’m interested in Bayesian parameter inference for partially observed stochastic differential equations, and have been building a JAX-based repo that computes the marginal log probability of given time-series data under a specified parametric model (exactly for linear SDEs and approximately for non-linear SDEs, using Kalman filters and their non-linear variants). I’ve written demos showing how to integrate it with numpyro via a simple numpyro.factor() call. I’m curious what it would take to write a similar demo for PyMC. Can PyMC use, and benefit from, a JAX-based likelihood?
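To make the setup concrete, here is a minimal JAX sketch of the exact (linear-SDE) case, assuming a scalar Ornstein–Uhlenbeck latent process started from its stationary distribution and observed with Gaussian noise at a fixed spacing `dt`. The model, the function name `ou_kalman_loglik`, and its parameters are illustrative assumptions, not the repo's API. It uses the exact OU transition and accumulates the prediction-error decomposition of the marginal log-likelihood with `jax.lax.scan`, so the result is differentiable with `jax.grad`.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Double precision so the filter agrees with NumPy reference computations
jax.config.update("jax_enable_x64", True)


def ou_kalman_loglik(theta, sigma, r, y, dt):
    """Exact log p(y | theta, sigma, r) for a noisily observed OU process (hypothetical example).

    Latent:   dx = -theta * x dt + sigma dW, started from its stationary law
    Observed: y_k = x(k * dt) + eps_k,  eps_k ~ N(0, r**2)
    """
    y = jnp.asarray(y, dtype=jnp.float64)
    a = jnp.exp(-theta * dt)             # exact one-step transition coefficient
    p_stat = sigma**2 / (2.0 * theta)    # stationary variance of x
    q = p_stat * (1.0 - a**2)            # one-step process-noise variance

    def step(carry, y_k):
        m, P = carry                     # predictive mean/variance of x_k
        S = P + r**2                     # predictive variance of y_k
        ll_k = -0.5 * (jnp.log(2.0 * jnp.pi * S) + (y_k - m) ** 2 / S)
        K = P / S                        # Kalman gain
        m_filt = m + K * (y_k - m)
        P_filt = (1.0 - K) * P
        # Predict x_{k+1} with the exact OU transition
        return (a * m_filt, a**2 * P_filt + q), ll_k

    init = (jnp.zeros((), dtype=y.dtype), jnp.asarray(p_stat, dtype=y.dtype))
    _, lls = jax.lax.scan(step, init, y)
    return jnp.sum(lls)


# Usage: simulate a short series exactly, then evaluate the log-likelihood and its gradient
rng = np.random.default_rng(0)
theta_true, sigma_true, r_true, dt, T = 0.5, 1.0, 0.3, 0.1, 200
a_true = np.exp(-theta_true * dt)
p_true = sigma_true**2 / (2 * theta_true)
x = np.empty(T)
x[0] = rng.normal(scale=np.sqrt(p_true))
for k in range(1, T):
    x[k] = a_true * x[k - 1] + rng.normal(scale=np.sqrt(p_true * (1 - a_true**2)))
y_obs = x + rng.normal(scale=r_true, size=T)

print(ou_kalman_loglik(theta_true, sigma_true, r_true, y_obs, dt))
print(jax.grad(ou_kalman_loglik, argnums=(0, 1, 2))(theta_true, sigma_true, r_true, y_obs, dt))
```

In a NumPyro model, a scalar like this is what gets passed to `numpyro.factor()`; on the PyMC side, the same JAX function is what a PyTensor wrapper would call.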
Check out “How to use JAX ODEs and Neural Networks in PyMC”.
And soon, more easily: “Implement as_jax_op” by aseyboldt (Pull Request #1614, pymc-devs/pytensor on GitHub).
Thanks! The PR you linked looks a lot easier to work with; maybe I should wait until it is released before writing my demo? Also, does this look like the correct way to use a) your as_jax_op syntax and b) the pm.Potential feature (sorry, I’m new to PyMC)?
```python
import numpy as np
import jax.numpy as jnp
import pymc as pm
from pytensor.link.jax import as_jax_op

# ---- Example data ----
rng = np.random.default_rng(0)
n, p = 50, 3
X = rng.normal(size=(n, p))
true_beta = np.array([2.0, -1.0, 0.5])
y = X @ true_beta + rng.normal(scale=1.0, size=n)

# ---- Define JAX log-likelihood for regression ----
def loglike_fn(beta, X_obs, y_obs):
    # Gaussian log-likelihood with known unit noise, up to an additive constant
    mu = X_obs @ beta
    return jnp.sum(-0.5 * (y_obs - mu) ** 2)

# ---- Wrap JAX function as pytensor Op ----
@as_jax_op
def loglike(beta, X_obs, y_obs):
    return loglike_fn(beta, X_obs, y_obs)

# ---- Build and sample PyMC model ----
with pm.Model() as model:
    beta = pm.Uniform("beta", lower=-5, upper=5, shape=p)  # <-- only real change!
    # pm.Potential adds this scalar to the model's joint log-probability
    pm.Potential("likelihood", loglike(beta, X, y))
    # Sample inside the model context, using NumPyro's NUTS as the backend
    idata = pm.sample(draws=1000, tune=500, nuts_sampler="numpyro")

print(idata.posterior)
```