Wrapping a scipy distribution with CustomDist

Hi all, I’ve been trying to use scipy.stats.gengamma with CustomDist, but I’m having serious trouble making sense of the documentation and of the errors I’ve been getting.

I found this previous post: Wrapping a scipy distribution with DensityDist

but it seems to refer to an older version of PyMC, so maybe it doesn’t apply anymore?

Either way, my understanding is that CustomDist should work with any distribution as long as I provide a way to compute its log-PDF and to draw samples from it, right? My hang-up is understanding how arguments are actually passed to my custom logp and random functions, and how those functions should be written to work with PyMC and PyTensor (the successor to Aesara, which PyMC 5 uses).

Here’s what I’ve tried:

```python
import numpy as np
from scipy import stats

import pymc as pm

q = np.array([0.137, 0.146, 0.143, 0.14, 0.136, 0.28, 0.223, 0.167, 0.158, 1.007, 0.477, 0.286, 
              0.209, 0.194, 0.19, 0.185, 0.181, 0.434, 0.24, 0.209, 0.197, 0.403, 0.235, 0.211,
              0.205, 0.184, 0.173, 0.714, 0.314, 0.256])
L = np.array([57.423, 56.73, 56.423, 55.273, 54.581, 72.969, 59.053, 59.927, 60.152, 72.497, 
              67.81, 64.752, 64.355, 64.503, 63.982, 63.033, 62.93, 66.293, 65.472, 64.952, 
              64.23, 66.867, 64.954, 65.373, 65.065, 61.19, 62.883, 67.027, 67.089, 70.58])

with pm.Model() as m:
    a = pm.Uniform("a", lower=0, upper=500)
    b = pm.Uniform("b", lower=0, upper=5)
    ak = pm.Uniform("ak", lower=0, upper=10)
    lk = pm.Uniform("lk", lower=0, upper=100)

    def random(lk, ak, a, b, size=None, rng=None):
        scale = (a*(ak)**b).eval()
        a = lk.eval()
        c = (a/b).eval()
        out = stats.gengamma(scale=scale, a=a, c=c).rvs(size=size.data, random_state=rng)
        return out

    def logp(value, lk, ak, a, b):
        scale = (a*(ak)**b).eval()
        a = lk.eval()
        c = (a/b).eval()
        out = stats.gengamma(scale=scale, a=a, c=c).logpdf(value)
        return out

    def logcdf(value, lk, ak, a, b):
        scale = (a*(ak)**b).eval()
        a = lk.eval()
        c = (a/b).eval()
        out = stats.gengamma(scale=scale, a=a, c=c).logcdf(value)
        return out

    L_obs = pm.CustomDist(
        "GenGamma", 
        lk, ak, a, b,
        random=random,
        logp=logp,
        logcdf=logcdf,
        dtype="float64",
        observed=L
    )

    idata = pm.sample(3000)
```

And the error message I get:
```
Traceback (most recent call last):
  File "/Users/alex/Documents/Work/ORNL/Projects/streamflow-active-length-regimes/pymc_testing.py", line 84, in <module>
    idata = pm.sample(3000)
            ^^^^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 564, in sample
    step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/sampling/mcmc.py", line 181, in assign_step_methods
    model_logp = model.logp()
                 ^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/model.py", line 764, in logp
    rv_logps = joint_logp(
               ^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/logprob/basic.py", line 359, in joint_logp
    temp_logp_terms = factorized_joint_logprob(
                      ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/logprob/basic.py", line 277, in factorized_joint_logprob
    q_logprob_vars = _logprob(
                     ^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/distributions/distribution.py", line 568, in custom_dist_logp
    return logp(values[0], *dist_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/Documents/Work/ORNL/Projects/streamflow-active-length-regimes/pymc_testing.py", line 61, in logp
    scale = (a*(ak)**b).eval()
            ^^^^^^^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/graph/basic.py", line 619, in eval
    self._fn_cache[inputs] = function(inputs, self)
                             ^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/compile/function/__init__.py", line 315, in function
    fn = pfunc(
         ^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/compile/function/pfunc.py", line 367, in pfunc
    return orig_function(
           ^^^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 1744, in orig_function
    m = Maker(
        ^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 1508, in __init__
    fgraph, found_updates = std_fgraph(
                            ^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/compile/function/types.py", line 228, in std_fgraph
    fgraph = FunctionGraph(
             ^^^^^^^^^^^^^^
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/graph/fg.py", line 157, in __init__
    self.add_output(output, reason="init")
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/graph/fg.py", line 167, in add_output
    self.import_var(var, reason=reason, import_missing=import_missing)
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/graph/fg.py", line 308, in import_var
    self.import_node(var.owner, reason=reason, import_missing=import_missing)
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pytensor/graph/fg.py", line 373, in import_node
    raise MissingInputError(error_msg, variable=var)
pytensor.graph.utils.MissingInputError: Input 0 (b_interval__) of the graph (indices start from 0), used to compute Elemwise{sigmoid,no_inplace}(b_interval__), was not provided and not given a value. Use the PyTensor flag exception_verbosity='high', for more information on this error.

Backtrace when that variable is created:

  File "/Users/alex/Documents/Work/ORNL/Projects/streamflow-active-length-regimes/pymc_testing.py", line 49, in <module>
    b = pm.Uniform("b", lower=0, upper=5)
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/distributions/distribution.py", line 312, in __new__
    rv_out = model.register_rv(
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/model.py", line 1333, in register_rv
    self.create_value_var(rv_var, transform)
  File "/Users/alex/mambaforge/envs/pymc5/lib/python3.11/site-packages/pymc/model.py", line 1526, in create_value_var
    value_var = transform.forward(rv_var, *rv_var.owner.inputs).type()
```

I wish I had more to share, but I haven’t been able to get any version of this running without errors. Any suggestions are appreciated!

Thanks,
Alex

Have you seen the code examples in the documentation? pymc.CustomDist — PyMC 5.15.1 documentation

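Any time you call .eval(), you stop building the computational graph and instead compile and evaluate everything that leads up to that variable. logp and logcdf receive symbolic inputs and must return symbolic outputs, but scipy expects numerical inputs and returns numerical outputs. That is what the traceback is telling you: inside logp, b stands for a transformation of the unobserved value variable b_interval__, which has no numerical value while PyMC is building the model's logp graph, so .eval() raises a MissingInputError. You can check this article on wrapping arbitrary Python code into a PyTensor Op for use in a case like this.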

For the record, you can implement a generalized gamma distribution with a generative model, which avoids writing a custom logp altogether. If you want to understand in depth how PyMC does this, these probability puzzles are a great resource. In particular, you can use this fact from Wikipedia:

  • More generally, if $X \sim \text{Gamma}(k, \theta)$, then $X^q$ for $q > 0$ follows a generalized gamma distribution with parameters $p = 1/q$, $d = k/q$, and $a = \theta^q$.
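That fact can be checked numerically with NumPy and scipy alone. Translating Wikipedia's parameters (a, d, p) into scipy's gengamma(a, c, scale) gives a = d/p = k, c = p = 1/q and scale = θ^q. This sketch (function names are hypothetical) compares scipy's log-density against a change-of-variables computation from the gamma density, then compares a Monte Carlo mean of X**q with scipy's analytic mean.

```python
import numpy as np
from scipy import stats


def gengamma_from_gamma_power(k, theta, q):
    """Hypothetical helper: scipy gengamma (a, c, scale) for Y = X**q, X ~ Gamma(k, theta).

    Wikipedia's parameters are p = 1/q, d = k/q, a = theta**q; scipy's shape is d / p.
    """
    p, d, a_wiki = 1.0 / q, k / q, theta**q
    return d / p, p, a_wiki


def logpdf_via_change_of_variables(y, k, theta, q):
    # f_Y(y) = f_X(y**(1/q)) * |dx/dy|, where dx/dy = (1/q) * y**(1/q - 1)
    x = y ** (1.0 / q)
    return stats.gamma.logpdf(x, k, scale=theta) - np.log(q) + (1.0 / q - 1.0) * np.log(y)


k, theta, q = 3.0, 2.0, 0.5
a_s, c_s, scale_s = gengamma_from_gamma_power(k, theta, q)

# 1) The two densities agree pointwise.
y = np.linspace(0.1, 5.0, 50)
direct = stats.gengamma.logpdf(y, a_s, c_s, scale=scale_s)
via_gamma = logpdf_via_change_of_variables(y, k, theta, q)
print(np.max(np.abs(direct - via_gamma)))

# 2) Simulating the generative story (draw X, raise it to q) matches the analytic mean.
rng = np.random.default_rng(42)
samples = rng.gamma(shape=k, scale=theta, size=200_000) ** q
mc_mean = samples.mean()
analytic_mean = stats.gengamma.mean(a_s, c_s, scale=scale_s)
print(mc_mean, analytic_mean)
```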