I have written several custom activation functions for neural networks as Ops of type scalar.UnaryScalarOp, with optimized c_code for each, and I use them to build Elemwise functions that run fast on the CPU. Theano would fall back to these CPU implementations when no GPU version existed. It looks like PyTensor does not do this and instead requires a JAX version of each Op. Is that right? Each time I try to compile a graph in JAX mode, I get an error saying my function has no attribute ‘nfunc_spec’, so I added one:
nfunc_spec = ("tgauss.scalar_tgauss", 1, 1)
but now I get the error:
AttributeError: module ‘jax’ has no attribute ‘tgauss’
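The message suggests that the dotted name in nfunc_spec is looked up one attribute at a time starting from the jax module, so "tgauss.scalar_tgauss" is treated as jax.tgauss.scalar_tgauss, which does not exist. The stdlib-only sketch below mimics that lookup. The stand-in module named "jax" and the helper resolve_nfunc_spec are hypothetical illustrations, not PyTensor's actual code.

```python
import functools
import math
import types

# Hypothetical stand-in for the real jax module: it has a numpy submodule
# with an exp function, but (like the real jax) no "tgauss" attribute.
fake_jax = types.ModuleType("jax")
fake_jax.numpy = types.ModuleType("jax.numpy")
fake_jax.numpy.exp = math.exp


def resolve_nfunc_spec(root, dotted_name):
    """Hypothetical helper: resolve "a.b.c" as getattr(getattr(getattr(root, "a"), "b"), "c")."""
    return functools.reduce(getattr, dotted_name.split("."), root)


# A name that exists under the root resolves to the function object.
exp_func = resolve_nfunc_spec(fake_jax, "numpy.exp")

# A custom name fails on the very first lookup, with the same kind of message.
try:
    resolve_nfunc_spec(fake_jax, "tgauss.scalar_tgauss")
    error_message = ""
except AttributeError as err:
    error_message = str(err)

print(error_message)
```

If that reading is right, nfunc_spec can only name functions that already exist inside JAX, so pointing it at a custom module will not work.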
I'm not sure what I need to do here. Any help would be appreciated.