You may need to disable multicore sampling in `sample_smc` (e.g. `cores=1`), but it would be fun to try. You can compile the PyMC simulator to JAX by passing `mode="JAX"` when you call `compile_pymc`, which gives you jitted, GPU-friendly code without any extra work.
In a few weeks you'll also be able to compile to Numba automatically, but only for CPU. For GPU you'd have to write it yourself, which doesn't seem to be a problem for you anyway.