Calling `pm.set_data` multiple times

In the "predicting on hold-out data" section (4.1) of the quickstart, why can’t I change the data a second time, like this? I’ve already changed the data once, as in the tutorial, with "x_obs": [-1, 0, 1.0] and coords={"idx": [1001, 1002, 1003]}.

with model:
    # change the value and shape of the data
    pm.set_data(
        {
            "x_obs": [.0, 0, .0],
            # use dummy values with the same shape:
            "y_obs": [.0, 0, .0]
        },
        coords={"idx": [1004, 1005, 1006]},
    )
    idata.extend(pm.sample_posterior_predictive(idata))
idata.posterior_predictive["obs"].mean(dim=["draw", "chain"])

It just returns the same result as the original code; the out-of-sample data hasn’t been updated at all.

Are you creating the data via pm.MutableData()?

Yes, I believe so. When the following code is run, the second attempt to change the mutable data has no effect.

import numpy as np
import pymc as pm

rng = np.random.default_rng(12345)

x = rng.standard_normal(100)
y = x > 0

coords = {"idx": np.arange(100)}
with pm.Model() as model:
    # create shared variables that can be changed later on
    x_obs = pm.MutableData("x_obs", x, dims="idx")
    y_obs = pm.MutableData("y_obs", y, dims="idx")

    coeff = pm.Normal("x", mu=0, sigma=1)
    logistic = pm.math.sigmoid(coeff * x_obs)
    pm.Bernoulli("obs", p=logistic, observed=y_obs, dims="idx")
    idata = pm.sample()

with model:
    # change the value and shape of the data
    pm.set_data(
        {
            "x_obs": [-1, 0, 1.0],
            # use dummy values with the same shape:
            "y_obs": [0, 0, 0],
        },
        coords={"idx": [1001, 1002, 1003]},
    )

    idata.extend(pm.sample_posterior_predictive(idata))

print(idata.posterior_predictive["obs"].mean(dim=["draw", "chain"]))

with model:
    # change the value and shape of the data
    pm.set_data(
        {
            "x_obs": [.0, 0, .0],
            # use dummy values with the same shape:
            "y_obs": [.0, 0, .0]
        },
        coords={"idx": [1004, 1005, 1006]},
    )
    idata.extend(pm.sample_posterior_predictive(idata))
idata.posterior_predictive["obs"].mean(dim=["draw", "chain"])

I’m not sure what happens when you try to extend an InferenceData object with groups that share a name with ones it already has (e.g., posterior_predictive).

This seems to work for me?

import pymc as pm
import numpy as np

rng = np.random.default_rng(12345)

x = rng.standard_normal(100)
y = x > 0

coords = {"idx": np.arange(100)}
with pm.Model() as model:
    # create shared variables that can be changed later on
    x_obs = pm.MutableData("x_obs", x, dims="idx")
    y_obs = pm.MutableData("y_obs", y, dims="idx")

    coeff = pm.Normal("x", mu=0, sigma=1)
    logistic = pm.math.sigmoid(coeff * x_obs)
    pm.Bernoulli("obs", p=logistic, observed=y_obs, dims="idx")
    idata = pm.sample()

with model:
    # change the value and shape of the data
    pm.set_data(
        {
            "x_obs": [-1, 0, 1.0],
            # use dummy values with the same shape:
            "y_obs": [0, 0, 0],
        },
        coords={"idx": [1001, 1002, 1003]},
    )

    pp1 = pm.sample_posterior_predictive(idata)


with model:
    # change the value and shape of the data
    pm.set_data(
        {
            "x_obs": [.0, 0, .0],
            # use dummy values with the same shape:
            "y_obs": [.0, 0, .0]
        },
        coords={"idx": [1004, 1005, 1006]},
    )
    pp2 = pm.sample_posterior_predictive(idata)

print(pp1.posterior_predictive["obs"].mean(dim=["draw", "chain"]))

#<xarray.DataArray 'obs' (idx: 3)>
#array([0.027 , 0.5025, 0.977 ])
#Coordinates:
#  * idx      (idx) int64 1001 1002 1003

print(pp2.posterior_predictive["obs"].mean(dim=["draw", "chain"]))

#<xarray.DataArray 'obs' (idx: 3)>
#array([0.51625, 0.5025 , 0.501  ])
#Coordinates:
#  * idx      (idx) int64 1004 1005 1006

The issue is (was) with the use of extend: it is not a merge or concatenation function.

As explained in its docstring, extend defaults to join="left", which means that groups present in both idata and other are kept from idata, unmodified. With join="right", the repeated groups in idata are instead replaced by the ones from other. That is why the second extend call left the first posterior_predictive group in place.
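
To make the join behaviour concrete, here is a toy sketch that uses plain dictionaries in place of InferenceData groups. It is not ArviZ code: the `extend` function below is a hypothetical stand-in that only mimics the group-level rule described above, and unlike the real method it returns a new dict rather than modifying idata in place.

```python
# Toy model of the group-level join rule; NOT the ArviZ implementation.
def extend(idata, other, join="left"):
    """Return a new dict of groups following the described `join` rule."""
    if join not in ("left", "right"):
        raise ValueError("join must be 'left' or 'right'")
    merged = dict(idata)
    for group, value in other.items():
        # New groups are always added; repeated groups are only
        # overwritten when join="right".
        if group not in merged or join == "right":
            merged[group] = value
    return merged


# Strings stand in for the actual datasets held by each group.
idata = {"posterior": "trace", "posterior_predictive": "idx 1001-1003"}
other = {"posterior_predictive": "idx 1004-1006"}

kept = extend(idata, other)  # default join="left"
replaced = extend(idata, other, join="right")

print(kept["posterior_predictive"])      # idx 1001-1003
print(replaced["posterior_predictive"])  # idx 1004-1006
```

In the code from this thread, the equivalent change would presumably be to pass join="right" in the second idata.extend(...) call, or to keep each result in its own variable as in the pp1/pp2 example.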
