Nothing else should be needed for forward-mode differentiation in numpyro, since that is done by JAX (not PyTensor) from the logp graph returned by PyMC. You just have to try passing that flag through to the numpyro sampler.
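To illustrate why nothing is needed on the PyMC side, here is a minimal sketch in plain JAX. The `logp` function below is a hypothetical stand-in for the log-density that PyMC hands over, not what PyMC actually generates. JAX can differentiate the same function in either mode and gets the same gradient; the flag (presumably numpyro's `forward_mode_differentiation` option on its NUTS kernel, which is my assumption since it isn't named here) would only change which mode the sampler uses.

```python
import jax
import jax.numpy as jnp


def logp(theta):
    # Hypothetical toy log-density (standard normal, up to a constant).
    # In practice this would be the JAX function built from the PyMC logp graph.
    return -0.5 * jnp.sum(theta ** 2)


theta = jnp.array([0.5, -1.0])

# Reverse-mode gradient: one backward pass through logp.
grad_reverse = jax.grad(logp)(theta)

# Forward-mode gradient: built from Jacobian-vector products,
# so it never needs a reverse pass through logp.
grad_forward = jax.jacfwd(logp)(theta)

print(grad_reverse, grad_forward)
```

Both calls work on the same Python function, so the choice of mode is purely a JAX/numpyro concern. Forward mode tends to be worth trying when the model has few parameters or contains operations that lack a reverse-mode rule.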