I am still fighting with this model. Since the latest version of PyMC, it is possible to use `sample_numpyro_nuts` with `scan`, which is awesome. However, I am now getting a different error during compilation when running the code in the gist above, namely:
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In[48], line 2
1 with model:
----> 2 trace = sample_numpyro_nuts(chains=1)
File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/jax.py:632, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_chunks, idata_kwargs, nuts_kwargs, **kwargs)
623 print("Compiling...", file=sys.stdout)
625 init_params = _get_batched_jittered_initial_points(
626 model=model,
627 chains=chains,
628 initvals=initvals,
629 random_seed=random_seed,
630 )
--> 632 logp_fn = get_jaxified_logp(model, negative_logp=False)
634 nuts_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs)
635 nuts_kernel = NUTS(
636 potential_fn=logp_fn,
637 target_accept_prob=target_accept,
638 **nuts_kwargs,
639 )
File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/jax.py:136, in get_jaxified_logp(model, negative_logp)
134 if not negative_logp:
135 model_logp = -model_logp
--> 136 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
138 def logp_fn_wrap(x):
139 return logp_fn(*x)[0]
File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/jax.py:129, in get_jaxified_graph(inputs, outputs)
126 mode.JAX.optimizer.rewrite(fgraph)
128 # We now jaxify the optimized fgraph
--> 129 return jax_funcify(fgraph)
File ~/mambaforge/envs/pymc_latest/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
905 if not args:
906 raise TypeError(f'{funcname} requires at least '
907 '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)
File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pytensor/link/jax/dispatch/basic.py:51, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
44 @jax_funcify.register(FunctionGraph)
45 def jax_funcify_FunctionGraph(
46 fgraph,
(...)
49 **kwargs,
50 ):
---> 51 return fgraph_to_python(
52 fgraph,
53 jax_funcify,
54 type_conversion_fn=jax_typify,
55 fgraph_name=fgraph_name,
56 **kwargs,
57 )
File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pytensor/link/utils.py:738, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
736 body_assigns = []
737 for node in order:
--> 738 compiled_func = op_conversion_fn(
739 node.op, node=node, storage_map=storage_map, **kwargs
740 )
742 # Create a local alias with a unique name
743 local_compiled_func_name = unique_name(compiled_func)
File ~/mambaforge/envs/pymc_latest/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
905 if not args:
906 raise TypeError(f'{funcname} requires at least '
907 '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)
File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scan.py:23, in jax_funcify_Scan(op, **kwargs)
21 rewriter = op.mode_instance.optimizer
22 rewriter(op.fgraph)
---> 23 scan_inner_func = jax_funcify(op.fgraph, **kwargs)
25 def scan(*outer_inputs):
26 # Extract JAX scan inputs
27 outer_inputs = list(outer_inputs)
File ~/mambaforge/envs/pymc_latest/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
905 if not args:
906 raise TypeError(f'{funcname} requires at least '
907 '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)
File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pytensor/link/jax/dispatch/basic.py:51, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
44 @jax_funcify.register(FunctionGraph)
45 def jax_funcify_FunctionGraph(
46 fgraph,
(...)
49 **kwargs,
50 ):
---> 51 return fgraph_to_python(
52 fgraph,
53 jax_funcify,
54 type_conversion_fn=jax_typify,
55 fgraph_name=fgraph_name,
56 **kwargs,
57 )
File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pytensor/link/utils.py:738, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
736 body_assigns = []
737 for node in order:
--> 738 compiled_func = op_conversion_fn(
739 node.op, node=node, storage_map=storage_map, **kwargs
740 )
742 # Create a local alias with a unique name
743 local_compiled_func_name = unique_name(compiled_func)
File ~/mambaforge/envs/pymc_latest/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
905 if not args:
906 raise TypeError(f'{funcname} requires at least '
907 '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)
File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pytensor/link/jax/dispatch/shape.py:69, in jax_funcify_Reshape(op, node, **kwargs)
66 return jnp.reshape(x, constant_shape)
68 else:
---> 69 assert_shape_argument_jax_compatible(shape)
71 def reshape(x, shape):
72 return jnp.reshape(x, shape)
File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pytensor/link/jax/dispatch/shape.py:55, in assert_shape_argument_jax_compatible(shape)
53 shape_op = shape.owner.op
54 if not isinstance(shape_op, (Shape, Shape_i, JAXShapeTuple)):
---> 55 raise NotImplementedError(SHAPE_NOT_COMPATIBLE)
NotImplementedError: JAX requires concrete values for the `shape` parameter of `jax.numpy.reshape`.
Concrete values are either constants:
>>> import pytensor.tensor as at
>>> x = at.ones(6)
>>> y = x.reshape((2, 3))
Or the shape of an array:
>>> mat = at.matrix('mat')
>>> y = x.reshape(mat.shape)
Since I am not using reshape explicitly anywhere, I assume that some of the operations I am using call reshape under the hood, but I find this hard to debug. Is anyone familiar with this error, or does anyone have any hints about what could be going wrong here?