Implementing rounding (by manual integration) more efficiently

The logp graphs are stored on the model itself, so you can access them directly with mod.logp() and mod.dlogp(). It's more useful, though, to look at the compiled graph, because that's what you're actually going to be timing. You can dprint compiled pytensor functions, but PyMC hides them from you a bit, like this:

wrapped_f = mod.compile_dlogp()

As the name suggests, PyMC wraps the underlying pytensor function it creates with some logic that makes it more convenient to pass in the outputs of MCMC steps (dicts of point values). For benchmarking and debugging, though, you need the "raw" function, which is stored in the .f attribute (wrapped_f.f). That is the object you can pass to pytensor.dprint.

One other note about timing: if you are compiling to a non-standard backend (like JAX or Numba), make sure you time the jitted function itself. You can't do this with mod.compile_dlogp; you have to use mod.compile_fn:

wrapped_f = mod.compile_fn(mod.dlogp(), mode='JAX')
jax_f = wrapped_f.f.vm.jit_fn

jax_f is the raw jitted JAX function, which you can then use for timing. Run it once before you %timeit it, so that the JIT compilation is triggered up front and not counted in the measurement.

It's also useful to enable the profiler. Pass profile=True when compiling, as in mod.compile_fn(mod.logp(), profile=True), run %timeit on the raw function, then look at f.profile.summary(), where f is the raw .f function. That shows which operations are consuming the most time.

Variable transformations are only for the logp graph, so I’m not sure what you mean.