Pickle a model containing a custom distribution

I have created a custom distribution that I use for the observed variable in my model. After I pickle my model (using pickle or cloudpickle), restart the script, load the model back and try to sample a trace from it, sampling fails. This is my code:

class GenNormRV(RandomVariable):
    """``RandomVariable`` op that draws from ``scipy.stats.gennorm``.

    The op takes three inputs, ``(beta, loc, scale)``. Each is scalar per draw
    (see ``ndims_params``), and each draw is itself a scalar
    (``ndim_supp == 0``).

    Note: when models using this op are pickled, keep this class in an
    importable module rather than in the script being run (see the
    discussion following this code).
    """

    name: str = "GenNorm"
    ndim_supp: int = 0  # each draw is a scalar
    ndims_params: List[int] = [0, 0, 0]  # beta, loc, scale are all scalars
    dtype: str = "floatX"
    _print_name: Tuple[str, str] = ("GenNorm", "GGD")

    @classmethod
    def rng_fn(
        cls,
        rng: np.random.RandomState,
        beta: np.ndarray,
        loc: np.ndarray,
        scale: np.ndarray,
        size: Tuple[int, ...],
    ) -> np.ndarray:
        """Draw samples by delegating to ``scipy.stats.gennorm.rvs``.

        ``beta`` is scipy's shape parameter; ``loc``/``scale`` map to scipy's
        ``loc``/``scale``; ``rng`` is forwarded as ``random_state``.
        """
        # Keywords make the mapping onto scipy's signature explicit.
        return ss.gennorm.rvs(
            beta,
            loc=loc,
            scale=scale,
            size=size,
            random_state=rng,
        )


# Single op instance; used as ``GenNorm.rv_op`` below.
gennormrv = GenNormRV()
    
class GenNorm(Continuous):
    """Generalized normal (generalized Gaussian) distribution.

    Log-density, in the same parameterization as ``scipy.stats.gennorm``::

        log f(x | beta, loc, scale) = log(beta / (2 * scale))
                                      - gammaln(1 / beta)
                                      - (|x - loc| / scale) ** beta

    Parameters
    ----------
    beta : tensor_like of float
        Shape parameter; must be strictly positive.
    loc : tensor_like of float
        Location; the density is symmetric about ``loc``.
    scale : tensor_like of float
        Scale parameter; must be strictly positive.

    Note: when pickling models that use this distribution, define this class
    in an importable module rather than in the script being run. Otherwise
    sampling the reloaded model may fail with "Logprob method not implemented
    for GenNorm_rv...".
    """

    rv_op = gennormrv

    @classmethod
    def dist(cls, beta, loc, scale, *args, **kwargs):
        """Convert the parameters to float tensors and build the RV."""
        beta = at.as_tensor_variable(floatX(beta))
        loc = at.as_tensor_variable(floatX(loc))
        scale = at.as_tensor_variable(floatX(scale))
        # Order must match GenNormRV.rng_fn / ndims_params: (beta, loc, scale).
        return super().dist([beta, loc, scale], *args, **kwargs)

    def moment(rv, size, beta, loc, scale):
        """Return ``loc`` broadcast against the other parameters (and ``size``).

        The density is symmetric about ``loc`` and positive there, so it is a
        safe initial point.
        """
        # broadcast_arrays returns one tensor per input (three here); keep only
        # the broadcast ``loc``. The old ``moment, _ = ...`` unpacking of three
        # outputs raised, and would have selected ``beta`` anyway.
        moment, *_ = at.broadcast_arrays(loc, beta, scale)
        if not rv_size_is_none(size):
            moment = at.full(size, moment)
        return moment

    def logp(value, beta, loc, scale):
        """Log-density at ``value``, guarded by ``beta > 0`` and ``scale > 0``."""
        res = (
            at.log(beta / (2 * scale))
            - at.gammaln(1.0 / beta)
            - (at.abs_(value - loc) / scale) ** beta
        )
        # Strict inequalities: beta == 0 or scale == 0 make the density
        # undefined (division by zero above).
        return check_parameters(res, beta > 0, scale > 0, msg="beta > 0, scale > 0")

    def logcdf(value, beta, loc, scale):
        """Log of the CDF at ``value``, guarded by ``beta > 0`` and ``scale > 0``.

        Uses the ``scipy.stats.gennorm.cdf`` formula
        ``(0.5 + c) - c * gammaincc(1 / beta, |z| ** beta)`` with
        ``c = 0.5 * sign(z)`` and ``z = (value - loc) / scale``.
        """
        b = value - loc
        # sgn(0) == 0, so the CDF is exactly 0.5 at value == loc. The previous
        # ``b / |b|`` evaluated to 0/0 = NaN there.
        c = 0.5 * at.sgn(b)
        cdf = (0.5 + c) - c * at.gammaincc(1.0 / beta, at.abs_(b / scale) ** beta)
        # The method is ``logcdf``: return the log of the CDF, not the CDF.
        res = at.log(cdf)
        return check_parameters(res, beta > 0, scale > 0, msg="beta > 0, scale > 0")

# Build the model; ``data_sample`` supplies the observed values.
with pm.Model() as model:
    beta = pm.TruncatedNormal("beta", mu=0, sigma=1, lower=0)
    loc = pm.Normal("loc", mu=0, sigma=1)
    scale = pm.HalfNormal("scale", sigma=1)
    obs = GenNorm(
        "obs",
        beta=beta,
        loc=loc,
        scale=scale,
        observed=data_sample,
    )

# Persist the model (write_model/read_model are the user's pickling helpers).
write_model("test.pkl", model)

# --- the script is restarted at this point ---
model = read_model("test.pkl")

with model:
    trace = pm.sample(
        draws=1000,
        step=pm.Metropolis(),
        chains=4,
        cores=4,
    )

This is my error:

NotImplementedError                       Traceback (most recent call last)
Input In [13], in <cell line: 5>()
      3 GenNorm.rv_op = gennormrv
      5 with model:
----> 6     trace = pm.sample(draws=1000,step=pm.Metropolis(),chains=4,cores=4,progressbar=False)

File ~/repos/magisterska/.venv/lib/python3.10/site-packages/pymc/step_methods/arraystep.py:89, in BlockedStep.__new__(cls, *args, **kwargs)
     86 step = super().__new__(cls)
     87 # If we don't return the instance we have to manually
     88 # call __init__
---> 89 step.__init__([var], *args, **kwargs)
     90 # Hack for creating the class correctly when unpickling.
     91 step.__newargs = ([var],) + args, kwargs

File ~/repos/magisterska/.venv/lib/python3.10/site-packages/pymc/step_methods/metropolis.py:229, in Metropolis.__init__(self, vars, S, proposal_dist, scaling, tune, tune_interval, model, mode, **kwargs)
    226 self.mode = mode
    228 shared = pm.make_shared_replacements(initial_values, vars, model)
--> 229 self.delta_logp = delta_logp(initial_values, model.logpt(), vars, shared)
    230 super().__init__(vars, shared)

File ~/repos/magisterska/.venv/lib/python3.10/site-packages/pymc/model.py:745, in Model.logpt(self, vars, jacobian, sum)
    743 rv_logps: List[TensorVariable] = []
    744 if rv_values:
--> 745     rv_logps = joint_logpt(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian)
    746     assert isinstance(rv_logps, list)
    748 # Replace random variables by their value variables in potential terms

File ~/repos/magisterska/.venv/lib/python3.10/site-packages/pymc/distributions/logprob.py:226, in joint_logpt(var, rv_values, jacobian, scaling, transformed, sum, **kwargs)
    223                 transform_map[value_var] = original_value_var.tag.transform
    225 transform_opt = TransformValuesOpt(transform_map)
--> 226 temp_logp_var_dict = factorized_joint_logprob(
    227     tmp_rvs_to_values,
    228     extra_rewrites=transform_opt,
    229     use_jacobian=jacobian,
    230     warn_missing_rvs=False,
    231     **kwargs,
    232 )
    234 # Raise if there are unexpected RandomVariables in the logp graph
    235 # Only SimulatorRVs are allowed
    236 from pymc.distributions.simulator import SimulatorRV

File ~/repos/magisterska/.venv/lib/python3.10/site-packages/aeppl/joint_logprob.py:147, in factorized_joint_logprob(rv_values, warn_missing_rvs, extra_rewrites, **kwargs)
    144 q_value_vars = remapped_vars[: len(q_value_vars)]
    145 q_rv_inputs = remapped_vars[len(q_value_vars) :]
--> 147 q_logprob_vars = _logprob(
    148     node.op,
    149     q_value_vars,
    150     *q_rv_inputs,
    151     **kwargs,
    152 )
    154 if not isinstance(q_logprob_vars, (list, tuple)):
    155     q_logprob_vars = [q_logprob_vars]

File ~/miniconda3/envs/py310/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/repos/magisterska/.venv/lib/python3.10/site-packages/aeppl/logprob.py:85, in _logprob(op, values, *inputs, **kwargs)
     71 @singledispatch
     72 def _logprob(
     73     op: Op,
   (...)
     76     **kwargs,
     77 ):
     78     """Create a graph for the log-density/mass of a ``RandomVariable``.
     79 
     80     This function dispatches on the type of ``op``, which should be a subclass
   (...)
     83 
     84     """
---> 85     raise NotImplementedError(f"Logprob method not implemented for {op}")

NotImplementedError: Logprob method not implemented for GenNorm_rv{0, (0, 0, 0), floatX, False}

Does anyone have any idea why this is happening?

Does this happen if you define the distribution in a separate script that is just imported (instead of in the same script where you sample)?

1 Like

No. I'm not sure why, but defining the custom distribution in a separate module and importing it into the script seems to solve the problem. Thank you!

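That’s actually quite common with pickling issues. A likely explanation is that classes defined in the script you run live in `__main__`, and cloudpickle serializes such classes by value. Unpickling therefore recreates a new `GenNormRV` class, which is not the class aeppl's `_logprob` dispatcher was registered for (it dispatches on the op's class, as the traceback shows), hence the `NotImplementedError`. When the class is imported from a module, it is pickled by reference, so the registered class is reused. Glad that I could help.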

1 Like