Hey all! Is it possible to use a CustomDist that does arithmetic on the returned variable with minibatched data? Am I missing a syntax error, or is this a bug? The code below runs fine without Minibatch but throws a NotImplementedError with it. It also works if the custom dist doesn't involve arithmetic.
import pymc as pm
import numpy as np
length = 100
rng = np.random.default_rng(1)
x = np.linspace(0, 1, num=length)
true_mu = 4
y = true_mu + rng.normal(0, 1, size=length)
batched_y, batched_x = pm.Minibatch(y, x, batch_size=10)
total_size = len(x)
# Custom distribution with addition (the code runs fine without the + - * / operators)
def custom_dist(mu, a, b, size=None):
    return mu + pm.Normal.dist(a, b, size=size)
#def custom_dist_that_works(mu, a, b, size=None):
#    return pm.Normal.dist(a, b, size=size)
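#A possible workaround for this toy example (it sidesteps the question rather than
#answering it): fold the shift into the Normal's mean, so no arithmetic is applied
#to the variable returned by the dist function. For a Normal the two are equivalent:
#def custom_dist_folded(mu, a, b, size=None):
#    return pm.Normal.dist(mu + a, b, size=size)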
with pm.Model() as model:
    mu = pm.Normal('mu', sigma=10)
    alpha = pm.HalfNormal('alpha', 1)
    beta = pm.HalfNormal('beta', 1)
    y_obs = pm.CustomDist('y_obs', mu, alpha, beta, dist=custom_dist, observed=batched_y, total_size=total_size)
    idata = pm.fit()
This gives a NotImplementedError and the following traceback:
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
Cell In[44], line 22
19 beta = pm.HalfNormal('beta', 1)
21 y_obs = pm.CustomDist('y_obs', mu, alpha, beta, dist=custom_dist, observed=batched_y, total_size=total_size)
---> 22 idata = pm.fit()
23 #idata = pm.sample_posterior_predictive(idata, extend_inferencedata=True)
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/inference.py:766, in fit(n, method, model, random_seed, start, start_sigma, inf_kwargs, **kwargs)
764 else:
765 raise TypeError(f"method should be one of {set(_select.keys())} or Inference instance")
--> 766 return inference.fit(n, **kwargs)
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/inference.py:149, in Inference.fit(self, n, score, callbacks, progressbar, progressbar_theme, **kwargs)
147 callbacks = []
148 score = self._maybe_score(score)
--> 149 step_func = self.objective.step_function(score=score, **kwargs)
151 if score:
152 state = self._iterate_with_loss(
153 0, n, step_func, progressbar, progressbar_theme, callbacks
154 )
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pytensor/configparser.py:44, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
41 @wraps(f)
42 def res(*args, **kwargs):
43 with self:
---> 44 return f(*args, **kwargs)
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:379, in ObjectiveFunction.step_function(self, obj_n_mc, tf_n_mc, obj_optimizer, test_optimizer, more_obj_params, more_tf_params, more_updates, more_replacements, total_grad_norm_constraint, score, fn_kwargs)
377 if score and not self.op.returns_loss:
378 raise NotImplementedError(f"{self.op} does not have loss")
--> 379 updates = self.updates(
380 obj_n_mc=obj_n_mc,
381 tf_n_mc=tf_n_mc,
382 obj_optimizer=obj_optimizer,
383 test_optimizer=test_optimizer,
384 more_obj_params=more_obj_params,
385 more_tf_params=more_tf_params,
386 more_updates=more_updates,
387 more_replacements=more_replacements,
388 total_grad_norm_constraint=total_grad_norm_constraint,
389 )
390 seed = self.approx.rng.randint(2**30, dtype=np.int64)
391 if score:
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:268, in ObjectiveFunction.updates(self, obj_n_mc, tf_n_mc, obj_optimizer, test_optimizer, more_obj_params, more_tf_params, more_updates, more_replacements, total_grad_norm_constraint)
266 if more_tf_params:
267 _warn_not_used("more_tf_params", self.op)
--> 268 self.add_obj_updates(
269 resulting_updates,
270 obj_n_mc=obj_n_mc,
271 obj_optimizer=obj_optimizer,
272 more_obj_params=more_obj_params,
273 more_replacements=more_replacements,
274 total_grad_norm_constraint=total_grad_norm_constraint,
275 )
276 resulting_updates.update(more_updates)
277 return resulting_updates
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:313, in ObjectiveFunction.add_obj_updates(self, updates, obj_n_mc, obj_optimizer, more_obj_params, more_replacements, total_grad_norm_constraint)
311 if more_replacements is None:
312 more_replacements = dict()
--> 313 obj_target = self(
314 obj_n_mc, more_obj_params=more_obj_params, more_replacements=more_replacements
315 )
316 grads = pm.updates.get_or_compute_grads(obj_target, self.obj_params + more_obj_params)
317 if total_grad_norm_constraint is not None:
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pytensor/configparser.py:44, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
41 @wraps(f)
42 def res(*args, **kwargs):
43 with self:
---> 44 return f(*args, **kwargs)
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:432, in ObjectiveFunction.__call__(self, nmc, **kwargs)
430 else:
431 m = 1.0
--> 432 a = self.op.apply(self.tf)
433 a = self.approx.set_size_and_deterministic(a, nmc, 0, kwargs.get("more_replacements"))
434 return m * self.op.T(a)
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/operators.py:63, in KL.apply(self, f)
62 def apply(self, f):
---> 63 return -self.datalogp_norm + self.beta * (self.logq_norm - self.varlogp_norm)
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:472, in Operator.<lambda>(self)
470 logp_norm = property(lambda self: self.approx.logp_norm)
471 varlogp_norm = property(lambda self: self.approx.varlogp_norm)
--> 472 datalogp_norm = property(lambda self: self.approx.datalogp_norm)
473 logq_norm = property(lambda self: self.approx.logq_norm)
474 model = property(lambda self: self.approx.model)
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/cachetools/__init__.py:814, in cachedmethod.<locals>.decorator.<locals>.wrapper(self, *args, **kwargs)
812 except KeyError:
813 pass # key not found
--> 814 v = method(self, *args, **kwargs)
815 try:
816 c[k] = v
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pytensor/configparser.py:44, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
41 @wraps(f)
42 def res(*args, **kwargs):
43 with self:
---> 44 return f(*args, **kwargs)
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:1358, in Approximation.datalogp_norm(self)
1355 @node_property
1356 def datalogp_norm(self):
1357 """*Dev* - normalized :math:`E_{q}(data term)`"""
-> 1358 return self.datalogp / self.symbolic_normalizing_constant
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/cachetools/__init__.py:814, in cachedmethod.<locals>.decorator.<locals>.wrapper(self, *args, **kwargs)
812 except KeyError:
813 pass # key not found
--> 814 v = method(self, *args, **kwargs)
815 try:
816 c[k] = v
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pytensor/configparser.py:44, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
41 @wraps(f)
42 def res(*args, **kwargs):
43 with self:
---> 44 return f(*args, **kwargs)
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:1319, in Approximation.datalogp(self)
1316 @node_property
1317 def datalogp(self):
1318 """*Dev* - computes :math:`E_{q}(data term)` from model via `pytensor.scan` that can be optimized later"""
-> 1319 return self.sized_symbolic_datalogp.mean(0)
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/cachetools/__init__.py:814, in cachedmethod.<locals>.decorator.<locals>.wrapper(self, *args, **kwargs)
812 except KeyError:
813 pass # key not found
--> 814 v = method(self, *args, **kwargs)
815 try:
816 c[k] = v
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pytensor/configparser.py:44, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
41 @wraps(f)
42 def res(*args, **kwargs):
43 with self:
---> 44 return f(*args, **kwargs)
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:1299, in Approximation.sized_symbolic_datalogp(self)
1296 @node_property
1297 def sized_symbolic_datalogp(self):
1298 """*Dev* - computes sampled data term from model via `pytensor.scan`"""
-> 1299 return self._sized_symbolic_varlogp_and_datalogp[1]
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/cachetools/__init__.py:814, in cachedmethod.<locals>.decorator.<locals>.wrapper(self, *args, **kwargs)
812 except KeyError:
813 pass # key not found
--> 814 v = method(self, *args, **kwargs)
815 try:
816 c[k] = v
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pytensor/configparser.py:44, in _ChangeFlagsDecorator.__call__.<locals>.res(*args, **kwargs)
41 @wraps(f)
42 def res(*args, **kwargs):
43 with self:
---> 44 return f(*args, **kwargs)
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/opvi.py:1287, in Approximation._sized_symbolic_varlogp_and_datalogp(self)
1283 @node_property
1284 def _sized_symbolic_varlogp_and_datalogp(self):
1285 """*Dev* - computes sampled prior term from model via `pytensor.scan`"""
1286 varlogp_s, datalogp_s = self.symbolic_sample_over_posterior(
-> 1287 [self.model.varlogp, self.model.datalogp]
1288 )
1289 return varlogp_s, datalogp_s
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/model/core.py:858, in Model.varlogp(self)
854 @property
855 def varlogp(self) -> Variable:
856 """PyTensor scalar of log-probability of the unobserved random variables
857 (excluding deterministic)."""
--> 858 return self.logp(vars=self.free_RVs)
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/model/core.py:742, in Model.logp(self, vars, jacobian, sum)
740 rv_logps: list[TensorVariable] = []
741 if rvs:
--> 742 rv_logps = transformed_conditional_logp(
743 rvs=rvs,
744 rvs_to_values=self.rvs_to_values,
745 rvs_to_transforms=self.rvs_to_transforms,
746 jacobian=jacobian,
747 )
748 assert isinstance(rv_logps, list)
750 # Replace random variables by their value variables in potential terms
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/logprob/basic.py:611, in transformed_conditional_logp(rvs, rvs_to_values, rvs_to_transforms, jacobian, **kwargs)
608 transform_rewrite = TransformValuesRewrite(values_to_transforms) # type: ignore
610 kwargs.setdefault("warn_rvs", False)
--> 611 temp_logp_terms = conditional_logp(
612 rvs_to_values,
613 extra_rewrites=transform_rewrite,
614 use_jacobian=jacobian,
615 **kwargs,
616 )
618 # The function returns the logp for every single value term we provided to it.
619 # This includes the extra values we plugged in above, so we filter those we
620 # actually wanted in the same order they were given in.
621 logp_terms = {}
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/logprob/basic.py:541, in conditional_logp(rv_values, warn_rvs, ir_rewriter, extra_rewrites, **kwargs)
538 q_values = remapped_vars[: len(q_values)]
539 q_rv_inputs = remapped_vars[len(q_values) :]
--> 541 q_logprob_vars = _logprob(
542 node.op,
543 q_values,
544 *q_rv_inputs,
545 **kwargs,
546 )
548 if not isinstance(q_logprob_vars, list | tuple):
549 q_logprob_vars = [q_logprob_vars]
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/functools.py:907, in singledispatch.<locals>.wrapper(*args, **kw)
903 if not args:
904 raise TypeError(f'{funcname} requires at least '
905 '1 positional argument')
--> 907 return dispatch(args[0].__class__)(*args, **kw)
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/variational/minibatch_rv.py:105, in minibatch_rv_logprob(op, values, *inputs, **kwargs)
103 [value] = values
104 rv, *total_size = inputs
--> 105 return _logprob_helper(rv, value, **kwargs) * get_scaling(total_size, value.shape)
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/logprob/abstract.py:68, in _logprob_helper(rv, *values, **kwargs)
66 def _logprob_helper(rv, *values, **kwargs):
67 """Helper that calls `_logprob` dispatcher."""
---> 68 logprob = _logprob(rv.owner.op, values, *rv.owner.inputs, **kwargs)
70 name = rv.name
71 if (not name) and (len(values) == 1):
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/functools.py:907, in singledispatch.<locals>.wrapper(*args, **kw)
903 if not args:
904 raise TypeError(f'{funcname} requires at least '
905 '1 positional argument')
--> 907 return dispatch(args[0].__class__)(*args, **kw)
File ~/anaconda3/envs/pipemodel_env/lib/python3.12/site-packages/pymc/logprob/abstract.py:63, in _logprob(op, values, *inputs, **kwargs)
49 @singledispatch
50 def _logprob(
51 op: Op,
(...)
54 **kwargs,
55 ):
56 """Create a graph for the log-density/mass of a ``RandomVariable``.
57
58 This function dispatches on the type of ``op``, which should be a subclass
(...)
61
62 """
---> 63 raise NotImplementedError(f"Logprob method not implemented for {op}")
NotImplementedError: Logprob method not implemented for Add
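Reading the tail of the traceback, the failure happens inside the Minibatch logp dispatch (minibatch_rv_logprob), which calls _logprob directly on the variable returned by custom_dist; that variable's op is Add, which has no logprob registered, hence the NotImplementedError. Given that the same model fits fine without Minibatch, the automatic logprob inference apparently handles the Add when the variable isn't wrapped in a minibatch RV. A minimal way to check that outside of any model (my own sketch, assuming pm.logp derives the density the same way it is derived during sampling):
import pymc as pm

def custom_dist(mu, a, b, size=None):
    return mu + pm.Normal.dist(a, b, size=size)

# Standalone CustomDist, no Minibatch wrapper: build and evaluate its logp directly
d = pm.CustomDist.dist(4.0, 1.0, 1.0, dist=custom_dist)
print(pm.logp(d, 4.5).eval())  # expected: log-density of Normal(4 + 1, 1) at 4.5
If that evaluates without error, the problem would seem to be specific to how the minibatch wrapper dispatches the logp, rather than to the CustomDist graph itself.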
Here is the relevant conda environment info:
# packages in environment at /home/ubuntu/anaconda3/envs/model_env:
#
# Name                    Version   Build                Channel
pymc                      5.16.2    hd8ed1ab_0           conda-forge
pymc-base                 5.16.2    pyhd8ed1ab_0         conda-forge
pytensor                  2.25.4    py312h97902ae_0      conda-forge
pytensor-base             2.25.4    py312h25a0e75_0      conda-forge
python                    3.12.5    h2ad013b_0_cpython   conda-forge