Hi!
I’m trying to sample an external model using NUTS. Externally provided are:
- a JAX log-likelihood function and its gradient which I’ll wrap in an Op
- the list of priors and starting values for its parameters
I want to parallelize the sampling on a CPU cluster (~100 chains), so I opted for BlackJAX with pmap. However, I’m not entirely clear on how to set this up in PyMC. Does `pm.sample(nuts_sampler="blackjax")`, as used in the tutorial, call the `pymc.sampling.jax.sample_blackjax_nuts` function (so that `chain_method='parallel'` is the default)?
Yes, and you can override its defaults (such as `chain_method`) with `nuts_sampler_kwargs`.
Why do you want 100 chains?
Well, the current MCMC implementation requires running 200-500 chains for several weeks when dealing with a complex cosmology and large datasets. I anticipate an order-of-magnitude improvement from the smaller number of steps needed after switching to NUTS, and a further improvement from switching to a differentiable cosmology solver at each step, but I would still like to be able to run at least a few dozen chains in parallel.
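To give a rough idea of what running chains "in parallel" means on the JAX side, here is a sketch contrasting `jax.pmap`, which maps one chain per device, with `jax.vmap`, which batches all chains on one device. It assumes that PyMC's `chain_method="parallel"` and `"vectorized"` correspond roughly to these two transforms. `one_chain` is a hypothetical placeholder for running a single chain. On a CPU, JAX exposes one device by default, so the sketch uses the XLA flag `--xla_force_host_platform_device_count` to create several virtual CPU devices. This only works if the flag is set before JAX initializes, and it applies only to the cores of a single machine.

```python
import os

# Must be set before JAX initializes: expose 4 virtual CPU devices
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp


def one_chain(key):
    # Hypothetical placeholder for running one chain: average 1000 normal draws
    return jnp.mean(jax.random.normal(key, (1000,)))


# pmap needs one device per mapped element, so use as many chains as devices
n_chains = jax.local_device_count()
keys = jax.random.split(jax.random.PRNGKey(0), n_chains)

vectorized = jax.vmap(one_chain)(keys)  # all chains batched on one device
parallel = jax.pmap(one_chain)(keys)    # one chain per (virtual) device
```

Both versions produce one result per chain from the same keys. The difference is only how the work is laid out across devices.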
Aha, thank you. I’ve got another question about the BlackJAX-PyMC interaction: the guide on wrapping JAX functions mentions that, when using NumPyro, there is no need to define a gradient Op, since the function is “unwrapped” into JAX before the gradient is taken. Is the same true for BlackJAX?
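The mechanism the question refers to can be sketched in plain JAX. Once a log-density is expressed entirely in JAX operations, JAX computes its gradient itself, so no separately written gradient function is involved. Whether BlackJAX in PyMC follows this path depends on PyMC handing it a JAX-converted log-density, as it does for NumPyro. That is an assumption here, not something confirmed above. `logdensity` is a hypothetical example function.

```python
import jax
import jax.numpy as jnp


# Hypothetical log-density built entirely from JAX operations
def logdensity(theta):
    return -0.5 * jnp.sum(theta**2)


# JAX derives value and gradient together; no hand-written gradient is used
value, grad = jax.value_and_grad(logdensity)(jnp.array([1.0, 2.0]))
```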