I would like to define a custom likelihood function with a logp function, similar to the v3 example here.
The logp output in turn depends on the output of a PyTorch model, which can provide gradients. How can I integrate the gradients/tensors produced by the PyTorch model? Does v4 make any of this easier? E.g. in v3 we must define a custom Theano Op as shown here, but this becomes tedious for a complex model. Are there new, simplified ways to achieve this in v4?
TIA
It's gonna be pretty much exactly like the example you linked to: a custom Aesara Op whose perform method computes the log-likelihood, a second Op whose perform method already calculates the gradient (PyTorch's autograd can supply it), and a grad method on the likelihood Op that calls the gradient Op on its inputs.
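For concreteness, here is a minimal sketch of that pattern, assuming a toy Gaussian log-likelihood written in PyTorch; the names `torch_logp`, `Logp`, and `LogpGrad` are made up for illustration and are not part of any library:

```python
import aesara.tensor as at
import numpy as np
import torch
from aesara.graph.op import Op


# Toy stand-in for "the PyTorch model": any function mapping a parameter
# tensor to a scalar torch tensor works here.
def torch_logp(theta_t, data_t):
    mu, log_sigma = theta_t[0], theta_t[1]
    return torch.distributions.Normal(mu, torch.exp(log_sigma)).log_prob(data_t).sum()


class LogpGrad(Op):
    """perform() returns d(logp)/d(theta), obtained from torch.autograd."""

    itypes = [at.dvector]
    otypes = [at.dvector]

    def __init__(self, data):
        self.data_t = torch.as_tensor(data, dtype=torch.float64)

    def perform(self, node, inputs, outputs):
        (theta,) = inputs
        theta_t = torch.tensor(theta, dtype=torch.float64, requires_grad=True)
        (grad,) = torch.autograd.grad(torch_logp(theta_t, self.data_t), theta_t)
        outputs[0][0] = grad.numpy()


class Logp(Op):
    """perform() returns the scalar log-likelihood; grad() defers to LogpGrad."""

    itypes = [at.dvector]
    otypes = [at.dscalar]

    def __init__(self, data):
        self.data_t = torch.as_tensor(data, dtype=torch.float64)
        self.logp_grad = LogpGrad(data)

    def perform(self, node, inputs, outputs):
        (theta,) = inputs
        theta_t = torch.tensor(theta, dtype=torch.float64)
        outputs[0][0] = np.asarray(torch_logp(theta_t, self.data_t).item(), dtype=np.float64)

    def grad(self, inputs, output_grads):
        (theta,) = inputs
        # Chain rule: upstream gradient times d(logp)/d(theta)
        return [output_grads[0] * self.logp_grad(theta)]
```

PyTorch's autograd does the differentiation; Aesara only ever sees two opaque Ops, exactly as in the black-box likelihood example.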
OK, thanks. If I were to write the equivalent PyTorch model so the gradients can be picked up natively in PyMC v4, should I re-write this model using Aesara or JAX? I'm not sure whether, if I write it in JAX, PyMC v4 can operate on it directly or whether it would require the code to be wrapped in Aesara. I haven't found any code snippets defining a simple neural network classification model in Aesara, but JAX has several examples, so if PyMC v4 can natively work with JAX it would be easy to re-write the model.
You could do either: Write everything in Aesara Ops that already have JAX implementations, or write one Aesara Op and its JAX implementation.
See Adding JAX and Numba support for Ops in the Aesara documentation.
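As a rough sketch of what that guide describes, you register a JAX callable for your Op with `jax_funcify`. Assuming the hypothetical `Logp` Op from the sketch above, its Gaussian logp re-expressed in `jax.numpy` might look like this (so JAX can trace and differentiate it):

```python
import jax.numpy as jnp
from aesara.link.jax.dispatch import jax_funcify


@jax_funcify.register(Logp)
def jax_funcify_Logp(op, **kwargs):
    # Capture the Op's data once, at dispatch time
    data = jnp.asarray(op.data_t.numpy())

    def logp(theta):
        mu, sigma = theta[0], jnp.exp(theta[1])
        # Gaussian log-density written in jax.numpy so jax.grad can handle it
        return jnp.sum(
            -0.5 * jnp.log(2 * jnp.pi)
            - jnp.log(sigma)
            - 0.5 * ((data - mu) / sigma) ** 2
        )

    return logp
```

With such a registration in place, the JAX backend differentiates the function itself, so the hand-written gradient Op is only needed on the C/Numba path.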
Whether to reproduce the neural net in Aesara Ops, or just wrap one Op around an existing JAX implementation, depends on the complexity of the NN and what you want to achieve.
If you implement the NN in Aesara, you could also compile with Numba or C, but if you already know that you'll sample with a JAX sampler, it's probably less work to write just that one wrapper Op.
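For completeness, a hedged usage sketch, assuming the hypothetical `Logp` Op from above and PyMC v4's `pymc.sampling_jax` module:

```python
import aesara.tensor as at
import numpy as np
import pymc as pm
import pymc.sampling_jax

data = np.random.normal(loc=1.0, scale=2.0, size=100)
logp_op = Logp(data)

with pm.Model():
    mu = pm.Normal("mu", 0.0, 10.0)
    log_sigma = pm.Normal("log_sigma", 0.0, 2.0)
    theta = at.as_tensor_variable([mu, log_sigma])
    pm.Potential("likelihood", logp_op(theta))

    # Default sampler: uses the Op's perform/grad, i.e. the PyTorch code
    idata = pm.sample()

    # JAX sampler: needs the jax_funcify registration (and JAX support for all Ops)
    # idata = pymc.sampling_jax.sample_numpyro_nuts()
```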