As stated in the title, using `pytensor.scan` in a model causes an error when sampling with `sample_blackjax_nuts`. The code is as follows:
import numpy as np
import pymc as pm
import pytensor
import pytensor.tensor as pt
from pymc import sampling_jax
def theano_fn(x, y):
    """Scan step: fold the current element into the carried value.

    Parameters
    ----------
    x : the t-th element of the scanned sequence ``a``.
    y : the output of the previous iteration (the value carried between
        steps, seeded by ``outputs_info``).

    Returns
    -------
    ``x + 1.2 * y``.

    Note that this is *not* a plain running sum of the previous values.
    Each step multiplies the carried value by 1.2 again, so with a zero seed
    element ``a[k]`` contributes ``1.2 ** (t - k) * a[k]`` to the t-th output.
    """
    return x + y * 1.2
with pm.Model() as model:
    a = pm.Normal("a", shape=10, mu=1)

    # Weighted running accumulation over `a`:
    #     acc[t] = a[t] + 1.2 * acc[t - 1], seeded with a float64 zero.
    # The seed uses numpy (`np`). The original snippet never imported numpy,
    # so it failed with NameError before reaching the sampler.
    acc, updates = pytensor.scan(
        theano_fn,
        sequences=[a],
        outputs_info=[np.float64(0.0)],
    )
    dets = pm.Deterministic("outputs", acc)

    # NOTE(review): no `observed=` argument is passed. This therefore creates
    # a second free random variable named "observed", not a likelihood on
    # data. Confirm this is intended for the reproduction.
    pm.Normal("observed", dets[-1])

with model:
    # Call the sampler through the module imported at the top
    # (`from pymc import sampling_jax`).
    #
    # The reported AttributeError is raised inside PyTensor's JAX dispatcher
    # for Scan: `jax_funcify_Scan` reads `op.inputs`, which the Scan op does
    # not have. It is therefore a library issue rather than a model bug.
    # Presumably a newer PyTensor/PyMC release fixes it; TODO confirm
    # against the release notes.
    trace = sampling_jax.sample_blackjax_nuts()
And the error is:
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Cell In[6], line 2
1 with model:
----> 2 trace = pm.sampling_jax.sample_blackjax_nuts()
File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling/jax.py:396, in sample_blackjax_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, keep_untransformed, chain_method, postprocessing_backend, postprocessing_chunks, idata_kwargs)
393 if chains == 1:
394 init_params = [np.stack(init_state) for init_state in zip(init_params)]
--> 396 logprob_fn = get_jaxified_logp(model)
398 seed = jax.random.PRNGKey(random_seed)
399 keys = jax.random.split(seed, chains)
File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling/jax.py:118, in get_jaxified_logp(model, negative_logp)
116 if not negative_logp:
117 model_logp = -model_logp
--> 118 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
120 def logp_fn_wrap(x):
121 return logp_fn(*x)[0]
File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling/jax.py:111, in get_jaxified_graph(inputs, outputs)
108 mode.JAX.optimizer.rewrite(fgraph)
110 # We now jaxify the optimized fgraph
--> 111 return jax_funcify(fgraph)
File ~/anaconda3/envs/pymc_env/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
885 if not args:
886 raise TypeError(f'{funcname} requires at least '
887 '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)
File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pytensor/link/jax/dispatch/basic.py:49, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
42 @jax_funcify.register(FunctionGraph)
43 def jax_funcify_FunctionGraph(
44 fgraph,
(...)
47 **kwargs,
48 ):
---> 49 return fgraph_to_python(
50 fgraph,
51 jax_funcify,
52 type_conversion_fn=jax_typify,
53 fgraph_name=fgraph_name,
54 **kwargs,
55 )
File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pytensor/link/utils.py:740, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
738 body_assigns = []
739 for node in order:
--> 740 compiled_func = op_conversion_fn(
741 node.op, node=node, storage_map=storage_map, **kwargs
742 )
744 # Create a local alias with a unique name
745 local_compiled_func_name = unique_name(compiled_func)
File ~/anaconda3/envs/pymc_env/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
885 if not args:
886 raise TypeError(f'{funcname} requires at least '
887 '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)
File ~/anaconda3/envs/pymc_env/lib/python3.10/site-packages/pytensor/link/jax/dispatch/scan.py:12, in jax_funcify_Scan(op, **kwargs)
10 @jax_funcify.register(Scan)
11 def jax_funcify_Scan(op, **kwargs):
---> 12 inner_fg = FunctionGraph(op.inputs, op.outputs)
13 jax_at_inner_func = jax_funcify(inner_fg, **kwargs)
15 def scan(*outer_inputs):
AttributeError: 'Scan' object has no attribute 'inputs'
Is there an obvious fix I am missing? The specific example above is easily rewritten without `scan`, but I have a complex model with many data points that relies on `scan`, and it would be great if it could fit faster on the GPU.