I would like to know how I can index explanatory variables wrapped in deterministic containers, as well as observed values, prior to inference. Let’s assume I want to perform inference on y_scaled < 0.2
in the code below, along with all the corresponding entries in our observations. While it would be straightforward to handle this before setting up the model, I want to do it afterward, since certain parameters are based on timesteps. The parametrization can be quite complex, so I am interested in understanding if it’s possible to index variables after the model has been set up.
Consider the code below i found in this repo: juanitorduz (Juan Orduz) · GitHub
import pytensor.tensor as pt
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import pymc.sampling_jax
import seaborn as sns
from scipy.stats import pearsonr
from sklearn.preprocessing import MaxAbsScaler
import xarray as xr
import pymc.sampling_jax as jax
from pymc import HalfCauchy, Model, Normal, sample
def main():
data_path = "https://raw.githubusercontent.com/juanitorduz/website_projects/master/data/ktr_data.csv"
data_df = pd.read_csv(data_path, parse_dates=["date"])
columns_to_keep = ["index", "date", "year", "month", "dayofyear", "z", "y"]
df = data_df[columns_to_keep].copy()
t = (df.index - df.index.min()) / (df.index.max() - df.index.min())
n_order = 7
periods = df["dayofyear"] / 365.25
fourier_features = pd.DataFrame(
{
f"{func}_order_{order}": getattr(np, func)(2 * np.pi * periods * order)
for order in range(1, n_order + 1)
for func in ("sin", "cos")
}
)
date = df["date"].to_numpy()
date_index = df.index
y = df["y"].to_numpy()
z = df["z"].to_numpy()
t = t.values
n_obs = y.size
endog_scaler = MaxAbsScaler()
endog_scaler.fit(y.reshape(-1, 1))
y_scaled = endog_scaler.transform(y.reshape(-1, 1)).flatten()
channel_scaler = MaxAbsScaler()
channel_scaler.fit(z.reshape(-1, 1))
z_scaled = channel_scaler.transform(z.reshape(-1, 1)).flatten()
coords = {"date": date, "fourier_mode": np.arange(2 * n_order)}
with pm.Model(coords=coords) as base_model:
# --- coords ---
base_model.add_coord(name="dat", values=date, mutable=True)
base_model.add_coord(name="fourier_mode", values=np.arange(2 * n_order), mutable=False)
# --- data containers ---
z_scaled_ = pm.MutableData(name="z_scaled", value=z_scaled, dims="date")
y_scaled_ = pm.MutableData(name="y_scaled_", value=y_scaled, dims="date")
# --- priors ---
## intercept
a = pm.Normal(name="a", mu=0, sigma=4)
## trend
b_trend = pm.Normal(name="b_trend", mu=0, sigma=2)
## seasonality
b_fourier = pm.Laplace(name="b_fourier", mu=0, b=2, dims="fourier_mode")
## regressor
b_z = pm.HalfNormal(name="b_z", sigma=2)
## standard deviation of the normal likelihood
sigma = pm.HalfNormal(name="sigma", sigma=0.5)
# degrees of freedom of the t distribution
nu = pm.Gamma(name="nu", alpha=25, beta=2)
# --- model parametrization ---
trend = pm.Deterministic(name="trend", var=a + b_trend * t, dims="date")
seasonality = pm.Deterministic(
name="seasonality", var=pm.math.dot(fourier_features, b_fourier), dims="date"
)
z_effect = pm.Deterministic(name="z_effect", var=b_z * z_scaled_, dims="date")
mu = pm.Deterministic(name="mu", var=trend + seasonality + z_effect, dims="date")
### here i would like to index observed and mu
# --- likelihood ---
pm.StudentT(name="likelihood", nu=nu, mu=mu, sigma=sigma, observed=y_scaled_, dims="date")
# --- prior samples ---
base_model_prior_predictive = pm.sample_prior_predictive()
with base_model:
base_model_trace = jax.sample_blackjax_nuts(
nuts_sampler="",
draws=6_000,
chains=4,
idata_kwargs={"log_likelihood": True},
)
base_model_posterior_predictive = pm.sample_posterior_predictive(
trace=base_model_trace
)
print(az.summary(
data=base_model_trace,
var_names=["a", "b_trend", "b_z", "sigma", "nu"],
))
print("-.")
if __name__ == '__main__':
main()
Lets say i solely want to conduct the inference on y_scaled < 0.2 and the corresponding observations BUT i want to do this filtering after i have set up the model with all interactions between variables and parameters, how can i achieve this?