This seems to be doing better. I changed the model to the following:
# Censored-normal model with one intercept per 'cann' label.
# Relies on `pm`, `cann`, `cann_idx` and `obs_array` being defined earlier.
coords = {'cann': cann}

with pm.Model(coords=coords) as cannibal_model:
    # Data containers registered on the model.
    # NOTE(review): cann_idx is presumably an integer index into `cann`
    # for each row of obs_array -- confirm against the data prep code.
    cannibal = pm.Data('cannibal', cann_idx, mutable=True)
    obs = pm.Data('obs', obs_array, mutable=True)

    # Disabled for now:
    # beta = pm.Normal('beta', mu=0, sigma=.1)

    # One intercept per entry of the 'cann' dimension.
    alpha = pm.Normal('alpha', mu=.99, sigma=.07, dims=['cann'])
    sigma = pm.HalfNormal('sigma', 1)

    # Unregistered latent normal; alpha is looked up through the index data.
    y_latent = pm.Normal.dist(mu=alpha[cannibal], sigma=sigma)

    # Likelihood: the latent normal censored to the interval [0, 1].
    eaches = pm.Censored(
        'predicted_eaches',
        dist=y_latent,
        lower=0,
        upper=1,
        observed=obs,
    )

    # Sampling must run inside the model context because no `model=`
    # argument is passed to the sampler.
    idata = pm.sampling_jax.sample_numpyro_nuts(
        draws=1000,
        tune=2000,
        target_accept=.95,
    )


I think this is a good base to expand out from. Thank you for the help.
