Running BlackJAX with pmap

Hi!
I’m trying to sample an external model using NUTS. Externally provided are:

  • a JAX log-likelihood function and its gradient, which I’ll wrap in an Op
  • the list of priors and starting values for its parameters

I want to parallelize the sampling on a CPU cluster (~100 chains), so I opted for BlackJAX with pmap. However, I am not entirely clear on how to actually set this up in PyMC. Does pm.sample(nuts_sampler="blackjax"), as used in the tutorial, call the pymc.sampling.jax.sample_blackjax_nuts function (so that chain_method='parallel' is the default)?

Yes, and you can override those defaults with nuts_sampler_kwargs
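For example, something along these lines (a minimal sketch; the placeholder prior and the draw/chain counts are purely illustrative):

```python
import pymc as pm

with pm.Model() as model:
    # Placeholder prior; substitute the externally provided priors here.
    theta = pm.Normal("theta", 0.0, 1.0)

    idata = pm.sample(
        draws=1000,
        chains=100,
        nuts_sampler="blackjax",
        # Forwarded to pymc.sampling.jax.sample_blackjax_nuts;
        # "parallel" is the default chain_method, written out here explicitly.
        nuts_sampler_kwargs={"chain_method": "parallel"},
    )
```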


Why do you want 100 chains?

Well, the current MCMC implementation requires running 200-500 chains for several weeks when dealing with a complex cosmology and large datasets. I anticipate an order-of-magnitude improvement from the decreased number of steps needed after switching to NUTS, and a further improvement from switching to a differentiable cosmology solver at each step, but I would still like to be able to run at least a few dozen chains in parallel.


Ah ha, thank you. I’ve got another question about the BlackJAX–PyMC interaction: the manual on wrapping JAX functions mentions that when using NumPyro there is no need to define a gradient Op, since the function is “unwrapped” into JAX before the gradient is taken. Is the same true for BlackJAX?

Yes, the same is true for BlackJAX.
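In other words, as long as the Op has a JAX translation registered, the JAX-based samplers differentiate straight through it and no gradient Op is required. A minimal sketch following the pattern in that manual (jax_loglike, LogLikeOp, and loglike_op_jax are illustrative names; jax_loglike stands in for the external log-likelihood):

```python
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as pt
from pytensor.graph import Apply, Op
from pytensor.link.jax.dispatch import jax_funcify


# Stand-in for the externally provided JAX log-likelihood.
def jax_loglike(theta):
    return -0.5 * jnp.sum(theta**2)


class LogLikeOp(Op):
    """Wraps the JAX log-likelihood; note that no grad() method is defined."""

    def make_node(self, theta):
        theta = pt.as_tensor_variable(theta)
        return Apply(self, [theta], [pt.dscalar()])

    def perform(self, node, inputs, outputs):
        # Only used by the default (non-JAX) backend.
        (theta,) = inputs
        outputs[0][0] = np.asarray(jax_loglike(theta), dtype="float64")


loglike_op = LogLikeOp()


# Tell PyTensor how to translate the Op when the model graph is compiled to
# JAX; NumPyro and BlackJAX then take gradients through jax_loglike itself.
@jax_funcify.register(LogLikeOp)
def loglike_op_jax(op, **kwargs):
    return jax_loglike
```

The Op can then be used in, e.g., a pm.Potential term, and sampling with nuts_sampler="blackjax" (or "numpyro") compiles the whole log-probability graph to JAX.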
