Sampling using a flax model in the loglikelihood

There was some recent discussion about this here that might be interesting to you.
Basically, you can:

  • Write a wrapper pytensor Op yourself
  • Use the package mentioned in the thread to automate writing this wrapper Op
  • Wait for a PR that adds this functionality to pytensor itself