Indexing pm.Deterministic prior to inference

I would like to know how I can index explanatory variables wrapped in deterministic containers, as well as observed values, prior to inference. Let’s assume I want to perform inference on y_scaled < 0.2 in the code below, along with all the corresponding entries in our observations. While it would be straightforward to handle this before setting up the model, I want to do it afterward, since certain parameters are based on timesteps. The parametrization can be quite complex, so I am interested in understanding if it’s possible to index variables after the model has been set up.

Consider the code below, which I found in this repo: juanitorduz (Juan Orduz) · GitHub

import arviz as az
import numpy as np
import pandas as pd
import pymc as pm
import pymc.sampling_jax
from sklearn.preprocessing import MaxAbsScaler
def main():
    data_path = "https://raw.githubusercontent.com/juanitorduz/website_projects/master/data/ktr_data.csv"

    data_df = pd.read_csv(data_path, parse_dates=["date"])
    columns_to_keep = ["index", "date", "year", "month", "dayofyear", "z", "y"]

    df = data_df[columns_to_keep].copy()
    t = (df.index - df.index.min()) / (df.index.max() - df.index.min())

    n_order = 7
    periods = df["dayofyear"] / 365.25
    fourier_features = pd.DataFrame(
        {
            f"{func}_order_{order}": getattr(np, func)(2 * np.pi * periods * order)
            for order in range(1, n_order + 1)
            for func in ("sin", "cos")
        }
    )

    date = df["date"].to_numpy()
    date_index = df.index
    y = df["y"].to_numpy()
    z = df["z"].to_numpy()
    t = t.values
    n_obs = y.size

    endog_scaler = MaxAbsScaler()
    endog_scaler.fit(y.reshape(-1, 1))
    y_scaled = endog_scaler.transform(y.reshape(-1, 1)).flatten()

    channel_scaler = MaxAbsScaler()
    channel_scaler.fit(z.reshape(-1, 1))
    z_scaled = channel_scaler.transform(z.reshape(-1, 1)).flatten()

    coords = {"date": date, "fourier_mode": np.arange(2 * n_order)}

    with pm.Model(coords=coords) as base_model:
        # --- data containers ---
        # (coords are already registered via the Model constructor above)
        z_scaled_ = pm.MutableData(name="z_scaled", value=z_scaled, dims="date")
        y_scaled_ = pm.MutableData(name="y_scaled", value=y_scaled, dims="date")

        # --- priors ---
        ## intercept
        a = pm.Normal(name="a", mu=0, sigma=4)
        ## trend
        b_trend = pm.Normal(name="b_trend", mu=0, sigma=2)
        ## seasonality
        b_fourier = pm.Laplace(name="b_fourier", mu=0, b=2, dims="fourier_mode")
        ## regressor
        b_z = pm.HalfNormal(name="b_z", sigma=2)
        ## standard deviation of the normal likelihood
        sigma = pm.HalfNormal(name="sigma", sigma=0.5)
        # degrees of freedom of the t distribution
        nu = pm.Gamma(name="nu", alpha=25, beta=2)

        # --- model parametrization ---
        trend = pm.Deterministic(name="trend", var=a + b_trend * t, dims="date")
        seasonality = pm.Deterministic(
            name="seasonality", var=pm.math.dot(fourier_features, b_fourier), dims="date"
        )
        z_effect = pm.Deterministic(name="z_effect", var=b_z * z_scaled_, dims="date")
        mu = pm.Deterministic(name="mu", var=trend + seasonality + z_effect, dims="date")

        ### here I would like to index mu and the observed values

        # --- likelihood ---
        pm.StudentT(name="likelihood", nu=nu, mu=mu, sigma=sigma, observed=y_scaled_, dims="date")

        # --- prior samples ---
        base_model_prior_predictive = pm.sample_prior_predictive()


    with base_model:
        base_model_trace = pm.sampling_jax.sample_blackjax_nuts(
            draws=6_000,
            chains=4,
            idata_kwargs={"log_likelihood": True},
        )
        base_model_posterior_predictive = pm.sample_posterior_predictive(
            trace=base_model_trace
        )

    print(
        az.summary(
            data=base_model_trace,
            var_names=["a", "b_trend", "b_z", "sigma", "nu"],
        )
    )

if __name__ == '__main__':
    main()

Let's say I want to run inference only on the entries where y_scaled < 0.2 and the corresponding observations, BUT I want to do this filtering after I have set up the model with all the interactions between variables and parameters. How can I achieve this?

You can index PyMC variables just like numpy arrays, so something like mu[y_scaled < 0.2] will work fine.

Where you will hit problems is in the observed data, which must be a root node (raw data with no operations applied to it). You can avoid this by dropping the pm.Data wrapper and directly passing observed=y_scaled[y_scaled < 0.2].
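For concreteness, here is a minimal, self-contained sketch of that approach (the toy data and the obs_mask / subset_model names are mine, not from your code): the mask is a plain numpy boolean array computed before the model, mu keeps its full length, and only the likelihood sees the filtered slice.

import numpy as np
import pymc as pm

# toy stand-ins for z_scaled / y_scaled (hypothetical values)
rng = np.random.default_rng(0)
z_scaled = rng.uniform(size=100)
y_scaled = 0.1 * z_scaled + rng.normal(scale=0.05, size=100)

# boolean mask over observations, computed from the raw numpy array,
# so its shape is fixed when the graph is built
obs_mask = y_scaled < 0.2

with pm.Model() as subset_model:
    b_z = pm.HalfNormal("b_z", sigma=2)
    sigma = pm.HalfNormal("sigma", sigma=0.5)
    # full-length deterministic, built with all interactions intact
    mu = pm.Deterministic("mu", b_z * z_scaled)
    # index the deterministic like a numpy array and pass the observed
    # values pre-filtered, without a pm.Data wrapper
    pm.Normal(
        "likelihood",
        mu=mu[obs_mask],
        sigma=sigma,
        observed=y_scaled[obs_mask],
    )
    idata = pm.sample()

Note that the subset likelihood cannot carry the full-length "date" dims, since it only covers the masked entries.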

Can you say more about why you want to do all this computation then throw it away?


As for your question: mainly due to stupid coding and to save time… :). Thanks for answering! I noticed that this does not work with certain samplers, though.
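For reference, a likely culprit, assuming the failing samplers are the JAX-based ones (sample_blackjax_nuts / sample_numpyro_nuts): JAX needs static array shapes, so a mask built from a symbolic variable such as a pm.MutableData container gives the indexed result a shape that is only known at runtime, and the graph fails to compile. A mask precomputed as a concrete numpy array keeps the shape static. A small sketch of the difference (toy data and names are mine):

import numpy as np
import pymc as pm

y_scaled = np.array([0.1, 0.3, 0.15, 0.5])

with pm.Model() as m:
    y_data = pm.MutableData("y_data", y_scaled)
    b = pm.Normal("b", mu=0, sigma=1)
    mu = b * y_data  # some full-length expression

    # symbolic mask: mu[sym_mask] has a runtime-dependent shape,
    # which JAX-based samplers cannot compile
    sym_mask = y_data < 0.2

    # concrete mask: the shape is fixed when the graph is built,
    # so JAX-based samplers can handle it
    conc_mask = y_scaled < 0.2

    pm.Normal(
        "likelihood",
        mu=mu[conc_mask],
        sigma=1.0,
        observed=y_scaled[conc_mask],
    )

The default PyTensor-based pm.sample() can often evaluate the dynamic-shape version, which would explain why only certain samplers choke on it.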