Using a lookup table in a PyMC model

In a geophysical application, there is a complex relationship between two variables. One variable, which I’ll call y here, is a function of x (plus uncertainty), but the functional relationship is not easily parameterized, so I would like to model it with a lookup table. In the end, I want to add more variables and fit to data, but here I am just interested in how to implement a lookup table or something equivalent.

Here is a simple NumPy example of what I would like to implement in PyMC:

import numpy as np

# values for x in lookup table
lut_x = np.linspace(0.0, 1.0, 11)
# corresponding values for y (using a simple relationship here)
lut_y = 0.1*np.sin(2.0 * np.pi * lut_x)

x = np.random.uniform()

ilut = 1
while x > lut_x[ilut]:
    ilut += 1

# linear interpolation
weight = (lut_x[ilut]-x)/(lut_x[ilut]-lut_x[ilut-1])
mu = weight * lut_y[ilut-1] + (1.0-weight) * lut_y[ilut]

y = np.random.normal(mu, 0.1)
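
For reference, the index search and weights above are equivalent to NumPy’s built-in linear interpolation, which is a handy cross-check:

mu = np.interp(x, lut_x, lut_y)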

Now, I have tried something similar in PyMC, but it either does not run or does not converge properly.
For example (leaving out the linear interpolation):

import numpy as np
import pymc as pm
import arviz
import aesara.tensor as at

coords = {'lut': np.linspace(0.0, 1.0, 11)}
model = pm.Model(coords=coords)

with model:
    # values for x in lookup table
    lut_x = pm.Data('lut_x', coords['lut'], dims=('lut',), mutable=False)
    # corresponding values for y (using a simple relationship here)
    lut_y = pm.Data('lut_y', 0.1*np.sin(2.0 * np.pi * coords['lut']), dims=('lut',), mutable=False)

    x = pm.Uniform('x', lower=0.0, upper=1.0)

    ilut = 1
    while at.gt(x, lut_x[ilut]):
        ilut += 1

    y = pm.Normal('y', mu=lut_y[ilut], sigma=0.01)

    idata = pm.sample(1000)
    print(arviz.summary(idata))

does not want to start or gets stuck early. A solution using

ilut = pm.Deterministic('ilut', at.argmin(at.abs(lut_x - x)))

is a bit slow, does not converge, or doesn’t sample the full space.

I realize that a lookup table is probably not ideal for the sampler, but it works well in Stan, which I am trying to move away from. So probably there’s a technique that I am not yet aware of. I looked at a related topic here, but that’s not quite what I’d like to achieve.

For the record (and for those interested), here is a PyStan (v2) version of this simple model which converges and produces the desired output:

import numpy as np
import pystan

stan_code = '''
data {
    int nlut;
    real lut_x[nlut];
    real lut_y[nlut];
}
parameters {
    real<lower=0.0, upper=1.0> x;
    real y;
}
model {
    x ~ uniform(0.0, 1.0);
    {
        int ilut;
        real weight;
        real mu;

        ilut = 2;
        while (x > lut_x[ilut]) {
            ilut += 1;
        }
        weight = (lut_x[ilut] - x) / (lut_x[ilut] - lut_x[ilut-1]);
        mu = weight * lut_y[ilut-1] + (1.0 - weight) * lut_y[ilut];
        y ~ normal(mu, 0.1);
    }
}
'''

# values for x in lookup table
lut_x = np.linspace(0.0, 1.0, 11)
# corresponding values for y (using a simple relationship here)
lut_y = 0.1*np.sin(2.0 * np.pi * lut_x)

stan_data = {
    'nlut': len(lut_x),
    'lut_x': lut_x,
    'lut_y': lut_y,
}

model = pystan.StanModel(model_code=stan_code)
fit = model.sampling(data=stan_data, iter=4000, chains=4)
results = fit.extract()
print(fit)

You can’t use while with PyMC; you would need to use a scan, which is the equivalent symbolic loop operator:

https://aesara.readthedocs.io/en/latest/library/scan.html
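
For illustration, here is a minimal sketch (untested, and more machinery than you need here) of what a scan replacement for that while loop could look like:

import aesara
import aesara.tensor as at

x = at.dscalar('x')
lut_x = at.dvector('lut_x')

# Count the table entries strictly below x; the final count is the
# index of the first entry >= x, just like the while loop.
counts, _ = aesara.scan(
    fn=lambda xi, acc, xv: acc + at.lt(xi, xv),
    sequences=[lut_x],
    outputs_info=[at.constant(0, dtype='int64')],
    non_sequences=[x],
)
ilut = counts[-1]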

However, in your simple case would this suffice?

ilut = (x > lut_x).sum()

I am not sure it is differentiable, though, which would prevent using NUTS.

You might also need to cast to an int for use in indexing: ilut = ilut.astype("int32")
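
As a quick plain-NumPy sanity check of that counting trick (same table as above), the comparison sum agrees with searchsorted:

import numpy as np

lut_x = np.linspace(0.0, 1.0, 11)
x = 0.37

ilut = (x > lut_x).sum()
assert ilut == np.searchsorted(lut_x, x)  # both give 4 here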

Thanks for the helpful tips. Unfortunately,

ilut = (x > lut_x).sum()

performs about as slowly as the

ilut = at.argmin(at.abs(lut_x - x))

solution.

Now, a lookup table is pretty simple. In the Stan code, I am using piecewise linear interpolation between table values. With piecewise polynomials of higher degree, I could also compute the gradient of the function efficiently. From my reading of various topics here, it looks like defining a custom Aesara Op may be a useful approach.

Writing your own Op would be slower. I don’t see anything that should be slow about the model (without the while loop, which is invalid).

I see you’re using a different sigma for your data (0.1) than for the model likelihood (0.01), which could be the problem.

If your likelihood is misspecified, NUTS could struggle or have to take tiny, slow steps.

Ok, I’ll try to use a different solution before going down the custom Op route.

And sorry for my variable naming: y is not the data here; I should perhaps have called it x2. In the full model, there are several other random variables with different prior distributions (switchable by the user). Everything goes into a mechanistic model, and the output of that model is fit to data. All of that already works in my PyMC implementation using NUTS; the only missing ingredient is the relationship between x and y here (again, y is not data in this example), where y = f(x) + \epsilon, \epsilon \sim N(0, \sigma), and f is given by the lookup table. There must be other PyMC examples using a form of interpolation; I’ll look into that.

I have had some partial success in implementing the lookup table using an Aesara function, but some things are not quite working.

Here is a model that runs fast and converges, but it produces the wrong output. A scatter plot of x and y shows scatter across a horizontal line (no sign of the sine), as if just one sample of x determined the index i; see the code:

import numpy as np
import pymc as pm
import arviz
import aesara
import aesara.tensor as at

coords = {'lut': np.linspace(0.0, 1.0, 11)}
model = pm.Model(coords=coords)

xi = at.vector('xi')
yi = at.vector('yi')
x = at.random.uniform(0.0, 1.0, size=None, name='x')
model.register_rv(x, name='x')
i = at.searchsorted(xi, x)
weight = (x - xi[i-1])/(xi[i] - xi[i-1])
res = weight * yi[i] + (1.0 - weight) * yi[i-1]

interpolate = pm.compile_pymc(inputs=[xi, yi], outputs=res)

with model:
    # values for x in lookup table
    lut_x = np.array(coords['lut'])
    # corresponding values for y (using a simple relationship here)
    lut_y = np.array(0.1*np.sin(2.0 * np.pi * coords['lut']))

    mu = interpolate(lut_x, lut_y)

    y = pm.Normal('y', mu=mu, sigma=0.01)

    idata = pm.sample(1000)
    print(arviz.summary(idata))

So there is likely an issue with the way I am setting up the model or Aesara function. Ditching the lookup table for a moment, I get very similar results (no sine wave) when using at.sin directly in the Aesara function:

model = pm.Model()

x = at.random.uniform(0.0, 1.0, size=None, name='x')
model.register_rv(x, name='x')
res = 0.1*at.sin(2.0*np.pi*x)

interpolate = pm.compile_pymc(inputs=[], outputs=[res])

with model:
    mu = interpolate()
    y = pm.Normal('y', mu=mu, sigma=0.01)

    idata = pm.sample(1000)
    print(arviz.summary(idata))

Extending the Aesara function to include y works better and produces a sine, but it suffers from bad convergence (r_hat > 1.1, and only part of the sine wave is sampled).

model = pm.Model()

x = at.random.uniform(0.0, 1.0, size=None, name='x')
model.register_rv(x, name='x')
res = 0.1*at.sin(2.0*np.pi*x)
y = at.random.normal(res, 0.01, size=None, name='y')
model.register_rv(y, name='y')

interpolate = pm.compile_pymc(inputs=[], outputs=[x, y])

with model:
    x, y = interpolate()

    idata = pm.sample(1000)
    print(arviz.summary(idata))

Does anyone have advice on how to restructure the code or change the function to make PyMC and Aesara work better together?

Without going through your code in detail: you shouldn’t use a compiled Aesara function inside a PyMC model. You should specify the relationship between variables using Aesara operators, and PyMC will itself compile whatever functions it needs for sampling.

Also, I see you’re registering variables manually in the model. You shouldn’t have to do this unless you’re doing something very specific. Calling pm.Uniform will do that for you, as well as make sure you passed the right inputs and that your variables are properly sized.
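
For example, here is a minimal sketch of that pattern, using the sine stand-in from your second snippet; note there is no compile_pymc or aesara.function call anywhere:

import numpy as np
import pymc as pm
import aesara.tensor as at

with pm.Model():
    x = pm.Uniform('x', 0.0, 1.0)
    # plain Aesara operators define the relationship symbolically;
    # PyMC compiles whatever functions it needs at sampling time
    mu = 0.1 * at.sin(2.0 * np.pi * x)
    y = pm.Normal('y', mu=mu, sigma=0.01)
    idata = pm.sample(1000)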

Thank you for your help, that is good to know. Most Aesara introductions start with functions and compiling them, so I thought I could directly use those here.

Here is a new piece of code with the linear interpolation and the lookup table directly included in the model. To use the index returned by at.searchsorted, I am using eval() here, as suggested in its help text:

coords = {'lut': np.linspace(0.0, 1.0, 11)}
model = pm.Model(coords=coords)

with model:
    # values for x in lookup table
    lut_x = np.array(coords['lut'])
    # corresponding values for y (using a simple relationship here)
    lut_y = np.array(0.1*np.sin(2.0 * np.pi * coords['lut']))

    x = pm.Uniform('x', 0.0, 1.0)
    i = at.searchsorted(lut_x, x).eval()

    weight = (x - lut_x[i-1])/(lut_x[i] - lut_x[i-1])
    mu = weight * lut_y[i] + (1.0 - weight) * lut_y[i-1]

    y = pm.Normal('y', mu=mu, sigma=0.001)

    idata = pm.sample(1000)
    print(arviz.summary(idata))

The outcome is similar to some of the examples I tried before. The model converges and samples x nicely, but it somehow appears to precompute the gradient over a small range of values (it’s not always the same). The result is a tangent line (plus the expected noise) and not a sine wave.

I admit I’m not 100% sure if this is your intent, but if you just want to:

  1. Draw a value between 0-1
  2. Quantize it into one of 10 buckets of equal length
  3. Use the quantized value as an index to access a lookup table
  4. Do some computation with the values from the table

It seems like a much easier way to accomplish (2) is to just multiply by 10 and convert to an integer:

import numpy as np
import pymc as pm
import aesara.tensor as at

x_vals = np.linspace(0, 1, 11)
y_vals = 0.1 * np.sin(2 * np.pi * x_vals)

with pm.Model() as model:
    x = pm.Uniform('x', 0, 1)
    # bucket index such that x_vals[i-1] <= x < x_vals[i]
    i = (x * 10).astype('int64') + 1

    x_at = at.as_tensor_variable(x_vals)
    y_at = at.as_tensor_variable(y_vals)

    weight = (x - x_at[i-1])/(x_at[i] - x_at[i-1])
    res = pm.Deterministic('result', weight * y_at[i] + (1.0 - weight) * y_at[i-1])

    idata = pm.sample_prior_predictive()

Here’s the resulting plot: [scatter of the prior predictive samples of result against x, tracing the sine]

Note that there’s no ground truth data, so it doesn’t make sense to use pm.sample.

Right, I see where you are coming from, but it’s more that PyMC is using Aesara, not you directly. You won’t see any PyMC examples with direct Aesara function compilation.

99% of users don’t need to know anything about Aesara other than calling at.foo instead of np.foo.

.eval() is just a helper that compiles and evaluates a graph for debugging purposes. For the same reason you can’t use a compiled function in PyMC, you can’t use an eval-ed variable (otherwise it will just be a constant with whatever eval returned the first time it was compiled).
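
To illustrate with a minimal (hypothetical) example: the value comes back as a plain number, so anything built from it is baked into the graph as a constant.

import numpy as np
import aesara.tensor as at

lut_x = at.as_tensor_variable(np.linspace(0.0, 1.0, 11))

# eval() compiles and runs the graph once, returning a NumPy value...
i = at.searchsorted(lut_x, 0.37).eval()
print(i)  # 4 -- a fixed number, no longer symbolic

# ...so an index computed this way can never change with the sampler's
# current value of x: the interpolation segment is frozen at compile time.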

Sometimes a little knowledge can be a dangerous thing. It’s better to ignore what you know about Aesara until you read more about how exactly PyMC uses Aesara. If you still want to understand better, this might be a good start:

https://www.pymc.io/projects/docs/en/stable/learn/core_notebooks/pymc_aesara.html

Again, 99% of users don’t need to understand much about Aesara, so depending on your goals that could be a waste of time.

I admit I’m not 100% sure if this is your intent, but if you just want to:

Draw a value between 0-1
Quantize it into one of 10 buckets of equal length
Use the quantized value as an index to access a lookup table
Do some computation with the values from the table

It seems like a much easier way to accomplish (2) is to just multiply by 10 and convert to an integer:

That solution is pretty much what I want to achieve, thank you. In the full model, the buckets/bins are not evenly spaced, but I am sure I’ll get it to work given the last posts here.

And just to clarify, the intent is to perform linear interpolation of a function f(x). f is difficult to compute (I am using the sine as a simple stand-in), so f has been pre-computed for many values of x and put into a lookup table. Instead of computing f(x) for new values of x, the lookup table is used to approximate f(x) from the values f(x_i) and f(x_{i+1}) contained in the table, with x_i \leq x < x_{i+1}.
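
In formula form, the piecewise linear approximation is f(x) \approx f(x_i) + \frac{x - x_i}{x_{i+1} - x_i}\,\bigl(f(x_{i+1}) - f(x_i)\bigr).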

Note that there’s no ground truth data, so it doesn’t make sense to use pm.sample.

Indeed, I forgot to change that when I eliminated the data from the example code.

.eval() is just a helper that compiles and evaluates a graph for debugging purposes. For the same reason you can’t use a compiled function in PyMC, you can’t use an eval-ed variable (otherwise it will just be a constant with whatever eval returned the first time it was compiled).

Sometimes a little knowledge can be a dangerous thing. It’s better to ignore what you know about Aesara until you read more about how exactly PyMC uses Aesara. If you still want to understand better, this might be a good start:

https://www.pymc.io/projects/docs/en/stable/learn/core_notebooks/pymc_aesara.html

Again, 99% of users don’t need to understand much about Aesara, so depending on your goals that could be a waste of time.

Okay, more things to learn (I had not quite realized that at.searchsorted(...).eval() was equivalent to compiling and evaluating an Aesara function); thanks for being patient with me. I’ll go through the PyMC and Aesara tutorial more carefully now and will hopefully be able to find a solution.


If the buckets aren’t of equal length, it will probably be easier to draw the indices from a categorical distribution. Come to think of it, that’s probably true even if the buckets are of equal length.
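
Something like this minimal sketch, say (hypothetical uneven edges; the index probabilities are just the normalized bucket widths):

import numpy as np
import pymc as pm
import aesara.tensor as at

# hypothetical uneven bucket edges
edges = np.array([0.0, 0.2, 0.45, 0.5, 0.8, 1.0])
widths = np.diff(edges)
edges_at = at.as_tensor_variable(edges)
widths_at = at.as_tensor_variable(widths)

with pm.Model():
    # bucket index, weighted by bucket width -- equivalent to
    # quantizing a Uniform(0, 1) draw over uneven buckets
    i = pm.Categorical('i', p=widths / widths.sum())
    # uniform position within the chosen bucket
    u = pm.Uniform('u', 0.0, 1.0)
    x = pm.Deterministic('x', edges_at[i] + u * widths_at[i])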

It was really all about not accidentally compiling the Aesara function and using tensor variables throughout. This is working as I would expect; I am using different bucket sizes for illustration purposes.

coords = {'lut': np.concatenate([np.linspace(0.0, 0.45, 3), np.linspace(0.5, 1.0, 101)])}
model = pm.Model(coords=coords)

with model:
    # values for x in lookup table
    lut_x = pm.Data('lut_x', coords['lut'], dims=('lut',), mutable=False)
    # or: lut_x = at.as_tensor_variable(coords['lut'])
    # corresponding values for y (using a simple relationship here)
    lut_y = pm.Data('lut_y', 0.1*np.sin(2.0 * np.pi * coords['lut']), dims=('lut',), mutable=False)
    # or: lut_y = at.as_tensor_variable(0.1*np.sin(2.0 * np.pi * coords['lut']))

    x = pm.Uniform('x', 0.0, 1.0)
    i = at.searchsorted(lut_x, x)

    weight = (x - lut_x[i-1])/(lut_x[i] - lut_x[i-1])
    mu = weight * lut_y[i] + (1.0 - weight) * lut_y[i-1]

    y = pm.Normal('y', mu=mu, sigma=0.01)

    idata = pm.sample_prior_predictive(1000)
    print(arviz.summary(idata))

Thank you both for your helpful input!
