This is something we are trying to improve. Ideally it would be done with a `CustomDist`, like this:
import numpy as np
import pymc as pm
import pytensor.tensor as pt
def switch_dist(treatment, rsp1_mu, rsp2_mu, size):
    """Generative graph for ``CustomDist``: choose a response per observation.

    Where ``treatment == 0`` the value comes from a HalfNormal (``rsp1_mu``
    passed as its first positional parameter) multiplied by -1, so it is
    <= 0. Everywhere else it comes from a TruncatedNormal with ``rsp2_mu``
    as its first positional parameter and ``lower=0``.
    """
    is_control = pt.eq(treatment, 0)
    negative_branch = -1 * pm.HalfNormal.dist(rsp1_mu, size=size)
    positive_branch = pm.TruncatedNormal.dist(rsp2_mu, lower=0, size=size)
    return pm.math.switch(is_control, negative_branch, positive_branch)
with pm.Model() as m:
    treatment = np.array([0, 0, 0, 1, 1, 1])
    # Responses are negative for treatment == 0 and positive otherwise.
    # Derive the sign from `treatment` instead of a hard-coded slice, so the
    # data stays consistent with the switch branches if `treatment` changes.
    obs_data = np.abs(np.random.normal(size=treatment.shape))
    obs_data[treatment == 0] *= -1
    rsp1_mu = pm.HalfNormal("rsp1_mu")
    rsp2_mu = pm.HalfNormal("rsp2_mu")
    pm.CustomDist(
        "llike",
        treatment,
        rsp1_mu,
        rsp2_mu,
        dist=switch_dist,
        observed=obs_data,
    )
    # Fails: PyMC cannot yet derive the logp of this switch graph automatically.
    m.point_logps()
This relies on PyMC's ability to automatically derive log-probability expressions for simple graphs. Support for `switch` is still very limited, so it fails in this case (I am planning to improve it this week).
For now, the simplest options are to either:
- Implement the `CustomDist` `logp` function directly, which should look something like:
import numpy as np
import pymc as pm
import pytensor.tensor as pt
def switch_dist(treatment, rsp1_mu, rsp2_mu, size):
    """Random graph describing how ``CustomDist`` draws values.

    In this variant the log-density is supplied separately (``switch_logp``).
    Observations with ``treatment == 0`` get ``-1 * HalfNormal(rsp1_mu)``
    (always <= 0). All others get ``TruncatedNormal(rsp2_mu, lower=0)``.
    """
    neg_half_normal = -1 * pm.HalfNormal.dist(rsp1_mu, size=size)
    trunc_normal = pm.TruncatedNormal.dist(rsp2_mu, lower=0, size=size)
    return pm.math.switch(pt.eq(treatment, 0), neg_half_normal, trunc_normal)
def switch_logp(value, treatment, rsp1_mu, rsp2_mu):
    """Explicit elementwise log-density matching ``switch_dist``.

    Builds unsized versions of the two branch distributions and evaluates
    ``pm.logp`` of ``value`` under each. Per element, it keeps the
    negated-HalfNormal term where ``treatment == 0`` and the TruncatedNormal
    term elsewhere.
    """
    neg_half_normal = -1 * pm.HalfNormal.dist(rsp1_mu)
    trunc_normal = pm.TruncatedNormal.dist(rsp2_mu, lower=0)
    # Both branch logps are built for every element; the switch keeps one.
    logp_if_control = pm.logp(neg_half_normal, value)
    logp_if_treated = pm.logp(trunc_normal, value)
    return pt.switch(pt.eq(treatment, 0), logp_if_control, logp_if_treated)
with pm.Model() as m:
    treatment = np.array([0, 0, 0, 1, 1, 1])
    # Responses are negative for treatment == 0 and positive otherwise.
    # Derive the sign from `treatment` instead of a hard-coded slice, so the
    # data stays consistent with switch_logp if `treatment` changes.
    obs_data = np.abs(np.random.normal(size=treatment.shape))
    obs_data[treatment == 0] *= -1
    rsp1_mu = pm.HalfNormal("rsp1_mu")
    rsp2_mu = pm.HalfNormal("rsp2_mu")
    # `logp` is given explicitly, so no automatic derivation of the switch
    # graph is needed here.
    pm.CustomDist(
        "llike",
        treatment,
        rsp1_mu,
        rsp2_mu,
        dist=switch_dist,
        logp=switch_logp,
        observed=obs_data,
    )
    m.point_logps()
- Or split your observations into two separate likelihoods, one per treatment group:
import numpy as np
import pymc as pm
def halfnormal_dist(rsp1_mu, size):
    """Generative graph for the treatment == 0 likelihood: a negated HalfNormal.

    ``rsp1_mu`` is passed as the HalfNormal's first positional parameter. The
    draw is multiplied by -1, so every value is <= 0.
    """
    magnitude = pm.HalfNormal.dist(rsp1_mu, size=size)
    return -1 * magnitude
with pm.Model() as m:
    treatment = np.array([0, 0, 0, 1, 1, 1])
    # One boolean mask drives both the data signs and the likelihood split,
    # instead of a hard-coded `[:3]` slice that assumes a row ordering.
    is_control = treatment == 0
    obs_data = np.abs(np.random.normal(size=treatment.shape))
    obs_data[is_control] *= -1
    rsp1_mu = pm.HalfNormal("rsp1_mu")
    rsp2_mu = pm.HalfNormal("rsp2_mu")
    pm.CustomDist(
        "llike1", rsp1_mu, dist=halfnormal_dist, observed=obs_data[is_control]
    )
    # Use the complement so every observation enters exactly one likelihood,
    # mirroring the "treatment != 0" branch of the switch formulation.
    pm.TruncatedNormal("llike2", rsp2_mu, lower=0, observed=obs_data[~is_control])
    m.point_logps()
At that point you can also drop the `HalfNormal` `CustomDist` altogether: negate the observed data and use a plain `HalfNormal` likelihood directly, e.g. `pm.HalfNormal("llike1", rsp1_mu, observed=-obs_data[treatment == 0])`.