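Yes, it's a bit custom. Basically, you create an HSGP over all channels and then apply a boolean mask that selects which channels keep the time-varying effect (the others get a constant multiplier of 1.0). For example: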
from typing import Self

import numpy as np
import pymc as pm
import pytensor.tensor as pt
import xarray as xr

from pymc_marketing.mmm import SoftPlusHSGP
from pymc_marketing.mmm.hsgp import HSGPBase
class MaskedHSGP(HSGPBase):
    """
    HSGP with channel masking - masked channels get constant 1.0.

    An HSGP is created over all channels; a boolean mask then selects which
    channels keep the time-varying HSGP multiplier and which are pinned to a
    constant multiplier of 1.0.

    Parameters
    ----------
    hsgp : SoftPlusHSGP
        The underlying HSGP instance. Must have 'channel' in dims.
    mask : xr.DataArray
        Boolean mask whose dims equal the non-time dims of ``hsgp`` (same
        names, same order). True = use HSGP, False = constant 1.0 (no
        time-varying effect). Non-boolean values are interpreted by truthiness.

    Raises
    ------
    ValueError
        If 'channel' is not among the HSGP dims, or if the mask dims do not
        match the HSGP non-time dims.

    Examples
    --------
    >>> import numpy as np
    >>> import xarray as xr
    >>> from pymc_marketing.mmm import SoftPlusHSGP
    >>> from pymc_marketing.mmm.multidimensional import MMM
    >>>
    >>> channels = ["TV", "Radio", "Digital", "Print"]
    >>> # Only TV and Radio have time-varying effects
    >>> mask = xr.DataArray(
    ...     [True, True, False, False],
    ...     dims=["channel"],
    ...     coords={"channel": channels},
    ... )
    >>>
    >>> hsgp = SoftPlusHSGP.parameterize_from_data(
    ...     X=np.arange(52),
    ...     dims=("date", "channel"),
    ... )
    >>>
    >>> masked_hsgp = MaskedHSGP(hsgp=hsgp, mask=mask)
    >>>
    >>> mmm = MMM(
    ...     date_column="date",
    ...     target_column="target",
    ...     channel_columns=channels,
    ...     adstock=...,
    ...     saturation=...,
    ...     time_varying_media=masked_hsgp,
    ... )
    """

    def __init__(self, hsgp: SoftPlusHSGP, mask: xr.DataArray):
        # NOTE(review): if HSGPBase is a pydantic BaseModel, plain attribute
        # assignment of undeclared fields may be rejected and the parent
        # __init__ is never run here -- confirm against the installed
        # pymc-marketing version (this snippet was not tested by its author).
        self.hsgp = hsgp
        self.mask = mask
        self._validate()

    def _dims_tuple(self) -> tuple[str, ...]:
        """Return the HSGP dims as a tuple, wrapping a bare string dim."""
        dims = self.hsgp.dims
        # A bare string would otherwise be iterated character by character.
        return (dims,) if isinstance(dims, str) else tuple(dims)

    def _validate(self) -> None:
        """Check that the HSGP has a 'channel' dim and that the mask dims line up.

        Raises
        ------
        ValueError
            If 'channel' is missing from the HSGP dims, or the mask dims are
            not exactly the HSGP dims minus 'date' (in the same order).
        """
        dims = self._dims_tuple()
        # Ensure 'channel' is in hsgp dims
        if "channel" not in dims:
            raise ValueError(
                "HSGP must have 'channel' in dims for per-channel masking. "
                f"Got dims: {self.hsgp.dims}"
            )
        # Ensure mask dims align with hsgp non-time dims
        hsgp_non_time_dims = tuple(d for d in dims if d != "date")
        if tuple(self.mask.dims) != hsgp_non_time_dims:
            raise ValueError(
                f"mask dims {self.mask.dims} must match HSGP non-time dims {hsgp_non_time_dims}"
            )

    # Delegate HSGPBase properties to the wrapped HSGP
    @property
    def m(self):
        return self.hsgp.m

    @property
    def dims(self):
        return self.hsgp.dims

    @property
    def X(self):
        return self.hsgp.X

    @property
    def X_mid(self):
        return self.hsgp.X_mid

    @property
    def transform(self):
        return self.hsgp.transform

    @property
    def demeaned_basis(self):
        return self.hsgp.demeaned_basis

    def register_data(self, X) -> Self:
        """Register time data - delegates to underlying HSGP. Returns self."""
        self.hsgp.register_data(X)
        return self

    @staticmethod
    def deterministics_to_replace(name: str) -> list[str]:
        """Deterministics to replace for out-of-sample predictions.

        The wrapped HSGP is created under ``f"{name}_raw"`` (see
        ``create_variable``), so the lookup is delegated with that name.
        """
        return SoftPlusHSGP.deterministics_to_replace(f"{name}_raw")

    def create_variable(self, name: str) -> pt.TensorVariable:
        """Create the masked HSGP variable.

        Returns HSGP values for unmasked channels and 1.0 for masked channels,
        registered as ``pm.Deterministic(name, ..., dims=self.dims)``. The raw
        HSGP for all channels is created under ``f"{name}_raw"``.
        """
        # Create raw HSGP for all channels
        raw_hsgp = self.hsgp.create_variable(f"{name}_raw")

        dims = self._dims_tuple()
        mask = self.mask.astype(bool)
        if "date" in dims:
            # Insert a length-1 'date' axis and reorder to the HSGP dim order
            # so the mask broadcasts over time wherever 'date' sits in dims
            # (a bare (channel,) mask only broadcasts when 'date' is first).
            # Assumes raw_hsgp's axes follow self.dims order -- TODO confirm.
            mask = mask.expand_dims("date").transpose(*dims)
        mask_tensor = pt.as_tensor_variable(mask.values)

        # pt.where: where mask is True, use HSGP; where False, use 1.0
        masked_result = pt.where(mask_tensor, raw_hsgp, 1.0)
        return pm.Deterministic(name, masked_result, dims=self.dims)

    def to_dict(self) -> dict:
        """Serialize to dictionary for model persistence."""
        return {
            "class": "MaskedHSGP",
            "hsgp": {
                **self.hsgp.to_dict(),
                "hsgp_class": self.hsgp.__class__.__name__,
            },
            "mask": self.mask.values.tolist(),
            "mask_dims": list(self.mask.dims),
            # Skip dims without a coordinate (indexing them raises KeyError);
            # .tolist() yields plain Python scalars instead of numpy ones.
            "mask_coords": {
                dim: self.mask.coords[dim].values.tolist()
                for dim in self.mask.dims
                if dim in self.mask.coords
            },
        }

    @classmethod
    def from_dict(cls, data: dict) -> "MaskedHSGP":
        """Deserialize from a dictionary produced by ``to_dict``."""
        from pymc_marketing.mmm.hsgp import hsgp_from_dict

        hsgp = hsgp_from_dict(data["hsgp"])
        mask = xr.DataArray(
            data["mask"],
            dims=data["mask_dims"],
            coords=data.get("mask_coords"),
        )
        return cls(hsgp=hsgp, mask=mask)
Then you can use it with the `MMM` class from `pymc_marketing.mmm.multidimensional`:
import numpy as np
import pandas as pd
import xarray as xr
from pymc_marketing.mmm import SoftPlusHSGP, GeometricAdstock, LogisticSaturation
from pymc_marketing.mmm.multidimensional import MMM

# MaskedHSGP is the class defined in the first snippet of this answer.

# Your data setup
channels = ["TV", "Radio", "Digital", "Print"]
n_dates = 104  # 2 years of weekly data

# Define which channels should have time-varying effects.
# Selecting them by name (rather than a positional True/False list) keeps the
# mask aligned with `channels` if channels are added, removed or reordered.
# True = HSGP time-varying, False = constant multiplier 1.0 (no time variation)
time_varying_channels = {"TV", "Radio"}  # Only TV and Radio vary over time
unknown_channels = time_varying_channels - set(channels)
if unknown_channels:
    # Catch typos early instead of silently masking every channel out.
    raise ValueError(f"Unknown time-varying channels: {sorted(unknown_channels)}")

mask = xr.DataArray(
    [channel in time_varying_channels for channel in channels],
    dims=["channel"],
    coords={"channel": channels},
)

# Create the base HSGP with channel dimension
base_hsgp = SoftPlusHSGP.parameterize_from_data(
    X=np.arange(n_dates),
    dims=("date", "channel"),  # Must include 'channel' for per-channel effects
)

# Wrap with masking
masked_hsgp = MaskedHSGP(hsgp=base_hsgp, mask=mask)

# Create MMM with masked time-varying media
mmm = MMM(
    date_column="date",
    target_column="target",
    channel_columns=channels,
    adstock=GeometricAdstock(l_max=8),
    saturation=LogisticSaturation(),
    time_varying_media=masked_hsgp,  # Pass the masked HSGP
)

# Fit as usual
# mmm.fit(X, y)
By fixing the multiplier at 1.0 for the channels you don't want to vary, their effectiveness stays constant over time (their contribution still changes with spend, adstock, and saturation). For channels that keep the HSGP — e.g. your Search channel — the multiplier can fluctuate, capturing how that channel's effectiveness changes throughout the year (e.g., seasonal patterns, market changes).
Note: I didn't test the code, but it should be along those lines!
Hope this helps! 