JAX Sampling Error with TruncatedNormal Distribution

Hello,

Has anyone had issues using a TruncatedNormal on the observed data together with JAX sampling?

I changed a model from using the Normal distribution to a TruncatedNormal and got the following error:

AttributeError: module 'jax.scipy.special' has no attribute 'erfcx'

When I sample with just pm.sample, it seems to run fine.

GitHub issue opened here:

PyMC Throws Error with TruncatedNormal distribution and Jax Sampling · Issue #6244 · pymc-devs/pymc (github.com)

This issue has been closed, as it is actually a problem in Aesara, and an issue has been opened there.

Has there been any development on this issue?

No, we still don’t have a translation of the needed Ops to JAX.

If you want to track progress on this front, the relevant issue is: ENH: Missing implementation of Erfcx and Erfcinv · Issue #43 · pymc-devs/pytensor · GitHub

There is a proposed work-around here that you can try: Use optional tensorflow implementation for missing JAX Ops · Issue #256 · pymc-devs/pytensor · GitHub
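
For anyone wondering what the missing function is: erfcx is the scaled complementary error function, erfcx(x) = exp(x^2) * erfc(x). The sketch below is plain standard-library Python, not PyTensor or JAX code, and the names naive_erfcx and asymptotic_erfcx are made up for illustration. It only shows that multiplying the two factors directly breaks down for large x, which is one likely reason a dedicated implementation is wanted rather than a one-line substitute.

```python
import math


def naive_erfcx(x: float) -> float:
    """erfcx straight from its definition: exp(x**2) * erfc(x)."""
    return math.exp(x * x) * math.erfc(x)


def asymptotic_erfcx(x: float) -> float:
    """Leading term of the large-x expansion: 1 / (x * sqrt(pi))."""
    return 1.0 / (x * math.sqrt(math.pi))


# For moderate arguments the definition is fine and already close
# to the leading asymptotic term.
print(naive_erfcx(5.0), asymptotic_erfcx(5.0))

# For large arguments exp(x**2) exceeds the float range, even though
# the true value of erfcx(x) is small and perfectly representable.
try:
    naive_erfcx(30.0)
except OverflowError as err:
    print("naive formula failed:", err)
    print("leading asymptotic term instead:", asymptotic_erfcx(30.0))
```

A numerically stable erfcx has to avoid forming exp(x^2) and erfc(x) separately, so the naive product above is not a safe drop-in replacement inside a truncated normal log-probability.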