We will try to fix JAX compatibility for Scan soon, but in the meantime, something like this should work:
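```python
import jax
import numpy as np


def jax_scan(outputs_info, sequences):
    def scan_update_fn(x, y):
        # x is the carried state, y is the current element of `sequences`
        next_state = x + y * 1.2
        # Return (new carry, per-step output); per-step outputs are stacked into `acc`
        return next_state, next_state
        # If you only need the last state, you can skip stacking outputs:
        # return next_state, ()

    last_state, acc = jax.lax.scan(
        f=scan_update_fn,
        init=outputs_info,
        xs=sequences,
    )
    return acc  # If you only need the last state, return `last_state` instead


# Float init and float sequences, so the carry dtype stays the same at every step
jax_scan(0.0, np.arange(5, dtype=np.float32))
```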
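To make the carry/output contract of `jax.lax.scan` concrete, here is a small sketch of the "last state only" variant mentioned in the comments, checked against the same recurrence written as a plain Python loop. `jax_scan_last` and `python_loop_reference` are hypothetical helper names used only for illustration; returning `None` as the per-step output is one way to tell `lax.scan` not to stack anything.

```python
import jax
import numpy as np


def scan_update_fn(x, y):
    # Same update as above; returning None means no per-step outputs are stacked
    return x + y * 1.2, None


def jax_scan_last(outputs_info, sequences):
    # Hypothetical helper: keep only the final carry
    last_state, _ = jax.lax.scan(scan_update_fn, outputs_info, sequences)
    return last_state


def python_loop_reference(outputs_info, sequences):
    # The same recurrence as an explicit Python loop, for comparison
    state = outputs_info
    history = []
    for y in sequences:
        state = state + y * 1.2
        history.append(state)
    return state, np.array(history)


seq = np.arange(5, dtype=np.float32)
last = jax_scan_last(0.0, seq)
ref_last, ref_history = python_loop_reference(0.0, seq)
print(last, ref_last, ref_history)
```

Here `ref_history` plays the role of `acc` in `jax_scan`, and `ref_last` the role of `last_state`; the scan version just expresses the loop in a form JAX can trace and compile.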
Once you have the equivalent Scan function written in JAX, you can use the recipes in this blog post to wrap it in a PyTensor Op: How to use JAX ODEs and Neural Networks in PyMC - PyMC Labs. If you are going to sample with NumPyro, you don't need to worry about the gradient part (just like in the last example, with the neural network).