pymc version: 5.21.1
nutpie version: 0.14.3
When I incorporate event data into a model with multiple locations and allow the coefficient to vary by location, the “date” of an event shifts in the posterior, but only with nutpie. The problem does not occur with the default PyMC sampler or with numpyro. The code below reproduces the issue. It plots the transformed data, sum(event_vectors * coefficients), before and after fitting.
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
# Create data representing ten time periods in which three events occur.
# There is one indicator column per event: 1.0 at the event's time period
# and 0.0 elsewhere.
# Building the frame from a float fill value gives float64 columns. The
# previous all-NaN frame followed by fillna(0) gave object-dtype columns and
# triggered pandas' fillna downcasting FutureWarning.
data = pd.DataFrame(
    0.0,
    index=range(10),
    columns=['feat_1', 'feat_2', 'feat_3'],
)
data.loc[1, 'feat_1'] = 1
data.loc[4, 'feat_2'] = 1
data.loc[7, 'feat_3'] = 1
# Named dimensions for the model: three events, two locations and ten
# time periods.
coords = dict(
    events=['feat_1', 'feat_2', 'feat_3'],
    locs=['A', 'B'],
    times=range(10),
)
with pm.Model(coords=coords) as model:
    # Register the event indicators as model data.
    #
    # The array is explicitly copied into C (row-major) order.
    # DataFrame.to_numpy() on a single-dtype frame typically returns a
    # Fortran-ordered (column-major) view of pandas' internal block.
    #
    # If that buffer is reinterpreted as row-major, the three events land at
    # (t=0, feat_2), (t=4, feat_3) and (t=9, feat_1) instead of (t=1, feat_1),
    # (t=4, feat_2) and (t=7, feat_3). That is exactly the event shift
    # observed with nutpie.
    # NOTE(review): nutpie 0.14.x appears to assume row-major shared data;
    # confirm against the nutpie issue tracker.
    #
    # The PyMC variable is bound to `event_data` so it does not shadow the
    # DataFrame `data`.
    event_data = pm.Data(
        'data',
        np.ascontiguousarray(data.to_numpy(dtype=float)),
        dims=('times', 'events'),
    )
    # Specify a prior effect size for each of the three events, but allow
    # each location to be impacted differently. There is one coefficient per
    # (location, event) pair, with mu/sigma broadcast across locations.
    coeffs = pm.Normal(
        'coeffs',
        mu=[1, 2, 3],
        sigma=[0.1, 0.1, 0.1],
        dims=('locs', 'events'),
    )
    # The output is the sum of event impacts at each point in time:
    # (times, 1, events) * (1, locs, events), summed over events,
    # gives (times, locs).
    output = pm.Deterministic(
        'output',
        (coeffs[None, :, :] * event_data[:, None, :]).sum(axis=2),
        dims=('times', 'locs'),
    )
with model:
    # Forward-sample the prior with PyMC itself.
    idata = pm.sample_prior_predictive()
    # Fit with nutpie as the NUTS backend, then merge its posterior group
    # into the same InferenceData.
    nutpie_fit = pm.sample(nuts_sampler='nutpie')
    idata.extend(nutpie_fit)
# Collect "output" from the prior and the posterior.
# Averaging over chains and draws leaves a (times, locs) array, matching the
# dims declared on the Deterministic, so there is one column per location.
prior_output = idata.prior.output.mean(["chain", "draw"])
posterior_output = idata.posterior.output.mean(["chain", "draw"])
op_pre = prior_output.to_numpy()
op_post = posterior_output.to_numpy()
# Take x positions and line labels from the stored coordinates instead of
# hard-coding range(10), so the plot follows the model's coords.
times = prior_output["times"].to_numpy()
loc_labels = [f"location {loc}" for loc in prior_output["locs"].to_numpy()]

# Plot "output" as given by the prior and the posterior on a shared time
# axis. If event timing is preserved by the sampler, the spikes line up in
# both panels.
fig, [ax, ax2] = plt.subplots(2, sharex=True)
for axis, values, title in (
    (ax, op_pre, "transformed data prior to fitting"),
    (ax2, op_post, "transformed data post fitting"),
):
    # A 2-D `values` produces one line per location column.
    lines = axis.plot(times, values)
    axis.legend(lines, loc_labels)
    axis.set_title(title)
fig.tight_layout()
# Needed for the figure to appear when this runs as a plain script.
plt.show()