I have created a custom distribution that I use for the observed variable in my model. After I pickle my model (using either pickle or cloudpickle), restart the script, and try to sample a trace from the reloaded model, sampling fails with a NotImplementedError. This is my code:
class GenNormRV(RandomVariable):
    """RandomVariable op for the generalized normal (generalized Gaussian).

    Sampling is delegated to ``scipy.stats.gennorm``; the three parameters
    (shape ``beta``, location ``loc``, scale ``scale``) are all scalars.
    """

    name: str = "GenNorm"
    ndim_supp: int = 0  # univariate support
    ndims_params: List[int] = [0, 0, 0]  # beta, loc, scale: scalar params
    dtype: str = "floatX"
    _print_name: Tuple[str, str] = ("GenNorm", "GGD")

    @classmethod
    def rng_fn(
        cls,
        rng: np.random.RandomState,
        beta: np.ndarray,
        loc: np.ndarray,
        scale: np.ndarray,
        size: Tuple[int, ...],
    ) -> np.ndarray:
        """Draw ``size`` samples from gennorm(beta, loc, scale) using ``rng``."""
        return ss.gennorm.rvs(beta, loc=loc, scale=scale, random_state=rng, size=size)


gennormrv = GenNormRV()
class GenNorm(Continuous):
    """Generalized normal (generalized Gaussian) distribution.

    pdf(x; beta, loc, scale) =
        beta / (2 * scale * Gamma(1/beta)) * exp(-(|x - loc| / scale) ** beta)

    with beta > 0 and scale > 0. The mean (and median/mode) is ``loc``.
    """

    rv_op = gennormrv

    @classmethod
    def dist(cls, beta, loc, scale, *args, **kwargs):
        """Create the distribution; parameters are cast to floatX tensors."""
        beta = at.as_tensor_variable(floatX(beta))
        loc = at.as_tensor_variable(floatX(loc))
        scale = at.as_tensor_variable(floatX(scale))
        return super().dist([beta, loc, scale], *args, **kwargs)

    def moment(rv, size, beta, loc, scale):
        """Finite starting point for samplers: the distribution mean, ``loc``.

        BUG FIX: the original unpacked the 3-tuple returned by
        ``at.broadcast_arrays`` into two names (a ValueError at graph build
        time) and used ``beta`` as the moment instead of the mean ``loc``.
        """
        # Broadcast loc against the other parameters so the moment has the
        # implied shape; only the broadcast loc is kept.
        moment, *_ = at.broadcast_arrays(loc, beta, scale)
        if not rv_size_is_none(size):
            moment = at.full(size, moment)
        return moment

    def logp(value, beta, loc, scale):
        """Log-density: log(beta/(2*scale)) - lgamma(1/beta) - (|x-loc|/scale)**beta.

        Parameters must be strictly positive (beta == 0 would divide by zero
        in 1/beta), hence the strict constraints.
        """
        return check_parameters(
            at.log(beta / (2 * scale))
            - at.gammaln(1.0 / beta)
            - (at.abs_(value - loc) / scale) ** beta,
            beta > 0,
            scale > 0,
        )

    def logcdf(value, beta, loc, scale):
        """Log of the CDF.

        CDF(x) = 1/2 + sign(x - loc)/2 * P(1/beta, (|x - loc|/scale)**beta),
        written below via the regularized upper incomplete gamma gammaincc.

        BUG FIXES: the original returned the CDF itself instead of its log,
        and computed the sign as ``b / |b|``, which is NaN at value == loc;
        ``at.sgn`` is 0 there, correctly giving the median CDF of 0.5.
        """
        b = value - loc
        c = 0.5 * at.sgn(b)
        cdf = (0.5 + c) - c * at.gammaincc(1.0 / beta, at.abs_(b / scale) ** beta)
        return check_parameters(at.log(cdf), beta > 0, scale > 0)
# Build the model, serialize it, then (after restarting the interpreter)
# reload it and sample with Metropolis.
with pm.Model() as model:
    beta = pm.TruncatedNormal("beta", mu=0, sigma=1, lower=0)
    loc = pm.Normal("loc", mu=0, sigma=1)
    scale = pm.HalfNormal("scale", sigma=1)
    obs = GenNorm("obs", beta=beta, loc=loc, scale=scale, observed=data_sample)

write_model("test.pkl", model)

# I am restarting my script after this
model = read_model("test.pkl")
with model:
    trace = pm.sample(draws=1000, step=pm.Metropolis(), chains=4, cores=4)
This is my error:
NotImplementedError Traceback (most recent call last)
Input In [13], in <cell line: 5>()
3 GenNorm.rv_op = gennormrv
5 with model:
----> 6 trace = pm.sample(draws=1000,step=pm.Metropolis(),chains=4,cores=4,progressbar=False)
File ~/repos/magisterska/.venv/lib/python3.10/site-packages/pymc/step_methods/arraystep.py:89, in BlockedStep.__new__(cls, *args, **kwargs)
86 step = super().__new__(cls)
87 # If we don't return the instance we have to manually
88 # call __init__
---> 89 step.__init__([var], *args, **kwargs)
90 # Hack for creating the class correctly when unpickling.
91 step.__newargs = ([var],) + args, kwargs
File ~/repos/magisterska/.venv/lib/python3.10/site-packages/pymc/step_methods/metropolis.py:229, in Metropolis.__init__(self, vars, S, proposal_dist, scaling, tune, tune_interval, model, mode, **kwargs)
226 self.mode = mode
228 shared = pm.make_shared_replacements(initial_values, vars, model)
--> 229 self.delta_logp = delta_logp(initial_values, model.logpt(), vars, shared)
230 super().__init__(vars, shared)
File ~/repos/magisterska/.venv/lib/python3.10/site-packages/pymc/model.py:745, in Model.logpt(self, vars, jacobian, sum)
743 rv_logps: List[TensorVariable] = []
744 if rv_values:
--> 745 rv_logps = joint_logpt(list(rv_values.keys()), rv_values, sum=False, jacobian=jacobian)
746 assert isinstance(rv_logps, list)
748 # Replace random variables by their value variables in potential terms
File ~/repos/magisterska/.venv/lib/python3.10/site-packages/pymc/distributions/logprob.py:226, in joint_logpt(var, rv_values, jacobian, scaling, transformed, sum, **kwargs)
223 transform_map[value_var] = original_value_var.tag.transform
225 transform_opt = TransformValuesOpt(transform_map)
--> 226 temp_logp_var_dict = factorized_joint_logprob(
227 tmp_rvs_to_values,
228 extra_rewrites=transform_opt,
229 use_jacobian=jacobian,
230 warn_missing_rvs=False,
231 **kwargs,
232 )
234 # Raise if there are unexpected RandomVariables in the logp graph
235 # Only SimulatorRVs are allowed
236 from pymc.distributions.simulator import SimulatorRV
File ~/repos/magisterska/.venv/lib/python3.10/site-packages/aeppl/joint_logprob.py:147, in factorized_joint_logprob(rv_values, warn_missing_rvs, extra_rewrites, **kwargs)
144 q_value_vars = remapped_vars[: len(q_value_vars)]
145 q_rv_inputs = remapped_vars[len(q_value_vars) :]
--> 147 q_logprob_vars = _logprob(
148 node.op,
149 q_value_vars,
150 *q_rv_inputs,
151 **kwargs,
152 )
154 if not isinstance(q_logprob_vars, (list, tuple)):
155 q_logprob_vars = [q_logprob_vars]
File ~/miniconda3/envs/py310/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
885 if not args:
886 raise TypeError(f'{funcname} requires at least '
887 '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)
File ~/repos/magisterska/.venv/lib/python3.10/site-packages/aeppl/logprob.py:85, in _logprob(op, values, *inputs, **kwargs)
71 @singledispatch
72 def _logprob(
73 op: Op,
(...)
76 **kwargs,
77 ):
78 """Create a graph for the log-density/mass of a ``RandomVariable``.
79
80 This function dispatches on the type of ``op``, which should be a subclass
(...)
83
84 """
---> 85 raise NotImplementedError(f"Logprob method not implemented for {op}")
NotImplementedError: Logprob method not implemented for GenNorm_rv{0, (0, 0, 0), floatX, False}
Does anyone have any idea why this is happening?