Error when using `sample_blackjax_nuts` with `pytensor.scan`

As stated in the title, `pytensor.scan` causes an error when sampling with `sample_blackjax_nuts`. The code is as follows:

import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc import sampling_jax

def theano_fn(x, y):
    # x is the t-th value of a
    # y is the output of previous iteration 
    # (which is the sum of all the previous values)
    return x + y*1.2

with pm.Model() as model:
    
    a = pm.Normal(
        'a',
        shape=10,
        mu=1
    )
    
    acc, updates = pytensor.scan(
        theano_fn,
        sequences=[a],
        outputs_info=[np.float64(0.)]
    )
    
    dets = pm.Deterministic(
        'outputs',
        acc
    )
    
    pm.Normal(
        'observed',
        dets[-1]
    )   

with model:
    trace = pm.sampling_jax.sample_blackjax_nuts()

And the error is:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[6], line 2
      1 with model:
----> 2     trace = pm.sampling_jax.sample_blackjax_nuts()

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling/jax.py:396, in sample_blackjax_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, keep_untransformed, chain_method, postprocessing_backend, postprocessing_chunks, idata_kwargs)
    393 if chains == 1:
    394     init_params = [np.stack(init_state) for init_state in zip(init_params)]
--> 396 logprob_fn = get_jaxified_logp(model)
    398 seed = jax.random.PRNGKey(random_seed)
    399 keys = jax.random.split(seed, chains)

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling/jax.py:118, in get_jaxified_logp(model, negative_logp)
    116 if not negative_logp:
    117     model_logp = -model_logp
--> 118 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
    120 def logp_fn_wrap(x):
    121     return logp_fn(*x)[0]

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling/jax.py:111, in get_jaxified_graph(inputs, outputs)
    108 mode.JAX.optimizer.rewrite(fgraph)
    110 # We now jaxify the optimized fgraph
--> 111 return jax_funcify(fgraph)

File ~/anaconda3/envs/pymc_env/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pytensor/link/jax/dispatch/basic.py:49, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
     42 @jax_funcify.register(FunctionGraph)
     43 def jax_funcify_FunctionGraph(
     44     fgraph,
   (...)
     47     **kwargs,
     48 ):
---> 49     return fgraph_to_python(
     50         fgraph,
     51         jax_funcify,
     52         type_conversion_fn=jax_typify,
     53         fgraph_name=fgraph_name,
     54         **kwargs,
     55     )

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pytensor/link/utils.py:740, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
    738 body_assigns = []
    739 for node in order:
--> 740     compiled_func = op_conversion_fn(
    741         node.op, node=node, storage_map=storage_map, **kwargs
    742     )
    744     # Create a local alias with a unique name
    745     local_compiled_func_name = unique_name(compiled_func)

File ~/anaconda3/envs/pymc_env/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pytensor/link/jax/dispatch/scan.py:12, in jax_funcify_Scan(op, **kwargs)
     10 @jax_funcify.register(Scan)
     11 def jax_funcify_Scan(op, **kwargs):
---> 12     inner_fg = FunctionGraph(op.inputs, op.outputs)
     13     jax_at_inner_func = jax_funcify(inner_fg, **kwargs)
     15     def scan(*outer_inputs):

AttributeError: 'Scan' object has no attribute 'inputs'

Is there any obvious fix I am missing? While the specific example above is easily rewritten without scan, I have a complex model with loads of datapoints that uses a scan, and it would be great if it could fit faster on the GPU.

At the moment we don’t support Scan in the JAX backend. We hope to fix this soon.

I see - so I assume it is not straightforward to implement. Are there alternatives to scan? E.g., would rewriting the model as an explicit loop and letting JAX optimize the loop work?

No, you can’t write the PyMC model with a loop. You can perhaps write the scan function in JAX and wrap it in a PyTensor Op as a quick alternative. There are some recipes on how to do something like that here: How to use JAX ODEs and Neural Networks in PyMC - PyMC Labs

@juststarted did you ever get this to work?

No, sorry - I couldn’t find an easy way to translate the pytensor scan into a JAX scan (turns out it’s not trivial). I tried to rewrite the scan as a loop and it compiled fine, but it was super slow. Let me know if you solve this!

We will try to fix Scan JAX compatibility soon, but in the meantime something like this should work:

import jax
import numpy as np

def jax_scan(outputs_info, sequences):
  def scan_update_fn(x, y):
    next_state = x + y * 1.2
    return next_state, next_state
    # If you only need the last state, you can just return it once
    # return next_state, ()

  last_state, acc = jax.lax.scan(
    f=scan_update_fn,
    init=outputs_info, 
    xs=sequences,
  )
  return acc  # If you only need last state, you can return it instead

jax_scan(np.float32(0.0), np.arange(5, dtype=np.float32))  # a float init keeps the carry dtype stable across iterations
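For intuition, the scan above just computes the running recursion state_t = x_t + 1.2 * state_{t-1}. A plain-Python reference (my own sketch, no JAX needed) is handy for sanity-checking the JAX version:

```python
def python_scan(init, xs):
    # Reference for: state_t = x_t + 1.2 * state_{t-1}
    state, acc = init, []
    for x in xs:
        state = x + state * 1.2
        acc.append(state)
    return acc

print(python_scan(0.0, range(5)))  # [0.0, 1.0, 3.2, 6.84, 12.208] (up to float rounding)
```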

Once you have the equivalent Scan function written in JAX, you can use the recipes in this blogpost to wrap them in a PyTensor Op. If you are going to use Numpyro you don’t need to worry about the grad part (just like in the last example with the NeuralNetwork): How to use JAX ODEs and Neural Networks in PyMC - PyMC Labs