Example of sampling from DensityDist in pymc3 source code does not work

I am trying to understand how to use pm.DensityDist. The source code contains an example, and I followed it with only a small change. Here is my code:

import pymc3 as pm
import numpy as np
import scipy.stats as ss


with pm.Model():
    mu = pm.Normal('mu', 0, 1)
    normal_dist = pm.Normal.dist(mu, 1)
    dens = pm.DensityDist(
        'density_dist',
        normal_dist.logp,
        observed=np.random.randn(100),
        random=ss.norm.rvs,
    )
    prior = pm.sample_prior_predictive(10)['density_dist']

However, I got the error “TypeError: _parse_args_rvs() got an unexpected keyword argument ‘point’”. Here is the full traceback:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
~/miniconda3/envs/ppl/lib/python3.7/site-packages/pymc3/distributions/distribution.py in _draw_value(param, point, givens, size)
    807                 try:
--> 808                     return dist_tmp.random(point=point, size=size)
    809                 except (ValueError, TypeError):

~/miniconda3/envs/ppl/lib/python3.7/site-packages/pymc3/distributions/distribution.py in random(self, point, size, **kwargs)
    392                         size=None,
--> 393                         not_broadcast_kwargs=not_broadcast_kwargs,
    394                     )

~/miniconda3/envs/ppl/lib/python3.7/site-packages/pymc3/distributions/distribution.py in generate_samples(generator, *args, **kwargs)
    966     if dist_bcast_shape[:len(size_tup)] == size_tup:
--> 967         samples = generator(size=dist_bcast_shape, *args, **kwargs)
    968     else:

~/miniconda3/envs/ppl/lib/python3.7/site-packages/scipy/stats/_distn_infrastructure.py in rvs(self, *args, **kwds)
    958         rndm = kwds.pop('random_state', None)
--> 959         args, loc, scale, size = self._parse_args_rvs(*args, **kwds)
    960         cond = logical_and(self._argcheck(*args), (scale >= 0))

TypeError: _parse_args_rvs() got an unexpected keyword argument 'point'

During handling of the above exception, another exception occurred:

TypeError                                 Traceback (most recent call last)
<ipython-input-1-ee9edf2b0ada> in <module>
     13         random=ss.norm.rvs,
     14     )
---> 15     prior = pm.sample_prior_predictive(10)['density_dist']

~/miniconda3/envs/ppl/lib/python3.7/site-packages/pymc3/sampling.py in sample_prior_predictive(samples, model, vars, var_names, random_seed)
   1384     names = get_default_varnames(vars_, include_transformed=False)
   1385     # draw_values fails with auto-transformed variables. transform them later!
-> 1386     values = draw_values([model[name] for name in names], size=samples)
   1387 
   1388     data = {k: v for k, v in zip(names, values)}

~/miniconda3/envs/ppl/lib/python3.7/site-packages/pymc3/distributions/distribution.py in draw_values(params, point, size)
    625                                         point=point,
    626                                         givens=temp_givens,
--> 627                                         size=size)
    628                     givens[next_.name] = (next_, value)
    629                     drawn[(next_, size)] = value

~/miniconda3/envs/ppl/lib/python3.7/site-packages/pymc3/distributions/distribution.py in _draw_value(param, point, givens, size)
    815                     with _DrawValuesContextBlocker():
    816                         val = np.atleast_1d(dist_tmp.random(point=point,
--> 817                                                             size=None))
    818                     # Sometimes point may change the size of val but not the
    819                     # distribution's shape

~/miniconda3/envs/ppl/lib/python3.7/site-packages/pymc3/distributions/distribution.py in random(self, point, size, **kwargs)
    391                         self.rand,
    392                         size=None,
--> 393                         not_broadcast_kwargs=not_broadcast_kwargs,
    394                     )
    395                     test_shape = test_draw.shape

~/miniconda3/envs/ppl/lib/python3.7/site-packages/pymc3/distributions/distribution.py in generate_samples(generator, *args, **kwargs)
    965         )
    966     if dist_bcast_shape[:len(size_tup)] == size_tup:
--> 967         samples = generator(size=dist_bcast_shape, *args, **kwargs)
    968     else:
    969         samples = generator(size=size_tup + dist_bcast_shape, *args, **kwargs)

~/miniconda3/envs/ppl/lib/python3.7/site-packages/scipy/stats/_distn_infrastructure.py in rvs(self, *args, **kwds)
    957         discrete = kwds.pop('discrete', None)
    958         rndm = kwds.pop('random_state', None)
--> 959         args, loc, scale, size = self._parse_args_rvs(*args, **kwds)
    960         cond = logical_and(self._argcheck(*args), (scale >= 0))
    961         if not np.all(cond):

TypeError: _parse_args_rvs() got an unexpected keyword argument 'point'.

I am not sure whether this is a bug in pymc3 or scipy, or whether my code is just wrong. Can anyone help?

I think we might have changed the API to do this, @lucianopaz?

The random method needs to accept the keyword argument point. This is because the samples it generates depend on the distribution's parameters (mu, sigma and tau), and these parameters can themselves be random variables. So you need to draw fixed parameter values before you can generate a random sample from the distribution. For example, you could borrow from the implementation of pymc3.distributions.Normal:

import pymc3 as pm
import numpy as np
import scipy.stats as ss
from pymc3.distributions import draw_values, generate_samples

with pm.Model():
    mu = pm.Normal('mu', 0, 1)
    normal_dist = pm.Normal.dist(mu, 1)
    def my_random_method(point=None, size=None):
        mu, tau, _ = draw_values([normal_dist.mu, normal_dist.tau, normal_dist.sigma],
                                 point=point, size=size)
        return generate_samples(ss.norm.rvs, loc=mu, scale=tau**-0.5,
                                dist_shape=normal_dist.shape,
                                size=size)
    dens = pm.DensityDist(
        'density_dist',
        normal_dist.logp,
        observed=np.random.randn(100),
        random=my_random_method,
    )
    prior = pm.sample_prior_predictive(10)['density_dist']
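Conceptually, this is ancestral sampling: first draw a concrete value for each parent parameter (here mu), then sample the child conditioned on those values. That is the role draw_values plays in my_random_method above. The sketch below shows the same idea with plain NumPy/SciPy and no pymc3, for the model mu ~ N(0, 1), y ~ N(mu, 1); `prior_predictive_normal` is a hypothetical helper for illustration only, not how pymc3's draw_values works internally.

```python
import numpy as np
import scipy.stats as ss


def prior_predictive_normal(samples, n_obs, seed=None):
    """Hypothetical helper: prior predictive draws for mu ~ N(0, 1), y ~ N(mu, 1)."""
    rng = np.random.RandomState(seed)
    # Step 1: draw a fixed value of the parent parameter for each prior sample.
    mu = ss.norm.rvs(loc=0.0, scale=1.0, size=samples, random_state=rng)
    # Step 2: sample observations conditioned on each drawn mu.
    # mu[:, None] has shape (samples, 1) and broadcasts against (samples, n_obs).
    y = ss.norm.rvs(loc=mu[:, None], scale=1.0, size=(samples, n_obs),
                    random_state=rng)
    return mu, y


mu, y = prior_predictive_normal(samples=10, n_obs=100, seed=0)
print(mu.shape, y.shape)  # (10,) (10, 100)
```

Each row of y is generated from its own value of mu, which is why the random method cannot sample y without first having fixed parameter values.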

@yusri-dh, are you looking at pymc3 3.7 or at the master branch on GitHub? The documentation should have been updated on GitHub. We’ve made some changes to DensityDist, but I’ve now noticed there’s something wrong with the docstring. The current docs read:

random: None or callable (Optional)
If None, no random method is attached to the DensityDist
instance.
If a callable, it is used as the distribution’s random method.
The behavior of this callable can be altered with the
wrap_random_with_dist_shape parameter.
The supplied callable must have the following signature:
random(size=None, **kwargs), where size is the number of
IID draws to take from the distribution. Any extra keyword
argument can be added as required.

The signature is wrong; it should read random(point=None, size=None, **kwargs). Could you open an issue on GitHub pointing out this error in the docstring?

We require this signature to make DensityDist.random compatible with the random signature of the other pymc3 distributions. For your model to work, it should be written like this:

import pymc3 as pm
import numpy as np
import scipy.stats as ss


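def rand(point=None, size=None, **kwargs):
    # `point` is accepted for compatibility with pymc3's random signature.
    # It is not used here, so these draws come from a standard normal (mu is ignored).
    return ss.norm.rvs(size=size, **kwargs)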


with pm.Model():
    mu = pm.Normal('mu', 0, 1)
    normal_dist = pm.Normal.dist(mu, 1)
    dens = pm.DensityDist(
        'density_dist',
        normal_dist.logp,
        observed=np.random.randn(100),
        random=rand,
    )
    prior = pm.sample_prior_predictive(10)['density_dist']
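To see why this version works where passing ss.norm.rvs directly did not, you can call the candidates with the point keyword outside of pymc3, the same way the traceback shows pymc3 calling them. The sketch below assumes only SciPy; `accepts_point_kwarg` and `rand_old_signature` are hypothetical names for illustration. Inspecting the signature of ss.norm.rvs would not catch the problem, because it is declared as rvs(*args, **kwds) and only rejects point further down, in _parse_args_rvs.

```python
import scipy.stats as ss


def rand(point=None, size=None, **kwargs):
    # Corrected signature: `point` is captured here and never reaches scipy.
    return ss.norm.rvs(size=size, **kwargs)


def rand_old_signature(size=None, **kwargs):
    # Signature from the outdated docstring: `point` lands in **kwargs
    # and is forwarded to scipy, which rejects it.
    return ss.norm.rvs(size=size, **kwargs)


def accepts_point_kwarg(fn):
    """Hypothetical helper: True if fn(point=None, size=3) runs without TypeError."""
    try:
        fn(point=None, size=3)
    except TypeError:
        return False
    return True


print(accepts_point_kwarg(ss.norm.rvs))         # False
print(accepts_point_kwarg(rand_old_signature))  # False
print(accepts_point_kwarg(rand))                # True
```

The old-docstring signature fails in the same way as ss.norm.rvs, which is why naming point explicitly in the signature matters.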

As @Dominik said, I recommend looking at the implementation of random in some of the other pymc3 distributions to get an idea of how to structure your own random method.

I am looking at GitHub’s master branch. Thank you so much @lucianopaz and @Dominik. The code is working now. I will also open an issue on GitHub about this.
