Time Varying Coefficients for a Single Channel

Is it possible to set up time-varying coefficients for a single channel when using the pymc_marketing.MMM class? For example, I just want to see if my coefficient for a search channel changes over time while keeping all other media channels the same.

Any examples and help is very appreciated!

Yes, it’s a bit custom. Basically, you create an HSGP for all channels then you mask the channels you want. e.g:

import numpy as npimport xarray as xr
import pymc as pm
import pytensor.tensor as pt
from typing import Self

from pymc_marketing.mmm import SoftPlusHSGP
from pymc_marketing.mmm.hsgp import HSGPBase

class MaskedHSGP(HSGPBase):
"""
HSGP with channel masking - masked channels get constant 1.0.Parameters
----------
hsgp : SoftPlusHSGP
    The underlying HSGP instance. Must have 'channel' in dims.
mask : xr.DataArray
    Boolean mask with dims matching the non-time dims of hsgp.
    True = use HSGP, False = constant 1.0 (no time-varying effect).

Examples
--------
>>> from pymc_marketing.mmm import SoftPlusHSGP, MMM
>>> import xarray as xr
>>> 
>>> channels = ["TV", "Radio", "Digital", "Print"]
>>> # Only TV and Radio have time-varying effects
>>> mask = xr.DataArray(
...     [True, True, False, False],
...     dims=["channel"],
...     coords={"channel": channels},
... )
>>> 
>>> hsgp = SoftPlusHSGP.parameterize_from_data(
...     X=np.arange(52),
...     dims=("date", "channel"),
... )
>>> 
>>> masked_hsgp = MaskedHSGP(hsgp=hsgp, mask=mask)
>>> 
>>> mmm = MMM(
...     date_column="date",
...     channel_columns=channels,
...     adstock=...,
...     saturation=...,
...     time_varying_media=masked_hsgp,
... )
"""

def __init__(self, hsgp: SoftPlusHSGP, mask: xr.DataArray):
    self.hsgp = hsgp
    self.mask = mask
    self._validate()
    
def _validate(self):
    # Ensure 'channel' is in hsgp dims
    if "channel" not in self.hsgp.dims:
        raise ValueError(
            "HSGP must have 'channel' in dims for per-channel masking. "
            f"Got dims: {self.hsgp.dims}"
        )
    # Ensure mask dims align with hsgp non-time dims
    hsgp_non_time_dims = tuple(d for d in self.hsgp.dims if d != "date")
    if tuple(self.mask.dims) != hsgp_non_time_dims:
        raise ValueError(
            f"mask dims {self.mask.dims} must match HSGP non-time dims {hsgp_non_time_dims}"
        )

# Delegate HSGPBase properties
@property
def m(self):
    return self.hsgp.m

@property
def dims(self):
    return self.hsgp.dims

@property
def X(self):
    return self.hsgp.X

@property
def X_mid(self):
    return self.hsgp.X_mid

@property
def transform(self):
    return self.hsgp.transform

@property
def demeaned_basis(self):
    return self.hsgp.demeaned_basis

def register_data(self, X) -> Self:
    """Register time data - delegates to underlying HSGP."""
    self.hsgp.register_data(X)
    return self

@staticmethod
def deterministics_to_replace(name: str) -> list[str]:
    """Deterministics to replace for out-of-sample predictions."""
    return SoftPlusHSGP.deterministics_to_replace(f"{name}_raw")

def create_variable(self, name: str) -> pt.TensorVariable:
    """Create the masked HSGP variable.
    
    Returns HSGP values for unmasked channels, 1.0 for masked channels.
    """
    # Create raw HSGP for all channels
    raw_hsgp = self.hsgp.create_variable(f"{name}_raw")
    
    # Build mask tensor with proper broadcasting
    # The mask should broadcast against all dims except 'date'
    mask_tensor = pt.as_tensor_variable(self.mask.values)
    
    # pt.where: where mask is True, use HSGP; where False, use 1.0
    masked_result = pt.where(mask_tensor, raw_hsgp, 1.0)
    
    return pm.Deterministic(name, masked_result, dims=self.dims)

def to_dict(self) -> dict:
    """Serialize to dictionary for model persistence."""
    return {
        "class": "MaskedHSGP",
        "hsgp": {
            **self.hsgp.to_dict(),
            "hsgp_class": self.hsgp.__class__.__name__,
        },
        "mask": self.mask.values.tolist(),
        "mask_dims": list(self.mask.dims),
        "mask_coords": {
            dim: list(self.mask.coords[dim].values) 
            for dim in self.mask.dims
        },
    }

@classmethod
def from_dict(cls, data: dict) -> "MaskedHSGP":
    """Deserialize from dictionary."""
    from pymc_marketing.mmm.hsgp import hsgp_from_dict
    
    hsgp = hsgp_from_dict(data["hsgp"])
    mask = xr.DataArray(
        data["mask"],
        dims=data["mask_dims"],
        coords=data.get("mask_coords"),
    )
    return cls(hsgp=hsgp, mask=mask)

Then you can use with the MMM class in multidimensional.

import numpy as np
import pandas as pd
import xarray as xr
from pymc_marketing.mmm import SoftPlusHSGP, GeometricAdstock, LogisticSaturation
from pymc_marketing.mmm.multidimensional import MMM

# Your data setup
channels = ["TV", "Radio", "Digital", "Print"]
n_dates = 104  # 2 years of weekly data

# Define which channels should have time-varying effects
# True = HSGP time-varying, False = constant (no time variation)
mask = xr.DataArray(
    [True, True, False, False],  # Only TV and Radio vary over time
    dims=["channel"],
    coords={"channel": channels},
)

# Create the base HSGP with channel dimension
base_hsgp = SoftPlusHSGP.parameterize_from_data(
    X=np.arange(n_dates),
    dims=("date", "channel"),  # Must include 'channel' for per-channel effects
)

# Wrap with masking
masked_hsgp = MaskedHSGP(hsgp=base_hsgp, mask=mask)

# Create MMM with masked time-varying media
mmm = MMM(
    date_column="date",
    target_column="target",
    channel_columns=channels,
    adstock=GeometricAdstock(l_max=8),
    saturation=LogisticSaturation(),
    time_varying_media=masked_hsgp,  # Pass the masked HSGP
)

# Fit as usual
# mmm.fit(X, y)

By setting the multiplier to 1.0 for channels you don’t want to vary, their contribution remains unchanged over time. For your Search channel with the HSGP, the multiplier can fluctuate, capturing how the effectiveness of Search changes throughout the year (e.g., seasonal patterns, market changes).

Note: I didn’t test the code, but should be around those lines!

Hope this helps! :fire:

1 Like

Thank you so much, this is super helpful! Unfortunately, when I try to pass masked_hsgp for time_varying_media I get an error informing me that the parameter will only take “bool” arguments. Do you know why?

You need to import from multidimensional the MMM class!

from pymc_marketing.mmm.multidimensional import MMM

Thank you! I was getting some errors triggered by pydantic - would this work instead? Otherwise if it’s a version issue (ex. ValueError: “MaskedHSGP” object has no field “hsgp”) I can make sure I update accordingly!

import xarray as xrimport pymc as pmimport pytensor.tensor as pt

from typing import Self from pydantic import ConfigDict, model_validator

from pymc_marketing.mmm import SoftPlusHSGPfrom pymc_marketing.mmm.hsgp import HSGPBase

class MaskedHSGP(HSGPBase):
model_config = ConfigDict(arbitrary_types_allowed=True)
hsgp: SoftPlusHSGP
mask: xr.DataArray

@model_validator(mode="before")
@classmethod
def \_populate_parent_fields(cls, data):
    """
    HSGPBase expected fields
    """
    if not isinstance(data, dict):
        return data

    h = data.get("hsgp")
    if h is None:
        return data

    # Populate potentially missing fields HSGP
    data.setdefault("m", h.m)
    data.setdefault("X", h.X)
    data.setdefault("X_mid", h.X_mid)
    data.setdefault("dims", h.dims)
    data.setdefault("transform", h.transform)
    data.setdefault("demeaned_basis", h.demeaned_basis)

    return data

@model_validator(mode="after")
def \_validate(self):
    if "channel" not in self.hsgp.dims:
        raise ValueError(
            "HSGP must have 'channel' in dims for per-channel masking. "
            f"Got dims: {self.hsgp.dims}"
        )

    hsgp_non_time_dims = tuple(d for d in self.hsgp.dims if d != "date")
    if tuple(self.mask.dims) != hsgp_non_time_dims:
        raise ValueError(
            f"mask dims {self.mask.dims} must match HSGP non-time dims {hsgp_non_time_dims}"
        )

    return self

def register_data(self, X) -> Self:
    self.hsgp.register_data(X)
    self.X = self.hsgp.X
    self.X_mid = self.hsgp.X_mid
    return self

@staticmethod
def deterministics_to_replace(name: str) -> list\[str\]:
    return SoftPlusHSGP.deterministics_to_replace(f"{name}\_raw")

def create_variable(self, name: str) -> pt.TensorVariable:
    raw_hsgp = self.hsgp.create_variable(f"{name}\_raw")
    mask_tensor = pt.shape_padleft(pt.as_tensor_variable(self.mask.values), 1)
    masked_result = pt.where(mask_tensor, raw_hsgp, 1.0)
    return pm.Deterministic(name, masked_result, dims=self.dims)

@classmethod
def from_dict(cls, data: dict) -> "MaskedHSGP":
    """Deserialize from dictionary."""
    from pymc_marketing.mmm.hsgp import hsgp_from_dict

    hsgp = hsgp_from_dict(data\["hsgp"\])
    mask = xr.DataArray(
        data\["mask"\],
        dims=data\["mask_dims"\],
        coords=data.get("mask_coords"),
    )
    return cls(hsgp=hsgp, mask=mask)

If helpful, I’m adding the package versions and full error traceback:

Package versions:
Numpy: 2.2.6
Xarray: 2025.9.0
pymc: 5.25.1
pytensor: 2.31.7
pymc_merketing; 0.17.1

The full error:
ValueError Traceback (most recent call last)
Cell In[40], line 1
----> 1 masked_hsgp = MaskedHSGP(hsgp=base_hsgp, mask=mask)

Cell In[35], line 50, in MaskedHSGP._init_(self, hsgp, mask)
49 def _init_(self, hsgp: SoftPlusHSGP, mask: xr.DataArray):
—> 50 self.hsgp = hsgp
51 self.mask = mask
52 self._validate()

File ~/PycharmProjects/MMMBuild_SD/Legacy Model Run Through/.venv/lib/python3.13/site-packages/pydantic/main.py:1032, in BaseModel._setattr_(self, name, value)
1030 setattr_handler(self, name, value)
1031 # if None is returned from _setattr_handler, the attribute was set directly
→ 1032 elif (setattr_handler := self._setattr_handler(name, value)) is not None:
1033 setattr_handler(self, name, value) # call here to not memo on possibly unknown fields
1034 self._pydantic_setattr_handlers_[name] = setattr_handler

File ~/PycharmProjects/MMMBuild_SD/Legacy Model Run Through/.venv/lib/python3.13/site-packages/pydantic/main.py:1079, in BaseModel._setattr_handler(self, name, value)
1076 elif name not in cls._pydantic_fields_:
1077 if cls.model_config.get(‘extra’) != ‘allow’:
1078 # TODO - matching error
→ 1079 raise ValueError(f’“{cls._name_}” object has no field “{name}”')
1080 elif attr is None:
1081 # attribute does not exist, so put it in extra
1082 self._pydantic_extra_[name] = value

ValueError: “MaskedHSGP” object has no field “hsgp”