I am playing around with a simple multi-level model and finding that when I non-center the parameters, the number of effective samples is much lower than when I use a centered parameterization. This is the opposite of what I was expecting, and I am wondering whether I am making a mistake or whether this is something that happens sometimes and I should not worry about it.
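For context on the warning quoted in the code below: the effective sample size (ESS) estimates how many independent draws a correlated MCMC chain is worth, so a slowly mixing chain gets an ESS well below its raw number of draws. The sketch below is a minimal single-chain illustration. It assumes a simplified estimator (the autocorrelation sum truncated with Geyer's initial positive sequence), not ArviZ's rank-normalized, multi-chain estimator, and it uses a synthetic AR(1) series as a hypothetical stand-in for a poorly mixing sampler.

```python
import numpy as np


def autocorr(x):
    """Sample autocorrelation of a 1-D chain at lags 0..N-1."""
    x = np.asarray(x, dtype=float)
    x = x - x.mean()
    n = len(x)
    acov = np.correlate(x, x, mode="full")[n - 1:] / n
    return acov / acov[0]


def ess_single_chain(x):
    """Simplified ESS for one chain: N / tau, with tau = 1 + 2 * sum_k rho_k.
    The sum is truncated with Geyer's initial positive sequence (stop at the
    first negative pair rho_k + rho_{k+1}). This is NOT ArviZ's estimator."""
    rho = autocorr(x)
    n = len(x)
    tau = -1.0  # -1 + 2 * (rho_0 + rho_1 + ...) == 1 + 2 * (rho_1 + rho_2 + ...)
    for k in range(0, n - 1, 2):
        pair = rho[k] + rho[k + 1]
        if pair < 0:
            break
        tau += 2.0 * pair
    return n / tau


rng = np.random.default_rng(0)
N = 4000

# Independent draws: ESS should be close to N
iid = rng.normal(size=N)

# Strongly autocorrelated AR(1) series, a hypothetical stand-in for a slowly mixing chain
phi = 0.9
ar = np.empty(N)
ar[0] = rng.normal()
for t in range(1, N):
    ar[t] = phi * ar[t - 1] + rng.normal()

print("iid   ESS / N:", ess_single_chain(iid) / N)
# For AR(1) the asymptotic ratio is (1 - phi) / (1 + phi), about 0.053 for phi = 0.9
print("AR(1) ESS / N:", ess_single_chain(ar) / N)
```

A ratio below 0.25 is the kind of situation the "smaller than 25%" message refers to; which parameter triggers it can be checked per variable in ArviZ's summary output.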
The data that I am using is from Richard McElreath’s Statistical Rethinking course and can be found here: https://github.com/rmcelreath/rethinking/blob/Experimental/data/reedfrogs.csv
The code for the model is basically taken from the Statistical Rethinking course (see the video from 15-Feb 11 here: https://github.com/rmcelreath/statrethinking_winter2019), except that I non-centered the parameters. It models the probability of survival of tadpoles in different ponds (also referred to as tanks). Each row of the dataset corresponds to a different pond. There are small and large ponds, and they contain different numbers of tadpoles, which is taken into account by the variable `n` below. The number of tadpoles that survive in each pond is modeled by a Binomial distribution.
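Written out in math (the notation here is my own shorthand, not necessarily the course's), the model in the code below is as follows, with $S_i$ the number of survivors in pond $i$ (the `surv` column) and $n_i$ its initial number of tadpoles (the `density` column). The first block is the non-centered version actually coded; the centered version replaces the $z_i$ line with the last line.

```latex
\begin{aligned}
S_i &\sim \operatorname{Binomial}(n_i, p_i) \\
\operatorname{logit}(p_i) &= \alpha_i \\
\alpha_i &= \bar{\alpha} + \sigma z_i, \qquad z_i \sim \operatorname{Normal}(0, 1) \\
\bar{\alpha} &\sim \operatorname{Normal}(0, 1.5) \\
\sigma &\sim \operatorname{Exponential}(1) \\[4pt]
\text{centered: } \alpha_i &\sim \operatorname{Normal}(\bar{\alpha}, \sigma)
\end{aligned}
```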
```python
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pymc3 as pm
import arviz as az

RANDOM_SEED = 95714

df1 = pd.read_csv('reedfrogs.csv', sep=';')
N_TANKS = len(df1)

with pm.Model() as model_1_1:
    # one row per pond (tank): pond index, initial density, number of survivors
    pond = pm.Data('pond', df1.index.values.astype('int'))
    n = pm.Data('n', df1.density.values.astype('int'))
    survival_obs = pm.Data('survival_obs', df1.surv)

    a_bar = pm.Normal('a_bar', mu=0, sigma=1.5)
    sigma = pm.Exponential('sigma', lam=1)

    # This parametrization does not give a warning
    # a = pm.Normal('a', mu=a_bar, sigma=sigma, shape=N_TANKS)

    # use non-centered variables for better sampling
    # this gives me a warning, but I don't understand why:
    # "The number of effective samples is smaller than 25% for some parameters"
    z = pm.Normal('z', mu=0, sigma=1, shape=N_TANKS)
    a = pm.Deterministic('a', a_bar + z*sigma)

    p = pm.invlogit(a[pond])
    survival = pm.Binomial('survival', n=n[pond], p=p, observed=survival_obs)

    trace_1_1 = pm.sample(2000, tune=2000, return_inferencedata=True,
                          random_seed=RANDOM_SEED, chains=4)
```
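To make the relationship between the two parameterizations concrete, here is a sketch in plain NumPy (no PyMC3) that simulates forward from the priors of the non-centered model above. The tank sizes in `n` and the fixed values `a_bar_fix` and `sigma_fix` are hypothetical, chosen only for illustration. The last part checks that drawing `a` directly from Normal(a_bar, sigma) (centered) and computing `a_bar + z*sigma` (non-centered) give the same distribution, so the two versions describe the same model and differ only in the geometry the sampler has to explore.

```python
import numpy as np

rng = np.random.default_rng(95714)

# Hypothetical tank sizes for illustration only; in the real model these
# come from the `density` column of reedfrogs.csv.
n = np.array([10, 10, 25, 25, 35, 35])
n_tanks = len(n)

# Priors, mirroring the PyMC3 model: a_bar ~ Normal(0, 1.5), sigma ~ Exponential(1)
a_bar = rng.normal(0.0, 1.5)
sigma = rng.exponential(1.0)  # NumPy takes the scale 1/lam; lam = 1 here

# Non-centered: standard-normal offsets z, scaled by sigma and shifted by a_bar
z = rng.normal(0.0, 1.0, size=n_tanks)
a = a_bar + z * sigma

# Inverse logit, then binomial counts of survivors per tank
p = 1.0 / (1.0 + np.exp(-a))
survival = rng.binomial(n, p)
print("simulated survivors:", survival, "out of", n)

# Hypothetical fixed hyperparameters: centered and non-centered draws of a
# should have matching mean and standard deviation
a_bar_fix, sigma_fix, draws = 0.5, 1.2, 100_000
a_centered = rng.normal(a_bar_fix, sigma_fix, size=draws)
a_noncentered = a_bar_fix + sigma_fix * rng.normal(0.0, 1.0, size=draws)
print("centered     mean/sd:", a_centered.mean(), a_centered.std())
print("non-centered mean/sd:", a_noncentered.mean(), a_noncentered.std())
```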
If anyone can shed some light on this, I’d be very grateful.