Hi @rlouf,
Love your work on transformer. Are you looking to integrate it with some Bayesian source?
As for your question itself, I think this is certainly functionality that could be added, but for lower-level control it probably makes more sense to work in TF/TFP itself. FWIW, you can use PyMC4 to generate the log_prob/loss function and plug it into an inference workflow of your choice.
As for iterative sampling, do you mean being able to write a for loop that trains and fetches metrics at the same time?
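
To make sure we mean the same thing, here is a rough sketch of the kind of loop I have in mind. It is plain Python, not PyMC4 or TFP API: `log_prob`, `grad_log_prob`, `fit`, and the toy data are all made up for illustration, and in practice the log_prob would come from the model rather than being written by hand.

```python
# Hypothetical toy example: plain Python, not PyMC4/TFP API.
# Model: data ~ Normal(mu, 1), with mu the only unknown.
data = [1.8, 2.1, 2.4, 1.9, 2.3]


def log_prob(mu):
    """Unnormalized log-density of the data under Normal(mu, 1)."""
    return -0.5 * sum((x - mu) ** 2 for x in data)


def grad_log_prob(mu):
    """Derivative of log_prob with respect to mu."""
    return sum(x - mu for x in data)


def fit(num_steps=200, learning_rate=0.05):
    """Gradient ascent on log_prob, recording metrics at every step."""
    mu = 0.0
    history = []  # metrics fetched inside the loop, one entry per step
    for step in range(num_steps):
        mu += learning_rate * grad_log_prob(mu)  # one update step
        history.append({"step": step, "mu": mu, "loss": -log_prob(mu)})
    return mu, history


mu_hat, history = fit()
```

The point is only that the update and the metric collection happen in the same user-controlled loop, so you could stop early, log, or change the step size between iterations.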