Bug: Minibatching with CustomDist?

Hey all! Is it possible to use a CustomDist that does arithmetic on a random variable with Minibatched data? Am I making a syntax error, or is this a bug? The following code runs without Minibatching but throws a NotImplementedError with it. It also runs fine if the custom dist doesn't involve arithmetic.

import pymc as pm
import numpy as np

length = 100

rng = np.random.default_rng(1)
x = np.linspace(0, 1, num=length)
true_mu = 4
y = true_mu + rng.normal(0, 1, size=length)
batched_y, batched_x = pm.Minibatch(y, x, batch_size=10)
total_size = len(x)

# custom distribution with addition (the code runs fine without the + - * / operators)
def custom_dist(mu, a, b, size=None):
    return mu + pm.Normal.dist(a, b, size=size)

#def custom_dist_that_works(mu, a, b, size=None):
#    return pm.Normal.dist(a, b, size=size)

with pm.Model() as model:
    mu = pm.Normal('mu', sigma=10)
    alpha = pm.HalfNormal('alpha', 1)
    beta = pm.HalfNormal('beta', 1)
    
    y_obs = pm.CustomDist('y_obs', mu, alpha, beta, dist=custom_dist, observed=batched_y, total_size=total_size)
    idata = pm.fit()

This gives a NotImplementedError and the following traceback:

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[44], line 22
     19 beta = pm.HalfNormal('beta', 1)
     21 y_obs = pm.CustomDist('y_obs', mu, alpha, beta, dist=custom_dist, observed=batched_y, total_size=total_size)
---> 22 idata = pm.fit()
     23 #idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True)

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/inference.py:766, in fit(n, method, model, random_seed, start, start_sigma, inf_kwargs, **kwargs)
    764 else:
    765     raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 766 return inference.fit(n, **kwargs)

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/inference.py:149, in Inference.fit(self, n, score, callbacks, progressbar, progressbar_theme, **kwargs)
    147     callbacks = []
    148 score = self._maybe_score(score)
--> 149 step_func = self.objective.step_function(score=score, **kwargs)
    151 if score:
    152     state = self._iterate_with_loss(
    153         0, n, step_func, progressbar, progressbar_theme, callbacks
    154     )

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pytensor/configparser.py:44, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
     41 @wraps(f)
     42 def res(*args, **kwargs):
     43     with self:
---> 44         return f(*args, **kwargs)

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:379, in ObjectiveFunction.step_function(self, obj_n_mc, tf_n_mc, obj_optimizer, test_optimizer, more_obj_params, more_tf_params, more_updates, more_replacements, total_grad_norm_constraint, score, fn_kwargs)
    377 if score and not self.op.returns_loss:
    378     raise NotImplementedError(f"{self.op} does not have loss")
--> 379 updates = self.updates(
    380     obj_n_mc=obj_n_mc,
    381     tf_n_mc=tf_n_mc,
    382     obj_optimizer=obj_optimizer,
    383     test_optimizer=test_optimizer,
    384     more_obj_params=more_obj_params,
    385     more_tf_params=more_tf_params,
    386     more_updates=more_updates,
    387     more_replacements=more_replacements,
    388     total_grad_norm_constraint=total_grad_norm_constraint,
    389 )
    390 seed = self.approx.rng.randint(2**30, dtype=np.int64)
    391 if score:

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:268, in ObjectiveFunction.updates(self, obj_n_mc, tf_n_mc, obj_optimizer, test_optimizer, more_obj_params, more_tf_params, more_updates, more_replacements, total_grad_norm_constraint)
    266     if more_tf_params:
    267         _warn_not_used("more_tf_params", self.op)
--> 268 self.add_obj_updates(
    269     resulting_updates,
    270     obj_n_mc=obj_n_mc,
    271     obj_optimizer=obj_optimizer,
    272     more_obj_params=more_obj_params,
    273     more_replacements=more_replacements,
    274     total_grad_norm_constraint=total_grad_norm_constraint,
    275 )
    276 resulting_updates.update(more_updates)
    277 return resulting_updates

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:313, in ObjectiveFunction.add_obj_updates(self, updates, obj_n_mc, obj_optimizer, more_obj_params, more_replacements, total_grad_norm_constraint)
    311 if more_replacements is None:
    312     more_replacements = dict()
--> 313 obj_target = self(
    314     obj_n_mc, more_obj_params=more_obj_params, more_replacements=more_replacements
    315 )
    316 grads = pm.updates.get_or_compute_grads(obj_target, self.obj_params + more_obj_params)
    317 if total_grad_norm_constraint is not None:

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pytensor/configparser.py:44, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
     41 @wraps(f)
     42 def res(*args, **kwargs):
     43     with self:
---> 44         return f(*args, **kwargs)

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:432, in ObjectiveFunction.__call__(self, nmc, **kwargs)
    430 else:
    431     m = 1.0
--> 432 a = self.op.apply(self.tf)
    433 a = self.approx.set_size_and_deterministic(a, nmc, 0, kwargs.get("more_replacements"))
    434 return m * self.op.T(a)

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/operators.py:63, in KL.apply(self, f)
     62 def apply(self, f):
---> 63     return -self.datalogp_norm + self.beta * (self.logq_norm - self.varlogp_norm)

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:472, in Operator.<lambda>(self)
    470 logp_norm = property(lambda self: self.approx.logp_norm)
    471 varlogp_norm = property(lambda self: self.approx.varlogp_norm)
--> 472 datalogp_norm = property(lambda self: self.approx.datalogp_norm)
    473 logq_norm = property(lambda self: self.approx.logq_norm)
    474 model = property(lambda self: self.approx.model)

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/cachetools/__init__.py:814, in cachedmethod.<locals>.decorator.<locals>.wrapper(self, *args, **kwargs)
    812 except KeyError:
    813     pass  # key not found
--> 814 v = method(self, *args, **kwargs)
    815 try:
    816     c[k] = v

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pytensor/configparser.py:44, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
     41 @wraps(f)
     42 def res(*args, **kwargs):
     43     with self:
---> 44         return f(*args, **kwargs)

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:1358, in Approximation.datalogp_norm(self)
   1355 @node_property
   1356 def datalogp_norm(self):
   1357     """*Dev* - normalized :math:`E_{q}(data term)`"""
-> 1358     return self.datalogp / self.symbolic_normalizing_constant

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/cachetools/__init__.py:814, in cachedmethod.<locals>.decorator.<locals>.wrapper(self, *args, **kwargs)
    812 except KeyError:
    813     pass  # key not found
--> 814 v = method(self, *args, **kwargs)
    815 try:
    816     c[k] = v

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pytensor/configparser.py:44, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
     41 @wraps(f)
     42 def res(*args, **kwargs):
     43     with self:
---> 44         return f(*args, **kwargs)

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:1319, in Approximation.datalogp(self)
   1316 @node_property
   1317 def datalogp(self):
   1318     """*Dev* - computes :math:`E_{q}(data term)` from model via `pytensor.scan` that can be optimized later"""
-> 1319     return self.sized_symbolic_datalogp.mean(0)

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/cachetools/__init__.py:814, in cachedmethod.<locals>.decorator.<locals>.wrapper(self, *args, **kwargs)
    812 except KeyError:
    813     pass  # key not found
--> 814 v = method(self, *args, **kwargs)
    815 try:
    816     c[k] = v

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pytensor/configparser.py:44, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
     41 @wraps(f)
     42 def res(*args, **kwargs):
     43     with self:
---> 44         return f(*args, **kwargs)

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:1299, in Approximation.sized_symbolic_datalogp(self)
   1296 @node_property
   1297 def sized_symbolic_datalogp(self):
   1298     """*Dev* - computes sampled data term from model via `pytensor.scan`"""
-> 1299     return self._sized_symbolic_varlogp_and_datalogp[1]

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/cachetools/__init__.py:814, in cachedmethod.<locals>.decorator.<locals>.wrapper(self, *args, **kwargs)
    812 except KeyError:
    813     pass  # key not found
--> 814 v = method(self, *args, **kwargs)
    815 try:
    816     c[k] = v

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pytensor/configparser.py:44, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
     41 @wraps(f)
     42 def res(*args, **kwargs):
     43     with self:
---> 44         return f(*args, **kwargs)

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:1287, in Approximation._sized_symbolic_varlogp_and_datalogp(self)
   1283 @node_property
   1284 def _sized_symbolic_varlogp_and_datalogp(self):
   1285     """*Dev* - computes sampled prior term from model via `pytensor.scan`"""
   1286     varlogp_s, datalogp_s = self.symbolic_sample_over_posterior(
-> 1287         [self.model.varlogp, self.model.datalogp]
   1288     )
   1289     return varlogp_s, datalogp_s

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/model/core.py:858, in Model.varlogp(self)
    854 @property
    855 def varlogp(self) -> Variable:
    856     """PyTensor scalar of log-probability of the unobserved random variables
    857     (excluding deterministic)."""
--> 858     return self.logp(vars=self.free_RVs)

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/model/core.py:742, in Model.logp(self, vars, jacobian, sum)
    740 rv_logps: list[TensorVariable] = []
    741 if rvs:
--> 742     rv_logps = transformed_conditional_logp(
    743         rvs=rvs,
    744         rvs_to_values=self.rvs_to_values,
    745         rvs_to_transforms=self.rvs_to_transforms,
    746         jacobian=jacobian,
    747     )
    748     assert isinstance(rv_logps, list)
    750 # Replace random variables by their value variables in potential terms

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/logprob/basic.py:611, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
    608     transform_rewrite = TransformValuesRewrite(values_to_transforms)  # type: ignore
    610 kwargs.setdefault("warn_rvs", False)
--> 611 temp_logp_terms = conditional_logp(
    612     rvs_to_values,
    613     extra_rewrites=transform_rewrite,
    614     use_jacobian=jacobian,
    615     **kwargs,
    616 )
    618 # The function returns the logp for every single value term we provided to it.
    619 # This includes the extra values we plugged in above, so we filter those we
    620 # actually wanted in the same order they were given in.
    621 logp_terms = {}

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/logprob/basic.py:541, in conditional_logp(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)
    538 q_values = remapped_vars[: len(q_values)]
    539 q_rv_inputs = remapped_vars[len(q_values) :]
--> 541 q_logprob_vars = _logprob(
    542     node.op,
    543     q_values,
    544     *q_rv_inputs,
    545     **kwargs,
    546 )
    548 if not isinstance(q_logprob_vars, list | tuple):
    549     q_logprob_vars = [q_logprob_vars]

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/functools.py:907, in singledispatch.<locals>.wrapper(*args, **kw)
    903 if not args:
    904     raise TypeError(f'{funcname} requires at least '
    905                     '1 positional argument')
--> 907 return dispatch(args[0].__class__)(*args, **kw)

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/minibatch_rv.py:105, in minibatch_rv_logprob(op, values, *inputs, **kwargs)
    103 [value] = values
    104 rv, *total_size = inputs
--> 105 return _logprob_helper(rv, value, **kwargs) * get_scaling(total_size, value.shape)

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/logprob/abstract.py:68, in _logprob_helper(rv, *values, **kwargs)
     66 def _logprob_helper(rv, *values, **kwargs):
     67     """Helper that calls `_logprob` dispatcher."""
---> 68     logprob = _logprob(rv.owner.op, values, *rv.owner.inputs, **kwargs)
     70     name = rv.name
     71     if (not name) and (len(values) == 1):

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/functools.py:907, in singledispatch.<locals>.wrapper(*args, **kw)
    903 if not args:
    904     raise TypeError(f'{funcname} requires at least '
    905                     '1 positional argument')
--> 907 return dispatch(args[0].__class__)(*args, **kw)

File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/logprob/abstract.py:63, in _logprob(op, values, *inputs, **kwargs)
     49 @singledispatch
     50 def _logprob(
     51     op: Op,
   (...)
     54     **kwargs,
     55 ):
     56     """Create a graph for the log-density/mass of a ``RandomVariable``.
     57 
     58     This function dispatches on the type of ``op``, which should be a subclass
   (...)
     61 
     62     """
---> 63     raise NotImplementedError(f"Logprob method not implemented for {op}")

NotImplementedError: Logprob method not implemented for Add

Here is the relevant conda environment info:

# packages in environment at /home/ubuntu/anaconda3/envs/model_env:
#
# Name                    Version                   Build  Channel

pymc                      5.16.2               hd8ed1ab_0    conda-forge
pymc-base                 5.16.2             pyhd8ed1ab_0    conda-forge

pytensor                  2.25.4          py312h97902ae_0    conda-forge
pytensor-base             2.25.4          py312h25a0e75_0    conda-forge
python                    3.12.5          h2ad013b_0_cpython    conda-forge

This should be allowed in the next PyMC release: Allow Minibatch of derived RVs and deprecate generators as data by ricardoV94 · Pull Request #7480 · pymc-devs/pymc · GitHub

Thanks @ricardoV94! Embarrassingly, it turns out I was wrong about Minibatching being directly involved in this error. This won't run using pm.sample either. The arithmetic within the custom distribution seems to be the unsupported thing here.

It shouldn’t be, but maybe it is. How did you arrive at that conclusion?

No, you’re right. I thought I had narrowed the issue down, but I was wrong. I didn’t realize that providing the total_size argument in the likelihood would affect how the variables are treated, even when using a NUTS sampler instead of ADVI. For example, in the code below the problematic CustomDist samples fine if I don’t supply a total_size.

I’ll readily admit I don’t understand the codebase, but it seems like total_size causes things to be treated as minibatch random variables even when there is no minibatching? For my own education, what is happening here? Why is the logprob of these seemingly simple arithmetic ops not implemented for minibatch variables in this situation?

# Custom distribution without the arithmetic operation:
def custom_dist_no_arith(mu, a, size=None):  # works in all situations
    return pm.Normal.dist(mu, a, size=size)

# dist using multiplication of a random variable:
def custom_dist(mu, a, size=None):  # NotImplementedError: Logprob method not implemented for Mul
    return a * pm.Normal.dist(mu, 1, size=size)

# NUTS, total_size supplied, CustomDist with no arithmetic:
#     runs fine
with pm.Model() as model:
    mu = pm.Normal('mu', sigma=10)
    alpha = pm.HalfNormal('alpha', 1)
    y_obs = pm.CustomDist('y_obs', mu, alpha, dist=custom_dist_no_arith, observed=y, total_size=length)
    idata = pm.sample()

# NUTS, total_size supplied, CustomDist with arithmetic:
#     NotImplementedError
with pm.Model() as m2:
    mu = pm.Normal('mu', sigma=10)
    alpha = pm.HalfNormal('alpha', 1)
    y_obs = pm.CustomDist('y_obs', mu, alpha, dist=custom_dist, observed=y, total_size=length)
    idata = pm.sample()

# NUTS, no total_size, with arithmetic:
#     runs fine
with pm.Model() as m3:
    mu = pm.Normal('mu', sigma=10)
    alpha = pm.HalfNormal('alpha', 1)
    y_obs = pm.CustomDist('y_obs', mu, alpha, dist=custom_dist, observed=y)  # total_size arg not supplied
    idata = pm.sample()

EDIT: one last wrinkle that has me even more baffled: separating out the arithmetic operation and storing the result in a temporary variable resolves the issue:

def custom_dist_two_liner(mu, a, size=None):  # works with total_size supplied
    mu2 = mu + a
    return pm.Normal.dist(mu2, 1, size=size)

def custom_dist_one_liner(mu, a, size=None):  # distributionally equivalent, but throws the error
    return a + pm.Normal.dist(mu, 1, size=size)
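Inspecting the two return values suggests why (a minimal sketch; the printed op names are just what I see on my setup): the two-liner returns a plain random variable, while the one-liner returns an Add op applied on top of one, matching the op named in the error.

import pymc as pm

one_liner = 2.0 + pm.Normal.dist(0.0, 1.0)  # arithmetic applied to the RV itself
two_liner = pm.Normal.dist(2.0 + 0.0, 1.0)  # arithmetic folded into the parameter

print(one_liner.owner.op)  # Add -- the op from the NotImplementedError
print(two_liner.owner.op)  # a plain normal RandomVariable op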

total_size rescales the log-probability of a variable. It is usually used in conjunction with Minibatch: you give, say, 20% of the datapoints, and the logp is rescaled as if 100% had been given.
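Concretely, the usual Minibatch pattern looks something like this (a minimal sketch with made-up numbers):

import numpy as np
import pymc as pm

rng = np.random.default_rng(0)
y = rng.normal(4, 1, size=100)              # full dataset
batched_y = pm.Minibatch(y, batch_size=20)  # 20% of the datapoints per draw

with pm.Model():
    mu = pm.Normal('mu', sigma=10)
    # total_size=100 tells PyMC to multiply this term's logp by
    # 100 / 20 = 5, as if all 100 datapoints had been observed
    pm.Normal('y_obs', mu, 1, observed=batched_y, total_size=100)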

As I wrote above, support for rescaling (total_size) with CustomDist was implemented only recently and will be available in the next release.
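In the meantime, a workaround is to fold the arithmetic into the distribution's parameters, as in your two-liner, so the CustomDist returns a plain random variable. A sketch for the original three-parameter example, assuming the shift can be absorbed into the location:

def custom_dist(mu, a, b, size=None):
    # mu + Normal(a, b) has the same distribution as Normal(mu + a, b),
    # so shifting the location avoids putting an Add op on top of the RV
    return pm.Normal.dist(mu + a, b, size=size)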