Sampling using a flax model in the loglikelihood

Hi all

I am trying to set up a model whose log-likelihood function uses a pre-trained flax MLP neural network. I would like to use PyMC to sample from the posterior instead of migrating entirely to BlackJAX.

Do you have any suggestions about how to make this happen?

Thank you in advance.

There was some recent discussion about this here that might be interesting to you.
Basically, you can:

  • Write a wrapper PyTensor Op yourself
  • Use the package mentioned in the thread to automate the wrapping
  • Wait for a PyTensor PR that adds this functionality