That’s beyond the scope of PyMC. We simply rely on NumPyro to create a JAX graph from the model logp and sample it. You would have to ask the NumPyro folks for more details, but even they might only forward you to the JAX folks.