Avoiding non-NUTS samplers when broadcasting a multiplication operation

I’m trying to revive an old model I built a couple of years ago with pymc3 that ran in roughly 2–4 hours. When I run the same model under pymc5, it takes several days. Under pymc3, the step method used for all parameters is NUTS; under pymc5, it uses a mix of NUTS and Metropolis, which also prevents me from trying some of the newer, faster samplers (e.g., numpyro). After playing a bit with the code, I was able to narrow down the exact operation that causes the switch from NUTS to Metropolis. In the toy code below, the x[:, np.newaxis] * np.arange(4) is what triggers the switch to Metropolis.

import pymc as pm
import numpy as np

with pm.Model() as model:
    x = pm.Normal('test', shape=74)
    mu = x[:, np.newaxis] * np.arange(4)  # the broadcasted multiplication in question
    mu = mu.max(axis=-1)
    pm.Normal('obs', mu, observed=np.random.uniform(size=74))
    pm.sample()

Is there a way I can modify this so that it will use the NUTS sampler?

What message do you get when you call model.dlogp()?

Yes, I get a NotImplementedError. It seems that adding a dimension (for broadcasting) and then multiplying causes this. Any suggestions for a workaround?

---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[17], line 1
----> 1 model.dlogp()

File ~\anaconda3\envs\R01-kujawa-NUTS\lib\site-packages\pymc\model\core.py:789, in Model.dlogp(self, vars, jacobian)
    787 cost = self.logp(jacobian=jacobian)
    788 cost = rewrite_pregrad(cost)
--> 789 return gradient(cost, value_vars)

File ~\anaconda3\envs\R01-kujawa-NUTS\lib\site-packages\pymc\pytensorf.py:317, in gradient(f, vars)
    314     vars = cont_inputs(f)
    316 if vars:
--> 317     return pt.concatenate([gradient1(f, v) for v in vars], axis=0)
    318 else:
    319     return empty_gradient

File ~\anaconda3\envs\R01-kujawa-NUTS\lib\site-packages\pymc\pytensorf.py:317, in <listcomp>(.0)
    314     vars = cont_inputs(f)
    316 if vars:
--> 317     return pt.concatenate([gradient1(f, v) for v in vars], axis=0)
    318 else:
    319     return empty_gradient

File ~\anaconda3\envs\R01-kujawa-NUTS\lib\site-packages\pymc\pytensorf.py:306, in gradient1(f, v)
    304 def gradient1(f, v):
    305     """flat gradient of f wrt v"""
--> 306     return pt.flatten(grad(f, v, disconnected_inputs="warn"))

File ~\anaconda3\envs\R01-kujawa-NUTS\lib\site-packages\pytensor\gradient.py:607, in grad(cost, wrt, consider_constant, disconnected_inputs, add_names, known_grads, return_disconnected, null_gradients)
    604     if hasattr(g.type, "dtype"):
    605         assert g.type.dtype in pytensor.tensor.type.float_dtypes
--> 607 _rval: Sequence[Variable] = _populate_grad_dict(
    608     var_to_app_to_idx, grad_dict, _wrt, cost_name
    609 )
    611 rval: MutableSequence[Variable | None] = list(_rval)
    613 for i in range(len(_rval)):

File ~\anaconda3\envs\R01-kujawa-NUTS\lib\site-packages\pytensor\gradient.py:1407, in _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name)
   1404     # end if cache miss
   1405     return grad_dict[var]
-> 1407 rval = [access_grad_cache(elem) for elem in wrt]
   1409 return rval

File ~\anaconda3\envs\R01-kujawa-NUTS\lib\site-packages\pytensor\gradient.py:1407, in <listcomp>(.0)
   1404     # end if cache miss
   1405     return grad_dict[var]
-> 1407 rval = [access_grad_cache(elem) for elem in wrt]
   1409 return rval

File ~\anaconda3\envs\R01-kujawa-NUTS\lib\site-packages\pytensor\gradient.py:1362, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1360 for node in node_to_idx:
   1361     for idx in node_to_idx[node]:
-> 1362         term = access_term_cache(node)[idx]
   1364         if not isinstance(term, Variable):
   1365             raise TypeError(
   1366                 f"{node.op}.grad returned {type(term)}, expected"
   1367                 " Variable instance."
   1368             )

File ~\anaconda3\envs\R01-kujawa-NUTS\lib\site-packages\pytensor\gradient.py:1037, in _populate_grad_dict.<locals>.access_term_cache(node)
   1034 if node not in term_dict:
   1035     inputs = node.inputs
-> 1037     output_grads = [access_grad_cache(var) for var in node.outputs]
   1039     # list of bools indicating if each output is connected to the cost
   1040     outputs_connected = [
   1041         not isinstance(g.type, DisconnectedType) for g in output_grads
   1042     ]

File ~\anaconda3\envs\R01-kujawa-NUTS\lib\site-packages\pytensor\gradient.py:1037, in <listcomp>(.0)
   1034 if node not in term_dict:
   1035     inputs = node.inputs
-> 1037     output_grads = [access_grad_cache(var) for var in node.outputs]
   1039     # list of bools indicating if each output is connected to the cost
   1040     outputs_connected = [
   1041         not isinstance(g.type, DisconnectedType) for g in output_grads
   1042     ]

File ~\anaconda3\envs\R01-kujawa-NUTS\lib\site-packages\pytensor\gradient.py:1362, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1360 for node in node_to_idx:
   1361     for idx in node_to_idx[node]:
-> 1362         term = access_term_cache(node)[idx]
   1364         if not isinstance(term, Variable):
   1365             raise TypeError(
   1366                 f"{node.op}.grad returned {type(term)}, expected"
   1367                 " Variable instance."
   1368             )

File ~\anaconda3\envs\R01-kujawa-NUTS\lib\site-packages\pytensor\gradient.py:1037, in _populate_grad_dict.<locals>.access_term_cache(node)
   1034 if node not in term_dict:
   1035     inputs = node.inputs
-> 1037     output_grads = [access_grad_cache(var) for var in node.outputs]
   1039     # list of bools indicating if each output is connected to the cost
   1040     outputs_connected = [
   1041         not isinstance(g.type, DisconnectedType) for g in output_grads
   1042     ]

File ~\anaconda3\envs\R01-kujawa-NUTS\lib\site-packages\pytensor\gradient.py:1037, in <listcomp>(.0)
   1034 if node not in term_dict:
   1035     inputs = node.inputs
-> 1037     output_grads = [access_grad_cache(var) for var in node.outputs]
   1039     # list of bools indicating if each output is connected to the cost
   1040     outputs_connected = [
   1041         not isinstance(g.type, DisconnectedType) for g in output_grads
   1042     ]

File ~\anaconda3\envs\R01-kujawa-NUTS\lib\site-packages\pytensor\gradient.py:1362, in _populate_grad_dict.<locals>.access_grad_cache(var)
   1360 for node in node_to_idx:
   1361     for idx in node_to_idx[node]:
-> 1362         term = access_term_cache(node)[idx]
   1364         if not isinstance(term, Variable):
   1365             raise TypeError(
   1366                 f"{node.op}.grad returned {type(term)}, expected"
   1367                 " Variable instance."
   1368             )

File ~\anaconda3\envs\R01-kujawa-NUTS\lib\site-packages\pytensor\gradient.py:1192, in _populate_grad_dict.<locals>.access_term_cache(node)
   1184         if o_shape != g_shape:
   1185             raise ValueError(
   1186                 "Got a gradient of shape "
   1187                 + str(o_shape)
   1188                 + " on an output of shape "
   1189                 + str(g_shape)
   1190             )
-> 1192 input_grads = node.op.L_op(inputs, node.outputs, new_output_grads)
   1194 if input_grads is None:
   1195     raise TypeError(
   1196         f"{node.op}.grad returned NoneType, expected iterable."
   1197     )

File ~\anaconda3\envs\R01-kujawa-NUTS\lib\site-packages\pytensor\graph\op.py:389, in Op.L_op(self, inputs, outputs, output_grads)
    362 def L_op(
    363     self,
    364     inputs: Sequence[Variable],
    365     outputs: Sequence[Variable],
    366     output_grads: Sequence[Variable],
    367 ) -> list[Variable]:
    368     r"""Construct a graph for the L-operator.
    369 
    370     The L-operator computes a row vector times the Jacobian.
   (...)
    387 
    388     """
--> 389     return self.grad(inputs, output_grads)

File ~\anaconda3\envs\R01-kujawa-NUTS\lib\site-packages\pytensor\graph\op.py:360, in Op.grad(self, inputs, output_grads)
    317 def grad(
    318     self, inputs: Sequence[Variable], output_grads: Sequence[Variable]
    319 ) -> list[Variable]:
    320     r"""Construct a graph for the gradient with respect to each input variable.
    321 
    322     Each returned `Variable` represents the gradient with respect to that
   (...)
    358 
    359     """
--> 360     raise NotImplementedError()

NotImplementedError: 

Broadcasting should be fine. The error message would be more useful if it said which Op is raising that NotImplementedError.
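In the meantime, one quick way to see which Ops end up in the logp graph is to print it and look for a Max-related node (a minimal sketch, assuming the toy model from the first post is still in scope):

import pytensor

pytensor.dprint(model.logp())  # inspect the printed graph for a Max-related Op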


I think it’s due to the max(), because this also raises:

import numpy as np
import pymc as pm

with pm.Model() as model:
    x = pm.Normal('test', shape=(74, 2))
    mu = x.max(axis=-1)
    pm.Normal('obs', mu, observed=np.random.uniform(size=74))
    model.dlogp()  # raises the same NotImplementedError

Strange, isn’t max differentiable (via the argmax, or something like that)?
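For reference, max is differentiable almost everywhere: the gradient just gets routed to the argmax position. A minimal sketch in plain pytensor, outside of any PyMC model:

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x", dtype="float64")
g = pytensor.grad(pt.max(x), x)                 # gradient of max w.r.t. x
print(g.eval({x: np.array([1.0, 3.0, 2.0])}))   # -> [0. 1. 0.], a one-hot at the argmax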

Is MaxAndArgmax being rewritten into Max too soon (in the pregrad rewrites)?

This works, so could it be related to MeasurableMax?

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.tensor('x', shape=(None, 2))
y = pt.max(x, axis=-1)
pytensor.grad(y.sum(), x).eval({x: np.random.normal(size=(10, 2))})

Don’t know. Either the logprob rewrites or the pregrad rewrites are likely introducing Max (which should be the default anyway -.-), but it doesn’t seem to have its grad implemented.

The logic should be the other way around: first there are maxes and argmaxes, and then a specialized Op can be introduced during rewrites to merge them. Anyway, we have to fix it.
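Until that’s fixed, here is a possible user-level workaround sketch (an assumption on my part, not something verified against every version): build the row-wise max from elementwise pt.maximum, whose gradient is implemented, instead of calling .max(axis=-1) on the broadcasted product.

import numpy as np
import pymc as pm
import pytensor.tensor as pt

with pm.Model() as model:
    x = pm.Normal('test', shape=74)
    terms = x[:, np.newaxis] * np.arange(4)
    # reduce over the last axis with elementwise maximum instead of Max
    mu = terms[:, 0]
    for i in range(1, 4):
        mu = pt.maximum(mu, terms[:, i])
    pm.Normal('obs', mu, observed=np.random.uniform(size=74))
    model.dlogp()  # should now build the gradient graph without the error

If the gradient graph compiles, NUTS should get assigned to all parameters again.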

Can you open an issue on PyMC?
