ArviZ plot_trace runs forever/doesn't complete

Hi all, I am running the following model to estimate a time-varying Poisson distribution for each of 5 time points:

with pm.Model() as model:

    beta = pm.Normal('beta', mu=0, sigma=np.sqrt(2))
    ha = pm.Normal('Ha', mu=0, sigma=np.sqrt(2))

    alpha_zero = pm.Normal.dist(0, np.sqrt(2), shape=30)

    alpha = pm.GaussianRandomWalk("alpha", init_dist=alpha_zero, sigma=np.sqrt(2), shape=(30, 10, 5))

    theta = pm.Deterministic("theta", pm.invlogit(pm.math.dot(df_pilot.pred_score, alpha[g, p, t]) + beta + ha))

    # Likelihood:
    Y_obs = pm.Poisson("Y_obs", mu=theta, observed=y_pilot_home.y_score)


with prior data for each of the 5 time points looking like this:

[image: Untitled]

This doesn’t sample too well:

with model:
    trace = pm.sample(1000, random_seed=rng, target_accept=.95)


100.00% [8000/8000 02:14<00:00 Sampling 4 chains, 3,991 divergences]


Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 166 seconds.
There were 999 divergences after tuning. Increase `target_accept` or reparameterize.
There were 998 divergences after tuning. Increase `target_accept` or reparameterize.
There were 999 divergences after tuning. Increase `target_accept` or reparameterize.
There were 995 divergences after tuning. Increase `target_accept` or reparameterize.

I would still expect to be able to plot the trace and posterior, but both of the calls below run for hours without producing a result.

with model:
    az.plot_trace(trace)
with model:
    pm.sample_posterior_predictive(trace, extend_inferencedata=True, random_seed=rng)

I’d appreciate any help/thoughts on:

  • what I can do about the mass-divergence / sampling issues
  • why arviz seems to fail

Jan

You might have too many variables, and computing the KDEs for all of them takes a prohibitive amount of time. Try using compact=False so that each variable has its own row. ArviZ has a feature that limits the number of axes plotted at once to prevent such issues, which I believe will be triggered then (it only limits the number of axes, though, not the number of lines inside a single axes).
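To make the scale concrete: alpha in the model above has shape (30, 10, 5), i.e. 1500 scalar entries, each of which gets its own KDE and trace line. Below is a minimal sketch of the compact=False suggestion on a small synthetic stand-in for the trace. It assumes the ArviZ 0.x API (az.from_dict, az.plot_trace, and the "plot.max_subplots" rcParam); the toy shapes and the cap of 8 are made up for illustration and are not taken from the original model.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display

import numpy as np
import arviz as az

rng = np.random.default_rng(0)

# Synthetic stand-in for the real trace (assumption): 2 chains, 100 draws,
# a scalar "beta" and a small (3, 2) "alpha" instead of the (30, 10, 5) one.
idata = az.from_dict(posterior={
    "beta": rng.normal(size=(2, 100)),
    "alpha": rng.normal(size=(2, 100, 3, 2)),
})

# Count scalar parameters: every dimension after (chain, draw) multiplies
# the number of KDEs / trace lines that plot_trace has to draw.
n_scalars = sum(
    int(np.prod(da.shape[2:])) for da in idata.posterior.data_vars.values()
)
print("scalar parameters:", n_scalars)

# With compact=False each scalar gets its own row of two axes, so the
# max_subplots cap should truncate the figure (ArviZ emits a warning)
# rather than trying to draw everything.
with az.rc_context(rc={"plot.max_subplots": 8}):
    axes = az.plot_trace(idata, compact=False)

print("rows actually drawn:", axes.shape[0])
```

On the real trace the same call would be az.plot_trace(trace, compact=False). The cap only limits how many rows are drawn, so for a first look it is usually more useful to pick out a few variables explicitly.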

Try using plot_forest for vector-valued variables. plot_trace is really just for inspecting individual parameters for convergence, etc. You can restrict the variables plotted with any of the plotting functions using the var_names argument.
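A rough sketch of both suggestions, again on synthetic data rather than the real trace. It assumes the ArviZ 0.x API; the variable names mirror the model above, but the shapes are shrunk for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display

import numpy as np
import arviz as az

rng = np.random.default_rng(1)

# Synthetic stand-in for the real trace (assumption): same variable names
# as the model above, but "alpha" is shrunk to shape (4, 3).
idata = az.from_dict(posterior={
    "beta": rng.normal(size=(2, 200)),
    "Ha": rng.normal(size=(2, 200)),
    "alpha": rng.normal(size=(2, 200, 4, 3)),
})

# Trace plot restricted to the scalar parameters only.
trace_axes = az.plot_trace(idata, var_names=["beta", "Ha"])

# Forest plot for the vector-valued variable: one interval per entry of
# alpha, with chains pooled (combined=True), instead of one KDE per entry.
forest_axes = az.plot_forest(idata, var_names=["alpha"], combined=True)
```

With the real trace this would be az.plot_trace(trace, var_names=["beta", "Ha"]), which leaves out alpha and the theta Deterministic.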
