I’ve updated @dovgalec’s code to work with current library versions and incorporated some of the suggestions above. Hopefully it’s useful to anyone else who ends up here at the end of a Google search.
I’m using `pm.Data` instead of `theano.shared`, because I couldn’t get the latter to work.
import arviz as az
import logging
import matplotlib.pyplot as plt
import numpy as np
import pymc3 as pm
from scipy.stats import bernoulli
# Quieten PyMC3: raise the threshold of its logger so that only records at
# ERROR level or above are emitted (INFO/WARNING output is suppressed).
logger = logging.getLogger("pymc3")
logger.setLevel(level=logging.ERROR)
# True (hidden) success rate used to simulate the bootstrap dataset.
p_true = 0.60
# Number of Bernoulli trials in the initial bootstrap dataset.
n_bootstrap = 100

with pm.Model() as model:
    # Mutable data containers, initialised with a simulated bootstrap
    # dataset; they are replaced later via pm.set_data.
    n = pm.Data("n", n_bootstrap)
    data = pm.Data("data", bernoulli.rvs(p_true, size=n_bootstrap).sum())

    # Beta(2, 2) prior on the success rate (symmetric about 0.5).
    p = pm.Beta("p", alpha=2, beta=2)

    # Likelihood: the observed number of successes out of n trials.
    y_obs = pm.Binomial("y_obs", p=p, n=n, observed=data)
# Bootstrap an initial estimate of the success rate p from the simulated
# dataset the model was built with. Only the posterior mean is needed: it is
# the rate used to simulate the datasets for the precision sweep.
with model:
    trace = pm.sample(10000, progressbar=False, return_inferencedata=True)

# Take the mean straight from the posterior draws (all chains and draws).
# az.summary would also compute diagnostics that are never used, and it
# rounds its output (2 decimals by default), which would quantise p_hat.
p_hat = float(trace.posterior["p"].mean())
# Sample sizes at which to measure the width of the posterior HDI.
sample_size = [10, 100, 1000, 10000]
# One simulated dataset per sample size, summarised as its number of
# successes, drawn at the bootstrapped success rate p_hat.
observed_data = [bernoulli.rvs(p_hat, size=s).sum() for s in sample_size]

# Refit the model for each sample size and record how wide the HDI is.
interval_widths = []
with model:
    for n_obs, successes in zip(sample_size, observed_data):
        # Swap the new dataset into the model's pm.Data containers.
        pm.set_data({"data": successes, "n": n_obs})
        trace = pm.sample(progressbar=False, return_inferencedata=True)
        # hdi_prob is the probability mass INSIDE the interval (0.95 -> 95%
        # HDI). Passing 0.05 here, as an alpha level, would yield a 5% HDI.
        interval = az.hdi(trace, hdi_prob=0.95)
        interval_width = (
            interval["p"].sel(hdi="higher") - interval["p"].sel(hdi="lower")
        )
        interval_widths.append(float(interval_width))
# Plot HDI width against sample size. Both axes are log-scaled: the sample
# sizes span several decades, and the posterior width is expected to shrink
# roughly like 1/sqrt(n), which appears as a straight line on log-log axes.
fig, ax = plt.subplots()
ax.plot(sample_size, interval_widths, marker="o")
ax.set_xscale("log")
ax.set_yscale("log")
ax.set_title("Credible Interval Precision Curve")
ax.set_ylabel("HDI Width")
ax.set_xlabel("Sample Size")
plt.show()