@mgilbert you shouldn’t need that approach in your case. You are running into a bug that still exists in PyTensor today, related to the gradient of a discrete switch.
The gradient is only being evaluated to figure out which variables NUTS can sample and which must be assigned a non-gradient sampler. It’s a bit silly, but you can work around the problem by rewriting the switch like this:
```python
import numpy as np
import scipy.stats

import pymc as pm

n = 200
dates = np.arange(n)

# Simulate signed counts: Poisson(3) rates with a random sign
# (+1 with probability 0.3, otherwise -1)
rates = scipy.stats.poisson(mu=3).rvs(n)
sign = scipy.stats.bernoulli(p=0.3).rvs(n) * 2 - 1
obs = rates * sign

with pm.Model() as switch_model:
    switch_model.add_coord("date", dates, mutable=True)

    rate_lambda = pm.Normal("rate_lambda", sigma=1)
    obs_sigma = pm.HalfNormal("obs_sigma", sigma=1)
    arrival_intensity = pm.Poisson("arrival_intensity", pm.math.abs(rate_lambda), dims="date")

    # Same result as switch(rate_lambda >= 0, 1, -1), but written as arithmetic
    # on boolean masks so the gradient never goes through the buggy switch Op
    arrivals = pm.Deterministic(
        "arrivals",
        arrival_intensity * ((rate_lambda >= 0) * 1 + (rate_lambda < 0) * -1),
        dims="date",
    )
    arrivals_obs = pm.Normal(
        "arrivals_obs", mu=arrivals, sigma=obs_sigma, dims="date", observed=obs
    )

with switch_model:
    idata = pm.sample()
```
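For reference, the pattern that triggers the bug is a direct discrete switch on the sign of a continuous variable. A minimal toy sketch of that failing graph (names mirror the model above; this is just to illustrate the pattern, not your exact model):

```python
import pymc as pm

# Toy sketch of the problematic pattern: a discrete switch on the sign of a
# continuous variable. pm.sample() evaluates the gradient of this graph while
# assigning step samplers, which is where the PyTensor bug surfaces.
with pm.Model():
    rate_lambda = pm.Normal("rate_lambda", sigma=1)
    arrival_intensity = pm.Poisson("arrival_intensity", pm.math.abs(rate_lambda))
    pm.Deterministic(
        "arrivals",
        arrival_intensity * pm.math.switch(rate_lambda >= 0, 1, -1),
    )
```

The boolean-mask arithmetic in the model above computes the same sign, but keeps everything as elementwise operations that PyTensor can differentiate without hitting the switch code path.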