Help navigating implementation of new inference algorithm

Nice!

Yes, you can do `x = [model.rvs_to_values[v].data for v in model.observed_RVs]`
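Building on that one-liner, here is a minimal sketch that collects the observed data and flattens it into a single vector, which is the kind of concatenation a generic implementation needs. It assumes the PyMC 4.x-era `model.observed_RVs` and `model.rvs_to_values` attributes used in the reply. The `.get_value()` fallback for observed data backed by a shared variable (e.g. `pm.Data`) is an assumption, and `observed_arrays` / `flatten_observed` are hypothetical helper names.

```python
import numpy as np


def observed_arrays(model):
    """Collect the observed data of each observed RV as a NumPy array.

    Assumes `model.observed_RVs` and `model.rvs_to_values` as in PyMC 4.x.
    """
    arrays = []
    for rv in model.observed_RVs:
        value_var = model.rvs_to_values[rv]
        if hasattr(value_var, "get_value"):
            # Assumption: shared-variable data (e.g. pm.Data) exposes get_value().
            arrays.append(np.asarray(value_var.get_value()))
        else:
            # Constant observed data (e.g. a TensorConstant) exposes `.data`.
            arrays.append(np.asarray(value_var.data))
    return arrays


def flatten_observed(model):
    """Ravel every observed array and concatenate them into one flat vector."""
    arrays = observed_arrays(model)
    if not arrays:
        return np.empty(0)
    return np.concatenate([a.ravel() for a in arrays])
```

Each array is raveled in C order, so keep the list of shapes around if you need to split the flat vector back into per-variable pieces.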

Nothing stands out from a quick glance at the implementation (although, as you said, the variables still need to be concatenated and transformed for it to work on a general PyMC model). For small functions Aesara is usually faster than JAX, so if you can profile the function call, we can find out which part is slow and try to find a solution.
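As a starting point for that profiling, here is a minimal timing sketch using only the standard library. It assumes you already have two compiled callables taking the same inputs; `logp_aesara`, `logp_jax`, `best_time_per_call` and `compare` are hypothetical names. Each callable is called once before timing so that one-off JIT compilation is not counted. Because JAX dispatches asynchronously, the JAX callable should be wrapped so it calls `.block_until_ready()` on its result before returning.

```python
import timeit


def best_time_per_call(fn, *args, repeat=5, number=1000):
    """Return the best observed seconds-per-call of fn(*args)."""
    fn(*args)  # warm-up call so one-off compilation is not timed
    timer = timeit.Timer(lambda: fn(*args))
    totals = timer.repeat(repeat=repeat, number=number)
    return min(totals) / number


def compare(candidates, *args, repeat=5, number=1000):
    """Time each named callable on the same inputs, fastest first."""
    results = {
        name: best_time_per_call(fn, *args, repeat=repeat, number=number)
        for name, fn in candidates.items()
    }
    return sorted(results.items(), key=lambda item: item[1])
```

For example, `compare({"aesara": logp_aesara, "jax": logp_jax}, x)` returns `(name, seconds_per_call)` pairs sorted from fastest to slowest. Taking the minimum over repeats reduces noise from other processes; for a per-Op breakdown on the Aesara side, `aesara.function` also accepts `profile=True`.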
