Hello,
I have set up a model in PyMC v4 with a custom likelihood function, defined using aesara.tensor operations. The likelihood is passed to the model via pymc.DensityDist. I am able to draw samples on GPU using the pymc.sampling_jax.sample_numpyro_nuts sampler. To speed things up further, I am trying to see whether I can create a jitted version of the likelihood function, similar to how it’s done here:
https://jax.readthedocs.io/en/latest/jax-101/02-jitting.html
as selu_jit = jax.jit(selu).
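For reference, this is roughly the pattern from that tutorial that I am trying to mimic. It is a sketch from memory, so the default constants and the test input are my assumptions rather than a verbatim copy:

```python
import jax
import jax.numpy as jnp


def selu(x, alpha=1.67, lmbda=1.05):
    # Written purely with jax.numpy operations, so JAX can trace it.
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)


# Compile the function once; later calls with same-shaped inputs reuse it.
selu_jit = jax.jit(selu)

x = jnp.arange(5.0)
y_plain = selu(x)    # ordinary (un-jitted) evaluation
y_jit = selu_jit(x)  # jitted evaluation, should give the same values
```

Here the input is a JAX array, which is what I do not have in my own case.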
I am lost as to how to do this (if it is possible at all), mainly because I cannot find relevant examples online of anyone doing something similar. In the above example, the original function is defined using jax.numpy operations, so jax.jit works smoothly. In my case, the inputs are objects generated by pymc (the sampling parameters), and I perform matrix operations on them using aesara.tensor, so I am far removed from jax.numpy.
The error I get is as follows:
TypeError: Argument 'alphas' of type <class 'aesara.tensor.var.TensorVariable'> is not a valid JAX type.
Here, alphas is one of the sampling parameters; it appears as the first argument in my pymc.DensityDist function.
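My understanding, which may be wrong, is that a jitted function only accepts arguments JAX can treat as arrays, and rejects anything else with a TypeError. Below is a minimal sketch that triggers the same kind of error without aesara. SymbolicStandIn and logp are hypothetical names I made up for illustration; they are not my actual model code:

```python
import jax
import jax.numpy as jnp


class SymbolicStandIn:
    """Hypothetical placeholder for a symbolic, non-JAX variable."""


@jax.jit
def logp(alphas):
    # Toy log-density; only meaningful for array-like inputs.
    return jnp.sum(alphas ** 2)


# Works: the argument is a JAX array.
ok = logp(jnp.ones(3))

# Fails: the argument is an arbitrary Python object, not a JAX type.
try:
    logp(SymbolicStandIn())
    raised = False
except TypeError as err:
    raised = True
    message = str(err)  # exact wording varies between JAX versions
```

If that is indeed the cause, then calling jax.jit directly on a function whose arguments are aesara variables would presumably fail in the same way.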
I can provide more details if it helps.
Any help will be appreciated. Thanks!