Hi!
I've had some success by subclassing `DelayedSaturatedMMM` and overriding its `build_model` method.
Starting from the original method, I added the following to the prior specification, so that delayed adstock is used whenever `theta` is specified in the model config:
```python
if self.model_config.get("theta"):
    self.theta_dist = self._get_distribution(dist=self.model_config["theta"])
    self._adstock_function = 'delayed_adstock'
else:
    self.theta_dist = None
    self._adstock_function = 'geometric_adstock'
```
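I then added this to select the actual transform. Note that `theta` itself still has to be created as a random variable from `self.theta_dist` (in the same way as `alpha`) before this point: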
```python
# Needed at module level:
# from pymc_marketing.mmm.transformers import delayed_adstock, geometric_adstock

if self._adstock_function == 'geometric_adstock':
    adstock_transform = geometric_adstock(
        x=channel_data_,
        alpha=alpha,
        l_max=self.adstock_max_lag,
        normalize=True,
        axis=0,
    )
elif self._adstock_function == 'delayed_adstock':
    adstock_transform = delayed_adstock(
        x=channel_data_,
        alpha=alpha,
        theta=theta,  # random variable created from self.theta_dist
        l_max=self.adstock_max_lag,
        normalize=True,
        axis=0,
    )
```
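To make the difference between the two branches concrete, here is a small NumPy re-implementation of the two weighting schemes as I understand them from the pymc-marketing source: geometric weights `alpha**l`, delayed weights `alpha**((l - theta)**2)`, each optionally normalized to sum to 1 and then convolved over the date axis. This is only an illustration, not the library code: `adstock_weights` and `apply_adstock` are made-up names, and `alpha`/`theta` are scalars here rather than one value per channel.

```python
import numpy as np


def adstock_weights(alpha, l_max, theta=None, normalize=True):
    """Illustrative lag weights for lags 0..l_max-1 (not the library code).

    Geometric (theta is None): w[l] = alpha ** l
    Delayed (theta given):     w[l] = alpha ** ((l - theta) ** 2)
    """
    lags = np.arange(l_max, dtype=float)
    if theta is None:
        w = alpha ** lags
    else:
        # Largest weight at lag == theta, i.e. the peak effect is delayed.
        w = alpha ** ((lags - theta) ** 2)
    if normalize:
        w = w / w.sum()
    return w


def apply_adstock(x, w):
    """Carry-over along axis 0 (dates): y[t] = sum_l w[l] * x[t - l].

    Values before the start of the series are treated as 0. Works for a
    1-D series or a 2-D (date, channel) array with one shared weight vector.
    """
    x = np.asarray(x, dtype=float)
    y = np.zeros_like(x)
    for lag, weight in enumerate(w):
        if lag == 0:
            y += weight * x
        else:
            y[lag:] += weight * x[:-lag]
    return y


impulse = np.array([1.0, 0.0, 0.0, 0.0, 0.0, 0.0])
geometric = apply_adstock(impulse, adstock_weights(alpha=0.5, l_max=4))
delayed = apply_adstock(impulse, adstock_weights(alpha=0.5, l_max=4, theta=2))
```

With spend only on the first date, the geometric response peaks immediately and then decays, while the delayed response with `theta=2` peaks two periods later. That shifted peak is exactly what `theta` adds on top of `alpha`.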
Finally, `adstock_transform` replaces the hard-coded `geometric_adstock(...)` call inside the `channel_adstock` deterministic:
```python
channel_adstock = pm.Deterministic(
    name="channel_adstock",
    var=adstock_transform,
    dims=("date", "channel"),
)
```
This appears to work: the model graph looks right and the model runs, but I haven't tested it in depth yet. The class (or rather its parent classes) probably has other methods that must also use the correct adstock transform, such as those for out-of-sample plotting, but I haven't checked those yet.
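One way to keep other methods consistent would be to decide the adstock once and route every call through a single helper, instead of repeating the `if`/`elif` wherever the transform is needed. The sketch below is hypothetical: `AdstockChoice` is a made-up name, the `theta` config entry shape is only an assumed example, and the registry uses stand-in functions so it runs without pymc-marketing. In the real subclass the registry values would be `geometric_adstock` and `delayed_adstock`, assuming they are imported from `pymc_marketing.mmm.transformers`.

```python
from typing import Any, Callable, Dict


class AdstockChoice:
    """Decide once which adstock to use, then reuse that decision everywhere."""

    def __init__(self, model_config: Dict[str, Any], registry: Dict[str, Callable[..., Any]]):
        # Same rule as in build_model: a "theta" entry switches to delayed adstock.
        self.name = "delayed_adstock" if model_config.get("theta") else "geometric_adstock"
        self._fn = registry[self.name]

    def __call__(self, x, alpha, theta=None, **kwargs):
        if self.name == "delayed_adstock":
            if theta is None:
                raise ValueError("delayed_adstock requires theta")
            return self._fn(x=x, alpha=alpha, theta=theta, **kwargs)
        # Geometric adstock takes no theta; any theta passed here is ignored.
        return self._fn(x=x, alpha=alpha, **kwargs)


# Stand-ins for the real transforms, only so this sketch is runnable.
def fake_geometric(x, alpha, **kwargs):
    return ("geometric", x, alpha)


def fake_delayed(x, alpha, theta, **kwargs):
    return ("delayed", x, alpha, theta)


registry = {"geometric_adstock": fake_geometric, "delayed_adstock": fake_delayed}
# Hypothetical config entry; the actual prior for theta is up to the user.
adstock = AdstockChoice({"theta": {"dist": "HalfNormal", "kwargs": {"sigma": 1}}}, registry)
```

`build_model` and any out-of-sample code would then both call the same `adstock(...)` object, so they cannot silently disagree about which transform is in use.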
Best Regards
Jonas
