@ferrine that made very little difference for me.
I'm still limited to running very small chains.
An option to use `jax.lax.map` (see the JAX documentation) instead of vectorizing over chains should reduce memory usage and stop the memory spike.
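A minimal sketch of the trade-off, assuming the per-chain work is currently vectorized with `jax.vmap`. The function `per_chain_stat` is hypothetical: it builds a large per-chain intermediate to stand in for whatever the real post-processing does. `jax.vmap` batches every chain into one computation, so it may allocate all intermediates at once. `jax.lax.map` loops over the leading axis (it is built on `lax.scan`), so only one chain's intermediate needs to be live at a time. The values are small integers so both paths give bit-identical float32 results.

```python
import jax
import jax.numpy as jnp


def per_chain_stat(draws):
    # Hypothetical per-chain computation with a large intermediate:
    # an (n_draws, n_draws) outer product that is then reduced to a scalar.
    gram = jnp.outer(draws, draws)
    return gram.sum()


# 4 chains x 1000 draws of small integer values (exact in float32).
chains = (jnp.arange(4 * 1000) % 7).astype(jnp.float32).reshape(4, 1000)

# vmap: one batched computation; may need memory for all 4 intermediates at once.
vmapped = jax.vmap(per_chain_stat)(chains)

# lax.map: sequential loop over chains; only one intermediate is needed at a time.
mapped = jax.lax.map(per_chain_stat, chains)

print(jnp.array_equal(vmapped, mapped))  # True
```

The cost is speed: `lax.map` runs the chains one after another instead of in parallel. Recent JAX releases also accept a `batch_size` keyword on `jax.lax.map` to process a few elements per step, which can be a middle ground if your installed version supports it.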