Is it possible to use a PyTorch model as the likelihood distribution in PyMC?

Many DDM/SSM variants in psychology have no tractable likelihood, which constrains the use of Bayesian estimation for these models. Recently, several papers have used deep learning models to approximate the likelihood function, such as the LANs in the HDDM package or flow-based models like SNLE. Although Aesara/Theano can be used to build a deep learning model, PyTorch is apparently the most popular deep learning package. Is it possible to use a pre-trained deep learning model as the target distribution in PyMC? Thanks.


Yes it is. We have an example with JAX, but you should be able to do the same with a PyTorch model: How to wrap a JAX function for use in PyMC — PyMC example gallery


There is also this example using PyMC3 and PyTorch: PyMC3 + PyTorch | Dan Foreman-Mackey, which might be a good complement to the v4 guide about wrapping external functions that uses JAX.


Thanks a lot!