Combining models using switch

This is something we are trying to improve. Ideally, it would be done via a CustomDist, like this:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

def switch_dist(treatment, rsp1_mu, rsp2_mu, size):
    """Build a symbolic response variable that depends on ``treatment``.

    Where ``treatment`` equals 0 the result takes the negated HalfNormal
    branch (built from ``rsp1_mu``); everywhere else it takes the
    TruncatedNormal branch (built from ``rsp2_mu``, truncated below at 0).
    Both branches are created with the requested ``size``.

    NOTE(review): ``rsp1_mu`` is passed as HalfNormal's first positional
    argument, which in PyMC is presumably the scale (sigma) rather than a
    mean -- confirm the naming is intended.
    """
    negative_branch = -1 * pm.HalfNormal.dist(rsp1_mu, size=size)
    positive_branch = pm.TruncatedNormal.dist(rsp2_mu, lower=0, size=size)
    use_rsp1 = pt.eq(treatment, 0)
    return pm.math.switch(use_rsp1, negative_branch, positive_branch)
   
# Synthetic data: six absolute-valued normal draws, with the first three
# (the entries whose treatment is 0) flipped to be negative.
treatment = np.array([0, 0, 0, 1, 1, 1])
obs_data = np.abs(np.random.normal(size=6))
obs_data[:3] *= -1

with pm.Model() as m:
    rsp1_mu = pm.HalfNormal("rsp1_mu")
    rsp2_mu = pm.HalfNormal("rsp2_mu")
    # Only `dist` is supplied here, so PyMC has to derive the logp of the
    # switch graph on its own.
    pm.CustomDist(
        "llike",
        treatment,
        rsp1_mu,
        rsp2_mu,
        dist=switch_dist,
        observed=obs_data,
    )

m.point_logps()  # Fails

This relies on PyMC's ability to derive simple logprob expressions automatically. Support for the switch case is still very limited, so it fails here (I am planning to improve it this week).

For now, the simplest option is to:

  1. Implement the CustomDist logp method directly, which should look something like:
import numpy as np
import pymc as pm
import pytensor.tensor as pt

def switch_dist(treatment, rsp1_mu, rsp2_mu, size):
    """Return the random-variable graph used for forward sampling.

    Elements with ``treatment == 0`` come from ``-1 * HalfNormal(rsp1_mu)``;
    all other elements come from ``TruncatedNormal(rsp2_mu, lower=0)``.
    Each branch is built with the given ``size``.
    """
    flipped_halfnormal = -1 * pm.HalfNormal.dist(rsp1_mu, size=size)
    truncated_normal = pm.TruncatedNormal.dist(rsp2_mu, lower=0, size=size)
    pick_first = pt.eq(treatment, 0)
    return pm.math.switch(pick_first, flipped_halfnormal, truncated_normal)

def switch_logp(value, treatment, rsp1_mu, rsp2_mu):
    """Elementwise log-density matching the branches built in ``switch_dist``.

    The log-probability of ``value`` is evaluated under both component
    distributions; the negated-HalfNormal term is used where ``treatment``
    equals 0 and the TruncatedNormal term is used everywhere else.
    """
    flipped_halfnormal = -1 * pm.HalfNormal.dist(rsp1_mu)
    truncated_normal = pm.TruncatedNormal.dist(rsp2_mu, lower=0)

    logp_rsp1 = pm.logp(flipped_halfnormal, value)
    logp_rsp2 = pm.logp(truncated_normal, value)

    return pt.switch(pt.eq(treatment, 0), logp_rsp1, logp_rsp2)
    
# Synthetic data: six absolute-valued normal draws, with the first three
# (the entries whose treatment is 0) flipped to be negative.
treatment = np.array([0, 0, 0, 1, 1, 1])
obs_data = np.abs(np.random.normal(size=6))
obs_data[:3] *= -1

with pm.Model() as m:
    rsp1_mu = pm.HalfNormal("rsp1_mu")
    rsp2_mu = pm.HalfNormal("rsp2_mu")
    # `logp` is given explicitly alongside `dist`, so the likelihood does not
    # depend on automatic logprob derivation for the switch.
    pm.CustomDist(
        "llike",
        treatment,
        rsp1_mu,
        rsp2_mu,
        dist=switch_dist,
        logp=switch_logp,
        observed=obs_data,
    )

m.point_logps()
  2. Split your observations into two CustomDists/Likelihoods:
import numpy as np
import pymc as pm

def halfnormal_dist(rsp1_mu, size):
    """Return a HalfNormal variable of the given ``size`` multiplied by -1."""
    positive_part = pm.HalfNormal.dist(rsp1_mu, size=size)
    return -1 * positive_part
   
# Synthetic data: six absolute-valued normal draws, with the first three
# (the entries whose treatment is 0) flipped to be negative.
treatment = np.array([0, 0, 0, 1, 1, 1])
obs_data = np.abs(np.random.normal(size=6))
obs_data[:3] *= -1

with pm.Model() as m:
    rsp1_mu = pm.HalfNormal("rsp1_mu")
    rsp2_mu = pm.HalfNormal("rsp2_mu")

    # One likelihood per treatment group, each seeing only its own rows.
    group0_obs = obs_data[treatment == 0]
    group1_obs = obs_data[treatment == 1]
    pm.CustomDist("llike1", rsp1_mu, dist=halfnormal_dist, observed=group0_obs)
    pm.TruncatedNormal("llike2", rsp2_mu, lower=0, observed=group1_obs)

m.point_logps()

At that point you can also avoid the HalfNormal CustomDist altogether, by simply negating the observed data and using a vanilla HalfNormal directly.

2 Likes