Nutpie panel data events "moving"

pymc version: 5.21.1
nutpie version: 0.14.3

When I incorporate event data into a model with multiple locations and allow the coefficients to vary by location, the “date” of an event shifts in the posterior, but only with nutpie. The shift disappears when sampling with pymc or numpyro. The code below reproduces the issue and plots the transformed data, sum(event_vectors * coefficients), before and after fitting.

import matplotlib.pyplot as plt
import pandas as pd
import arviz as az
import pymc as pm


# create data that represents ten time periods,
# in which three events occur
data = pd.DataFrame(index=range(10), columns=[
                    'feat_1', 'feat_2', 'feat_3']).fillna(0)

data.loc[1, 'feat_1'] = 1

data.loc[4, 'feat_2'] = 1

data.loc[7, 'feat_3'] = 1

# the coordinates are:
# three events, two locations, ten time periods
coords = {
    'events': ['feat_1', 'feat_2', 'feat_3'],
    'locs': ['A', 'B'],
    'times': range(10)}

with pm.Model(coords=coords) as model:

    data = pm.Data('data',
                   data.to_numpy(),
                   dims=('times', 'events')
                   )
    # specify a prior effect size of each of the three
    # events, but allow each location to be impacted
    # differently
    coeffs = pm.Normal('coeffs',
                       mu=[1, 2, 3],
                       sigma=[0.1, 0.1, 0.1],
                       dims=('locs', 'events'))

    # the output is the sum of event impacts at each
    # point in time
    output = pm.Deterministic('output',
                              (coeffs[None, :, :] *
                               data[:, None, :]).sum(axis=2),
                              dims=('times', 'locs')
                              )

with model:
    idata = pm.sample_prior_predictive()
    # USE NUTPIE AS THE SAMPLER
    idata.extend(pm.sample(
        nuts_sampler='nutpie'))

# collect "output" from the prior and the posterior
op_pre = idata.prior.output.mean(["chain", "draw"]).to_numpy()
op_post = idata.posterior.output.mean(["chain", "draw"]).to_numpy()

# plot "output" as given by the prior and the posterior to
# observe that the times of events 1 and 3 have changed
fig, [ax, ax2] = plt.subplots(2)
ax.plot(range(10), op_pre)
ax.set_title("transformed data prior to fitting")
ax2.plot(range(10), op_post)
ax2.set_title("transformed data post fitting")
fig.tight_layout()

The top figure shows expected behavior.
The bottom figure shows that the “time” of the first and third events has shifted.


This is the same plot with the sampler changed to pymc (or numpyro)

I can reproduce, it’s odd.

If I set the nutpie backend to “jax”, the results are correct. If I set the backend of PyMC to “numba”, the results are also correct. If I manually call pm.compute_deterministics with the “numba” backend, it’s also correct.

@aseyboldt any idea?

I think I have a rough idea of what’s going on:

nutpie has some special handling for shared variables (for instance, the data variable above) that appear in the logp and expand functions with the numba backend. To get rid of call overhead, we store pointers to shared variables in a data structure that is passed to the compiled numba functions. But it seems that this code silently assumes that those arrays are in row-major (C) order.

In this example, data happens to be in column-major (F) order, so the entries of the array are in effect shuffled around. The problem goes away if we explicitly use a C-contiguous array:

import numpy as np

# inside the model block; `data` still refers to the pandas DataFrame here
data = pm.Data(
    'data',
    np.ascontiguousarray(data.to_numpy()),
    dims=('times', 'events')
)

The jax backend works fine, because it doesn’t have this data structure for the shared variables.
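The shuffling described above can be illustrated with plain NumPy. This is only a sketch of the effect, not nutpie's actual code: it assumes the array passed to pm.Data is column-major (np.asfortranarray stands in for what data.to_numpy() returned here), and it mimics a consumer that reads the raw buffer as if it were row-major. The names f_ordered, misread and fixed are made up for the illustration.

```python
import numpy as np

# Same layout as the example: 10 time periods x 3 events.
events = np.zeros((10, 3), dtype=np.int64)
events[1, 0] = 1  # feat_1 at time 1
events[4, 1] = 1  # feat_2 at time 4
events[7, 2] = 1  # feat_3 at time 7

# Column-major copy, standing in for the array that reached pm.Data.
f_ordered = np.asfortranarray(events)
print(f_ordered.flags.c_contiguous, f_ordered.flags.f_contiguous)

# Take the elements in memory order and reinterpret them as row-major,
# which is what a reader assuming C order would effectively do.
raw = f_ordered.ravel(order="K")
misread = raw.reshape(f_ordered.shape)

# (time, event) positions of the nonzero entries, before and after.
print(np.argwhere(events))
print(np.argwhere(misread))

# A C-contiguous copy keeps the logical values and has the layout
# the reader expects.
fixed = np.ascontiguousarray(f_ordered)
print(fixed.flags.c_contiguous, np.array_equal(fixed, events))
```

In this sketch the misread array has its events at times 0, 4 and 9 instead of 1, 4 and 7, and the event columns are permuted as well, which is consistent with the first and third events appearing to move while the second stays put.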

This is a pretty nasty bug, as it produces silently incorrect results…
I can only hope it didn’t appear in too many models before (it needs somewhat special circumstances to show up?).

I’ll try to make a bugfix release quickly.

WIP fix is here: pymc-devs/nutpie pull request #217, “fix(numba): non-contiguous shared variable”.

A bugfix release, 0.15.1, is now available on PyPI. The conda-forge package will still need a little time to finish.

@KL_Bean Thanks for reporting this!
If you have the time, it would be great if you could double check that you also don’t see this anymore with the new version.
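Before re-running the reproduction, it may help to confirm which nutpie version is actually installed. The following is a minimal sketch using only the standard library; the helper names parse_release and has_fix are made up here, and the parsing assumes a plain numeric "X.Y.Z" version string (pre-release suffixes are not handled).

```python
from importlib.metadata import PackageNotFoundError, version

FIXED_IN = (0, 15, 1)  # bugfix release named in this thread


def parse_release(text):
    """Turn 'X.Y.Z' into a tuple of ints (assumes a plain numeric version)."""
    return tuple(int(part) for part in text.split(".")[:3])


def has_fix(version_text):
    """Return True if the given version is at least the bugfix release."""
    return parse_release(version_text) >= FIXED_IN


if __name__ == "__main__":
    try:
        installed = version("nutpie")
    except PackageNotFoundError:
        print("nutpie is not installed in this environment")
    else:
        status = "includes the fix" if has_fix(installed) else "predates the fix"
        print(f"nutpie {installed} {status}")
```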