I am trying to set up a model whose log-likelihood function uses a pre-trained Flax MLP. I would like to use PyMC to sample from the posterior instead of migrating entirely to BlackJAX.
Do you have any suggestions on how to plug the network into a PyMC model?