How to deal with missing values

Automatic imputation is compatible with JAX; it's even tested in our CI.

You may, however, have a model built on mutable data or coords. These have dynamic shapes by default, which JAX can't handle. In that case you can:

  1. Pass an explicit shape to pm.Data and to your observed variable (you can do this alongside the dims).
  2. Call freeze_dims_and_data before JAX sampling (see pymc.model.transform.optimization.freeze_dims_and_data in the PyMC v5.16.1 documentation).

You’ll probably need the latest version of PyMC.

The reason it fails is probably specific to your model, which is why I asked for more details.