Dimension mismatch in `NUTS` and `metropolis`

Dear all,

I’d like to understand why I get the dimension mismatch error below.

I’m considering the paradigmatic disaster model from the documentation:

```python
import pymc3 as pm
import theano.tensor as tt
from numpy import arange, array

disasters_data = array([4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
                        3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
                        2, 2, 3, 4, 2, 1, 3, 2, 2, 1, 1, 1, 1, 3, 0, 0,
                        1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
                        0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
                        3, 3, 1, 1, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
                        0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1])
year = arange(1851, 1962)

with pm.Model() as model:
    switchpoint = pm.DiscreteUniform('switchpoint', lower=year.min(), upper=year.max())
    early_mean = pm.Exponential('early_mean', lam=1.)
    late_mean = pm.Exponential('late_mean', lam=1.)
    # Per-year rate: early_mean up to and including switchpoint, late_mean afterwards
    rate = tt.switch(switchpoint >= year, early_mean, late_mean)
    disasters = pm.Poisson('disasters', rate, observed=disasters_data)
```
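To make explicit what this model computes, here is a minimal NumPy sketch (not PyMC3 code) of its unnormalised log posterior. `rate_vector` reproduces the elementwise `tt.switch`: years up to and including `switchpoint` get `early_mean`, later years get `late_mean`, so the rate has one entry per year (111 here). The names `rate_vector` and `log_posterior` are hypothetical, chosen only for illustration.

```python
import math

import numpy as np

disasters_data = np.array([4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
                           3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
                           2, 2, 3, 4, 2, 1, 3, 2, 2, 1, 1, 1, 1, 3, 0, 0,
                           1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
                           0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
                           3, 3, 1, 1, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
                           0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1])
year = np.arange(1851, 1962)


def rate_vector(switchpoint, early_mean, late_mean):
    # Elementwise analogue of tt.switch(switchpoint >= year, early_mean, late_mean)
    return np.where(switchpoint >= year, early_mean, late_mean)


def log_posterior(switchpoint, early_mean, late_mean):
    """Unnormalised log posterior of the disaster model (illustrative sketch)."""
    if not (year.min() <= switchpoint <= year.max()):
        return -np.inf  # outside the DiscreteUniform support
    if early_mean <= 0 or late_mean <= 0:
        return -np.inf  # outside the Exponential support
    # DiscreteUniform prior: constant -log(number of admissible years)
    lp = -math.log(year.size)
    # Exponential(lam=1) priors: log density is -x for x > 0
    lp += -early_mean - late_mean
    # Poisson log-pmf summed over years: k*log(rate) - rate - log(k!)
    rate = rate_vector(switchpoint, early_mean, late_mean)
    log_fact = np.array([math.lgamma(k + 1) for k in disasters_data])
    lp += np.sum(disasters_data * np.log(rate) - rate - log_fact)
    return lp
```

For example, `log_posterior(1890, 3.0, 1.0)` is larger than `log_posterior(1890, 1.0, 3.0)`, since the early years of the record have more disasters than the late ones.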

Since `switchpoint` is a discrete random variable, it needs a Metropolis sampler:

```python
with model:
    step1 = pm.Metropolis(switchpoint)
    step2 = pm.NUTS([early_mean, late_mean])
    tr = pm.sample(200, tune=100, step=[step1, step2])
```
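Roughly speaking, when `step` is a list, `pm.sample` combines the methods into a compound step that applies each one in turn, each updating only the variables it was assigned. The pure-NumPy sketch below imitates that structure for the two-step assignment above and for a three-step variant in which each mean gets its own step method. It is not PyMC3's implementation: random-walk Metropolis on the log scale stands in for NUTS on the two means, and all function names (`log_post`, `metropolis_int`, `metropolis_log`, `sample`) are hypothetical.

```python
import math

import numpy as np

disasters_data = np.array([4, 5, 4, 0, 1, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6,
                           3, 3, 5, 4, 5, 3, 1, 4, 4, 1, 5, 5, 3, 4, 2, 5,
                           2, 2, 3, 4, 2, 1, 3, 2, 2, 1, 1, 1, 1, 3, 0, 0,
                           1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1,
                           0, 1, 0, 1, 0, 0, 0, 2, 1, 0, 0, 0, 1, 1, 0, 2,
                           3, 3, 1, 1, 2, 1, 1, 1, 1, 2, 4, 2, 0, 0, 1, 4,
                           0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1])
year = np.arange(1851, 1962)
LOG_FACT = np.array([math.lgamma(k + 1) for k in disasters_data])


def log_post(state):
    """Unnormalised log posterior; the flat switchpoint prior is a constant."""
    s, e, l = state["switchpoint"], state["early_mean"], state["late_mean"]
    if not (year[0] <= s <= year[-1]) or e <= 0 or l <= 0:
        return -np.inf
    rate = np.where(s >= year, e, l)
    return -e - l + np.sum(disasters_data * np.log(rate) - rate - LOG_FACT)


def metropolis_int(name, width=5):
    """Random-walk Metropolis on one integer variable (symmetric proposal)."""
    def step(state, rng):
        prop = dict(state)
        prop[name] = state[name] + int(rng.integers(-width, width + 1))
        if np.log(rng.random()) < log_post(prop) - log_post(state):
            return prop
        return state
    return step


def metropolis_log(names, scale=0.2):
    """Joint random-walk Metropolis on log(x) for positive variables.
    The log-Jacobian term keeps the target density in x-space correct."""
    def step(state, rng):
        prop = dict(state)
        log_jac = 0.0
        for name in names:
            prop[name] = state[name] * math.exp(scale * rng.standard_normal())
            log_jac += math.log(prop[name]) - math.log(state[name])
        if np.log(rng.random()) < log_post(prop) - log_post(state) + log_jac:
            return prop
        return state
    return step


def sample(steps, start, draws, seed=0):
    """Compound step: every draw applies each step method in turn, and each
    method only changes the variables it was given."""
    rng = np.random.default_rng(seed)
    state, trace = dict(start), []
    for _ in range(draws):
        for step in steps:
            state = step(state, rng)
        trace.append(dict(state))
    return trace


start = {"switchpoint": 1900, "early_mean": 1.0, "late_mean": 1.0}
# Like step=[step1, step2]: one step for switchpoint, one for both means
two_steps = [metropolis_int("switchpoint"),
             metropolis_log(["early_mean", "late_mean"])]
# Three steps: each mean gets its own method (Metropolis stands in for NUTS)
three_steps = [metropolis_int("switchpoint"),
               metropolis_log(["late_mean"]),
               metropolis_log(["early_mean"])]
trace2 = sample(two_steps, start, draws=3000, seed=1)
trace3 = sample(three_steps, start, draws=3000, seed=1)
```

In this sketch, splitting `[early_mean, late_mean]` into two separate steps changes only how the variables are blocked, not the target distribution, so both traces should settle around the same posterior.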

So far so good: assigning Metropolis to `switchpoint` and NUTS to the two means is indeed the default choice. However, if I try:

```python
with model:
    step1 = pm.Metropolis(switchpoint)
    step2 = pm.Metropolis(late_mean)
    step3 = pm.NUTS(early_mean)
    tr = pm.sample(200, tune=100, step=[step1, step2, step3])
```

I get a long error, which culminates in:

```
ValueError: Dimension mismatch; shapes are (2), (1)
```

Is this something inherent to the NUTS sampler, or is it a bug?

Cheers,

M.

Could you update to master and try again? I cannot reproduce your problem.

Thanks. I couldn’t install the master branch right away, but I can now confirm that master (installed from the GitHub repository) does not reproduce the problem: the snippet above only displays a warning (as Metropolis needs a longer adaptation interval).

PS: The bug is still present in PyMC3 3.1 as installed via conda-forge.

Thanks for the update @mcavallaro!