Hi there, I’ve been developing a model of user-based repetition wear and restoration effects in online advertising, and I’m confused by a dimension-mismatch error raised inside pytensor.scan.
Here’s an excerpt of the code:
import pymc as pm
import pytensor
import pytensor.tensor as pt

coords = {"user_ids": imps_df.index, "weeks": imps_df.columns}
with pm.Model(coords=coords) as model1:
    ## OBSERVED IMPRESSIONS
    # dims must be an ordered tuple; a set like {"user_ids", "weeks"} has no guaranteed order
    imps = pm.MutableData("imps", imps_df, dims=("user_ids", "weeks"))
    wks_since_last_imp = pm.MutableData("wks_since_last_imp", wks_since_last_imp_df, dims=("user_ids", "weeks"))
    ...
    ###### REPETITION WEAR AND RESTORATION
    # repetition wear
    delta_1 = pm.Uniform("delta_1", lower=0, upper=1)
    # restoration effects - universal
    rho_1 = pm.Uniform("rho_1", lower=0, upper=1)
    # geometric decay rate
    adstock_decay = pm.Uniform("adstock_decay", lower=.4, upper=.8)
    ## Add restoration and wear to adstock transformation
    R1 = pm.Deterministic("r1", (rho_1*wks_since_last_imp)/(1+rho_1*wks_since_last_imp))
    wearout = pm.Deterministic("wearout", 1 - delta_1**imps)
    eit = pm.Deterministic("eit", advertising_effect*(1 - wearout + R1 * wearout))
    # Recursively compute adstock efficiently via pytensor.scan
    adstock0 = pt.zeros(nusers)
    # def update_adstock(eit, prior_result, adstock_decay):
    adstock, __ = pytensor.scan(fn=lambda eit, prior_result, adstock_decay: prior_result * adstock_decay + eit,
                                outputs_info=adstock0,  # start at zero, so adstock equals eit in week 1
                                sequences=[eit],
                                non_sequences=[adstock_decay],
                                n_steps=nweeks)
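Written out for user i in week t, the excerpt above is meant to compute the following. The symbols are shorthand introduced here, not names from the model: ρ = rho_1, δ = delta_1, λ = adstock_decay, β = advertising_effect (defined in the elided part of the model), n_{it} = imps, w_{it} = wks_since_last_imp.

$$
\begin{aligned}
R_{it} &= \frac{\rho\, w_{it}}{1 + \rho\, w_{it}}, \qquad W_{it} = 1 - \delta^{\,n_{it}}, \\
e_{it} &= \beta \bigl(1 - W_{it} + R_{it}\, W_{it}\bigr), \\
A_{it} &= \lambda\, A_{i,t-1} + e_{it}, \qquad A_{i,0} = 0 .
\end{aligned}
$$

The recursion in the last line runs over weeks t separately for each user i, so that is the axis scan has to step along.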
My model graph lines up with what I’m envisioning, but I’m getting the following error:
ValueError: Input dimension mismatch: (input[0].shape[0] = 10000, input[2].shape[0] = 45
I believe this is coming from the prior_result argument that scan passes to the step function. I want the recursion to look back at the adstock from previous weeks, not from previous users. The input dataframe is indexed by user_id, with one column per successive week.
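As a rough illustration of the suspected cause, here is a pure-NumPy stand-in for scan's looping. The helper `scan_like` and the toy sizes are hypothetical and not part of the PyTensor API; the sketch only assumes that scan hands the step function one slice of each sequence per step, taken along the first axis. With eit laid out users × weeks, each step would receive one user's weekly values (45 of them) while prior_result holds one value per user (10000), which lines up with the two shapes in the error. Iterating over the transpose instead gives one week for every user per step.

```python
import numpy as np


def scan_like(step, sequence, init, decay):
    """Hypothetical stand-in for scan: one step per slice along axis 0."""
    state = init
    outputs = []
    for x_t in sequence:  # iterating a 2-D array yields its rows (axis 0)
        state = step(x_t, state, decay)
        outputs.append(state)
    return np.stack(outputs)


def adstock_step(eit_t, prior_result, decay):
    # Geometric adstock: carry over last step's value, then add this step's effect
    return prior_result * decay + eit_t


n_users, n_weeks = 4, 3  # toy stand-ins for 10000 users and 45 weeks
rng = np.random.default_rng(0)
eit = rng.uniform(size=(n_users, n_weeks))  # users x weeks, like imps_df
adstock0 = np.zeros(n_users)                # one starting value per user
decay = 0.5

# Users-first: each step gets one user's row of n_weeks values,
# but prior_result has n_users entries, so the shapes clash.
users_first_failed = False
try:
    scan_like(adstock_step, eit, adstock0, decay)
except ValueError as err:
    users_first_failed = True
    print("shape mismatch:", err)

# Weeks-first: each step gets one week for every user, matching adstock0.
adstock = scan_like(adstock_step, eit.T, adstock0, decay).T  # back to users x weeks
print(adstock.shape)
```

The toy sizes are deliberately unequal: if the number of users happened to equal the number of weeks, the users-first version would run without any error and silently recurse across users instead of weeks.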
Any thoughts? Do I just need to transpose the inputs to the scan function? I appreciate the active community and any help moving further along.