Gradientless sampler on the GPU through PyMC?

Is there any way to use a gradientless sampler (e.g. Slice, ESS, or random walk) on the GPU through PyMC? As best I can tell, the least-friction option is to build the model in PyMC, extract it, and then run it through numpyro.infer.

This is frustrating for two reasons.
First, for my application, slice sampling is likely the best option, and NumPyro doesn’t support that.
Second, I would like to keep everything “within” PyMC so other collaborators can more easily understand my code.

Any ideas are appreciated.

Yes, you can pass compile_kwargs={'mode': 'JAX'} to pm.sample and/or to the step methods themselves. IIRC you might have to do both. This will compile the delta_logp function with JAX and run it on the GPU if you have that configured. You will still pay a device-transfer overhead, though, because the slice algorithm itself is still pure Python. If the logp function is expensive enough, that might be fine. If you want end-to-end GPU sampling, we don’t offer that; you’ll have to extract your logp and take it to NumPyro or BlackJAX, as you are already doing.
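To make the overhead point concrete, here is a rough sketch of a univariate slice sampler (stepping out, then shrinkage) in plain Python. This is not PyMC's implementation; slice_sample, std_normal_logp, and the call counter are hypothetical names made up for illustration. The loop and all the accept/shrink decisions run on the host, and each draw calls logp several times. If logp were compiled for the GPU, every one of those calls would be a separate host-to-device round trip.

```python
import math
import random


def slice_sample(logp, x0, n_samples, w=1.0, max_steps=50, seed=0):
    """Univariate slice sampler that also counts logp evaluations.

    Illustrative sketch only (hypothetical helper, not PyMC code).
    Returns (samples, n_logp_calls).
    """
    rng = random.Random(seed)
    calls = 0

    def counted_logp(x):
        # Stand-in for a compiled logp; each call would cross the
        # host/device boundary if logp lived on the GPU.
        nonlocal calls
        calls += 1
        return logp(x)

    samples = []
    x = x0
    for _ in range(n_samples):
        # Log-height of the slice: log(u * p(x)), u drawn from (0, 1].
        log_y = counted_logp(x) + math.log(1.0 - rng.random())

        # Place an interval of width w randomly around x, then step out
        # until both ends fall outside the slice.
        left = x - w * rng.random()
        right = left + w
        for _ in range(max_steps):
            if counted_logp(left) < log_y:
                break
            left -= w
        for _ in range(max_steps):
            if counted_logp(right) < log_y:
                break
            right += w

        # Shrinkage: propose uniformly in [left, right]; on rejection,
        # shrink the interval towards the current point.
        while True:
            x_new = left + (right - left) * rng.random()
            if counted_logp(x_new) >= log_y:
                x = x_new
                break
            if x_new < x:
                left = x_new
            else:
                right = x_new
        samples.append(x)
    return samples, calls


def std_normal_logp(x):
    # Unnormalised log-density of a standard normal.
    return -0.5 * x * x


samples, calls = slice_sample(std_normal_logp, x0=0.0, n_samples=2000)
print(f"logp evaluations per draw: {calls / len(samples):.1f}")
```

Each draw needs at least four logp evaluations here (one for the slice height, one per interval end, one per proposal), which is why the transfer cost only pays off when a single logp evaluation is expensive.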