Tanh saturation severely increasing sampling runtime

Hi All,

I am relatively new to PyMC and PyTensor, but I am having runtime issues when using the PyTensor tanh function in my model. Below is a simplified version of my model that outlines the issue. It takes roughly 50 seconds to run, but when implemented in the larger model this runtime becomes a problem.

import numpy as np
import pandas as pd
import xarray as xr
import pymc as pm
import pytensor.tensor as pt

def tanh_saturation(x, b, c):
    return b * pt.tanh(x / (b * c))

# generating fake data
avg_spend = 100000
spend = np.random.normal(avg_spend, avg_spend / 10, size=(15, 8))
date_range = pd.date_range(start='2022-01-01', periods=15, freq='MS')
states = ['ACT', 'NSW', 'NT', 'QLD', 'SA', 'TAS', 'VIC', 'WA']
spend_array = xr.DataArray(spend, coords=[date_range, states], dims=['date', 'state'])
ceiling = np.random.normal(avg_spend * 2, avg_spend / 5, size=8)
cpa = 0.34
sales_actual = ceiling * xr.apply_ufunc(np.tanh, spend_array / (ceiling * cpa))

coords = {'date': date_range, 'state': states}
with pm.Model(coords=coords) as model1:
    spend_arr = pm.MutableData("spend_arr", spend_array, dims=['date', 'state'])
    sales_arr = pm.MutableData("sales_arr", sales_actual, dims=['date', 'state'])
    ceiling = pm.Exponential('ceiling', scale=avg_spend, dims='state')
    cpa = pm.Exponential('cpa', scale=1)
    expected_sales = pm.Deterministic('expected_sales',
                                      var=tanh_saturation(spend_arr, ceiling.dimshuffle("x", 0), cpa),
                                      dims=['date', 'state'])
    sales_noise = pm.Exponential('sales_noise', scale=avg_spend / 8)
    sales = pm.Normal('sales', mu=expected_sales, sigma=sales_noise, observed=sales_arr)
    trace = pm.sample(draws=1000,
                      tune=5000,
                      chains=1)

If anyone could provide a fix, a workaround, or any other tips to decrease the runtime, it would be greatly appreciated.

Thanks a lot