No, you’re right. I thought I had narrowed the issue down, but I was wrong. I didn’t realize that passing the total_size argument to the likelihood would affect how the variables are treated, even when sampling with NUTS rather than ADVI. For example, in the code below the problematic CustomDist samples fine if I don’t supply total_size.
I’ll readily admit I don’t understand the codebase, but it seems like this is because total_size is causing things to be interpreted as Minibatch Random Variables even when there is no minibatching? For my own education, what is happening here? Why is the logprob for these seemingly simple arithmetic ops not implemented for minibatch variables in this situation?
import pymc as pm

# Custom distribution without the arithmetic operation:
def custom_dist_no_arith(mu, a, size=None):  # works in all situations
    return pm.Normal.dist(mu, a, size=size)

# Custom distribution that multiplies a random variable by a parameter:
def custom_dist(mu, a, size=None):  # NotImplementedError: Logprob method not implemented for Mul
    return a * pm.Normal.dist(mu, 1, size=size)
# NUTS, total_size supplied, CustomDist with no arithmetic:
# Runs fine
# (y and length are defined elsewhere)
with pm.Model() as model:
    mu = pm.Normal('mu', sigma=10)
    alpha = pm.HalfNormal('alpha', 1)
    y_obs = pm.CustomDist('y_obs', mu, alpha, dist=custom_dist_no_arith, observed=y, total_size=length)
    idata = pm.sample()
# NUTS, total_size supplied, CustomDist with arithmetic:
# NotImplementedError
with pm.Model() as m2:
    mu = pm.Normal('mu', sigma=10)
    alpha = pm.HalfNormal('alpha', 1)
    y_obs = pm.CustomDist('y_obs', mu, alpha, dist=custom_dist, observed=y, total_size=length)
    idata = pm.sample()
# NUTS, no total_size, CustomDist with arithmetic:
# Runs fine
with pm.Model() as m3:
    mu = pm.Normal('mu', sigma=10)
    alpha = pm.HalfNormal('alpha', 1)
    y_obs = pm.CustomDist('y_obs', mu, alpha, dist=custom_dist, observed=y)  # total_size not supplied
    idata = pm.sample()
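For reference, my (possibly wrong) mental model of what total_size is for: it rescales the observed log-likelihood by total_size / batch_size, so that a minibatch stands in for the full dataset. Here is a plain-Python sketch of that arithmetic only. It is not PyMC's actual implementation, and normal_logp and scaled_batch_logp are names I made up for illustration:

```python
import math

def normal_logp(value, mu, sigma):
    """Log-density of Normal(mu, sigma) evaluated at value."""
    z = (value - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)

def scaled_batch_logp(batch, mu, sigma, total_size):
    """Sum the batch log-likelihood, then rescale it to the full dataset size.

    Hypothetical helper: it mimics the intent of total_size, not PyMC internals.
    """
    batch_logp = sum(normal_logp(v, mu, sigma) for v in batch)
    return batch_logp * (total_size / len(batch))

data = [0.1, -0.4, 1.2, 0.7]

# Full data with total_size == len(data): the scaling factor is 1.
full = scaled_batch_logp(data, mu=0.0, sigma=1.0, total_size=len(data))

# Half of the data with the same total_size: the batch logp is doubled.
half = scaled_batch_logp(data[:2], mu=0.0, sigma=1.0, total_size=len(data))
```

If that picture is right, supplying total_size when the observed data is already the full dataset should multiply the logp by 1, which is why I expected it to be harmless with NUTS.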
EDIT: one last wrinkle that has me even more baffled: moving the arithmetic onto the parameters and storing the result in a temporary variable resolves the issue:
def custom_dist_two_liner(mu, a, size=None):  # works with total_size supplied
    mu2 = mu + a
    return pm.Normal.dist(mu2, 1, size=size)

def custom_dist_one_liner(mu, a, size=None):  # equivalent, but throws the error
    return a + pm.Normal.dist(mu, 1, size=size)
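To spell out why I consider the two versions equivalent: the density of a + X with X ~ Normal(mu, 1) at a value v is just the density of X at v - a, which is the same as Normal(mu + a, 1) at v. For the multiplication case, a * X picks up an extra -log|a| term and matches Normal(a * mu, |a|). A quick standard-library check of that change of variables follows. The helper names are hypothetical, and this says nothing about how PyMC derives the logp internally:

```python
import math

def normal_logp(value, mu, sigma):
    """Log-density of Normal(mu, sigma) evaluated at value."""
    z = (value - mu) / sigma
    return -0.5 * z * z - math.log(sigma) - 0.5 * math.log(2.0 * math.pi)

def shifted_logp(value, mu, a):
    """Log-density of a + X, X ~ Normal(mu, 1): evaluate X at value - a."""
    return normal_logp(value - a, mu, 1.0)

def scaled_logp(value, mu, a):
    """Log-density of a * X, X ~ Normal(mu, 1), for a != 0.

    Evaluate X at value / a and subtract log|a| (the Jacobian term).
    """
    return normal_logp(value / a, mu, 1.0) - math.log(abs(a))

value, mu, a = 0.8, 0.3, 1.7

# One-liner (a + Normal(mu, 1)) versus two-liner (Normal(mu + a, 1)).
shift_matches = math.isclose(shifted_logp(value, mu, a),
                             normal_logp(value, mu + a, 1.0))

# a * Normal(mu, 1) versus Normal(a * mu, |a|).
scale_matches = math.isclose(scaled_logp(value, mu, a),
                             normal_logp(value, a * mu, abs(a)))
```

So mathematically the logp for these ops is well defined, which is what makes the NotImplementedError surprising to me.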