Some questions about PyMC Models in online settings

I’m trying to play with a PyMC model inside a reinforcement learning loop, and I’m hitting some pain points related to fitting/refitting/changing data.

The setup is that I have an environment/simulator that generates states (data) given action inputs, and an agent that generates actions given state inputs. The goal is to progressively learn the conditional distribution over actions given states, so that the agent slowly converges to an optimal strategy.

The workflow is something like this:

  1. Initialize the agent’s action model with some priors
  2. Run an “episode” of the simulation, generating a dataset of states, actions, and rewards
  3. Input the states, actions, and rewards into the model and sample a posterior action model
  4. Set the posterior model to be the prior model
  5. Go to 2
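For concreteness, the loop above can be sketched with a toy conjugate model, where step 4 is exact and no sampler is needed. This is not PyMC code: the two-armed bandit, the `run_episode` helper, and the Beta-Bernoulli action model are all assumptions made up for illustration. A non-conjugate model would need some approximation of the posterior at step 4 instead.

```python
import numpy as np

rng = np.random.default_rng(0)
TRUE_REWARD_PROB = np.array([0.3, 0.7])  # hypothetical environment, unknown to the agent

def run_episode(alpha, beta, n_steps=50):
    """Step 2: act by Thompson sampling from the current prior and collect rewards."""
    actions = np.empty(n_steps, dtype=int)
    rewards = np.empty(n_steps, dtype=int)
    for t in range(n_steps):
        sampled_prob = rng.beta(alpha, beta)  # one draw per arm
        actions[t] = int(np.argmax(sampled_prob))
        rewards[t] = int(rng.random() < TRUE_REWARD_PROB[actions[t]])
    return actions, rewards

# Step 1: Beta(1, 1) priors on each arm's reward probability
alpha = np.ones(2)
beta = np.ones(2)

for episode in range(20):
    actions, rewards = run_episode(alpha, beta)
    # Steps 3 and 4: the conjugate posterior is Beta(alpha + successes, beta + failures),
    # and it becomes the prior for the next episode.
    successes = np.bincount(actions, weights=rewards, minlength=2)
    pulls = np.bincount(actions, minlength=2)
    alpha = alpha + successes
    beta = beta + (pulls - successes)

posterior_mean = alpha / (alpha + beta)
```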

It’s not obvious how to swap data and parameters in and out without recompiling the whole model. One thought I had was to use the pm.Minibatch API for this, because this is basically what I’d do in a PyTorch training loop. Looking at the example, though, it seems I need to supply all the data up front, so that’s not ideal. Can minibatches be supplied manually inside a training loop?

I also thought about storing a pm.ADVI object somewhere, swapping out data with pm.MutableData containers, and then using pm.ADVI.refine to “resume” fitting. Does this have the effect that I think it does? I guess not, because it doesn’t “update” the priors after each episode. But will it work as I expect if I want it to be a manual minibatch?

On updating priors, this notebook uses stats.gaussian_kde to construct priors from posteriors using pm.Interpolated. I guess it’s not possible to automatically update these pm.Interpolated priors using MutableData? I see they require numpy inputs, because scipy is used to compute the univariate splines under the hood. In principle, though, it should be possible to get the splines natively in pytensor, yes? I’m thinking about B-splines in the sympy package, which gives back (differentiable!) symbolic expressions for splines. Should be easy enough to code up something like that in pytensor (famous last words). Does anyone know the history of the usage of splines in pm.Interpolated, and why the scipy implementation is called? Curious if I’m missing a technical point before I go trying to brew up something.
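A minimal sketch of the KDE step described here, using only numpy and scipy. The helper name `kde_prior_grid`, its padding factor, and the stand-in draws are assumptions for illustration, loosely following the approach in the notebook rather than reproducing it.

```python
import numpy as np
from scipy import stats

def kde_prior_grid(samples, n_points=100, pad=3.0):
    """Return (x, pdf) grid points approximating the density of `samples`."""
    samples = np.asarray(samples, dtype=float)
    smin, smax = samples.min(), samples.max()
    width = smax - smin
    x = np.linspace(smin, smax, n_points)
    pdf = stats.gaussian_kde(samples)(x)
    # Extend the domain with zero-density end points so that values outside the
    # sampled range keep a small, nonzero interpolated density instead of being excluded.
    x = np.concatenate([[x[0] - pad * width], x, [x[-1] + pad * width]])
    pdf = np.concatenate([[0.0], pdf, [0.0]])
    return x, pdf

rng = np.random.default_rng(1)
posterior_draws = rng.normal(loc=2.0, scale=0.5, size=2000)  # stand-in for real posterior draws
x_points, pdf_points = kde_prior_grid(posterior_draws)

# Inside a model context these arrays would be passed as
# pm.Interpolated("theta", x_points=x_points, pdf_points=pdf_points).
```

Because the grid is plain numpy, a new pair of arrays means building a new pm.Interpolated variable, which is the recompilation cost the question is about.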

Last point is on repeated compilation. I know that by default, it’s not possible to re-use a compiled PyMC logp function for sampling. Short of writing my own sampler, is there any way to hack that? Another angle I was thinking about was that it is possible to re-use a nutpie compiled model. If I use pm.set_data, will that get updated inside nutpie? I guess not, since it’s not magic. But is there a way to update shared variables inside a compiled nutpie model?

Nutpie allows you to update shared variables in a compiled model using its with_data method.

That’s what I considered in one project where we wanted to do NUTS inside NUTS.

Otherwise you could consider SMC? It is supposed to be good for online learning because you can use your posterior draws as the initial particles once you get more data. Getting the particles close to the posterior is a big chunk of the work in SMC, so if they already start there it could be faster? That’s the theory at least, but I never had the chance to try it out. If the PyMC-based one is too slow, you can try the one from blackjax.
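To illustrate the idea of carrying particles from one batch to the next, here is a toy numpy sketch for a Normal mean with known noise. It only does the reweight-and-resample step; real SMC implementations such as pm.sample_smc or blackjax also temper the likelihood and apply MCMC moves to the particles, which this sketch leaves out. The model and the `update_particles` helper are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
SIGMA = 1.0  # known observation noise (assumption of this toy model)

def update_particles(particles, batch):
    """Reweight particles by the likelihood of `batch`, then resample."""
    # log p(batch | mu) for every particle, up to an additive constant
    sq_err = ((batch[None, :] - particles[:, None]) ** 2).sum(axis=1)
    log_w = -0.5 * sq_err / SIGMA**2
    w = np.exp(log_w - log_w.max())  # subtract the max for numerical stability
    w /= w.sum()
    idx = rng.choice(particles.size, size=particles.size, p=w)
    return particles[idx]

particles = rng.normal(0.0, 5.0, size=20_000)  # draws from the initial Normal(0, 5) prior
batches = [rng.normal(1.5, SIGMA, size=20) for _ in range(2)]
for batch in batches:
    # The particles from the previous round play the role of the prior
    particles = update_particles(particles, batch)

# Analytic posterior mean for comparison (Normal prior, Normal likelihood)
all_data = np.concatenate(batches)
exact_mean = (all_data.sum() / SIGMA**2) / (1 / 5.0**2 + all_data.size / SIGMA**2)
```

Without a move step the set of distinct particles shrinks with every round, so this is only meant to show the data flow, not a usable sampler.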

Variational inference might be a good candidate.

The default PyMC samplers are unfortunately very black-boxy, and there is no easy way to reuse a cached function or stop/resume sampling. They are in dire need of being refactored to be more functional and less OOP-like.


This sounds close to the .refine approach. But at each restart it will still be evaluating the likelihood with respect to the original priors, not something updated, yeah? As far as I can tell there’s no way to update priors without recompiling the model.

Correct. Otherwise you have to somehow represent your posterior as a prior approximation.

I think we can make the Interpolated distribution work with symbolic inputs. There is even an open issue for that: Enable `pm.Interpolated` to accept symbolic inputs · Issue #4767 · pymc-devs/pymc · GitHub
