Error when using `sample_blackjax_nuts` with `pytensor.scan`

We will try to fix Scan's JAX compatibility soon, but in the meantime, writing the loop directly in JAX with `jax.lax.scan` should work:

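```python
import jax
import numpy as np

def jax_scan(outputs_info, sequences):
    def scan_update_fn(x, y):
        # x is the carried state, y is the current element of `sequences`
        next_state = x + y * 1.2
        # Return (new carry, per-step output). If you only need the last
        # state, return `next_state, ()` so nothing is accumulated.
        return next_state, next_state

    last_state, acc = jax.lax.scan(
        f=scan_update_fn,
        init=outputs_info,
        xs=sequences,
    )
    return acc  # If you only need the last state, return `last_state` instead

# A float initial state keeps the carry dtype consistent with next_state
jax_scan(0.0, np.arange(5))  # → approximately [0., 1.2, 3.6, 7.2, 12.]
```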
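
As the comments note, you may only need the final state. Below is a minimal sketch of that variant (the name `jax_scan_last` is hypothetical), plus a `jax.grad` call to check that gradients propagate through `jax.lax.scan`. That differentiability is what should let JAX-based NUTS samplers work with the model without a separate PyTensor gradient.

```python
import jax
import jax.numpy as jnp

def jax_scan_last(outputs_info, sequences):
    """Same recurrence as jax_scan, but keeps only the final carry."""
    def scan_update_fn(x, y):
        next_state = x + y * 1.2
        # None as the per-step output: nothing is stacked across iterations
        return next_state, None

    last_state, _ = jax.lax.scan(scan_update_fn, outputs_info, sequences)
    return last_state

xs = jnp.arange(5.0)
final = jax_scan_last(0.0, xs)  # 0 + 1.2 * (0 + 1 + 2 + 3 + 4) = 12.0

# Gradients w.r.t. the initial state and w.r.t. every sequence element
grad_wrt_init = jax.grad(jax_scan_last)(0.0, xs)             # 1.0
grad_wrt_seq = jax.grad(jax_scan_last, argnums=1)(0.0, xs)   # 1.2 for each element
```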

Once you have the equivalent Scan function written in JAX, you can use the recipes in this blog post to wrap it in a PyTensor Op: How to use JAX ODEs and Neural Networks in PyMC - PyMC Labs. If you are going to sample with NumPyro, you don't need to worry about the grad part, just like in the last example with the neural network.
