Trying to speed up a model with custom likelihood

Hi folks,

I am modelling the data from an experiment where participants learn a simple language over several (200) trials. In each trial, the participant sees four images and a sentence in an unknown language (always 3 words which could mean e.g. “Blue hits circle”) that describes one of the scenes (e.g. the scene where a blue object with an arm hits a circle). There is a total of 4 objects and 3 actions. After seeing the four scenes and the sentence, the participant picks a scene and they get feedback about which of the scenes the sentence actually referred to. Over the experiment, they are meant to learn the meaning of each word and also the word order of the language (e.g. subject verb object).

I am trying to model the actual learning process of each participant, in a hierarchical fashion. As you can imagine, the model involves quite a complex likelihood function, including a big scan over trials with a set_subtensor inside it. Here’s a gist with the model (excuse the references to theano). The code in the gist runs parameter recovery on a single prior sample, using a much smaller size than the final dataset, which has 200 trials and more than 200 participants. At the moment, fitting the full dataset takes more than 120 hours, which unfortunately is the hard limit on the server I am using. Here’s what I have tried:

  • Rewriting the scan as a loop and running with sample_blackjax_nuts using the GPU. This didn’t seem to make it any faster (in fact, it made it slower). I’m no JAX expert though, so I might have missed something.
  • Stopping and restarting the run in a different server job. This is difficult for reasons discussed recently in this post.
  • Running variational inference instead. Unfortunately, parameter recovery with VI turned out pretty poor (in contrast to some tests with NUTS).

I was wondering if there is any way to simplify the likelihood function, or even get rid of the scan. Any help would be much appreciated. Thank you!
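For reference, here is a stripped-down toy version of the kind of scan I mean. This is not the actual model from the gist; the sizes, names, and update rule are just placeholders to show the trial-by-trial scan with a set_subtensor inside:

import pytensor
import pytensor.tensor as pt

n_words, n_meanings = 9, 7  # toy sizes, not the real experiment

def trial_step(word_idx, meaning_idx, assoc, lr):
    # strengthen the association between the word and meaning seen on this trial
    assoc = pt.set_subtensor(
        assoc[word_idx, meaning_idx], assoc[word_idx, meaning_idx] + lr
    )
    # choice probabilities for this trial from the updated association matrix
    p = pt.exp(assoc[word_idx]) / pt.sum(pt.exp(assoc[word_idx]))
    return assoc, p

lr = pt.scalar("lr")
words = pt.ivector("words")        # word index observed on each trial
meanings = pt.ivector("meanings")  # correct meaning on each trial

(assoc_seq, p_seq), _ = pytensor.scan(
    fn=trial_step,
    sequences=[words, meanings],
    outputs_info=[pt.zeros((n_words, n_meanings)), None],
    non_sequences=[lr],
)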

I am still fighting with this model. Since the latest version of PyMC it is possible to use sample_numpyro_nuts with scan, which is awesome. However, now I am getting a different error during compilation when running the code in the gist above, namely:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[48], line 2
     1 with model:
----> 2     trace = sample_numpyro_nuts(chains=1)

File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/jax.py:632, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progressbar, keep_untransformed, chain_method, postprocessing_backend, postprocessing_chunks, idata_kwargs, nuts_kwargs, **kwargs)
   623 print("Compiling...", file=sys.stdout)
   625 init_params = _get_batched_jittered_initial_points(
   626     model=model,
   627     chains=chains,
   628     initvals=initvals,
   629     random_seed=random_seed,
   630 )
--> 632 logp_fn = get_jaxified_logp(model, negative_logp=False)
   634 nuts_kwargs = _update_numpyro_nuts_kwargs(nuts_kwargs)
   635 nuts_kernel = NUTS(
   636     potential_fn=logp_fn,
   637     target_accept_prob=target_accept,
   638     **nuts_kwargs,
   639 )

File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/jax.py:136, in get_jaxified_logp(model, negative_logp)
   134 if not negative_logp:
   135     model_logp = -model_logp
--> 136 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])
   138 def logp_fn_wrap(x):
   139     return logp_fn(*x)[0]

File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pymc/sampling/jax.py:129, in get_jaxified_graph(inputs, outputs)
   126 mode.JAX.optimizer.rewrite(fgraph)
   128 # We now jaxify the optimized fgraph
--> 129 return jax_funcify(fgraph)

File ~/mambaforge/envs/pymc_latest/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
   905 if not args:
   906     raise TypeError(f'{funcname} requires at least '
   907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pytensor/link/jax/dispatch/basic.py:51, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
    44 @jax_funcify.register(FunctionGraph)
    45 def jax_funcify_FunctionGraph(
    46     fgraph,
  (...)
    49     **kwargs,
    50 ):
---> 51     return fgraph_to_python(
    52         fgraph,
    53         jax_funcify,
    54         type_conversion_fn=jax_typify,
    55         fgraph_name=fgraph_name,
    56         **kwargs,
    57     )

File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pytensor/link/utils.py:738, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
   736 body_assigns = []
   737 for node in order:
--> 738     compiled_func = op_conversion_fn(
   739         node.op, node=node, storage_map=storage_map, **kwargs
   740     )
   742     # Create a local alias with a unique name
   743     local_compiled_func_name = unique_name(compiled_func)

File ~/mambaforge/envs/pymc_latest/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
   905 if not args:
   906     raise TypeError(f'{funcname} requires at least '
   907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pytensor/link/jax/dispatch/scan.py:23, in jax_funcify_Scan(op, **kwargs)
    21 rewriter = op.mode_instance.optimizer
    22 rewriter(op.fgraph)
---> 23 scan_inner_func = jax_funcify(op.fgraph, **kwargs)
    25 def scan(*outer_inputs):
    26     # Extract JAX scan inputs
    27     outer_inputs = list(outer_inputs)

File ~/mambaforge/envs/pymc_latest/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
   905 if not args:
   906     raise TypeError(f'{funcname} requires at least '
   907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pytensor/link/jax/dispatch/basic.py:51, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
    44 @jax_funcify.register(FunctionGraph)
    45 def jax_funcify_FunctionGraph(
    46     fgraph,
  (...)
    49     **kwargs,
    50 ):
---> 51     return fgraph_to_python(
    52         fgraph,
    53         jax_funcify,
    54         type_conversion_fn=jax_typify,
    55         fgraph_name=fgraph_name,
    56         **kwargs,
    57     )

File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pytensor/link/utils.py:738, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
   736 body_assigns = []
   737 for node in order:
--> 738     compiled_func = op_conversion_fn(
   739         node.op, node=node, storage_map=storage_map, **kwargs
   740     )
   742     # Create a local alias with a unique name
   743     local_compiled_func_name = unique_name(compiled_func)

File ~/mambaforge/envs/pymc_latest/lib/python3.11/functools.py:909, in singledispatch.<locals>.wrapper(*args, **kw)
   905 if not args:
   906     raise TypeError(f'{funcname} requires at least '
   907                     '1 positional argument')
--> 909 return dispatch(args[0].__class__)(*args, **kw)

File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pytensor/link/jax/dispatch/shape.py:69, in jax_funcify_Reshape(op, node, **kwargs)
    66         return jnp.reshape(x, constant_shape)
    68 else:
---> 69     assert_shape_argument_jax_compatible(shape)
    71     def reshape(x, shape):
    72         return jnp.reshape(x, shape)

File ~/mambaforge/envs/pymc_latest/lib/python3.11/site-packages/pytensor/link/jax/dispatch/shape.py:55, in assert_shape_argument_jax_compatible(shape)
    53 shape_op = shape.owner.op
    54 if not isinstance(shape_op, (Shape, Shape_i, JAXShapeTuple)):
---> 55     raise NotImplementedError(SHAPE_NOT_COMPATIBLE)

NotImplementedError: JAX requires concrete values for the `shape` parameter of `jax.numpy.reshape`.
Concrete values are either constants:

>>> import pytensor.tensor as at
>>> x = at.ones(6)
>>> y = x.reshape((2, 3))

Or the shape of an array:

>>> mat = at.matrix('mat')
>>> y = x.reshape(mat.shape)

Since I am not using reshape explicitly anywhere, I assume some of the operations I am using rely on reshape under the hood. But I find this hard to debug. Is anyone familiar with this error, or does anyone have hints about what could be going wrong here?

It looks like you are using scan, which raises some unique complications when compiling to JAX/numba. In particular, it can sometimes be the case that the compile mode flag doesn’t make it down to the inner graph used by the scan. In plain English: the function that your scan loops over might not be getting correctly converted to JAX.

To work around this, you can manually pass mode="JAX" as an argument to scan (or, better yet, add a scan_mode argument to your model factory function and pass that through; pass None to get the default C backend).
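Schematically, something like this (the factory-function name and arguments here are made up, adapt them to your model):

import pytensor

def make_trial_scan(step_fn, sequences, outputs_info, non_sequences, scan_mode=None):
    # scan_mode=None gives the default C backend; pass "JAX" when you plan to
    # sample with sample_numpyro_nuts / sample_blackjax_nuts
    return pytensor.scan(
        fn=step_fn,
        sequences=sequences,
        outputs_info=outputs_info,
        non_sequences=non_sequences,
        mode=scan_mode,
    )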

Otherwise, in general when debugging, always start by calling pytensor.dprint on your logp graph and visually inspecting what is going on. Since your graph is quite complicated it will give you a ton of output, but your first step from there will be to ctrl+F for Reshape Ops and see where they are coming from.
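Concretely, something along these lines (assuming model is your PyMC model object):

import pytensor

logp_graph = model.logp()    # the joint log-probability graph of the model
pytensor.dprint(logp_graph)  # dumps the whole graph; search the output for "Reshape" nodes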


In addition to what @jessegrabowski wrote, note that JAX is very restrictive about shapes, and especially dynamic shapes (or shapes it thinks are dynamic). This can happen easily with indexing/slicing/masking/arange operations.
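For example, both of these produce outputs whose length is only known at runtime, which the JAX backend cannot handle (generic illustrations, not necessarily what is happening in your model):

import pytensor.tensor as pt

x = pt.vector("x")
n = pt.iscalar("n")

masked = x[x > 0]  # length depends on the values in x -> dynamic shape
ar = pt.arange(n)  # length depends on a runtime scalar -> dynamic shape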

You may consider writing your intended scan in JAX directly, just to see if you can get it to compile at all. If not, you might need to look for an alternative algorithm that is “jax-compatible”.
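As a minimal sketch (toy shapes, and a hard-coded constant standing in for the learning-rate parameter), a trial-by-trial update like the toy one above could be written with jax.lax.scan as:

import jax
import jax.numpy as jnp

def trial_step(assoc, inputs):
    word_idx, meaning_idx = inputs
    # functional in-place update, JAX's counterpart of set_subtensor
    assoc = assoc.at[word_idx, meaning_idx].add(0.1)
    p = jax.nn.softmax(assoc[word_idx])
    return assoc, p

assoc0 = jnp.zeros((9, 7))
words = jnp.array([0, 3, 5])
meanings = jnp.array([1, 2, 4])

assoc_final, p_per_trial = jax.lax.scan(trial_step, assoc0, (words, meanings))

If something like this jits without complaint, the problem is on the pytensor-to-JAX translation side rather than in the algorithm itself.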

Another option is to try the numba backend and the nutpie sampler, which does not impose the same kind of shape restrictions as JAX. This one is a bit less well supported at the moment, so no promises there either. But if it works for your model, it could provide nice speedups as well.
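Roughly (assuming nutpie is installed and model is your PyMC model):

import nutpie

compiled = nutpie.compile_pymc_model(model)  # compiles the model logp with numba
idata = nutpie.sample(compiled, chains=4)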