Dynamic shaping, "round" function, JAX, and a "few" more questions

A lot of great questions! Forgive me if I skip/bundle up some of them:

Yes. The easiest way is to use a CustomDist and let PyMC infer the logprob of the rounding operation. Something like this (untested code):

import pytensor.tensor as pt
import pymc as pm

def round_dist(mu, sigma, size):
    # CustomDist calls this to build the underlying graph:
    # a Normal that is then rounded
    raw_dist = pm.Normal.dist(mu, sigma, size=size)
    return pt.round(raw_dist)

data = [1, 2, 3]  # dummy data
with pm.Model() as m:
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma")
    llike = pm.CustomDist("llike", mu, sigma, dist=round_dist, observed=data)

# just to confirm PyMC can figure out the logp
m.point_logps()  # {'mu': -0.92, 'sigma': -0.73, 'llike': -9.34}
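
If you want to look at that inferred term on its own, you can also compile just the likelihood logp and evaluate it (another untested sketch, reusing m from above):

ip = m.initial_point()
logp_llike = m.compile_logp(vars=[m["llike"]])
logp_llike(ip)  # should match the 'llike' entry from point_logps()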

Yes (or Numba). You can do this by setting the PyTensor mode (PyTensor in PyMC v5, Aesara in v4).
For example, if you want to try using JAX in the PyMC NUTS sampler after defining your model:

import pytensor

with pm.Model() as m:
    ...
    with pytensor.config.change_flags(mode="JAX"):
        pm.sample()  # or m.compile_logp() or whatever you want to do

Now you might find some issues with gradient-based samplers, because the way PyTensor (and PyMC specifically) defines gradient graphs tends not to play well with jitted JAX (shape unravelling and things like that).
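
If you hit those issues, the usual workaround is to hand the whole model to a JAX-native NUTS implementation instead of compiling our sampler's graph to JAX. A minimal sketch, assuming you have numpyro installed:

import pymc.sampling.jax as pmjax

with m:
    # JAX-native NUTS via numpyro, sidestepping the PyTensor gradient graph issues
    idata = pmjax.sample_numpyro_nuts(draws=1000, tune=1000)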

I don’t remember that example too well; does it state somewhere that this is not possible? I have a vague recollection of custom JAXed Ops failing in multiprocessing, but I am not sure. It could very well work out of the box… or maybe JAX doesn’t like how we pickle and unpickle things for multiprocessing. If you test it out and find problems, feel free to open an issue in the GitHub repository. Feedback in that area is incredibly valuable!

We are not connected to aeMCMC/AePPL in terms of development, but we share many goals with those libraries. Most progress in this area is being made in PyMC’s logprob submodule (a direct fork of AePPL that contains the whole logprob inference machinery) and also in pymc-experimental, where you can find some examples of the latter.

Definitely. There has been a focus on establishing the core logic, and documentation is lagging, but functionality is slowly coming to the surface. The CustomDist I described above is one way we try to offer these capabilities to users (you can already find some simple examples of using dist and letting PyMC infer the logp): pymc.CustomDist — PyMC v5.0.1 documentation
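
For contrast, CustomDist also lets you skip the inference entirely and supply a logp function yourself. A minimal sketch (the hand-written normal_logp is just for illustration):

import numpy as np
import pytensor.tensor as pt
import pymc as pm

def normal_logp(value, mu, sigma):
    # hand-written Normal logp, standing in for whatever density you need
    return -0.5 * ((value - mu) / sigma) ** 2 - pt.log(sigma) - 0.5 * np.log(2 * np.pi)

with pm.Model():
    mu = pm.Normal("mu")
    sigma = pm.HalfNormal("sigma")
    pm.CustomDist("llike", mu, sigma, logp=normal_logp, observed=[1.0, 2.0, 3.0])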

We have some collaborators expanding on the logprob submodule this summer and I have been nudging them to start adding more documentation, so maybe that will improve soon :smiley:

We also use it in the codebase without users knowing about it, for instance in pm.RandomWalk and pm.Censored. There has also been a big push to infer the probability of timeseries defined with Scan, which for now lives in a gist, but one day @jessegrabowski might make a pymc-example out of it :smiley:
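
pm.Censored is a nice example of that machinery at work: you hand it an unobserved base distribution and the censored logp is derived for you (untested sketch):

import pymc as pm

with pm.Model():
    mu = pm.Normal("mu")
    # logp of a Normal clipped below at 0 is derived automatically
    pm.Censored("y", pm.Normal.dist(mu, 1.0), lower=0.0, upper=None, observed=[0.0, 0.5, 1.2])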

Thanks for your input. Let me know if you have more questions / suggestions.
