Why does the (automatic) nested CompoundStep fail?

Hi all,

I’m trying to implement a simple model based on two-class Gaussian clustering with particular priors on the ratios of the classes. However, it seems that I’m running into an issue when PyMC automatically groups variables within a Metropolis compound step. What am I doing wrong, or is this expected behaviour?

Here is a simple example that already fails (the mixture part has been removed for simplicity).

with pm.Model() as modeWithErros:

    a   = pm.Poisson("a",mu=10)

    b   = pm.Binomial("b", n=a, p=0.8)

    c   = pm.Poisson("c",mu=11)

    d   = pm.Dirichlet("d",a=pt.stack([c,b]))

    pm.sample(draws=1000,tune=1000,chains=4)

Console output:

Multiprocess sampling (4 chains in 4 jobs)
CompoundStep
>CompoundStep
>>Metropolis: [a]
>>Metropolis: [b]
>>Metropolis: [c]
>NUTS: [d]

Error:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[250], line 11
      7 c   = pm.Poisson("c",mu=11)
      9 d   = pm.Dirichlet("d",a=pt.stack([c,b]))
---> 11 pm.sample(draws=1000,tune=1000,chains=4)

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\sampling\mcmc.py:935, in sample(draws, tune, chains, cores, random_seed, progressbar, progressbar_theme, step, var_names, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, blas_cores, model, compile_kwargs, **kwargs)
    933 _print_step_hierarchy(step)
    934 try:
--> 935     _mp_sample(**sample_args, **parallel_args)
    936 except pickle.PickleError:
    937     _log.warning("Could not pickle model, sampling singlethreaded.")

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\sampling\mcmc.py:1411, in _mp_sample(draws, tune, step, chains, cores, rngs, start, progressbar, progressbar_theme, traces, model, callback, blas_cores, mp_ctx, **kwargs)
   1409 try:
   1410     with sampler:
-> 1411         for draw in sampler:
   1412             strace = traces[draw.chain]
   1413             if not zarr_recording:
   1414                 # Zarr recording happens in each process

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\sampling\parallel.py:513, in ParallelSampler.__iter__(self)
    510 draw = ProcessAdapter.recv_draw(self._active)
    511 proc, is_last, draw, tuning, stats = draw
--> 513 self._progress.update(
    514     chain_idx=proc.chain, is_last=is_last, draw=draw, tuning=tuning, stats=stats
    515 )
    517 if is_last:
    518     proc.join()

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\util.py:886, in ProgressBarManager.update(self, chain_idx, is_last, draw, tuning, stats)
    883 if not tuning and stats and stats[0].get("diverging"):
    884     self.divergences += 1
--> 886 self.progress_stats = self.update_stats(self.progress_stats, stats, chain_idx)
    887 more_updates = (
    888     {stat: value[chain_idx] for stat, value in self.progress_stats.items()}
    889     if self.full_stats
    890     else {}
    891 )
    893 self._progress.update(
    894     self.tasks[chain_idx],
    895     completed=draw,
   (...)
    899     **more_updates,
    900 )

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\step_methods\compound.py:340, in CompoundStep._make_update_stats_function.<locals>.update_stats(stats, step_stats, chain_idx)
    338 def update_stats(stats, step_stats, chain_idx):
    339     for step_stat, update_fn in zip(step_stats, update_fns):
--> 340         stats = update_fn(stats, step_stat, chain_idx)
    342     return stats

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\step_methods\compound.py:340, in CompoundStep._make_update_stats_function.<locals>.update_stats(stats, step_stats, chain_idx)
    338 def update_stats(stats, step_stats, chain_idx):
    339     for step_stat, update_fn in zip(step_stats, update_fns):
--> 340         stats = update_fn(stats, step_stat, chain_idx)
    342     return stats

File c:\Users\<user>\venvs\pymc\Lib\site-packages\pymc\step_methods\metropolis.py:354, in Metropolis._make_update_stats_function.<locals>.update_stats(stats, step_stats, chain_idx)
    351 if isinstance(step_stats, list):
    352     step_stats = step_stats[0]
--> 354 stats["tune"][chain_idx] = step_stats["tune"]
    355 stats["accept_rate"][chain_idx] = step_stats["accept"]
    356 stats["scaling"][chain_idx] = step_stats["scaling"]

TypeError: string indices must be integers, not 'str'

I have seen this behaviour at other times as well, where the CompoundStep also has a nested CompoundStep with Metropolis-sampled variables. I’m quite confused.

Thanks!

This is a bug in the progress bar. As a workaround you can set progressbar=False, but could you open an issue on the GitHub repo and copy that model as a minimal example of the bug?

Thanks Jesse, indeed the workaround works. I’ll raise an issue on the GitHub repo. Cheers.