Using a scipy distribution with CustomDist
Hi all, I’ve been trying to use scipy.stats.gengamma with CustomDist, but I’m having serious trouble making sense of both the documentation and the errors I’ve been getting.
I found this previous post, "Wrapping a scipy distribution with DensityDist", but it seems to refer to an older version of pymc and may not apply any more.
Either way, from my understanding, CustomDist should work with any distribution so long as it has a way to compute the log-PDF and draw samples, right? I guess my hangup is understanding how arguments are actually passed to my custom logp and random functions, and how those functions should be written to work with pymc and pytensor (the aesara fork that pymc 5 uses).
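For reference, here is a minimal sketch of the log-density I believe scipy.stats.gengamma computes, written out with only the standard library so the formula is explicit. The function name `gengamma_logpdf` is my own, and the parametrization (shape `a`, power `c`, plus `scale`) is assumed to follow scipy's documented form, pdf(x; a, c) = |c| x^(c*a - 1) exp(-x^c) / Gamma(a), applied to x / scale and divided by scale.

```python
import math


def gengamma_logpdf(x, a, c, scale=1.0):
    """Log-density of a generalized gamma in an (a, c, scale) parametrization.

    Assumed to match scipy.stats.gengamma; this is a sketch, not a drop-in
    replacement, and it treats the boundary x == 0 as outside the support.
    """
    if a <= 0 or c == 0 or scale <= 0:
        raise ValueError("need a > 0, c != 0 and scale > 0")
    if x <= 0:
        return -math.inf
    z = x / scale  # standardized value
    return (
        math.log(abs(c))
        + (c * a - 1.0) * math.log(z)
        - z**c
        - math.lgamma(a)
        - math.log(scale)  # Jacobian of the scale transform
    )


# Sanity check: with a = 1 and c = 1 the density is exponential with mean `scale`.
x, scale = 0.7, 2.0
exponential_logpdf = -math.log(scale) - x / scale
print(math.isclose(gengamma_logpdf(x, 1.0, 1.0, scale), exponential_logpdf))
```

As far as I can tell, a logp that works on symbolic inputs would need to reproduce this expression with tensor operations (for example `pytensor.tensor.log` and `pytensor.tensor.gammaln`, which I'm assuming is the right route) rather than call scipy on evaluated values.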
Here’s what I’ve tried:
import numpy as np
from scipy import stats
import pymc as pm
q = np.array([0.137, 0.146, 0.143, 0.14, 0.136, 0.28, 0.223, 0.167, 0.158, 1.007, 0.477, 0.286,
              0.209, 0.194, 0.19, 0.185, 0.181, 0.434, 0.24, 0.209, 0.197, 0.403, 0.235, 0.211,
              0.205, 0.184, 0.173, 0.714, 0.314, 0.256])
L = np.array([57.423, 56.73, 56.423, 55.273, 54.581, 72.969, 59.053, 59.927, 60.152, 72.497,
              67.81, 64.752, 64.355, 64.503, 63.982, 63.033, 62.93, 66.293, 65.472, 64.952,
              64.23, 66.867, 64.954, 65.373, 65.065, 61.19, 62.883, 67.027, 67.089, 70.58])
with pm.Model() as m:
    a = pm.Uniform("a", lower=0, upper=500)
    b = pm.Uniform("b", lower=0, upper=5)
    ak = pm.Uniform("ak", lower=0, upper=10)
    lk = pm.Uniform("lk", lower=0, upper=100)

    def random(lk, ak, a, b, size=None, rng=None):
        scale = (a*(ak)**b).eval()
        a = lk.eval()
        c = (a/b).eval()
        out = stats.gengamma(scale=scale, a=a, c=c).rvs(size=size.data, random_state=rng)
        return out

    def logp(value, lk, ak, a, b):
        scale = (a*(ak)**b).eval()
        a = lk.eval()
        c = (a/b).eval()
        out = stats.gengamma(scale=scale, a=a, c=c).logpdf(value)
        return out

    def logcdf(value, lk, ak, a, b):
        scale = (a*(ak)**b).eval()
        a = lk.eval()
        c = (a/b).eval()
        out = stats.gengamma(scale=scale, a=a, c=c).logcdf(value)
        return out

    L_obs = pm.CustomDist(
        "GenGamma",
        lk, ak, a, b,
        random=random,
        logp=logp,
        logcdf=logcdf,
        dtype="float64",
        observed=L
    )

    idata = pm.sample(3000)
And the error message I get:
Traceback (most recent call last):
  File "/Users/alex/Documents/Work/ORNL/Projects/streamflow-active-length-regimes/pymc_testing.py", line 84, in <module>
    idata = pm.sample(3000)
            ^^^^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 564, in sample
    step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 181, in assign_step_methods
    model_logp = model.logp()
                 ^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/model.py", line 764, in logp
    rv_logps = joint_logp(
               ^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/logprob/basic.py", line 359, in joint_logp
    temp_logp_terms = factorized_joint_logprob(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/logprob/basic.py", line 277, in factorized_joint_logprob
    q_logprob_vars = _logprob(
                     ^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/distributions/distribution.py", line 568, in custom_dist_logp
    return logp(values[0], *dist_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/Documents/Work/ORNL/Projects/streamflow-active-length-regimes/pymc_testing.py", line 61, in logp
    scale = (a*(ak)**b).eval()
            ^^^^^^^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/graph/basic.py", line 619, in eval
    self._fn_cache[inputs] = function(inputs, self)
                             ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/compile/function/__init__.py", line 315, in function
    fn = pfunc(
         ^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py", line 367, in pfunc
    return orig_function(
           ^^^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 1744, in orig_function
    m = Maker(
        ^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 1508, in __init__
    fgraph, found_updates = std_fgraph(
                            ^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 228, in std_fgraph
    fgraph = FunctionGraph(
             ^^^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/graph/fg.py", line 157, in __init__
    self.add_output(output, reason="init")
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/graph/fg.py", line 167, in add_output
    self.import_var(var, reason=reason, import_missing=import_missing)
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/graph/fg.py", line 308, in import_var
    self.import_node(var.owner, reason=reason, import_missing=import_missing)
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/graph/fg.py", line 373, in import_node
    raise MissingInputError(error_msg, variable=var)
pytensor.graph.utils.MissingInputError: Input 0 (b_interval__) of the graph (indices start from 0), used to compute Elemwise{sigmoid,no_inplace}(b_interval__), was not provided and not given a value. Use the PyTensor flag exception_verbosity='high', for more information on this error.

Backtrace when that variable is created:

  File "/Users/alex/Documents/Work/ORNL/Projects/streamflow-active-length-regimes/pymc_testing.py", line 49, in <module>
    b = pm.Uniform("b", lower=0, upper=5)
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/distributions/distribution.py", line 312, in __new__
    rv_out = model.register_rv(
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/model.py", line 1333, in register_rv
    self.create_value_var(rv_var, transform)
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/model.py", line 1526, in create_value_var
    value_var = transform.forward(rv_var, *rv_var.owner.inputs).type()
I wish I had more to share, but I haven’t been able to get any version of this running without errors. Any suggestions are appreciated!
Thanks,
Alex