Combining models using switch

@mgilbert you shouldn’t need that approach in your case. You are observing a bug that still exists in PyTensor today, related to the gradient of a discrete switch.

The gradient is only being called to figure out which variables can be sampled by NUTS and which must be given a non-gradient sampler. It’s a bit silly, but you can work around the problem by rewriting the switch like this:

import numpy as np
import scipy.stats
import pymc as pm

n = 200
dates = np.arange(n)
rates = scipy.stats.poisson(mu=3).rvs(n)
sign = scipy.stats.bernoulli(p=0.3).rvs(n) * 2 - 1
obs = rates * sign

with pm.Model() as switch_model:
    switch_model.add_coord("date", dates, mutable=True)

    rate_lambda = pm.Normal("rate_lambda", sigma=1)
    obs_sigma = pm.HalfNormal("obs_sigma", sigma=1)

    arrival_intensity = pm.Poisson("arrival_intensity", pm.math.abs(rate_lambda), dims="date")

    # Encode the sign of rate_lambda arithmetically instead of with a
    # discrete switch, so the gradient graph stays well-defined
    arrivals = pm.Deterministic(
        "arrivals",
        arrival_intensity * ((rate_lambda >= 0) * 1 + (rate_lambda < 0) * -1),
        dims="date",
    )
    arrivals_obs = pm.Normal("arrivals_obs", mu=arrivals, sigma=obs_sigma, dims="date", observed=obs)

with switch_model:
    pm.sample()
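As a quick sanity check of the boolean-arithmetic trick (this is just plain NumPy, not part of the model above), the expression `(x >= 0) * 1 + (x < 0) * -1` produces the same values a switch would:

```python
import numpy as np

x = np.array([-2.0, -0.5, 0.0, 1.5])

# Arithmetic sign trick: no discrete switch node in the graph
sign_trick = (x >= 0) * 1 + (x < 0) * -1

# Equivalent switch-style expression
switch_version = np.where(x >= 0, 1, -1)

print(np.array_equal(sign_trick, switch_version))  # True
```

The same identity holds symbolically in PyTensor, which is why the rewritten model samples the way the switch-based one was intended to.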