Different posterior predictive results after loading saved model

I am using the pymc-experimental ModelBuilder class to save my model. After saving, I generate a posterior predictive sample from the fitted (unsaved) model. I then load the saved model into a new object and generate a posterior predictive sample from the same data. The predictions from the loaded model are way off. I am not sure where to begin debugging this. Has anyone else experienced something similar?

I am adding a repro here:

First the ModelBuilder file “multilevel_model.py”:

import json
from pathlib import Path
from typing import Dict, Tuple, Union

import arviz as az
import numpy as np
import polars as pl
import pymc as pm
from pymc_experimental.model_builder import ModelBuilder


class MultilevelModel(ModelBuilder):
    """
    Defines the Multilevel risk model
    """

    # Give the model a name
    _model_type = "MultilevelModel"

    # And a version
    version = "0.1"

    def build_model(self, X: pl.DataFrame, y: pl.Series, **kwargs):
        """
        build_model creates the PyMC model

        Priors are read from self.model_config (keys: normal_mu, normal_sigma, cauchy_sigma).

        Parameters:
        X : pl.DataFrame
            The input data that is going to be used in the model. This should be a DataFrame
            containing the features (predictors) for the model. For efficiency reasons, it should
            only contain the necessary data columns, not the entire available dataset, as this
            will be encoded into the data used to recreate the model.

        y : pl.Series
            The target data for the model. This should be a Series representing the output
            or dependent variable for the model.

        kwargs : dict
            Additional keyword arguments that may be used for model configuration.
        """
        # Extract the index arrays, coordinate labels and feature columns from X and y

        self._generate_and_preprocess_model_data(X, y)

        with pm.Model() as self.model:
            # Data coords
            self.model.add_coord("patients", self.patients)
            self.model.add_coord("gender", self.gender)
            self.model.add_coord("unit", self.unit)
            self.model.add_coord("care", self.care)
            self.model.add_coord("obs", range(self.X.shape[0]))

            # Data containers
            patient_index = pm.Data("patient_index", self.patient_idx, dims="obs")
            gender_index = pm.Data("gender_index", self.gender_idx, dims="obs")
            unit_index = pm.Data("unit_index", self.unit_idx, dims="obs")
            care_index = pm.Data("care_index", self.care_idx, dims="obs")
            age_ = pm.Data("age", self.age, dims="obs")
            time_ = pm.Data("time", self.time, dims="obs")
            hematocrit_ = pm.Data("hematocrit", self.hematocrit, dims="obs")
            troponin_ = pm.Data("troponin", self.troponin, dims="obs")
            uramy_ = pm.Data("uramy", self.uramy, dims="obs")
            myelo_ = pm.Data("myelo", self.myelo, dims="obs")
            ururo_ = pm.Data("uroro", self.ururo, dims="obs")
            ursod_ = pm.Data("ursod", self.ursod, dims="obs")
            creact_ = pm.Data("creact", self.creact, dims="obs")
            abscd8_ = pm.Data("abscd8", self.abscd8, dims="obs")
            abscd4_ = pm.Data("abscd4", self.abscd4, dims="obs")
            abseos_ = pm.Data("abseos", self.abseos, dims="obs")
            abslymp_ = pm.Data("abslymp", self.abslymp, dims="obs")
            absneu_ = pm.Data("absneu", self.absneu, dims="obs")
            absc8_ = pm.Data("absc8", self.absc8, dims="obs")
            absc4_ = pm.Data("absc4", self.absc4, dims="obs")
            min_bp_ = pm.Data("min_bp", self.min_BP, dims="obs")
            max_bp_ = pm.Data("max_bp", self.max_BP, dims="obs")
            min_temp_ = pm.Data("min_temp", self.min_temp, dims="obs")
            max_temp_ = pm.Data("max_temp", self.max_temp, dims="obs")
            min_rsp_ = pm.Data("min_rsp", self.min_rsp, dims="obs")
            max_rsp_ = pm.Data("max_rsp", self.max_rsp, dims="obs")
            min_pls_ = pm.Data("min_pls", self.min_pls, dims="obs")
            max_pls_ = pm.Data("max_pls", self.max_pls, dims="obs")
            min_sgr_ = pm.Data("min_sgr", self.min_sgr, dims="obs")
            max_sgr_ = pm.Data("max_sgr", self.max_sgr, dims="obs")
            min_pn_ = pm.Data("min_pn", self.min_pn, dims="obs")
            max_pn_ = pm.Data("max_pn", self.max_pn, dims="obs")
            transferred_ = pm.Data("transferred", self.y, dims="obs")

            # Prior
            sigma_a = pm.HalfCauchy("sigma_a", self.model_config.get("cauchy_sigma"))

            # Patient Baseline
            gamma = pm.Normal(
                "gamma_0",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            gamma_gender = pm.Normal(
                "gamma_gender",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
                dims="gender",
            )
            gamma_unit = pm.Normal(
                "gamma_unit",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
                dims="unit",
            )
            gamma_care = pm.Normal(
                "gamma_care",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
                dims="care",
            )
            gamma_age = pm.Normal(
                "gamma_age",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )

            mu_a = pm.Deterministic(
                "mu_a",
                var=gamma
                + gamma_gender[gender_index]
                + gamma_unit[unit_index]
                + gamma_care[care_index]
                + gamma_age * age_,
            )

            # Unexplained Patient Variation
            epsilon_a = pm.Normal(
                "epsilon_a",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
                dims="patients",
            )
            a = pm.Deterministic("a", var=mu_a + sigma_a * epsilon_a[patient_index])

            # Common slopes priors
            b_time = pm.Normal(
                "b_time",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_hematocrit = pm.Normal(
                "b_hematocrit",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_troponin = pm.Normal(
                "b_troponin",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_uramy = pm.Normal(
                "b_uramy",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_myelo = pm.Normal(
                "b_myelo",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_ururo = pm.Normal(
                "b_ururo",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_ursod = pm.Normal(
                "b_ursod",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_creact = pm.Normal(
                "b_creact",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_abscd8 = pm.Normal(
                "b_abscd8",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_abscd4 = pm.Normal(
                "b_abscd4",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_abseos = pm.Normal(
                "b_abseos",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_abslymp = pm.Normal(
                "b_abslymp",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_absneu = pm.Normal(
                "b_absneu",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_absc8 = pm.Normal(
                "b_absc8",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_absc4 = pm.Normal(
                "b_absc4",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_min_bp = pm.Normal(
                "b_min_bp",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_max_bp = pm.Normal(
                "b_max_bp",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_min_temp = pm.Normal(
                "b_min_temp",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_max_temp = pm.Normal(
                "b_max_temp",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_min_rsp = pm.Normal(
                "b_min_rsp",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_max_rsp = pm.Normal(
                "b_max_rsp",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_min_pls = pm.Normal(
                "b_min_pls",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_max_pls = pm.Normal(
                "b_max_pls",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_min_sgr = pm.Normal(
                "b_min_sgr",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_max_sgr = pm.Normal(
                "b_max_sgr",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_min_pn = pm.Normal(
                "b_min_pn",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )
            b_max_pn = pm.Normal(
                "b_max_pn",
                self.model_config.get("normal_mu"),
                self.model_config.get("normal_sigma"),
            )

            # Expectation
            p = pm.Deterministic(
                "p",
                pm.math.invlogit(
                    a  # already one value per observation; indexing by patient_index again would pick the wrong rows
                    + b_time * time_
                    + b_hematocrit * hematocrit_
                    + b_troponin * troponin_
                    + b_uramy * uramy_
                    + b_myelo * myelo_
                    + b_ururo * ururo_
                    + b_ursod * ursod_
                    + b_creact * creact_
                    + b_abscd4 * abscd4_
                    + b_abscd8 * abscd8_
                    + b_absc4 * absc4_
                    + b_absc8 * absc8_
                    + b_abseos * abseos_
                    + b_abslymp * abslymp_
                    + b_absneu * absneu_
                    + b_min_bp * min_bp_
                    + b_max_bp * max_bp_
                    + b_min_temp * min_temp_
                    + b_max_temp * max_temp_
                    + b_min_rsp * min_rsp_
                    + b_max_rsp * max_rsp_
                    + b_min_pls * min_pls_
                    + b_max_pls * max_pls_
                    + b_min_sgr * min_sgr_
                    + b_max_sgr * max_sgr_
                    + b_min_pn * min_pn_
                    + b_max_pn * max_pn_
                ),
                dims="obs",
            )
            pm.Bernoulli("likelihood", p=p, observed=transferred_, dims="obs")

    def _data_setter(
        self, X: Union[pl.DataFrame, np.ndarray], y: Union[pl.Series, np.ndarray] = None
    ):
        """
        Sets the data for predictions
        """
        patient_idx, patients = self.factorize(X["ClientID"])
        gender_idx, gender = self.factorize(X["PatientGender"])
        unit_idx, unit = self.factorize(X["UnitDescription"])
        care_idx, care = self.factorize(X["AlternateCareLevel"])
        time = X["Time"].to_numpy()
        age = X["age"].to_numpy()
        hematocrit = X["hematocrit-istat"].to_numpy()
        troponin = X["troponin i"].to_numpy()
        uramy = X["urine amylase"].to_numpy()
        myelo = X["absolute myelocytes"].to_numpy()
        ururo = X["urine urobilinogin"].to_numpy()
        ursod = X["urine sodium"].to_numpy()
        creact = X["c reactive protein"].to_numpy()
        abscd8 = X["absolute cd8+ cells"].to_numpy()
        abscd4 = X["absolute cd4+ cells"].to_numpy()
        abseos = X["absolute eosinophils"].to_numpy()
        abslymp = X["absolute lymphocytes"].to_numpy()
        absneu = X["absolute neutrophils"].to_numpy()
        absc8 = X["cd8, absolute"].to_numpy()
        absc4 = X["cd4, absolute"].to_numpy()
        min_BP = X["min_VitalsDescription_BP - Systolic"].to_numpy()
        max_BP = X["max_VitalsDescription_BP - Systolic"].to_numpy()
        min_temp = X["min_VitalsDescription_Temperature"].to_numpy()
        max_temp = X["max_VitalsDescription_Temperature"].to_numpy()
        min_rsp = X["min_VitalsDescription_Respiration"].to_numpy()
        max_rsp = X["max_VitalsDescription_Respiration"].to_numpy()
        min_pls = X["min_VitalsDescription_Pulse"].to_numpy()
        max_pls = X["max_VitalsDescription_Pulse"].to_numpy()
        min_sgr = X["min_VitalsDescription_Blood Sugar"].to_numpy()
        max_sgr = X["max_VitalsDescription_Blood Sugar"].to_numpy()
        min_pn = X["min_VitalsDescription_Pain Level"].to_numpy()
        max_pn = X["max_VitalsDescription_Pain Level"].to_numpy()

        with self.model:
            pm.set_data(
                new_data={
                    "patient_index": patient_idx,
                    "gender_index": gender_idx,
                    "unit_index": unit_idx,
                    "care_index": care_idx,
                    "age": age,
                    "time": time,
                    "hematocrit": hematocrit,
                    "troponin": troponin,
                    "uramy": uramy,
                    "myelo": myelo,
                    "uroro": ururo,
                    "ursod": ursod,
                    "creact": creact,
                    "abscd8": abscd8,
                    "abscd4": abscd4,
                    "abseos": abseos,
                    "abslymp": abslymp,
                    "absneu": absneu,
                    "absc8": absc8,
                    "absc4": absc4,
                    "min_bp": min_BP,
                    "max_bp": max_BP,
                    "min_temp": min_temp,
                    "max_temp": max_temp,
                    "min_rsp": min_rsp,
                    "max_rsp": max_rsp,
                    "min_pls": min_pls,
                    "max_pls": max_pls,
                    "min_sgr": min_sgr,
                    "max_sgr": max_sgr,
                    "min_pn": min_pn,
                    "max_pn": max_pn,
                    "transferred": np.zeros_like(max_pn, dtype=np.int32),
                },
                coords=dict(
                    patients=patients,
                    gender=gender,
                    unit=unit,
                    care=care,
                    obs=range(X.shape[0]),
                ),
            )

    @staticmethod
    def get_default_model_config() -> Dict:
        """
        Returns a class default config dict for model builder if no model_config is provided on class initialization.
        The model config dict is generally used to specify the prior values we want to build the model with.
        It supports more complex data structures like lists, dictionaries, etc.
        It will be passed to the class instance on initialization, in case the user doesn't provide any model_config of their own.
        """
        model_config: Dict = {"normal_mu": 0, "normal_sigma": 1, "cauchy_sigma": 2}
        return model_config

    @staticmethod
    def get_default_sampler_config() -> Dict:
        """
        Returns a class default sampler dict for model builder if no sampler_config is provided on class initialization.
        The sampler config dict is used to send parameters to the sampler .
        It will be used during fitting in case the user doesn't provide any sampler_config of their own.
        """
        sampler_config: Dict = {
            "draws": 1_000,
            "tune": 1_000,
            "chains": 3,
            "target_accept": 0.95,
            "nuts_sampler": "nutpie",
            "progressbar": False,
        }
        return sampler_config

    @property
    def output_var(self):
        """
        Ensures variable of interest in predictions
        """
        return "likelihood"

    @property
    def _serializable_model_config(self) -> Dict[str, Union[int, float, Dict]]:
        """
        _serializable_model_config is a property that returns a dictionary with all the model parameters that we want to save.
        as some of the data structures are not json serializable, we need to convert them to json serializable objects.
        Some models will need them, others can just define them to return the model_config.
        """
        return self.model_config

    def _generate_and_preprocess_model_data(
        self, X: Union[pl.DataFrame, pl.Series], y: Union[pl.Series, np.ndarray]
    ) -> None:
        """
        Process data
        """
        self.X = X
        self.y = y
        self.patient_idx, self.patients = self.factorize(X["ClientID"])
        self.gender_idx, self.gender = self.factorize(X["PatientGender"])
        self.unit_idx, self.unit = self.factorize(X["UnitDescription"])
        self.care_idx, self.care = self.factorize(X["AlternateCareLevel"])
        self.time = X["Time"].to_numpy()
        self.age = X["age"].to_numpy()
        self.hematocrit = X["hematocrit-istat"].to_numpy()
        self.troponin = X["troponin i"].to_numpy()
        self.uramy = X["urine amylase"].to_numpy()
        self.myelo = X["absolute myelocytes"].to_numpy()
        self.ururo = X["urine urobilinogin"].to_numpy()
        self.ursod = X["urine sodium"].to_numpy()
        self.creact = X["c reactive protein"].to_numpy()
        self.abscd8 = X["absolute cd8+ cells"].to_numpy()
        self.abscd4 = X["absolute cd4+ cells"].to_numpy()
        self.abseos = X["absolute eosinophils"].to_numpy()
        self.abslymp = X["absolute lymphocytes"].to_numpy()
        self.absneu = X["absolute neutrophils"].to_numpy()
        self.absc8 = X["cd8, absolute"].to_numpy()
        self.absc4 = X["cd4, absolute"].to_numpy()
        self.min_BP = X["min_VitalsDescription_BP - Systolic"].to_numpy()
        self.max_BP = X["max_VitalsDescription_BP - Systolic"].to_numpy()
        self.min_temp = X["min_VitalsDescription_Temperature"].to_numpy()
        self.max_temp = X["max_VitalsDescription_Temperature"].to_numpy()
        self.min_rsp = X["min_VitalsDescription_Respiration"].to_numpy()
        self.max_rsp = X["max_VitalsDescription_Respiration"].to_numpy()
        self.min_pls = X["min_VitalsDescription_Pulse"].to_numpy()
        self.max_pls = X["max_VitalsDescription_Pulse"].to_numpy()
        self.min_sgr = X["min_VitalsDescription_Blood Sugar"].to_numpy()
        self.max_sgr = X["max_VitalsDescription_Blood Sugar"].to_numpy()
        self.min_pn = X["min_VitalsDescription_Pain Level"].to_numpy()
        self.max_pn = X["max_VitalsDescription_Pain Level"].to_numpy()

    def factorize(self, series: pl.Series) -> Tuple[np.ndarray, np.ndarray]:
        """
        Factorize polars series
        """
        name = series.name
        df = series.to_frame()
        df = df.fill_null("<NA>")
        df_ranked = df.unique(maintain_order=True).with_row_index(name=f"{name}_index")
        uniques = df_ranked[name].to_numpy()
        codes = df.join(df_ranked, how="left", on=name)[f"{name}_index"].to_numpy()
        return codes, uniques

    def _validate_data(self, X, y=None):
        if y is not None:
            return X, y
        else:
            return X

    @classmethod
    def load(cls, fname: str):
        """
        Creates a ModelBuilder instance from a file,
        Loads inference data for the model.

        Parameters
        ----------
        fname : string
            This denotes the name with path from where idata should be loaded from.

        Returns
        -------
        Returns an instance of ModelBuilder.

        Raises
        ------
        ValueError
            If the inference data that is loaded doesn't match with the model.
        Examples
        --------
        >>> class MyModel(ModelBuilder):
        >>>     ...
        >>> name = './mymodel.nc'
        >>> imported_model = MyModel.load(name)
        """
        filepath = Path(str(fname))
        idata = az.from_netcdf(filepath)
        # needs to be converted, because json.loads was changing tuple to list
        model_config = cls._model_config_formatting(
            json.loads(idata.attrs["model_config"])
        )
        model = cls(
            model_config=model_config,
            sampler_config=json.loads(idata.attrs["sampler_config"]),
        )
        model.idata = idata
        dataset = idata.fit_data.to_dataframe()
        X = dataset.drop(columns=[model.output_var])
        y = dataset[model.output_var]
        X = pl.from_pandas(X)
        y = pl.from_pandas(y)
        model.build_model(X, y)
        # All previously used data is in idata.

        if model.id != idata.attrs["id"]:
            raise ValueError(
                f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'"
            )

        return model

Then the reproducible script:

import polars as pl
import polars.selectors as cs
import pymc as pm
import arviz as az
import numpy as np
from multilevel_model import MultilevelModel

df = pl.DataFrame(
    {
        "ClientID": np.arange(0, 100),
        "PatientGender": np.random.choice(['M', 'F'], size=100),
        "UnitDescription": np.random.choice(['a', 'b'], size=100),
        "AlternateCareLevel": np.random.choice(['a', 'b'], size=100),
        "age": np.random.normal(size=100),
        "Time": np.random.normal(size=100),
        "hematocrit-istat": np.random.normal(size=100),
        "troponin i": np.random.normal(size=100),
        "urine amylase": np.random.normal(size=100),
        "absolute myelocytes": np.random.normal(size=100),
        "urine urobilinogin": np.random.normal(size=100),
        "urine sodium": np.random.normal(size=100),
        "c reactive protein": np.random.normal(size=100),
        "absolute cd8+ cells": np.random.normal(size=100),
        "absolute cd4+ cells": np.random.normal(size=100),
        "absolute eosinophils": np.random.normal(size=100),
        "absolute lymphocytes": np.random.normal(size=100),
        "absolute neutrophils": np.random.normal(size=100),
        "cd8, absolute": np.random.normal(size=100),
        "cd4, absolute": np.random.normal(size=100),
        "min_VitalsDescription_BP - Systolic": np.random.normal(size=100),
        "max_VitalsDescription_BP - Systolic": np.random.normal(size=100),
        "min_VitalsDescription_Temperature": np.random.normal(size=100),
        "max_VitalsDescription_Temperature": np.random.normal(size=100),
        "min_VitalsDescription_Respiration": np.random.normal(size=100),
        "max_VitalsDescription_Respiration": np.random.normal(size=100),
        "min_VitalsDescription_Pulse": np.random.normal(size=100),
        "max_VitalsDescription_Pulse": np.random.normal(size=100),
        "min_VitalsDescription_Blood Sugar": np.random.normal(size=100),
        "max_VitalsDescription_Blood Sugar": np.random.normal(size=100),
        "min_VitalsDescription_Pain Level": np.random.normal(size=100),
        "max_VitalsDescription_Pain Level": np.random.normal(size=100),
        "transferred": np.random.binomial(1, 0.5, size=100)
    }
)

model = MultilevelModel()
model.fit(
    X = df.drop("transferred"),
    y = df['transferred']
)

model.save("./multilevel_model.nc")
model.predict_posterior(
    X_pred=df.drop("transferred"),  # column name is lowercase in df
    var_names=['p', 'likelihood']
)
az.hdi(model.idata.posterior_predictive.p, hdi_prob=0.8)

model2 = MultilevelModel.load("./multilevel_model.nc")
model2.predict_posterior(
    X_pred=df.drop("transferred"),  # column name is lowercase in df
    var_names=['p', 'likelihood']
)
az.hdi(model2.idata.posterior_predictive.p, hdi_prob=0.8)

A couple of things I discovered while putting this example together. On macOS this works as expected; on Windows and Linux, however, the posterior predictive samples are way off. When sampling the posterior predictive from the loaded model on macOS, only [likelihood] is sampled, whereas on Windows and Linux [gamma_care, gamma_unit, likelihood] are sampled, and the result is very different from that of the fitted, unloaded model.

EDIT:
Actually, there is no difference between the three operating systems. The reproducible example works as expected: the loaded model samples only [likelihood]. With the real data, however, the loaded model samples [gamma_care, gamma_unit, likelihood], while the fitted (not loaded) model samples only [likelihood]. I am not sure why this is happening. My real data has no missing values, just like the repro example.
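One way to start narrowing down a mismatch like this is to compare the category labels of the data passed to fit() with the labels that come back from the saved file. The sketch below is only an illustration in plain Python: `label_differences` is a hypothetical helper, not part of pymc-experimental, and the "nan" value is a made-up stand-in for whatever a null might turn into after a save/load round trip.

```python
from typing import Dict, List, Optional, Sequence

Column = Sequence[Optional[str]]


def label_differences(
    before: Dict[str, Column], after: Dict[str, Column]
) -> Dict[str, Dict[str, List[str]]]:
    """For each column, list the labels that appear on only one side."""
    report: Dict[str, Dict[str, List[str]]] = {}
    for name in before:
        # repr() keeps a real None distinguishable from the string "None"
        left = {repr(v) for v in before[name]}
        right = {repr(v) for v in after[name]}
        if left != right:
            report[name] = {
                "only_before": sorted(left - right),
                "only_after": sorted(right - left),
            }
    return report


# Hypothetical values: the None stands for a null in the original data,
# and "nan" is an assumed placeholder for what it becomes after reloading.
original = {"UnitDescription": ["a", "b", None], "PatientGender": ["M", "F", "M"]}
reloaded = {"UnitDescription": ["a", "b", "nan"], "PatientGender": ["M", "F", "M"]}

print(label_differences(original, reloaded))
```

With real data, the two dictionaries could be filled from the original DataFrame columns and from the DataFrame that load() builds out of idata.fit_data. Any column that shows up in the report is a candidate for a coordinate that differs between the fitted and the loaded model.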

Can you provide a minimal reproducible example?

@ricardoV94 Thank you for your response and sorry for the late reply. It took me a bit of time to produce the example. I added it to my original question above.

I figured out the problem. The columns behind gamma_unit and gamma_care (UnitDescription and AlternateCareLevel) contained null values, which were imputed only during preprocessing in the _generate_and_preprocess_model_data() method of the ModelBuilder class. However, the data was stored in the class before the imputation, so the nulls leaked into the fit_data group of the idata object. Imputing the data before passing it to the ModelBuilder class fixes the issue.
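To illustrate why the order of imputation matters, here is a minimal plain-Python sketch of the mechanism. It mimics the first-appearance factorization and the "<NA>" placeholder from the model file above, but `fake_round_trip` is purely hypothetical: it assumes that a null comes back from the saved file as some other value ("nan" here), which may not be exactly what happens with a real netCDF file.

```python
from typing import Dict, List, Optional, Sequence, Tuple

NULL_LABEL = "<NA>"  # same placeholder the factorize method in the post uses


def impute(values: Sequence[Optional[str]]) -> List[str]:
    """Replace nulls with the placeholder label."""
    return [NULL_LABEL if v is None else v for v in values]


def factorize(values: Sequence[Optional[str]]) -> Tuple[List[int], List[str]]:
    """Return (codes, uniques) with uniques in first-appearance order."""
    uniques: List[str] = []
    index: Dict[str, int] = {}
    codes: List[int] = []
    for v in impute(values):
        if v not in index:
            index[v] = len(uniques)
            uniques.append(v)
        codes.append(index[v])
    return codes, uniques


def fake_round_trip(values: Sequence[Optional[str]]) -> List[Optional[str]]:
    # ASSUMPTION: stand-in for save/load turning a null into another value
    return ["nan" if v is None else v for v in values]


raw = ["a", None, "b", "a"]

# Labels seen by the model at fit time (nulls imputed inside factorize)
_, fitted_labels = factorize(raw)

# Labels rebuilt on load when the raw data (with nulls) was what got saved
_, loaded_labels_raw = factorize(fake_round_trip(raw))

# Labels rebuilt on load when the data was imputed before fitting
_, loaded_labels_clean = factorize(fake_round_trip(impute(raw)))

print(fitted_labels, loaded_labels_raw, loaded_labels_clean)
```

Under that assumption, the labels rebuilt from the raw data no longer match the ones used at fit time, while the labels rebuilt from pre-imputed data do. In the polars setting of this thread, the same idea amounts to calling fill_null on the categorical columns before calling fit().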
