How to control the random seed for RVs in v4.0?

Congrats on the new release! I’m eager to get my MUSE inference package working with it. It had been working on 4.0.0b6; the only new thing in the release I can’t figure out is how to handle random seeding.

Essentially, my library needs to compile some custom sampling functions for the RVs. As a simple example, consider:

```python
import aesara
import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x")
    y = pm.Normal("y")

sample_x_y = aesara.function([], model.basic_RVs)
```

It then needs to call sample_x_y with control over the random seed (FWIW, the real thing does something slightly more complex, so using e.g. sample_prior_predictive isn’t an option).

The 4.0.0b6 solution was based on model.rng_seq, which seems to have disappeared. It was hacky and slow, though, so I’m happy to be rid of it, assuming the new approach is better.

So what’s the appropriate way to do this now? Thanks.

You can call pymc.aesaraf.compile_pymc, which accepts a random seed and will reseed your variables before compiling an Aesara function, as well as set automatic updates so that the values vary between calls.

Something like this (writing from memory):

```python
import numpy as np
from pymc.aesaraf import compile_pymc

# I think seed can be a SeedSequence, but also just an integer
seed = np.random.SeedSequence(123)

sample_x_y = compile_pymc([], model.basic_RVs, random_seed=seed)
```

If you need to reseed between calls you can also do that (there are some utilities for it in the same aesaraf module), but that would slow you down.
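
To make the two modes concrete without PyMC, here is a minimal NumPy-only analogue (an illustration I’m assuming is faithful in spirit, not PyMC’s actual mechanism). The hypothetical `make_advancing_sampler` mimics a function compiled with RNG updates: the generator advances, so each call gives new draws, yet the whole sequence is reproducible from the seed. The hypothetical `make_reseeding_sampler` mimics reseeding before every call: the generator is rebuilt from the same seed, so every call returns identical draws.

```python
import numpy as np


def make_advancing_sampler(seed):
    """Hypothetical analogue of a function compiled with RNG updates:
    the generator state advances between calls."""
    rng = np.random.default_rng(seed)

    def sample():
        return rng.normal(size=2)  # one draw each for "x" and "y"

    return sample


def make_reseeding_sampler(seed):
    """Hypothetical analogue of reseeding before every call:
    the generator is rebuilt from the same seed each time."""

    def sample():
        rng = np.random.default_rng(seed)
        return rng.normal(size=2)

    return sample


advancing = make_advancing_sampler(123)
first, second = advancing(), advancing()
print(np.allclose(first, second))  # False: draws change between calls

reseeding = make_reseeding_sampler(123)
print(np.allclose(reseeding(), reseeding()))  # True: same draws every call
```

The reseeding variant pays for rebuilding the generator on every call, which is the kind of overhead the reply above warns about.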

Thanks! Hmm, yeah, that’s exactly what I need to do, and it looks like those utilities can do it. This gives me exactly what I want:

```python
import aesara
import pymc

sample_x_y = aesara.function([], model.basic_RVs)
pymc.aesaraf.reseed_rngs(pymc.aesaraf.find_rng_nodes(model.basic_RVs), seed)
sample_x_y()
```

Now I can control which samples I get with seed, without having to recompile every time. Glancing at the code, I think it’s doing roughly what my hacky version was doing before anyway, but it seems to work. Thanks!

You should probably cache the result of find_rng_nodes in that case. We do something like this for model.initial_point here: pymc/initial_point.py (pymc-devs/pymc on GitHub, commit da1f63b95f64d02c958302ddc44ee6d8b838a39d)
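
A NumPy-only sketch of "find the RNGs once, then reseed them as often as needed". My understanding (an assumption; check `pymc.aesaraf.reseed_rngs` in your installed version) is that it spawns one child seed per RNG from a `SeedSequence` and installs a fresh generator in each RNG shared variable. Here `RngSlot` is a hypothetical stand-in for such a shared variable and `reseed_slots` a hypothetical stand-in for the reseeding utility; the list of slots is built once, which plays the role of caching the RNG search.

```python
import numpy as np


class RngSlot:
    """Hypothetical stand-in for an RNG shared variable whose value
    can be swapped in place (like shared.set_value(...))."""

    def __init__(self):
        self.value = np.random.default_rng()

    def set_value(self, rng):
        self.value = rng


def reseed_slots(slots, seed):
    """Spawn one independent child seed per slot from a single seed and
    install a fresh PCG64-based Generator in each slot."""
    children = np.random.SeedSequence(seed).spawn(len(slots))
    for slot, child in zip(slots, children):
        slot.set_value(np.random.Generator(np.random.PCG64(child)))


# Build the list of slots once (the analogue of caching the RNG search).
slots = [RngSlot(), RngSlot()]  # one per random variable, e.g. x and y


def sample_x_y():
    """Stand-in for the compiled function: reads whatever generators
    are currently installed, so no rebuild is needed after reseeding."""
    return [slot.value.normal() for slot in slots]


reseed_slots(slots, 123)
a = sample_x_y()
reseed_slots(slots, 123)
b = sample_x_y()
reseed_slots(slots, 456)
c = sample_x_y()
print(a == b)  # True: same seed, same draws
```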

Out of curiosity, why do you need to reseed between calls? The values will still follow a deterministic sequence across calls.

The algorithm basically solves an equation like $\langle f(x,\theta) \rangle_{x\sim\mathcal{P}(x\,|\,\theta)} = 0$ for $\theta$, where $f$ is some function and the (Monte Carlo-computed) average is over a set of samples of $x$ from a likelihood $\mathcal{P}(x\,|\,\theta)$. It really helps the convergence of the solver if these $x$ are drawn with the same seeds at each iteration of the solver, since that makes the MC average vary smoothly with $\theta$, rather than adding a new random MC error at each trial value of $\theta$.
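
A small NumPy illustration of why fixed seeds help, using a toy setup chosen only for clarity (not MUSE’s actual model): $x \sim \mathcal{N}(\theta, 1)$ and the hypothetical $f(x,\theta) = x - 2$, so the exact root is $\theta = 2$. With the same seed at every $\theta$ (common random numbers), the MC average is $\theta + \bar{z} - 2$ for a fixed noise mean $\bar{z}$, an exactly linear function of $\theta$; with fresh draws at every $\theta$, each estimate carries its own independent error.

```python
import numpy as np

N = 100  # Monte Carlo samples per estimate


def mc_average(theta, rng):
    """MC estimate of <f(x, theta)> with x ~ N(theta, 1) and the
    hypothetical f(x, theta) = x - 2 (exact root: theta = 2)."""
    x = theta + rng.standard_normal(N)
    return np.mean(x - 2.0)


thetas = np.linspace(1.5, 2.5, 11)

# Same seed at every theta: the estimate is smooth (here linear) in theta.
fixed = np.array([mc_average(t, np.random.default_rng(0)) for t in thetas])

# Fresh random numbers at every theta: independent MC error at each point.
rng = np.random.default_rng(0)
fresh = np.array([mc_average(t, rng) for t in thetas])

# Second differences vanish (up to rounding) only for the fixed-seed curve.
print(np.max(np.abs(np.diff(fixed, 2))))
print(np.max(np.abs(np.diff(fresh, 2))))
```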

So basically you seed, generate a set of $x$ given the current $\theta$, use them to compute the next $\theta$, reseed, generate same-seeded $x$’s with the new $\theta$, and so on.
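
A minimal sketch of that seed/draw/update/reseed loop, assuming a toy model with $x \sim \mathcal{N}(\theta, 1)$ and the hypothetical $f(x,\theta) = x^2 - 5$ (exact root $\theta = 2$). The update rule is a plain secant step, a stand-in for whatever solver MUSE actually uses. Because `mc_residual` rebuilds the generator from the same seed on every call, the residual is a deterministic, smooth function of $\theta$ and the secant iteration converges cleanly to its root.

```python
import numpy as np

N = 1000   # Monte Carlo samples per iteration
SEED = 42  # the same seed is reused at every iteration


def mc_residual(theta, seed=SEED):
    """Reseed, draw x ~ N(theta, 1), and return the MC average of the
    hypothetical f(x, theta) = x**2 - 5 (exact root: theta = 2)."""
    rng = np.random.default_rng(seed)  # reseed: same noise at every theta
    x = theta + rng.standard_normal(N)
    return np.mean(x**2 - 5.0)


def solve(theta0, theta1, tol=1e-10, max_iter=50):
    """Secant iterations on the fixed-seed residual."""
    r0, r1 = mc_residual(theta0), mc_residual(theta1)
    for _ in range(max_iter):
        theta0, theta1 = theta1, theta1 - r1 * (theta1 - theta0) / (r1 - r0)
        r0, r1 = r1, mc_residual(theta1)
        if abs(r1) < tol:
            break
    return theta1


theta_hat = solve(1.0, 1.5)
print(theta_hat, mc_residual(theta_hat))
```

The solution is the root of the fixed-seed MC equation, so it differs from the exact $\theta = 2$ by an amount of order the MC error; changing `SEED` moves it slightly.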

I didn’t know MUSE existed – looks cool! Did you announce it here?

Thanks! Probably in the next couple of days, just need to get these last tweaks to get it working on the final V4 release and some final API tweaks.

Hmm… sounds a bit funny. I am not sure what the properties of the new draws with a new theta but the same seed will be…

Btw if you are running this algorithm in parallel you will want to actually swap the rng variables so that reseeding them in one process will not reseed them in another.
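
A NumPy-only sketch of the hazard being described, not PyMC’s actual mechanism: if two workers hold the same mutable RNG container, reseeding it for one also reseeds it for the other; giving each worker its own deep copy (standing in for swapping in fresh RNG variables) keeps them independent. The `make_state` and `reseed` helpers are hypothetical.

```python
import copy

import numpy as np


def make_state(seed):
    """Hypothetical mutable container standing in for an RNG shared variable."""
    return {"rng": np.random.default_rng(seed)}


def reseed(state, seed):
    """Replace the generator in place, like setting a shared variable's value."""
    state["rng"] = np.random.default_rng(seed)


template = make_state(0)

# Shared: both workers point at the same container.
shared_a, shared_b = template, template
reseed(shared_a, 1)
print(shared_b["rng"] is shared_a["rng"])  # True: worker b was reseeded too

# Swapped: each worker gets its own deep copy before reseeding.
own_a, own_b = copy.deepcopy(template), copy.deepcopy(template)
reseed(own_a, 2)
print(own_b["rng"] is own_a["rng"])  # False: worker b is unaffected
```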

In the initial point code I sent you, that’s done a few lines above.

Did you try to integrate the JAX implementation of MUSE with a PyMC model running on the new JAX backend?

I don’t think we support RandomVariables (updates) in the JAX backend.

No, but it’s vaguely on my to-do list to look into it. If you glanced at the docs (which it sounds like you did, since you saw the JAX interface), you saw that the PyMC version has some overhead, since the algorithm itself is pure Python; only the calls to the various posteriors, gradients, and transformations are compiled with Aesara. Is that what this might help with? Or are posterior gradients themselves just faster with JAX? (Any resources you could point me to about this new backend for a library author would be helpful!)

If the algorithm itself is written in Python there won’t be any speed improvements. You would only get these benefits if you implemented the algorithm in JAX (or, better yet, directly in Aesara, like aehmc does: aesara-devs/aehmc on GitHub, an experimental HMC implementation in Aesara).

Thanks, good to know. A JAX/JIT-able version is close, although there are no plans for an Aesara version. It’s maybe not the highest priority, since MUSE isn’t really for super cheap posteriors anyway. But how transparent is the “JAX backend” to PyMC? Like, can I just put a jax.jit around some function that internally calls aesara.function-compiled code and get the speedup? Or is it not that simple? (Sorry for the barrage of questions, no hurry.)

Pretty much that simple; you can look at pymc/sampling_jax.py (main branch of pymc-devs/pymc on GitHub) for some inspiration.

I’d still recommend upstreaming MUSE to BlackJAX; then PyMC can use MUSE through BlackJAX directly.

I understand the excitement about JAX, but I would not recommend it for this algorithm, simply because we don’t have a way to sample from the prior of a PyMC model with the jax backend and this algorithm requires it.

Thanks, BlackJAX is definitely still on the horizon. I suppose, based on what @ricardoV94 mentions, that would mean automatic PyMC integration would then just work, but it could still be used in BlackJAX alone? Or does BlackJAX not specify an interface for problems to define how to sample from the prior, only how to evaluate the posterior? That would indeed be a limitation for using MUSE in BlackJAX. (Also, sorry, this question might be better suited for the BlackJAX repo; I can ask there, or at least start looking around.)

Oh, I didn’t know that forward sampling doesn’t work yet in JAX mode (sorry @ricardoV94, I missed your earlier reply). In that case, the automatic PyMC integration won’t just work, as you will hit a bug in the JAX backend when doing prior sampling.

I also didn’t know that, I thought that whole code was now aesara-fied? Should we open an issue?