JAX Sampling Error with TruncatedNormal Distribution

Hello,

Has anyone had issues using a TruncatedNormal for the observed data together with JAX sampling?

I changed a model from using the Normal distribution to a TruncatedNormal and got the following error:

AttributeError: module 'jax.scipy.special' has no attribute 'erfcx'
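A quick way to confirm whether a given environment is affected is to check if the installed JAX exposes the function at all. This is only a diagnostic sketch; the helper name `jax_has_special_function` is made up for illustration and is not part of PyMC or JAX.

```python
import importlib


def jax_has_special_function(name: str) -> bool:
    """Return True if jax.scipy.special exposes `name`.

    Hypothetical helper for illustration. Returns False when the
    attribute is missing or when JAX itself is not installed.
    """
    try:
        special = importlib.import_module("jax.scipy.special")
    except ImportError:
        # JAX is not installed in this environment.
        return False
    return hasattr(special, name)


print("erfcx available:", jax_has_special_function("erfcx"))
```

If this prints `False` while JAX is installed, the AttributeError above is expected whenever a model needs that function on the JAX backend.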

When I try with just pm.sample, it seems to run fine.

GitHub issue opened here:

PyMC Throws Error with TruncatedNormal distribution and Jax Sampling · Issue #6244 · pymc-devs/pymc (github.com)

This issue has been closed, as it’s actually a problem with Aesara, and a separate issue has been opened there.

Has there been any development on this issue?

No, we still don’t have a translation of the needed Ops to JAX.
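For context on what is missing: erfcx is the scaled complementary error function, erfcx(x) = exp(x^2) * erfc(x). It presumably enters through the truncated normal's log-probability, where a plain erfc would underflow in the tails. The plain-Python sketch below shows why it needs a dedicated implementation rather than the naive product. It is not the PyTensor or JAX implementation; the function names, the switch point of 10 and the number of series terms are arbitrary choices for illustration.

```python
import math


def erfcx_naive(x: float) -> float:
    """Direct definition: exp(x^2) * erfc(x).

    Raises OverflowError once x * x exceeds roughly 709,
    because math.exp overflows a double there.
    """
    return math.exp(x * x) * math.erfc(x)


def erfcx_sketch(x: float, n_terms: int = 6) -> float:
    """Illustrative erfcx that stays finite for large positive x.

    For x < 10 the naive product is used (it still overflows for very
    negative x, where the true value is itself too large for a double).
    For x >= 10 the asymptotic series
        1 / (x * sqrt(pi)) * (1 - 1/(2x^2) + 3/(4x^4) - ...)
    is summed with `n_terms` terms.
    """
    if x < 10.0:
        return erfcx_naive(x)
    inv_2x2 = 1.0 / (2.0 * x * x)
    term = 1.0
    total = 1.0
    for k in range(1, n_terms):
        # Each term is the previous one times -(2k - 1) / (2x^2).
        term *= -(2 * k - 1) * inv_2x2
        total += term
    return total / (x * math.sqrt(math.pi))
```

Calling `erfcx_naive(30.0)` raises an OverflowError, while `erfcx_sketch(30.0)` returns a small positive number close to 1 / (30 * sqrt(pi)). A JAX translation of the Op would need this kind of numerically stable formulation.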

If you want to track progress on this front, the relevant issue is: ENH: Missing implementation of Erfcx and Erfcinv · Issue #43 · pymc-devs/pytensor · GitHub

There is a proposed work-around here that you can try: Use optional tensorflow implementation for missing JAX Ops · Issue #256 · pymc-devs/pytensor · GitHub