UnaryScalarOps for JAX?

I have written several custom activation functions for neural networks as Ops of type scalar.UnaryScalarOp, each with optimized c_code, and used them to build elemwise functions that run fast on the CPU. Theano would fall back to the CPU implementations when GPU versions did not exist. It looks like PyTensor does not, and instead requires a JAX version of each Op. Is this right? Each time I try to compile a graph in JAX mode, I get an error that my function has no attribute 'nfunc_spec', which I then added:

nfunc_spec = ("tgauss.scalar_tgauss", 1, 1)

but now I get the error:

AttributeError: module 'jax' has no attribute 'tgauss'

I'm not sure what I need to do here. Any help would be appreciated.

You need to tell PyTensor how to convert an Op to JAX. See "Adding JAX and Numba support for Ops" in the PyTensor dev documentation.

There are also some examples in "How to use JAX ODEs and Neural Networks in PyMC" from PyMC Labs.

Thanks!!