Integration, ODE and Lambert W function in PyMC4

Hello everyone, I have a few concerns about using PyMC4 for my MCMC needs, and I would like to outline them here in hopes of getting some guidance.

Brief summary of the model

I have a model that gives me a function u'(t | \theta), which I have to solve using an ODE solver to obtain u(t | \theta), which I will then have to integrate and compute \int_0^T u(t | \theta) dt, where T also depends on the parameters.

ODE solvers

In my search I came across sunode, a Python package that allows me to solve an IVP for ODEs. I haven’t tried it out, but its usage looks somewhat straightforward and I’m currently learning how to apply it to my model.

Integration routines

What concerns me is that there seems to be no integration routine compatible with PyMC4, or at least none that is obvious to me.

I say this because I came across this post on StackOverflow, named “Custom Theano Op to do numerical integration”, which suggests a particularly complicated solution based on Theano, which was deprecated in favor of Aesara. The hyperlink that I use to go to the repository of the latter actually leads me to PyTensor, which I imagine is the replacement for Aesara.

Posts such as “Custom theano Op to do numerical integration”, “PyMC3: Using a parameter as a limit of integral” and “Using Pymc3 to do forecasting and numerical integration” (which actually leads to the first forum post) all work based on Theano.

Lambert function

Additionally, in order to compute one of the quantities in my model, I need to have the Lambert W function in PyMC4.

This function is available in SciPy, however, I can’t use it directly in PyMC. I imagine the problem is the same as any other external function that I want to use, and in order to have it work correctly I will most likely face the same challenges as the people trying to implement the integration routines in the posts I have mentioned in the previous section.

Questions

My questions are then as follows:

  1. Has anybody used this ODE solver successfully in PyMC4?
  2. What integration routine should I use in PyMC4?
  3. How can I use the Lambert W function in PyMC4?

If both questions 2. and 3. do not have an answer, is it recommended to downgrade to a previous version of PyMC? If so, which one? And which integration routine and Lambert W functions would I use?


If you are interested in the model itself, I will share it with you here in further detail.
Feel free to skip this section if you don’t want to hear about it.

First, I must solve the ODE
u^{\prime}(t | \theta) = \frac{e^{-\lambda u^2}}{\lambda u\left(u^{-2}-2 \lambda\right)-u^{-3}}\left[\frac{3}{2}\Omega_m(1+t)^2+2 \Omega_r(1+t)^3\right],

where \Omega_m and \Omega_r are the two parameters in this model, and \lambda is derived from the parameters by relying on the main branch of the Lambert function

\lambda = \frac{1}{2} + W_0\left( - \frac{\Omega_m + \Omega_r}{2e^{1/2}} \right) .
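To make the constraint on the parameters concrete, here is a small dependency-free sketch of the formula above. It is not the PyMC/Aesara implementation: `lambert_w0` and `lam` are hypothetical helper names, the Newton iteration stands in for `scipy.special.lambertw`, and the parameter values are made up for illustration. Since W_0 is real only when its argument is at least -1/e, the formula requires \Omega_m + \Omega_r \le 2 e^{-1/2} (roughly 1.21), which is something the priors would have to respect.

```python
import math

def lambert_w0(x, tol=1e-14, max_iter=100):
    """Principal branch W_0(x) for real x > -1/e, via Newton's method on w*exp(w) = x."""
    if x <= -1.0 / math.e:
        # The branch point itself is excluded: Newton converges slowly there.
        raise ValueError("W_0 is real only for x >= -1/e")
    w = math.log1p(x)  # starting guess that lies at or to the right of the root
    for _ in range(max_iter):
        ew = math.exp(w)
        step = (w * ew - x) / (ew * (1.0 + w))
        w -= step
        if abs(step) <= tol * (1.0 + abs(w)):
            break
    return w

def lam(omega_m, omega_r):
    """lambda = 1/2 + W_0(-(Omega_m + Omega_r) / (2 e^{1/2}))."""
    arg = -(omega_m + omega_r) / (2.0 * math.exp(0.5))
    return 0.5 + lambert_w0(arg)

# Hypothetical parameter values, chosen only for illustration.
print(lam(0.3, 1e-4))

# Analytically, Omega_m + Omega_r = 1 gives W_0 = -1/2 and hence lambda = 0.
print(lam(0.7, 0.3))
```

Values with \Omega_m + \Omega_r above the bound make `lam` raise a `ValueError` in this sketch, which mirrors the point where the SciPy function would start returning complex numbers.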

The final step is to compute

\int_0^T u(t | \theta) dt,

where T is a function of the parameters \Omega_m and \Omega_r with a rather complicated expression which I will omit here for brevity, as I believe it only matters that it depends on the parameters.

This integral can then be directly compared with the observations.


I would also like to share with you the reason why I’m using PyMC4: because Stan has a bug, which I found, that doesn’t have an ETA on its solution.
Apparently, using an integration routine on top of an ODE solver when applied to this system somehow gives a segmentation fault.
If you are interested, I opened a bug report in the Stan math repository: Segmentation fault on making use of `ode_rk45` with `integrate_1d` · Issue #2848 · stan-dev/math · GitHub


If something isn’t clear, do let me know.
Thank you in advance!

While I can’t help with the ODE side of the question (@aseyboldt @michaelosthege), I can help with the LambertW part. You can add a new function to PyMC by writing a LambertW Op for Aesara, which is the computational backend for PyMC (it handles building the computation graph and getting gradients of the likelihood function).

Since LambertW is a scalar function with a single input, you will need to subclass aesara.scalar.basic.UnaryScalarOp. It’s not as bad as it looks, although there are a lot of pitfalls; luckily the community here is always around to help when you face them. (Thanks to @ricardoV94 for helping me work through the problems I had coming up with this implementation.)

import aesara.scalar as aes
from aesara.tensor.elemwise import Elemwise
from scipy.special import lambertw

class LambertW(aes.basic.UnaryScalarOp):
        
    def impl(self, x):
        return lambertw(x).real
    
    def grad(self, inputs, output_grads):
        (x,) = inputs
        (gz,) = output_grads
        w = self(x)
        return [gz * w / (x * (1 + w))] #edit: fixed typo

impl handles the forward computation, which in this case is applying the lambertw function. I am using .real on the output because PyMC can’t really handle complex variables anyway. You will need to choose priors so that you stay in the principal branch (I think – correct me if my choice of words here is wrong w.r.t. ensuring outputs are strictly real).

The second bit is the grad method, which takes 1) the value at which the derivative is to be evaluated (inputs), and 2) the gradient of the objective computed so far (output_grads). Compute the gradient of lambert W, then multiply it by the gradient computed so far. This function will be called during the backward pass of the computational graph, so you’re doing the chain rule here.

So this gives you an abstract class that represents computation of the lambert W function. To use it, you need to make an actual instance of the class:

lambert_w_scalar = LambertW(aes.upgrade_to_float, name='lambert_w_scalar')

aes.upgrade_to_float is a helper function that makes the function less fussy about datatypes (it will upgrade anything you give it to a float). Now that you have a usable Aesara function, Aesara provides a helper function to check that the gradient is written correctly using numeric approximation:

import aesara
import aesara.tensor as at
import numpy as np

x = at.dscalar('x')
rng = np.random.default_rng(1337)

assert lambertw(1.5).real == lambert_w_scalar(x).eval({x:1.5})
aesara.gradient.verify_grad(lambert_w_scalar, pt=[1.5], rng=rng)

You’re almost done, but this function will only take scalar valued data, which might cause problems if you put it into a PyMC model. You can transform a scalar function to a function that broadcasts elementwise to an arbitrary tensor using the Elemwise wrapper:

lambert_w = Elemwise(lambert_w_scalar)

And now you’re good to use this in a PyMC model. You have a defined gradient, so you can sample your final model with NUTS. Here’s a code snippet just to show that things work as expected:

import pymc as pm

with pm.Model():
    x = pm.Uniform('x', 1 / pm.math.exp(1), 1000)
    z = pm.Deterministic('z', lambert_w(x))
    
    idata = pm.sample_prior_predictive(var_names=['z'])

Many thanks for your detailed answer!

Just one minor detail, your verification of the gradient fails because the analytical solution provided has a mistake somewhere.
The gradient of the W_0(x) is actually given by

\frac{d W_0(x)}{dx} = \frac{1}{e^{W_0(x)} + x}

and therefore the class should read

class LambertW(aes.basic.UnaryScalarOp):
    def impl(self, x):
        return lambertw(x).real

    def grad(self, inputs, output_grads):
        (x,) = inputs
        (gz,) = output_grads
        w = self(x)
        return [gz * 1/(2.718281828459045**w + x)]

The full Python script is therefore the following

# imports
from aesara.tensor.elemwise import Elemwise
import aesara.scalar as aes
import aesara.tensor as at
import aesara
from scipy.special import lambertw
import numpy as np
import pymc as pm

# creates a class that works with Aesara (PyMC computational backend)
# details: https://discourse.pymc.io/t/integration-ode-and-lambert-w-function-in-pymc4/10897/2
class LambertW(aes.basic.UnaryScalarOp):
    def impl(self, x):
        return lambertw(x).real

    def grad(self, inputs, output_grads):
        (x,) = inputs
        (gz,) = output_grads
        w = self(x)
        return [gz * 1/(2.718281828459045**w + x)]

# initiate a LambertW class and have it set to work with tensors
lambert_w_scalar = LambertW(aes.upgrade_to_float, name='lambert_w_scalar')
lambert_w = Elemwise(lambert_w_scalar)

# run checks
x = at.dscalar('x')
rng = np.random.default_rng(1337)
assert lambertw(1.5).real == lambert_w_scalar(x).eval({x:1.5})
aesara.gradient.verify_grad(lambert_w_scalar, pt=[1.5], rng=rng)

# test inside of PyMC model
with pm.Model():
    x = pm.Uniform('x', 1 / pm.math.exp(1), 1000)
    z = pm.Deterministic('z', lambert_w(x))

    idata = pm.sample_prior_predictive(var_names=['z'])

Now, I would like to try and run this model to ensure everything is alright, but… I get the following error:

Traceback (most recent call last):
  File "/home/undercover/ze/research/msc/misc/pymc-tests/lambertw.py", line 37, in <module>
    with pm.Model():
         ^^^^^^^^
AttributeError: module 'pymc' has no attribute 'Model'

This is about as weird as it gets, because it was working yesterday and I did nothing.
Additionally, I’m using a virtual environment. To ensure it wasn’t some magic installation problem, I’ve created a brand new virtual environment and only installed PyMC onto it…

I followed the math here for the derivative: calculus - Derivative of Lambert W function. - Mathematics Stack Exchange. This thread mentions that your form is equivalent to what I tried to write, but you’re right that I had a typo (it should be w / (x * (1 + w))).
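For readers following along, the equivalence of the two forms can be worked out from the defining relation of the Lambert W function. This is a standard derivation, written here with W short for W_0(x); the second form assumes x \neq 0.

```latex
\begin{align}
W e^{W} &= x \\
\frac{dW}{dx}\, e^{W} (1 + W) &= 1
  && \text{(differentiate both sides with respect to } x\text{)} \\
\frac{dW}{dx} &= \frac{1}{e^{W} + W e^{W}} = \frac{1}{e^{W} + x} \\
\frac{dW}{dx} &= \frac{1}{e^{W}(1 + W)} = \frac{W}{x\,(1 + W)}
  && \text{(using } e^{W} = x / W \text{ for } x \neq 0\text{)}
\end{align}
```

The form 1 / (e^{W} + x) stays finite at x = 0, whereas w / (x * (1 + w)) evaluates to 0/0 there, which may matter if a prior puts mass near zero.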

For e ** x you can use at.exp(x), no need to write out the constant yourself.

For PyMC installation, I always follow the installation instructions on the wiki when setting up a new environment (link to windows instructions, you can find mac/linux on the sidebar of the same page)

Alright, I am very dumb and had a file named pymc in the same folder as the script that I was running.
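For anyone hitting the same AttributeError, a quick way to spot this kind of shadowing is to ask Python where it would import the module from. The sketch below uses only the standard library; `module_origin` is a hypothetical helper name.

```python
import importlib.util

def module_origin(name):
    """Return the file a top-level module would be imported from, or None if it is not found."""
    spec = importlib.util.find_spec(name)
    return None if spec is None else spec.origin

# If this prints a path next to your script instead of one inside
# site-packages, a local file or folder is shadowing the installed package.
print(module_origin("pymc"))
```

A folder named pymc without an `__init__.py` can also shadow the package; in that case the origin is typically reported as None even though the import "succeeds".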

Anyways… Thank you very much for your help, I will now see if I can get the ODE to work, as it was presented in the GitHub page I have shared previously, and hopefully figure out a way to deploy an integration routine in PyMC as well.

Just a quick reply on the ODE part of your question:

I would recommend sunode. It works with PyMC 4.x unless something broke (the CI pipeline didn’t run since July).
Soon we’ll also update it to work with PyTensor instead of Aesara.

For the numerical integration you could certainly write a custom Op that uses scipy.integrate.quad or something similar in its perform method. You’ll probably lose out on gradients unless you can also implement a grad method.
Alternatively you could write your own integrator with pure PyTensor/Aesara Ops, but a naive implementation might be very slow.
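To give an idea of what a grad method for such an integral would have to compute when the upper limit depends on a parameter, here is a plain-Python sketch based on the Leibniz integral rule: d/d\theta \int_0^{T(\theta)} u(t | \theta) dt = u(T | \theta) T'(\theta) + \int_0^{T} \partial u / \partial\theta \, dt. Everything in it is an assumption made for illustration: the toy integrand u(t | \theta) = e^{-\theta t}, the toy limit T(\theta) = \theta, and the Simpson rule standing in for scipy.integrate.quad. It is not an Aesara/PyTensor Op.

```python
import math

def simpson(g, a, b, n=200):
    """Composite Simpson rule on [a, b] with an even number n of subintervals."""
    h = (b - a) / n
    total = g(a) + g(b)
    for i in range(1, n):
        total += (4 if i % 2 else 2) * g(a + i * h)
    return total * h / 3

# Hypothetical toy model: u(t | theta) = exp(-theta * t), T(theta) = theta.
def u(t, theta):
    return math.exp(-theta * t)

def du_dtheta(t, theta):
    return -t * math.exp(-theta * t)

def T(theta):
    return theta

def dT_dtheta(theta):
    return 1.0

def integral(theta):
    """Forward value: integral of u from 0 to T(theta)."""
    return simpson(lambda t: u(t, theta), 0.0, T(theta))

def integral_grad(theta):
    """Leibniz rule: boundary term plus the integral of the partial derivative."""
    boundary = u(T(theta), theta) * dT_dtheta(theta)
    interior = simpson(lambda t: du_dtheta(t, theta), 0.0, T(theta))
    return boundary + interior

theta = 1.3
print(integral(theta), integral_grad(theta))
```

For this toy case the integral has the closed form (1 - e^{-\theta^2}) / \theta, so both functions can be checked by hand. In a real model the partial derivative of u with respect to the parameters is usually not available in closed form, which is part of why the gradient is the hard bit.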

The thing is that before you try it you won’t know whether it’s worth it to do the extra steps of implementing the model in a differentiable way. Sometimes it’s more efficient to use a gradient-free sampler for more iterations compared to a more sample-efficient step method with a slow gradient computation.

So simply wrapping scipy.integrate.quad as done for the Lambert W function before would work?
I thought that the gradient was required in order for the function to work with PyMC, that’s why I thought this was more of a complicated business. Then, if you don’t mind explaining, why are we required to write the functions in a specific math backend? Is it simply for performance reasons?
Also, is there any go-to reference which explains how to write a generic function to said backend?

That was the library I mentioned in the blog post, I’m glad to know it’s the recommended one!
I haven’t had the time to try it out, but I will definitely do so as soon as I can.

Hello :)

I don’t think you need a separate integrator, if you already solve an ode, why not just add an additional state to do the integration?

So u'(t) = f(t, u) and y'(t) = u(t) with y(0) = 0. That way the ode solver will do the integration.
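As a rough illustration of this trick outside of sunode, the sketch below integrates the augmented system (u, y) with a hand-written classic RK4 scheme. The function name `rk4_augmented`, the fixed step count, and the toy right-hand side u' = -u are all assumptions made for the example; they are not part of sunode or of the model in this thread.

```python
import math

def rk4_augmented(f, u0, t_end, n_steps=1000):
    """Integrate u' = f(t, u) together with y' = u, y(0) = 0, using classic RK4.

    Returns (u(t_end), y(t_end)), where y(t_end) is the integral of u over [0, t_end].
    """
    def rhs(t, state):
        u, _y = state
        return (f(t, u), u)  # the extra state simply accumulates u

    h = t_end / n_steps
    t, state = 0.0, (u0, 0.0)
    for _ in range(n_steps):
        k1 = rhs(t, state)
        k2 = rhs(t + h / 2, tuple(s + h / 2 * k for s, k in zip(state, k1)))
        k3 = rhs(t + h / 2, tuple(s + h / 2 * k for s, k in zip(state, k2)))
        k4 = rhs(t + h, tuple(s + h * k for s, k in zip(state, k3)))
        state = tuple(
            s + h / 6 * (a + 2 * b + 2 * c + d)
            for s, a, b, c, d in zip(state, k1, k2, k3, k4)
        )
        t += h
    return state

# Toy problem: u' = -u with u(0) = 1, so u(t) = exp(-t)
# and the integral over [0, T] is 1 - exp(-T).
u_T, integral = rk4_augmented(lambda t, u: -u, u0=1.0, t_end=2.0)
print(u_T, integral, 1.0 - math.exp(-2.0))
```

The same idea carries over to any solver: the integral comes out as the final value of the extra state, so no separate quadrature routine is needed.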

Unfortunately the aesara wrappers in sunode currently don’t support non-constant integration boundaries.

I added that in a WIP PR here: Support computing derivatives with respect to the ode evaluation time by aseyboldt · Pull Request #38 · pymc-devs/sunode · GitHub
If you want to use this, you either have to wait until it is merged and we create a new release, or you can install the branch from GitHub: install the dependencies manually (sundials<6.0) and then run pip install -e . in the directory of the repo.


Many thanks for both the pull request and the rather obvious mathematical statement which I completely missed.
I will give manual installation a shot!

Happens :)
Let me know if you run into any problems.

Apologies for the late reply, but I would just like to let you know that the manual installation worked out perfectly and that I was able to use the ODE in PyMC!
I will mark this as solved, considering that I can turn my integral into an ODE and use the ODE solver provided, and the Lambert W function has already been implemented above.
Many thanks!


Just a heads-up: we released sunode v0.4.0 (Release v0.4.0 · pymc-devs/sunode on GitHub), which now includes the changes from that PR.
