KeyError: 'tune' in _get_sampler_stats with custom pdf

Hi, I was trying to set up a custom probability distribution. Unfortunately I couldn’t find an up-to-date version of the guide, so I roughly followed the source code of the continuous distributions.
Here is the model:

import numpy as np
import scipy
import pytensor.tensor as pt
from pytensor.tensor import gammaln
from pytensor.tensor.random.op import RandomVariable
from typing import List, Tuple
from pymc.pytensorf import floatX
from pymc.distributions.distribution import Continuous
from pymc.distributions.dist_math import check_parameters


class VarianceGammaRV(RandomVariable):
    # https://arxiv.org/pdf/2303.05615.pdf eq 2.20
    name: str = "variance_gamma"

    ndim_supp: int = 0

    ndims_params: List[int] = [0, 0, 0, 0]

    dtype: str = "floatX"
    @classmethod
    def rng_fn(
        cls,
        rng: np.random.RandomState,
        r: np.ndarray,
        theta: np.ndarray,
        sigma: np.ndarray,
        mu: np.ndarray,
        size: Tuple[int, ...],
    ) -> np.ndarray:
        v0 = np.sqrt(theta**2+sigma**2)
        v1 = 1/(v0 + theta)
        v2 = 1/(v0 - theta)
        # pass the scale by keyword: scipy's gamma signature is (a, loc=0, scale=1, ...)
        s1 = scipy.stats.gamma.rvs(r/2, scale=v1, random_state=rng, size=size)
        s2 = scipy.stats.gamma.rvs(r/2, scale=v2, random_state=rng, size=size)
        return (mu + s1 - s2)

variance_gamma = VarianceGammaRV()

class VarianceGamma(Continuous):
    rv_op = variance_gamma
    
    @classmethod
    def dist(cls, r, theta, sigma, mu, *args, **kwargs):
        r = pt.as_tensor_variable(floatX(r))
        theta = pt.as_tensor_variable(floatX(theta))
        sigma = pt.as_tensor_variable(floatX(sigma))
        mu = pt.as_tensor_variable(floatX(mu))

        return super().dist([r, theta, sigma, mu], *args, **kwargs)

    def logp(value, r, theta, sigma, mu):
        x = value - mu
        sigma2 = sigma**2
        abs_x = pt.abs(x)
        s = pt.sqrt(theta**2 + sigma2)
        v1 = -pt.log(sigma*pt.sqrt(np.pi))-gammaln(r/2)
        v2 = theta*x/sigma2
        v3 = (r-1)/2*pt.log(abs_x/(2*s))
        alpha = (r-1.0)/2.0
        v40 = - s*abs_x/sigma2 - 1.0/2.0*pt.log(np.pi/(2.0*s*abs_x/sigma2))
        v4 = pt.switch(pt.abs(alpha)>0.0,pt.log(np.pi/2.0*(pt.iv(-alpha, s*abs_x/sigma2) - pt.iv(alpha, s*abs_x/sigma2))/pt.sin(alpha*np.pi)), v40 )
        # v4 = scipy.special.kv(alpha, s*abs_x/sigma2)
        res = v1 + v2 + v3 + v4
        return check_parameters(
            res,
            r > 0,
            sigma > 0,
            msg="r > 0, sigma > 0",
        )
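
A note on the v4 term: pt.iv gives the modified Bessel function of the first kind, and since (as far as I could tell) there is no graph-level K_alpha available, the logp uses the standard connection formula K_alpha(z) = pi/2 * (I_{-alpha}(z) - I_alpha(z)) / sin(alpha*pi), which holds for non-integer alpha, and falls back to a large-argument approximation when alpha = 0. A quick numerical check of the identity with scipy:

import numpy as np
from scipy.special import iv, kv

# connection formula used in v4, valid for non-integer order a
a, z = 0.3, 1.7
lhs = kv(a, z)
rhs = np.pi / 2 * (iv(-a, z) - iv(a, z)) / np.sin(a * np.pi)
print(np.isclose(lhs, rhs))  # True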

The distribution seems to work properly: I get reasonable random draws and a reasonable pdf. However, when I try to compute the trace I get

File ~/.local/lib/python3.11/site-packages/pymc/backends/ndarray.py:125, in NDArray._get_sampler_stats(self, varname, sampler_idx, burn, thin)
    122 def _get_sampler_stats(
    123     self, varname: str, sampler_idx: int, burn: int, thin: int
    124 ) -> np.ndarray:
--> 125     return self._stats[sampler_idx][varname][burn::thin]

KeyError: 'tune'

I guess it has something to do with RandomVariable (maybe I should use ScipyRandomVariable, but it’s not clear to me how I am supposed to handle the loc/scale parameters).
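
For context on the scipy side of the loc/scale question: scipy.stats.gamma.rvs has the signature (a, loc=0, scale=1, ...), so a bare second positional argument ends up as loc rather than scale, which is why the scale is best passed by keyword as in rng_fn above. A minimal illustration with arbitrary values:

import scipy.stats

# the second positional argument fills loc (a shift of the support), not scale
shifted = scipy.stats.gamma.rvs(2.0, 0.5, size=3)
# passing scale by keyword gives the scaled gamma instead
scaled = scipy.stats.gamma.rvs(2.0, scale=0.5, size=3)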
Thanks!

Hi, I wasn’t able to reproduce your problem locally. I think your RV might be well-defined and the problem may lie elsewhere.

I built the RV with this code

import numpy as np
import scipy
import pytensor.tensor as pt
import pymc as pm
from pytensor.tensor.random.op import RandomVariable
from typing import List, Tuple
from pymc.pytensorf import floatX
from pymc.distributions.distribution import Continuous
from pymc.distributions.dist_math import check_parameters

class VarianceGammaRV(RandomVariable):
    # https://arxiv.org/pdf/2303.05615.pdf eq 2.20
    name: str = "variance_gamma"

    ndim_supp: int = 0

    ndims_params: List[int] = [0, 0, 0, 0]

    dtype: str = "floatX"
    @classmethod
    def rng_fn(
        cls,
        rng: np.random.RandomState,
        r: np.ndarray,
        theta: np.ndarray,
        sigma: np.ndarray,
        mu: np.ndarray,
        size: Tuple[int, ...],
    ) -> np.ndarray:
        v0 = np.sqrt(theta**2+sigma**2)
        v1 = 1/(v0 + theta)
        v2 = 1/(v0 - theta)
        # pass the scale by keyword: scipy's gamma signature is (a, loc=0, scale=1, ...)
        s1 = scipy.stats.gamma.rvs(r/2, scale=v1, random_state=rng, size=size)
        s2 = scipy.stats.gamma.rvs(r/2, scale=v2, random_state=rng, size=size)
        return (mu + s1 - s2)

class VarianceGamma(Continuous):
    #rv_op = variance_gamma
    rv_op = VarianceGammaRV()
    
    @classmethod
    def dist(cls, r, theta, sigma, mu, *args, **kwargs):
        r = pt.as_tensor_variable(floatX(r))
        theta = pt.as_tensor_variable(floatX(theta))
        sigma = pt.as_tensor_variable(floatX(sigma))
        mu = pt.as_tensor_variable(floatX(mu))

        return super().dist([r, theta, sigma, mu], *args, **kwargs)

    def logp(value, r, theta, sigma, mu):
        x = value - mu
        sigma2 = sigma**2
        abs_x = pt.abs(x)
        s = pt.sqrt(theta**2 + sigma2)
        v1 = -pt.log(sigma*pt.sqrt(np.pi))# - gammaln(r/2)
        v2 = theta*x/sigma2
        v3 = (r-1)/2*pt.log(abs_x/(2*s))
        alpha = (r-1.0)/2.0
        v40 = - s*abs_x/sigma2 - 1.0/2.0*pt.log(np.pi/(2.0*s*abs_x/sigma2))
        v4 = pt.switch(pt.abs(alpha)>0.0,pt.log(np.pi/2.0*(pt.iv(-alpha, s*abs_x/sigma2) - pt.iv(alpha, s*abs_x/sigma2))/pt.sin(alpha*np.pi)), v40 )
        # v4 = scipy.special.kv(alpha, s*abs_x/sigma2)
        res = v1 + v2 + v3 + v4
        return check_parameters(
            res,
            r > 0,
            sigma > 0,
            msg="r > 0, sigma > 0",
        )

I generated a simulated dataset with:

VG = VarianceGamma.dist(0.5,1,1,0)
data = pm.draw(VG,draws=100)

Tested the logp with:

pm.logp(VG,data).eval()

And then sampled the model with:

with pm.Model() as m0:
    r = 0.5
    theta = pm.Normal('theta',0,1) 
    sigma = pm.Normal('sigma',0,1)
    mu = pm.Normal('mu',0,1)
    
    y = VarianceGamma('y',r,theta,sigma,mu,observed=data)
    trace = pm.sample()

The code runs as is. I don’t know if this clarifies anything, but the sampler didn’t perform particularly well; my trace plots show that a couple of chains barely moved during sampling.
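
To put numbers on that, the usual convergence diagnostics are more telling than eyeballing the trace plots; a quick check, using the trace object from the model above:

import arviz as az

# r_hat well above 1.01 and a small ess_bulk confirm stuck chains
print(az.summary(trace, var_names=["theta", "sigma", "mu"]))
az.plot_trace(trace)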

Also, could you clear up where gammaln(r/2) comes from?

In my version, I commented out the gammaln subtraction for the time being.
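
For reference, a gammaln usable inside the graph is available directly from pytensor; the later snippet in this thread imports it this way:

from pytensor.tensor import gammaln  # log-gamma that operates on pytensor tensors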


Hi Daniel,

you are right; the error has now disappeared.
Thanks for the quick reply!


Hello again.
The problem actually reappeared once I reintroduced the r prior:

import pytensor.tensor as pt
from pytensor.tensor.random.op import RandomVariable
from typing import List, Tuple
from pymc.pytensorf import floatX
from pymc.distributions.distribution import Continuous
import scipy
import numpy as np
import pymc as pm
import arviz as az
from pytensor.tensor import gammaln
from jax import numpy as jnp
from pymc.distributions.dist_math import (check_parameters)


class VarianceGammaRV(RandomVariable):
    # https://arxiv.org/pdf/2303.05615.pdf eq 2.20
    name: str = "variance_gamma"

    ndim_supp: int = 0

    ndims_params: List[int] = [0, 0, 0, 0]

    dtype: str = "floatX"
    @classmethod
    def rng_fn(
        cls,
        rng: np.random.RandomState,
        r: np.ndarray,
        theta: np.ndarray,
        sigma: np.ndarray,
        mu: np.ndarray,
        size: Tuple[int, ...],
    ) -> np.ndarray:
        v0 = np.sqrt(theta**2+sigma**2)
        v1 = 1/(v0 + theta)
        v2 = 1/(v0 - theta)
        # pass the scale by keyword: scipy's gamma signature is (a, loc=0, scale=1, ...)
        s1 = scipy.stats.gamma.rvs(r/2, scale=v1, random_state=rng, size=size)
        s2 = scipy.stats.gamma.rvs(r/2, scale=v2, random_state=rng, size=size)
        return (mu + s1 - s2)

class VarianceGamma(Continuous):
    #rv_op = variance_gamma
    rv_op = VarianceGammaRV()
    
    @classmethod
    def dist(cls, r, theta, sigma, mu, *args, **kwargs):
        r = pt.as_tensor_variable(floatX(r))
        theta = pt.as_tensor_variable(floatX(theta))
        sigma = pt.as_tensor_variable(floatX(sigma))
        mu = pt.as_tensor_variable(floatX(mu))

        return super().dist([r, theta, sigma, mu], *args, **kwargs)

    def logp(value, r, theta, sigma, mu):
        x = value - mu
        sigma2 = sigma**2
        abs_x = pt.abs(x)
        s = pt.sqrt(theta**2 + sigma2)
        v1 = -pt.log(sigma*pt.sqrt(np.pi)) - gammaln(r/2)
        v2 = theta*x/sigma2
        v3 = (r-1)/2*pt.log(abs_x/(2*s))
        alpha = (r-1.0)/2.0
        v40 = - s*abs_x/sigma2 - 1.0/2.0*pt.log(np.pi/(2.0*s*abs_x/sigma2))
        v4 = pt.log(np.pi/2.0*(pt.iv(-alpha, s*abs_x/sigma2) - pt.iv(alpha, s*abs_x/sigma2))/pt.sin(alpha*np.pi))
        res = v1 + v2 + v3 + v4
        return check_parameters(
            res,
            r > 0,
            sigma > 0,
            msg="r > 0, sigma > 0",
        )
    
VG = VarianceGamma.dist(0.5,1,1,0)
data = pm.draw(VG,draws=100)

with pm.Model() as model:
    mu = 0
    v = pm.Gamma('v', 2, 1)
    theta = pm.Normal('theta',0,1) 
    sigma = pm.Gamma('sigma',2,1)
    
    y = VarianceGamma('y',r=v, theta=theta,sigma=sigma,mu=mu,observed=data)
    trace_model = pm.sample()

Here I am using:
pymc==5.3.1
pytensor==2.11.3
numpy==1.24.3

I also report the full traceback:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[3], line 8
      5 sigma = pm.Gamma('sigma',2,1)
      7 y = VarianceGamma('y',r=v, theta=theta,sigma=sigma,mu=mu,observed=data)
----> 8 trace_model = pm.sample()

File ~/.local/lib/python3.11/site-packages/pymc/sampling/mcmc.py:702, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
    698 t_sampling = time.time() - t_start
    700 # Packaging, validating and returning the result was extracted
    701 # into a function to make it easier to test and refactor.
--> 702 return _sample_return(
    703     run=run,
    704     traces=traces,
    705     tune=tune,
    706     t_sampling=t_sampling,
    707     discard_tuned_samples=discard_tuned_samples,
    708     compute_convergence_checks=compute_convergence_checks,
    709     return_inferencedata=return_inferencedata,
    710     keep_warning_stat=keep_warning_stat,
    711     idata_kwargs=idata_kwargs or {},
    712     model=model,
    713 )

File ~/.local/lib/python3.11/site-packages/pymc/sampling/mcmc.py:742, in _sample_return(run, traces, tune, t_sampling, discard_tuned_samples, compute_convergence_checks, return_inferencedata, keep_warning_stat, idata_kwargs, model)
    738 # count the number of tune/draw iterations that happened
    739 # ideally via the "tune" statistic, but not all samplers record it!
    740 if "tune" in mtrace.stat_names:
    741     # Get the tune stat directly from chain 0, sampler 0
--> 742     stat = mtrace._straces[0].get_sampler_stats("tune", sampler_idx=0)
    743     stat = tuple(stat)
    744     n_tune = stat.count(True)

File ~/.local/lib/python3.11/site-packages/pymc/backends/base.py:254, in BaseTrace.get_sampler_stats(self, stat_name, sampler_idx, burn, thin)
    230 """Get sampler statistics from the trace.
    231 
    232 Note: This implementation attempts to squeeze object arrays into a consistent dtype,
   (...)
    251     Otherwise, the shape should be `(draws, samplers)`.
    252 """
    253 if sampler_idx is not None:
--> 254     return self._get_sampler_stats(stat_name, sampler_idx, burn, thin)
    256 sampler_idxs = [i for i, s in enumerate(self.sampler_vars) if stat_name in s]
    257 if not sampler_idxs:

File ~/.local/lib/python3.11/site-packages/pymc/backends/ndarray.py:125, in NDArray._get_sampler_stats(self, varname, sampler_idx, burn, thin)
    122 def _get_sampler_stats(
    123     self, varname: str, sampler_idx: int, burn: int, thin: int
    124 ) -> np.ndarray:
--> 125     return self._stats[sampler_idx][varname][burn::thin]

KeyError: 'tune'

Another possibly relevant message is the sampling log:

...
Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>Slice: [v]
>NUTS: [theta, sigma]

This is just a silly bug: the Slice sampler doesn’t return information about tuning. It was fixed in a recent release, where we also sped up the Slice sampler. If you can update, that should fix it.

There is still a lingering issue reported here: Faulty logic to retrieve `tune` from mixed samplers · Issue #6710 · pymc-devs/pymc · GitHub

But it’s fixed for the Slice sampler.
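A quick way to confirm which release is installed before and after upgrading (the versions reported earlier in the thread were pymc 5.3.1 and pytensor 2.11.3):

import pymc
import pytensor

# compare against the versions reported above; the fix is in a newer pymc release
print(pymc.__version__, pytensor.__version__)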


Done, now it works.
Thanks!