Combining models using switch

Hi all,
I have a model in which subjects can receive one of two treatments. The distribution of the response depends on the treatment: treatment A results in a negative, HalfNormal-distributed response, and treatment B results in a positive, TruncatedNormal-distributed response. In my model I created a switch statement and applied the appropriate model based on the treatment. The problem is that I cannot treat the output of the switch as an observed value. In the model below, is there a way to use the output of the switch statement as the observed data?

import pymc as pm

model = pm.Model()
with model:

    trt1_idx = (df['treatment']==1).astype(int).values
    trt2_idx = (df['treatment']==2).astype(int).values

    obs_data = df['response'].values

    # Treatment 1
    mu_age = pm.Normal('mu_BLd', 20, 3)
    Bage = pm.Normal('Bage', mu_age)

    mu_wt = pm.Normal('mu_wt', 150, 10)
    Bwt = pm.Normal('Bwt', mu_wt)

    rsp1_mu = pm.Deterministic('rsp1_mu', (df['age'].values**mu_age) * (df['weight'].values**Bwt))
    rsp1 = -1*pm.HalfNormal('rsp1', rsp1_mu)

    # Treatment 2
    mu_ht = pm.Normal('mu_ht', 80, 5)
    Bht = pm.Normal('Bht', mu_ht)

    rsp2_mu = pm.Deterministic('rsp2_mu', (df['height'].values**Bht))
    rsp2 = pm.TruncatedNormal('rsp2', rsp2_mu, lower=0)

    # Combined response
    response = pm.math.switch(df['treatment']==1, rsp1[trt1_idx], rsp2[trt2_idx], shape=df.shape[0], observed=obs_data)
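For reference, the data-generating process described in the question can be simulated directly with NumPy and SciPy. This is a minimal sketch, not PyMC code: the parameter values (`sigma1`, `mu2`, `sigma2`) and the random treatment assignment are hypothetical, and the regression on age, weight and height is left out.

```python
import numpy as np
from scipy import stats

# Hypothetical parameters, chosen only for illustration.
rng = np.random.default_rng(0)
n = 10
treatment = rng.integers(1, 3, size=n)  # 1 or 2, as in df['treatment']
sigma1 = 2.0                            # scale of the HalfNormal (treatment 1)
mu2, sigma2 = 1.0, 1.0                  # location/scale of the TruncatedNormal (treatment 2)

# Treatment 1: negated HalfNormal draws, so always <= 0
rsp1 = -stats.halfnorm.rvs(scale=sigma1, size=n, random_state=rng)

# Treatment 2: Normal(mu2, sigma2) truncated to [0, inf), so always >= 0.
# scipy's truncnorm expects the bounds in standardized units.
a = (0.0 - mu2) / sigma2
rsp2 = stats.truncnorm.rvs(a, np.inf, loc=mu2, scale=sigma2, size=n, random_state=rng)

# Each subject's response comes from the branch matching its treatment
response = np.where(treatment == 1, rsp1, rsp2)
print(treatment)
print(response)
```

The sign of each response identifies which branch generated it, which is why the likelihood can be split by treatment.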

This is something we are trying to improve. Ideally it would be done with a CustomDist, like this:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

def switch_dist(treatment, rsp1_mu, rsp2_mu, size):
    rsp1 = -1 * pm.HalfNormal.dist(rsp1_mu, size=size)
    rsp2 = pm.TruncatedNormal.dist(rsp2_mu, lower=0, size=size)
    return pm.math.switch(pt.eq(treatment, 0), rsp1, rsp2)
   
with pm.Model() as m:
    treatment = np.array([0, 0, 0, 1, 1, 1])
    obs_data = np.abs(np.random.normal(size=6))
    obs_data[:3] *= -1

    rsp1_mu = pm.HalfNormal("rsp1_mu")
    rsp2_mu = pm.HalfNormal('rsp2_mu')
    pm.CustomDist("llike", treatment, rsp1_mu, rsp2_mu, dist=switch_dist, observed=obs_data)

m.point_logps()  # Fails

This relies on PyMC’s ability to automatically derive logp expressions for simple graphs. Support for switch is still very limited, so it fails in this case (I am planning to improve it this week).
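The part of that derivation that does work here, the logp of `-1 * pm.HalfNormal.dist(...)`, is a plain change of variables: if Y = -X with X ~ HalfNormal(σ), then p_Y(y) = p_X(-y), because the Jacobian of y ↦ -y has absolute value 1, so Y is supported on (-∞, 0]. A small SciPy check of that identity follows; the helper `neg_halfnormal_logpdf` is hypothetical and not part of PyMC.

```python
import numpy as np
from scipy import integrate, stats

def neg_halfnormal_logpdf(y, sigma):
    # Y = -X with X ~ HalfNormal(sigma): log p_Y(y) = log p_X(-y), unit Jacobian
    return stats.halfnorm.logpdf(-np.asarray(y, dtype=float), scale=sigma)

# Positive values fall outside the support and get -inf
print(neg_halfnormal_logpdf(np.array([-2.0, -0.5, 0.5]), sigma=1.0))

# The density integrates to 1 over (-inf, 0]
total, _ = integrate.quad(
    lambda v: float(np.exp(neg_halfnormal_logpdf(v, 1.0))), -np.inf, 0.0
)
print(total)
```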

For now, the simplest options are:

  1. Implement the CustomDist logp method directly, which should look something like:
import numpy as np
import pymc as pm
import pytensor.tensor as pt

def switch_dist(treatment, rsp1_mu, rsp2_mu, size):
    rsp1 = -1 * pm.HalfNormal.dist(rsp1_mu, size=size)
    rsp2 = pm.TruncatedNormal.dist(rsp2_mu, lower=0, size=size)
    return pm.math.switch(pt.eq(treatment, 0), rsp1, rsp2)

def switch_logp(value, treatment, rsp1_mu, rsp2_mu):
    rsp1 = -1 * pm.HalfNormal.dist(rsp1_mu)
    rsp2 = pm.TruncatedNormal.dist(rsp2_mu, lower=0)
    logp = pt.switch(
        pt.eq(treatment, 0),
        pm.logp(rsp1, value),
        pm.logp(rsp2, value),
    )
    return logp
    
with pm.Model() as m:
    treatment = np.array([0, 0, 0, 1, 1, 1])
    obs_data = np.abs(np.random.normal(size=6))
    obs_data[:3] *= -1

    rsp1_mu = pm.HalfNormal("rsp1_mu")
    rsp2_mu = pm.HalfNormal('rsp2_mu')
    pm.CustomDist("llike", treatment, rsp1_mu, rsp2_mu, dist=switch_dist, logp=switch_logp, observed=obs_data)

m.point_logps()
  2. Split your observations into two CustomDists/likelihoods:
import numpy as np
import pymc as pm

def halfnormal_dist(rsp1_mu, size):
    return -1 * pm.HalfNormal.dist(rsp1_mu, size=size)
   
with pm.Model() as m:
    treatment = np.array([0, 0, 0, 1, 1, 1])
    obs_data = np.abs(np.random.normal(size=6))
    obs_data[:3] *= -1

    rsp1_mu = pm.HalfNormal("rsp1_mu")
    rsp2_mu = pm.HalfNormal('rsp2_mu')
    pm.CustomDist("llike1", rsp1_mu, dist=halfnormal_dist, observed=obs_data[treatment == 0])
    pm.TruncatedNormal("llike2", rsp2_mu, lower=0, observed=obs_data[treatment == 1])

m.point_logps()

At that point you can also avoid the HalfNormal CustomDist by simply negating the observed data and using a plain HalfNormal directly.
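As a sanity check that both options (and the negated-data shortcut) describe the same likelihood, here is a plain NumPy/SciPy version of it. This is not PyMC code: it assumes fixed, hypothetical values for `rsp1_mu` and `rsp2_mu` and a TruncatedNormal sigma of 1, and it only covers the likelihood terms, not the priors. The three totals should agree up to floating-point error.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(42)
treatment = np.array([0, 0, 0, 1, 1, 1])
obs_data = np.abs(rng.normal(size=6))
obs_data[:3] *= -1

rsp1_mu, rsp2_mu = 0.8, 1.3  # hypothetical parameter values

def halfnorm_neg_logp(value, sigma):
    # log-density of -1 * HalfNormal(sigma) evaluated at value
    return stats.halfnorm.logpdf(-value, scale=sigma)

def truncnorm_pos_logp(value, mu, sigma=1.0):
    # log-density of Normal(mu, sigma) truncated to [0, inf)
    a = (0.0 - mu) / sigma
    return stats.truncnorm.logpdf(value, a, np.inf, loc=mu, scale=sigma)

# Option 1: elementwise switch between the two log-densities
logp_switch = np.where(
    treatment == 0,
    halfnorm_neg_logp(obs_data, rsp1_mu),
    truncnorm_pos_logp(obs_data, rsp2_mu),
).sum()

# Option 2: two separate likelihoods on the split observations
logp_split = (
    halfnorm_neg_logp(obs_data[treatment == 0], rsp1_mu).sum()
    + truncnorm_pos_logp(obs_data[treatment == 1], rsp2_mu).sum()
)

# Option 2 without the custom distribution: negate the data, plain HalfNormal
logp_negated = (
    stats.halfnorm.logpdf(-obs_data[treatment == 0], scale=rsp1_mu).sum()
    + truncnorm_pos_logp(obs_data[treatment == 1], rsp2_mu).sum()
)
print(logp_switch, logp_split, logp_negated)
```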


Thank you for the quick reply. The first approach worked perfectly for me.

Not to hijack this thread, but I’m running into a similar issue. I have an arrival/departure rate that can be positive or negative, and I am attempting to model it as a Poisson process. The toy example below isn’t exactly what I am doing, but captures the spirit. It raises an AttributeError when I attempt to sample.

import numpy as np
import scipy
import pymc as pm
import aesara
import aesara.tensor as at

n = 200
dates = np.arange(n)
rates = scipy.stats.poisson(mu=3).rvs(n)
sign = scipy.stats.bernoulli(p=0.3).rvs(n)*2 - 1
obs = rates * sign


with pm.Model() as switch_model:
    switch_model.add_coord("date", dates, mutable=True)    
    
    rate_lambda = pm.Normal("rate_lambda", sigma=1)
    obs_sigma = pm.HalfNormal("obs_sigma", sigma=1)
    
    arrival_intensity = pm.Poisson("arrival_intensity", pm.math.abs(rate_lambda), dims="date")
    
    arrivals = pm.Deterministic(
        "arrivals",
         at.switch(
             rate_lambda >= 0,
             arrival_intensity,
             -arrival_intensity,
         ),
        dims="date"
    )
    arrivals_obs = pm.Normal("arrivals_obs", mu=arrivals, sigma=obs_sigma, dims="date", observed=obs)

with switch_model:
    pm.sample()

This gives the following error:

python3.8/site-packages/aesara/tensor/elemwise.py in transform(r)
    619         def transform(r):
    620             # From a graph of ScalarOps, make a graph of Broadcast ops.
--> 621             if isinstance(r.type, (NullType, DisconnectedType)):
    622                 return r
    623             if r in scalar_inputs:

AttributeError: 'float' object has no attribute 'type'

I’m using the following versions and am limited in my ability to upgrade them. Is it possible to do something similar to what is outlined above with a CustomDist under these version restrictions?

pymc version: 4.2.1
aesara version: 2.8.6

@mgilbert you shouldn’t need that approach in your case. You are hitting a bug, still present in PyTensor today, in the gradient of a switch with discrete branches.

The gradient is computed only to figure out which variables can be sampled by NUTS and which must be given a non-gradient sampler. It’s a bit silly, but you can work around the problem by rewriting the switch like this:

import numpy as np
import scipy.stats
import pymc as pm
import pytensor
import pytensor.tensor as at

n = 200
dates = np.arange(n)
rates = scipy.stats.poisson(mu=3).rvs(n)
sign = scipy.stats.bernoulli(p=0.3).rvs(n) * 2 - 1
obs = rates * sign

with pm.Model() as switch_model:
    switch_model.add_coord("date", dates, mutable=True)

    rate_lambda = pm.Normal("rate_lambda", sigma=1)
    obs_sigma = pm.HalfNormal("obs_sigma", sigma=1)

    arrival_intensity = pm.Poisson("arrival_intensity", pm.math.abs(rate_lambda), dims="date")

    arrivals = pm.Deterministic(
        "arrivals",
        arrival_intensity * ((rate_lambda >= 0) * 1 + (rate_lambda < 0) * -1),
        dims="date"
    )
    arrivals_obs = pm.Normal("arrivals_obs", mu=arrivals, sigma=obs_sigma, dims="date", observed=obs)

with switch_model:
    pm.sample()        
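The rewrite computes the same values as the switch while avoiding the problematic gradient. A quick NumPy check of that equivalence follows; it runs outside PyTensor, so it says nothing about the gradient bug itself, and the function names are hypothetical.

```python
import numpy as np

def sign_via_where(rate_lambda):
    # The original formulation: a switch on the sign of rate_lambda
    return np.where(rate_lambda >= 0, 1, -1)

def sign_via_arithmetic(rate_lambda):
    # The workaround: booleans cast to 0/1 and combined arithmetically
    return (rate_lambda >= 0) * 1 + (rate_lambda < 0) * -1

lams = np.array([-1.5, -1e-9, 0.0, 2.0])
intensity = np.array([3, 1, 4, 2])
print(intensity * sign_via_arithmetic(lams))  # [-3 -1  4  2]
```

Note that `np.sign` (or `pm.math.sgn`) is not a drop-in replacement here: it returns 0 when the rate is exactly 0, whereas both expressions above map 0 to +1.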

I opened an issue in our repo: "Gradient of discrete switch returns wrong types" (pymc-devs/pytensor issue #331 on GitHub).