Automatic imputation is compatible with JAX; it's even tested in our CI.
However, your model may use mutable data or coords, which are dynamic in shape by default, and JAX can't handle dynamic shapes. In that case you can either:
- Pass an explicit `shape` to `pm.Data` and to your observed variable (you can do this alongside `dims`), or
- Call `freeze_dims_and_data` before JAX sampling; see the PyMC v5.16.1 documentation for `pymc.model.transform.optimization.freeze_dims_and_data`.
You’ll probably need the latest version of PyMC.
The failure is probably specific to your model, which is why I asked for more details.