Hi Chris,
I have revised my code based on your suggestion. Here is my main code now:
with pm.Model() as model:
    ### prior for unknown parameters
    # prior for lambda
    BoundedNormal_lam = pm.Bound(pm.Normal, lower=lam_lbd, upper=lam_ubd)
    lam = BoundedNormal_lam('lam', mu=lam_mean, sd=lam_sd, shape = 8)
    # prior for gamma
    BoundedNormal_gamma = pm.Bound(pm.Normal, lower=0.2, upper=0.8)
    gamma = BoundedNormal_gamma('gamma', mu=0.5, sd=0.1)
    # prior for beta
    BoundedNormal_beta = pm.Bound(pm.Normal, lower=b_lbd, upper=b_ubd)
    beta = BoundedNormal_beta('beta', mu=b_mean, sd=b_sd, shape = 8)
    # prior for v
    v = pm.Uniform('v', lower = 1e-2, upper =1e2)
    # decayed promotion: pde * lam_j^{t-l}
    xxdat['ad'] = (xxdat['pde_norm'].values) * (lam[xxdat['chnl_cd'].values-1]**(xxdat['gap_nrx_promo'].values))
    # Sum up of decayed promotion for each NRx + each channel: \sum_t pde * lam_j^{t-l}
    yydat = xxdat.groupby(['spcl_st_cd','nrx_wk', 'chnl_cd', 'nrx_norm'], as_index= False).agg({'ad':test_sum})
    yydat.rename(columns = {'ad': 'ad_sum'}, inplace = True)
    yydat['ad_sum_r'] = (yydat['ad_sum'] ** gamma) # Add adstock diminishing factor (\sum_t pde * lam_j^{t-l})^r
    # normalize each channel's adstock
    yydat['ad_norm'] = (yydat['ad_sum_r'].values)/(yydat.groupby('chnl_cd', as_index=False).agg({'ad_sum_r':test_mean})['ad_sum_r'][yydat['chnl_cd'].values-1]).values
    # sum up adstock multiplying impact factor over channels
    yydat['bx'] = (beta[yydat['chnl_cd'].values -1]*(yydat['ad_norm']))
    zzdat = yydat.groupby(['spcl_st_cd', 'nrx_wk', 'nrx_norm'], as_index = False).agg({'bx':test_sum})
    # Get the final mu
    mu = test_mean(zzdat['bx'])
    ### likelihood (sampling distribution) of observations
    Y_obs = pm.Normal('Y_obs', mu=mu, sd = v, observed = zzdat['nrx_norm'])
And here are the two test functions:
def test_sum(x):
    return tt.sum(x.to_list())
def test_mean(x):
    return tt.mean(x.to_list())
It runs successfully when the input data has fewer than 256 rows, but once there are more rows than that, I get this error:
fatal error: bracket nesting level exceeded maximum of 256.
I have tried changing the maximum to a larger number, but that gives other errors, so I think I still need to revise my code. Do you have any other ideas for my case? Thank you!
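
For reference, below is a minimal plain-Python sketch of what the deterministic part of the model is meant to compute, with fixed numbers in place of the random variables. The toy rows, the parameter values, and the function name `adstock_mu` are made up for illustration; only the column names come from the code above. It is not the PyMC3 model itself, just the arithmetic behind `mu`.

```python
from collections import defaultdict


def adstock_mu(rows, lam, gamma, beta):
    """Compute the scalar mu from the model above for fixed parameter values.

    rows  : list of dicts with keys spcl_st_cd, nrx_wk, nrx_norm,
            chnl_cd (1-based), pde_norm, gap_nrx_promo
    lam   : per-channel decay rates
    gamma : diminishing-returns exponent
    beta  : per-channel impact factors
    """
    # Decayed promotion pde * lam_j^{gap}, summed per (observation, channel).
    ad_sum = defaultdict(float)
    for r in rows:
        key = (r["spcl_st_cd"], r["nrx_wk"], r["nrx_norm"], r["chnl_cd"])
        ad_sum[key] += r["pde_norm"] * lam[r["chnl_cd"] - 1] ** r["gap_nrx_promo"]

    # Diminishing factor: (sum of decayed promotion) ** gamma.
    ad_sum_r = {k: v ** gamma for k, v in ad_sum.items()}

    # Mean adstock per channel, used to normalize each channel.
    per_channel = defaultdict(list)
    for (_, _, _, chnl), v in ad_sum_r.items():
        per_channel[chnl].append(v)
    chnl_mean = {c: sum(vs) / len(vs) for c, vs in per_channel.items()}

    # Impact-weighted, normalized adstock summed over channels per observation.
    bx = defaultdict(float)
    for (st, wk, nrx, chnl), v in ad_sum_r.items():
        bx[(st, wk, nrx)] += beta[chnl - 1] * v / chnl_mean[chnl]

    # Same final step as `mu = test_mean(zzdat['bx'])`.
    return sum(bx.values()) / len(bx)


# Hypothetical toy data: two observations, two channels.
rows = [
    {"spcl_st_cd": "A", "nrx_wk": 1, "nrx_norm": 0.5, "chnl_cd": 1, "pde_norm": 1.0, "gap_nrx_promo": 0},
    {"spcl_st_cd": "A", "nrx_wk": 1, "nrx_norm": 0.5, "chnl_cd": 1, "pde_norm": 2.0, "gap_nrx_promo": 1},
    {"spcl_st_cd": "A", "nrx_wk": 1, "nrx_norm": 0.5, "chnl_cd": 2, "pde_norm": 4.0, "gap_nrx_promo": 0},
    {"spcl_st_cd": "B", "nrx_wk": 1, "nrx_norm": 0.7, "chnl_cd": 1, "pde_norm": 3.0, "gap_nrx_promo": 0},
    {"spcl_st_cd": "B", "nrx_wk": 1, "nrx_norm": 0.7, "chnl_cd": 2, "pde_norm": 4.0, "gap_nrx_promo": 2},
]
mu_value = adstock_mu(rows, lam=[0.5, 0.5], gamma=0.5, beta=[1.0, 2.0])
print(mu_value)
```

Each step here is a grouped sum or mean over plain numbers, so the sketch can serve as a reference value when checking a rewritten version of the model on a small input.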