Hi All,
I am relatively new to PyMC and PyTensor, and I am running into runtime issues when using the PyTensor function `pt.tanh` in my model. Below is a simplified version of my model that reproduces the issue. It takes roughly 50 seconds to run, but when the same pattern is used in the larger model the runtime becomes a problem.
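```python
import numpy as np
import pandas as pd
import xarray as xr
import pymc as pm
import pytensor.tensor as pt

def tanh_saturation(x, b, c):
    # Saturating response: for b, c > 0 it approaches b for large x, with slope 1/c at x = 0
    return b * pt.tanh(x / (b * c))
```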
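For reference, the shape of `tanh_saturation` can be checked numerically. For positive b and c it levels off at b as x grows, its slope at x = 0 is 1/c (since the derivative of b·tanh(x/(bc)) is sech²(x/(bc))/c), and a spend of b·c reaches tanh(1), about 76% of the ceiling. The sketch below uses a NumPy copy of the function so it runs without PyMC; the name `tanh_saturation_np` and the values of b and c are illustrative, not part of the model.

```python
import numpy as np

def tanh_saturation_np(x, b, c):
    """NumPy copy of tanh_saturation, for quick numerical checks."""
    return b * np.tanh(x / (b * c))

b, c = 200_000.0, 0.34  # illustrative values, similar in scale to the fake data

# Far past the saturation point the output levels off at the ceiling b.
print(tanh_saturation_np(1e7, b, c))  # 200000.0

# Near zero spend the curve is almost linear with slope 1 / c.
h = 1.0
print(tanh_saturation_np(h, b, c) / h, 1 / c)

# A spend of b * c reaches tanh(1) of the ceiling.
print(tanh_saturation_np(b * c, b, c) / b)
```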
```python
# generating fake data
avg_spend = 100000
spend = np.random.normal(avg_spend, avg_spend / 10, size=(15, 8))
date_range = pd.date_range(start='2022-01-01', periods=15, freq='MS')
states = ['ACT', 'NSW', 'NT', 'QLD', 'SA', 'TAS', 'VIC', 'WA']
spend_array = xr.DataArray(spend, coords=[date_range, states], dims=['date', 'state'])
ceiling = np.random.normal(avg_spend * 2, avg_spend / 5, size=8)
cpa = 0.34
# "True" sales from the same tanh curve; ceiling broadcasts along the 'state' axis
sales_actual = ceiling * xr.apply_ufunc(np.tanh, spend_array / (ceiling * cpa))
```
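One thing that may be worth checking (an assumption, not verified here): `sales_actual` lies exactly on the tanh curve with no observation noise, while the model includes a `sales_noise` term. With noise-free data the posterior for `sales_noise` can collapse toward zero, which tends to force very small NUTS step sizes and long trajectories. A NumPy-only sketch of the same fake data with Gaussian noise added is below; the seed and the noise level `true_noise` (set to the `sales_noise` prior scale) are hypothetical choices.

```python
import numpy as np

rng = np.random.default_rng(42)  # hypothetical seed, for reproducibility only
avg_spend = 100_000
spend = rng.normal(avg_spend, avg_spend / 10, size=(15, 8))
ceiling = rng.normal(avg_spend * 2, avg_spend / 5, size=8)
cpa = 0.34

# Noise-free curve, as in the original fake data (ceiling broadcasts over states).
sales_clean = ceiling * np.tanh(spend / (ceiling * cpa))

# Hypothetical observation noise, matching the prior scale used for sales_noise.
true_noise = avg_spend / 8
sales_noisy = sales_clean + rng.normal(0.0, true_noise, size=sales_clean.shape)
```

The model could then be fit to `sales_noisy` instead of `sales_actual` to see whether the sampling time changes.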
```python
coords = {'date': date_range, 'state': states}

with pm.Model(coords=coords) as model1:
    spend_arr = pm.MutableData("spend_arr", spend_array, dims=['date', 'state'])
    sales_arr = pm.MutableData("sales_arr", sales_actual, dims=['date', 'state'])

    # Note: this rebinds `ceiling` (the NumPy array above) to a random variable
    ceiling = pm.Exponential('ceiling', scale=avg_spend, dims='state')
    cpa = pm.Exponential('cpa', scale=1)

    # ceiling has shape (state,); dimshuffle("x", 0) adds a broadcastable date axis
    expected_sales = pm.Deterministic(
        'expected_sales',
        var=tanh_saturation(spend_arr, ceiling.dimshuffle("x", 0), cpa),
        dims=['date', 'state'],
    )

    sales_noise = pm.Exponential('sales_noise', scale=avg_spend / 8)
    sales = pm.Normal('sales', mu=expected_sales, sigma=sales_noise, observed=sales_arr)

    trace = pm.sample(draws=1000, tune=5000, chains=1)
```
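Another possible direction (untested here): b·tanh(x/(bc)) keeps its shape when x and b are divided by the same factor k. The output is simply divided by k and c is untouched, so the model could equivalently be written in units of `avg_spend`. That means dividing spend, sales, and the scales of the `ceiling` and `sales_noise` priors by `avg_spend`, while leaving `cpa` as it is. Keeping parameters near order one is a common suggestion for NUTS, but whether it helps in this model is an assumption. A quick NumPy check of the scaling identity is below; k, x, b and c are illustrative values.

```python
import numpy as np

def tanh_saturation_np(x, b, c):
    """NumPy copy of tanh_saturation."""
    return b * np.tanh(x / (b * c))

k = 100_000.0  # hypothetical scale factor, e.g. avg_spend
x = np.array([50_000.0, 100_000.0, 250_000.0])
b, c = 200_000.0, 0.34

original = tanh_saturation_np(x, b, c)
# Same curve in scaled units: spend and ceiling divided by k, c unchanged.
scaled = tanh_saturation_np(x / k, b / k, c)
print(np.allclose(scaled * k, original))  # True
```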
If anyone could suggest a fix, a workaround, or any other tips to reduce the runtime, I would greatly appreciate it.
Thanks a lot