JAX Sampling Error with TruncatedNormal Distribution


Has anyone had issues using a TruncatedNormal on the observed data together with JAX sampling?

I changed a model from using the Normal distribution to a TruncatedNormal and got the following error:

AttributeError: module 'jax.scipy.special' has no attribute 'erfcx'

When I sample with plain pm.sample instead, it seems to run fine.
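
For anyone who wants to check their own environment, here is a small sketch that tests whether the installed JAX exposes the attribute named in the error. The helper name module_has_attr is hypothetical (it is not part of PyMC or JAX); only the module and attribute names are taken from the error message. It returns None when the module cannot be imported at all, so it can be run without JAX installed.

```python
import importlib


def module_has_attr(module_name, attr_name):
    """Hypothetical helper: True/False if the module imports, None if it does not."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # The module (or one of its parent packages) is not installed.
        return None
    return hasattr(module, attr_name)


# Module and attribute names are copied from the AttributeError above.
print(module_has_attr("jax.scipy.special", "erfcx"))
```

If this prints False, the AttributeError is coming from the JAX backend itself and not from anything specific to the model.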

GitHub issue opened here:

PyMC Throws Error with TruncatedNormal distribution and Jax Sampling · Issue #6244 · pymc-devs/pymc (github.com)

This issue has been closed, as it’s actually a problem with aesara, and an issue has been opened there.

