Sampling the Banana-shaped distribution

Hi,
I want to implement a new sampler and need the banana-shaped distribution as a test case.

Here I found an implementation for pymc, but I can’t transfer it to pymc3.

This is how I can transform samples from a multivariate normal into the banana shape:

import numpy

def banana(X, b=0.03):
    """Twist the second row of X (shape 2 x N) into a banana."""
    X = X.copy()
    X[1] += b * X[0]**2 - 100 * b
    return X

C = [
    [100, 1],
    [1, 1]
]
bananasamples = banana(numpy.random.multivariate_normal([0, 0], C, size=200).T)
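The twist only shifts the second coordinate by a function of the first (y1 = x1 + b*x0**2 - 100*b). That map is a shear with Jacobian determinant 1, so the banana density is simply the MvNormal density evaluated at the untwisted point. Here is a minimal NumPy sketch of that log density, assuming the same b = 0.03 and covariance C as above; banana_logp is a hypothetical helper name, not a PyMC3 function.

```python
import numpy as np

C = np.array([[100.0, 1.0], [1.0, 1.0]])

def banana_logp(y, b=0.03, cov=C):
    """Normalized log density of the banana distribution at points y of shape (..., 2)."""
    y = np.asarray(y, dtype=float)
    # Invert the twist: x1 = y1 - b*y0**2 + 100*b. The map is a shear with
    # Jacobian determinant 1, so no log-Jacobian correction is needed.
    x = np.stack([y[..., 0], y[..., 1] - b * y[..., 0] ** 2 + 100 * b], axis=-1)
    prec = np.linalg.inv(cov)
    quad = np.einsum("...i,ij,...j->...", x, prec, x)
    _, logdet = np.linalg.slogdet(cov)
    return -0.5 * (quad + logdet + 2 * np.log(2 * np.pi))

# The density peaks at the image of the MvNormal mean: (0, -100*b) = (0, -3).
print(banana_logp([0.0, -3.0]))
```

In a PyMC3 model, the same untwisting step would have to be written with Theano operations inside a Potential or DensityDist; this sketch only checks the math.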

When it comes to the PyMC3 model, I have a fundamental misunderstanding, but I can't figure out what it is.

def run_stepper(stepper_cls):
    with pymc3.Model() as pmodel:
        x = pymc3.Uniform('x', lower=-50, upper=50)
        y = pymc3.Uniform('y', lower=-50, upper=50)

        G = pymc3.MvNormal('G', mu=[0,0], cov=C)
        # ???????
        pymc3.Potential('X', G.logp([x,y]))

        step = stepper_cls()
        mt = pymc3.sample(step=step, tune=1000)
    return mt


if __name__ == '__main__':
    methods = [pymc3.Metropolis, pymc3.NUTS]
    fig, hosts = visualization.fig_hosts(1, len(methods), 5, 5)
    for c in range(len(methods)):
        mt = run_stepper(methods[c])
        hosts[0,c].scatter(bananasamples[0], bananasamples[1], s=0.2)
        hosts[0,c].scatter(mt['x'], mt['y'], s=0.3)
        accept_stat = [s for s in mt.stat_names if 'accept' in s][0]
        accept_rate = mt.get_sampler_stats(accept_stat).mean()
        title = '{} ({:.1f}%)'.format(methods[c].name, accept_rate * 100)
        hosts[0,c].set_title(title)
    pyplot.show()

I am sure there is a one- or two-line way to do it right and would appreciate any help.
I then expect Metropolis to have a much lower acceptance rate than NUTS.

thanks
michael
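To see the acceptance-rate effect the question anticipates without PyMC3, here is a minimal sketch of plain random-walk Metropolis on the banana log density in NumPy. It assumes the same C and b = 0.03 as in the question; banana_logp and random_walk_metropolis are hypothetical names, and this is a textbook Metropolis sampler, not PyMC3's implementation.

```python
import numpy as np

C = np.array([[100.0, 1.0], [1.0, 1.0]])
PREC = np.linalg.inv(C)

def banana_logp(y, b=0.03):
    """Unnormalized banana log density at a single point y = (y0, y1)."""
    # Undo the twist (unit-Jacobian shear), then use the MvNormal quadratic form.
    x = np.array([y[0], y[1] - b * y[0] ** 2 + 100 * b])
    return -0.5 * x @ PREC @ x

def random_walk_metropolis(logp, start, n_steps, step_size, seed=0):
    """Random-walk Metropolis with an isotropic Gaussian proposal.

    Returns the chain (n_steps x dim) and the fraction of accepted proposals.
    """
    rng = np.random.default_rng(seed)
    current = np.asarray(start, dtype=float)
    current_lp = logp(current)
    chain = np.empty((n_steps, current.size))
    accepted = 0
    for i in range(n_steps):
        proposal = current + step_size * rng.standard_normal(current.size)
        proposal_lp = logp(proposal)
        # Symmetric proposal, so accept with probability min(1, p(proposal) / p(current)).
        if np.log(rng.uniform()) < proposal_lp - current_lp:
            current, current_lp = proposal, proposal_lp
            accepted += 1
        chain[i] = current
    return chain, accepted / n_steps

rates = {}
for step_size in (0.5, 2.0, 8.0):
    chain, rates[step_size] = random_walk_metropolis(banana_logp, [0.0, -3.0], 5000, step_size)
    print(f"step_size={step_size}: acceptance rate {rates[step_size]:.2f}")
```

A small step is accepted often but crawls along the long curved axis, while a large step is mostly rejected because it leaves the narrow ridge. That trade-off is what gradient-based samplers like NUTS avoid.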

You can define a banana-shaped log-density function and wrap it in pm.Potential or pm.DensityDist. For example, using the Rosenbrock function as the potential (negative log density):

import numpy as np
import pymc3 as pm
import theano.tensor as tt
import theano

def pot1(z):
    z = z.T
    a = 1.
    b = 100.
    return (a-z[0])**2 + b*(z[1] - z[0]**2)**2

z = tt.matrix('z')
z.tag.test_value = pm.floatX([[0., 0.]])
pot1f = theano.function([z], pot1(z))
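pot1 is a potential U(z), i.e. a negative log density up to a constant, so the density being sampled is proportional to exp(-U(z)). To check values without compiling a Theano function, here is a NumPy equivalent (pot1_np is a hypothetical name). The mode sits at the Rosenbrock minimum (a, a**2) = (1, 1), where the potential is 0.

```python
import numpy as np

def pot1_np(z, a=1.0, b=100.0):
    """NumPy version of pot1: the Rosenbrock function for points of shape (N, 2) or (2,)."""
    z = np.asarray(z, dtype=float).T
    return (a - z[0]) ** 2 + b * (z[1] - z[0] ** 2) ** 2

print(pot1_np([[0.0, 0.0], [1.0, 1.0], [2.0, 4.0]]))
```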

Visualize the function by evaluating it on a grid:

import matplotlib.pyplot as plt
from mpl_toolkits import mplot3d  # registers the '3d' projection on older matplotlib versions
xlim = (-2, 2)
ylim = (-1, 3)
grid = pm.floatX(np.mgrid[xlim[0]:xlim[1]:100j, ylim[0]:ylim[1]:100j])
grid_2d = grid.reshape(2, -1).T
Z = pot1f(grid_2d)
fig = plt.figure()
ax = fig.add_subplot(projection='3d')  # fig.gca(projection='3d') is not accepted by recent matplotlib
surf = ax.plot_surface(grid[0], grid[1], Z.reshape(100, 100), cmap='viridis',
                       linewidth=0, antialiased=False)
plt.show()
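The grid bookkeeping above is easy to get wrong, so here is the same reshape round trip in plain NumPy. In this sketch, rosenbrock is a hypothetical stand-in for the compiled pot1f. np.mgrid with a complex step of 100j returns 100 points per axis, including both end points. reshape(2, -1).T turns the grid into one (x, y) row per point, and reshaping the result back to (100, 100) lines each value up with grid[0] and grid[1].

```python
import numpy as np

xlim, ylim = (-2, 2), (-1, 3)
grid = np.mgrid[xlim[0]:xlim[1]:100j, ylim[0]:ylim[1]:100j]  # shape (2, 100, 100)
grid_2d = grid.reshape(2, -1).T                                # shape (10000, 2), one (x, y) per row

def rosenbrock(z, a=1.0, b=100.0):
    """Same formula as pot1, evaluated row-wise on an (N, 2) array."""
    z = z.T
    return (a - z[0]) ** 2 + b * (z[1] - z[0] ** 2) ** 2

# Back on the grid: Z[i, j] is the potential at (grid[0][i, j], grid[1][i, j]).
Z = rosenbrock(grid_2d).reshape(100, 100)
```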

Sample from the above function in PyMC3 (the log density is the negative of the potential) and visualize the samples:

def cust_logp(z):
    return -pot1(z)

with pm.Model() as pot1m:
    pm.DensityDist('pot1', logp=cust_logp, shape=(2,))
    trace1 = pm.sample(1000, step=pm.NUTS())
    trace2 = pm.sample(1000, step=pm.Metropolis())

_, ax = plt.subplots(1,2,figsize=(10,5))
tr1 = trace1['pot1']
ax[0].plot(tr1[:,0], tr1[:,1], 'ro--',alpha=.1)
tr2 = trace2['pot1']
ax[1].plot(tr2[:,0], tr2[:,1], 'bo--',alpha=.1)
plt.tight_layout()

There are also similar examples in the normalizing flows documentation: http://docs.pymc.io/notebooks/normalizing_flows_overview.html#Simulated-data-example


Sorry to bother you… I'm using Jupyter 4.4.0 and Python 3.6.4 on Ubuntu 14.04, and I got this error when running your solution in PyMC3:

BrokenProcessPool: A process in the executor was terminated abruptly while the future was running or pending.

It seems that the problem is in this line:

trace1 = pm.sample(1000, step=pm.NUTS())

I am wondering if there is an easy fix to this … thanks!

What version of PyMC3 are you on? Could you please try to upgrade to the newest release?

Thank you for your prompt reply! It is PyMC3 v3.4.1, and upgrading has not resolved the issue. Still the same error (full traceback, for what it's worth):

BrokenProcessPool                         Traceback (most recent call last)
<ipython-input-21-fb1a0d873163> in <module>()
  4 with pm.Model() as pot1m:
  5     pm.DensityDist('pot1', logp=cust_logp, shape=(2,))
----> 6     trace1 = pm.sample(1000, step=pm.NUTS())
  7     trace2 = pm.sample(1000, step=pm.Metropolis())
  8 

~/anaconda3/lib/python3.6/site-packages/pymc3/sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, nuts_kwargs, step_kwargs, progressbar, model, random_seed, live_plot, discard_tuned_samples, live_plot_kwargs, compute_convergence_checks, use_mmap, **kwargs)
440 
441         has_population_samplers = np.any([ isinstance(m, arraystep.PopulationArrayStepShared)
--> 442             for m in (step.methods if isinstance(step, CompoundStep) else [step])])
443 
444         parallel = cores > 1 and chains > 1 and not has_population_samplers

~/anaconda3/lib/python3.6/site-packages/pymc3/sampling.py in _mp_sample(**kwargs)
980         for idx in range(chain, chain + chains):
981             if trace is not None:
--> 982                 strace = _choose_backend(copy(trace), idx, model=model)
983             else:
984                 strace = _choose_backend(None, idx, model=model)

~/anaconda3/lib/python3.6/site-packages/joblib/parallel.py in __call__(self, iterable)
960 
961             with self._backend.retrieval_context():
--> 962                 self.retrieve()
963             # Make sure that we get a last message telling us we are done
964             elapsed_time = time.time() - self._start_time

~/anaconda3/lib/python3.6/site-packages/joblib/parallel.py in retrieve(self)
863             try:
864                 if getattr(self._backend, 'supports_timeout', False):
--> 865                     self._output.extend(job.get(timeout=self.timeout))
866                 else:
867                     self._output.extend(job.get())

~/anaconda3/lib/python3.6/site-packages/joblib/_parallel_backends.py in wrap_future_result(future, timeout)
513         AsyncResults.get from multiprocessing."""
514         try:
--> 515             return future.result(timeout=timeout)
516         except LokyTimeoutError:
517             raise TimeoutError()

~/anaconda3/lib/python3.6/site-packages/joblib/externals/loky/_base.py in result(self, timeout)
429                 raise CancelledError()
430             elif self._state == FINISHED:
--> 431                 return self.__get_result()
432             else:
433                 raise TimeoutError()

~/anaconda3/lib/python3.6/site-packages/joblib/externals/loky/_base.py in __get_result(self)
380     def __get_result(self):
381         if self._exception:
--> 382             raise self._exception
383         else:
384             return self._result

BrokenProcessPool: A process in the executor was terminated abruptly while the future was running or pending.

If you upgrade to 3.5 (we just released it yesterday), it no longer uses joblib, so the error you are seeing here (which is joblib-related) should be resolved.

I have upgraded to PyMC3 v3.5 and that issue seems to be resolved (thanks!). However, now the NUTS sampler does not seem to initialize; it gets stuck at

Multiprocess sampling (4 chains in 4 jobs) 
NUTS: [pot1]

Any help would be very much appreciated! Would you recommend switching to Python 2.7?

Hmm, definitely not. We are sunsetting Python 2.7 support soon, so it would be better to stay on Python 3 and try to resolve this.

Do you see any problem trying to sample a simple model?

I have tried another model and now I am getting the following "[Errno 32] Broken pipe" error:

BrokenPipeError                           Traceback (most recent call last)
<ipython-input-4-4bee32470d49> in <module>()
      4 with pm.Model() as pot1m:
      5     pm.DensityDist('pot1', logp=cust_logp, shape=(2,))
----> 6     trace1 = pm.sample(draws=10, step=pm.NUTS())
      7     #trace2 = pm.sample(10, step=pm.Metropolis())
      8 

~\Anaconda2\envs\python3.6\lib\site-packages\pymc3\sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, nuts_kwargs, step_kwargs, progressbar, model, random_seed, live_plot, discard_tuned_samples, live_plot_kwargs, compute_convergence_checks, use_mmap, **kwargs)
    447             _print_step_hierarchy(step)
    448             try:
--> 449                 trace = _mp_sample(**sample_args)
    450             except pickle.PickleError:
    451                 _log.warning("Could not pickle model, sampling singlethreaded.")

~\Anaconda2\envs\python3.6\lib\site-packages\pymc3\sampling.py in _mp_sample(draws, tune, step, chains, cores, chain, random_seed, start, progressbar, trace, model, use_mmap, **kwargs)
    994         sampler = ps.ParallelSampler(
    995             draws, tune, chains, cores, random_seed, start, step,
--> 996             chain, progressbar)
    997         try:
    998             with sampler:

~\Anaconda2\envs\python3.6\lib\site-packages\pymc3\parallel_sampling.py in __init__(self, draws, tune, chains, cores, seeds, start_points, step_method, start_chain_num, progressbar)
    273             ProcessAdapter(draws, tune, step_method,
    274                            chain + start_chain_num, seed, start)
--> 275             for chain, seed, start in zip(range(chains), seeds, start_points)
    276         ]
    277 

~\Anaconda2\envs\python3.6\lib\site-packages\pymc3\parallel_sampling.py in <listcomp>(.0)
    273             ProcessAdapter(draws, tune, step_method,
    274                            chain + start_chain_num, seed, start)
--> 275             for chain, seed, start in zip(range(chains), seeds, start_points)
    276         ]
    277 

~\Anaconda2\envs\python3.6\lib\site-packages\pymc3\parallel_sampling.py in __init__(self, draws, tune, step_method, chain, seed, start)
    180             draws, tune, seed)
    181         # We fork right away, so that the main process can start tqdm threads
--> 182         self._process.start()
    183 
    184     @property

~\Anaconda2\envs\python3.6\lib\multiprocessing\process.py in start(self)
    103                'daemonic processes are not allowed to have children'
    104         _cleanup()
--> 105         self._popen = self._Popen(self)
    106         self._sentinel = self._popen.sentinel
    107         # Avoid a refcycle if the target function holds an indirect

~\Anaconda2\envs\python3.6\lib\multiprocessing\context.py in _Popen(process_obj)
    221     @staticmethod
    222     def _Popen(process_obj):
--> 223         return _default_context.get_context().Process._Popen(process_obj)
    224 
    225 class DefaultContext(BaseContext):

~\Anaconda2\envs\python3.6\lib\multiprocessing\context.py in _Popen(process_obj)
    320         def _Popen(process_obj):
    321             from .popen_spawn_win32 import Popen
--> 322             return Popen(process_obj)
    323 
    324     class SpawnContext(BaseContext):

~\Anaconda2\envs\python3.6\lib\multiprocessing\popen_spawn_win32.py in __init__(self, process_obj)
     63             try:
     64                 reduction.dump(prep_data, to_child)
---> 65                 reduction.dump(process_obj, to_child)
     66             finally:
     67                 set_spawning_popen(None)

~\Anaconda2\envs\python3.6\lib\multiprocessing\reduction.py in dump(obj, file, protocol)
     58 def dump(obj, file, protocol=None):
     59     '''Replacement for pickle.dump() using ForkingPickler.'''
---> 60     ForkingPickler(file, protocol).dump(obj)
     61 
     62 #

BrokenPipeError: [Errno 32] Broken pipe

Thank you for your help!

What about the following model?

with pm.Model() as m:
    mu = pm.Normal('mu', 0., 10.)
    x = pm.Normal('x', mu, 1., observed=np.array([2., 1.4]))
    trace = pm.sample()

That works:

Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (4 chains in 4 jobs)
NUTS: [mu]
Sampling 4 chains: 100%|██████████████████████████████████████████████████| 4000/4000 [00:02<00:00, 1213.70draws/s]
The acceptance probability does not match the target. It is 0.8865605470216715, but should be close to 0.8. Try to increase the number of tuning steps.
The acceptance probability does not match the target. It is 0.8836544226015702, but should be close to 0.8. Try to increase the number of tuning steps.

Possibly related: I was trying to compare the NUTS and HamiltonianMC sampling from here with the following code:

def jointplot(ary):
    """Helper to plot everything consistently"""
    sns.jointplot(*ary.T, alpha=0.1, stat_func=None, xlim=(-1.2, 1.2), ylim=(-1.2, 1.2))

def tt_donut_pdf(scale):
    """Compare to `donut_pdf`"""
    def logp(x):
         return -tt.square((1 - x.norm(2)) / scale)
    return logp

@sampled
def donut(scale=0.1, **observed):
    """Gets samples from the donut pdf, and allows adjusting the scale of the donut at sample time."""
    pm.DensityDist('donut', logp=tt_donut_pdf(scale), shape=2, testval=[0, 1])

with donut(scale=0.1):
    hamiltonianmc_sample1 = pm.sample(draws=100, init=None, step=pm.HamiltonianMC())
    nuts_sample1 = pm.sample(draws=100, init=None, step=pm.NUTS())
    
jointplot(hamiltonianmc_sample1.get_values('donut'))
jointplot(nuts_sample1.get_values('donut'))

But I get this error:

AttributeError: Can't pickle local object 'tt_donut_pdf.<locals>.logp'

Are these two problems possibly related? Thanks!

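Right, lambdas and locally defined functions (such as the logp closure returned by tt_donut_pdf) cannot be pickled by the standard pickle module, and multiprocess sampling needs to pickle the model to send it to the worker processes. In these cases, sampling with cores=1 should resolve the issue.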
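The difference can be reproduced with the standard pickle module alone. Below is a minimal sketch that uses math in place of Theano and the hypothetical names make_donut_logp and donut_logp. The closure returned by the factory cannot be pickled, while functools.partial over a module-level function can, which may let you keep multiprocess sampling instead of falling back to cores=1.

```python
import math
import pickle
from functools import partial

def make_donut_logp(scale):
    """Closure factory in the style of tt_donut_pdf: returns a locally defined function."""
    def logp(x):
        return -((1 - math.hypot(x[0], x[1])) / scale) ** 2
    return logp

def donut_logp(x, scale):
    """Same log density, defined at module level with scale as an explicit argument."""
    return -((1 - math.hypot(x[0], x[1])) / scale) ** 2

def is_picklable(obj):
    """Return True if the standard pickle module can serialize obj."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        return False

closure = make_donut_logp(0.1)
bound = partial(donut_logp, scale=0.1)
print(is_picklable(closure))  # local object: not picklable
print(is_picklable(bound))    # partial of a module-level function: picklable
```

The partial is pickled by reference to donut_logp, so the function must live at the top level of an importable module (or the script being run).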


Thank you very much! That actually solved both of my problems! :grinning:


To compare the coverage of these methods after a given number of iterations, is there an easy way to add a contour plot of the Rosenbrock function to your original plots? Thanks!

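Sure, matplotlib has contour plot functions (plt.contour and plt.contourf). You can evaluate the potential on the same grid as the surface plot and overlay the samples on top.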
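Here is a minimal sketch with Axes.contour, assuming the same Rosenbrock parameters (a = 1, b = 100) and plotting range as above. pot1_np is a hypothetical NumPy stand-in for pot1. The commented-out lines show where the NUTS and Metropolis samples (trace1 and trace2 from the answer above) would be overlaid.

```python
import numpy as np
import matplotlib.pyplot as plt

def pot1_np(z, a=1.0, b=100.0):
    """NumPy stand-in for pot1 (Rosenbrock potential) on points of shape (N, 2)."""
    z = np.asarray(z, dtype=float).T
    return (a - z[0]) ** 2 + b * (z[1] - z[0] ** 2) ** 2

xlim, ylim = (-2, 2), (-1, 3)
grid = np.mgrid[xlim[0]:xlim[1]:200j, ylim[0]:ylim[1]:200j]
Z = pot1_np(grid.reshape(2, -1).T).reshape(grid.shape[1:])

fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for ax, title in zip(axes, ["NUTS", "Metropolis"]):
    # Log-spaced levels of the potential resolve the narrow curved valley.
    cs = ax.contour(grid[0], grid[1], Z, levels=np.logspace(-1, 3, 9), cmap="viridis")
    ax.set_title(title)
    ax.set_xlim(xlim)
    ax.set_ylim(ylim)
# Overlay the samples, e.g.:
# tr1 = trace1['pot1']; axes[0].plot(tr1[:, 0], tr1[:, 1], 'r.', alpha=0.1)
# tr2 = trace2['pot1']; axes[1].plot(tr2[:, 0], tr2[:, 1], 'b.', alpha=0.1)
plt.tight_layout()
```

Plotting contours of the potential rather than of exp(-potential) keeps the outer levels visible. Otherwise, almost all of the density is squeezed into a thin band around y = x**2.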