Thanks for the observation about the X array.
By the way, I had to modify the code because Theano raised an error when indexing with a boolean mask, so I switched to integer indices via `.nonzero()`:
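```python
import numpy as np
import pymc3 as pm
from theano import shared, tensor as tt

# Data: same survival times and event indicators as in the R code below
time = np.array([59, 115, 156, 421, 431, 448, 464, 475, 477, 563, 638, 744, 769,
                 770, 803, 855, 1040, 1106, 1129, 1206, 1227, 268, 329, 353, 365, 377])
event = np.array([1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0,
                  0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0])
cens = event == 0  # True for right-censored observations

y = np.log(time)
y_std = (y - y.mean()) / y.std()  # standardized log-times

event_ = shared(event)
X = np.ones((len(time), 1))  # intercept-only design matrix
X_ = shared(X)

def gumbel_sf(y, mu, sigma):
    # Survival function P(Y > y) of the (maximum) Gumbel distribution
    return 1.0 - tt.exp(-tt.exp(-(y - mu) / sigma))

# Build Bayesian model
with pm.Model() as model:
    # http://personal.psu.edu/drh20/525/weekly/weibull.pdf
    # log(T) = sigmaW + gamma^\top X
    # where X = 1 (intercept only) and sigmaW ~ Gumbel(0, s) with s ~ HalfNormal(tau=5)
    # Hyperprior
    s = pm.HalfNormal("s", tau=5.0)
    # Priors
    gamma = pm.Normal("gamma", 0., 5.0, shape=X.shape[1])
    gammaX = gamma.dot(X_.T)
    # Likelihood for uncensored and censored survival times
    # (integer indices from .nonzero() instead of a boolean mask, which failed in Theano)
    y_obs = pm.Gumbel("y_obs", mu=gammaX[(event == 1).nonzero()], beta=s, observed=y_std[~cens])
    y_cens = pm.Bernoulli("y_cens",
                          p=gumbel_sf(y_std[cens], mu=gammaX[(event == 0).nonzero()], sigma=s),
                          observed=np.ones(cens.sum()))
    # Initialization
    start = pm.find_MAP()

# Perform MCMC sampling
with model:
    trace = pm.sample(draws=2000, start=start, tune=1000)
```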
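As a cross-check that does not involve PyMC3, here is a minimal maximum-likelihood sketch of the same intercept-only, right-censored model, written on the log-time scale with the minimum extreme-value (Gumbel-min) form that survreg documents for the Weibull (survreg's scale = 1/shape, intercept = log of the Weibull scale). It assumes SciPy is available; `neg_loglik` is just my own helper name. If the likelihood is set up correctly, the estimates should line up with the survreg output.

```python
import numpy as np
from scipy.optimize import minimize

# Same data as in the R code (1 = event observed, 0 = right-censored)
time = np.array([59, 115, 156, 421, 431, 448, 464, 475, 477, 563, 638, 744, 769,
                 770, 803, 855, 1040, 1106, 1129, 1206, 1227, 268, 329, 353, 365, 377],
                dtype=float)
event = np.array([1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0,
                  0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0])
y = np.log(time)


def neg_loglik(params, y, event):
    """Negative log-likelihood of log(T) = mu + sigma * W, where W is a
    standard minimum extreme-value (Gumbel-min) variable, with right censoring."""
    mu, log_sigma = params
    sigma = np.exp(log_sigma)                # optimize log(sigma) so sigma stays > 0
    z = (y - mu) / sigma
    log_pdf = z - np.exp(z) - np.log(sigma)  # contribution of an observed event
    log_sf = -np.exp(z)                      # censored: log P(log T > y)
    return -np.sum(np.where(event == 1, log_pdf, log_sf))


res = minimize(neg_loglik, x0=[y.mean(), 0.0], args=(y, event), method="Nelder-Mead")
mu_hat, sigma_hat = res.x[0], np.exp(res.x[1])
eta_hat = np.exp(mu_hat)    # Weibull scale, comparable to exp(r$coefficients[1])
beta_hat = 1.0 / sigma_hat  # Weibull shape, comparable to 1/r$scale
print(f"eta = {eta_hat:.1f}, beta = {beta_hat:.4f}")
```

Optimizing over `log(sigma)` rather than `sigma` is only a convenience to keep the scale positive without a bounded optimizer.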
Regarding the computation of $\beta$: $s$ is the hyperprior of one of the hyperparameters of a Gumbel distribution, so I don't think $\beta = 1/s$ is correct. It looks like $\sigma \sim \text{Gumbel}(0, s)$ and then $\beta = 1/\sigma$. Any comments?
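For reference, this is how I understand the standard link between the Weibull parameters and the extreme-value form of log-time (the parameterization survreg documents). With $\eta$ the Weibull scale and $\beta$ the shape:

$$
\begin{aligned}
S_T(t) &= \exp\!\left[-\left(t/\eta\right)^{\beta}\right] \\
P(\log T > y) &= S_T\!\left(e^{y}\right) = \exp\!\left[-e^{\beta\,(y - \log\eta)}\right] = \exp\!\left[-e^{(y-\mu)/\sigma}\right], \qquad \mu = \log\eta,\ \ \sigma = 1/\beta
\end{aligned}
$$

So $\log T = \mu + \sigma W$, where $W$ has the standard minimum extreme-value distribution, $P(W > w) = \exp(-e^{w})$, and $\sigma$ is what survreg reports as `r$scale`, hence $\beta = 1/\sigma$. In the model above, `s` is passed directly as the Gumbel scale (`beta=s`), so it plays the role of $\sigma$, but on the standardized scale of `y_std`. Also, as far as I can tell, `pm.Gumbel` and `gumbel_sf` use the maximum-Gumbel form, $P(Y > y) = 1 - \exp\!\left(-e^{-(y-\mu)/\sigma}\right)$, which describes $-\log T = -\mu + \sigma(-W)$ rather than $\log T$.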
Also, I'm getting very low values for $\eta$: fitting the same model in R with `survreg` gives very different values. Here's the R code:
```r
library(survival)

time <- c(59, 115, 156, 421, 431, 448, 464, 475, 477, 563, 638, 744, 769, 770, 803, 855, 1040, 1106, 1129, 1206, 1227, 268, 329, 353, 365, 377)
event <- c(1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0)

r <- survreg(Surv(time, event) ~ 1, dist = "weibull")
beta <- 1 / r$scale            # Weibull shape
eta <- exp(r$coefficients[1])  # Weibull scale
```
And I get $\beta = 1.10806$ and $\eta = 1225.419$. The PyMC3 code above gives a posterior mean for `gamma` of -0.185689, so `eta = np.exp(trace["gamma"])` gives values nowhere near 1225.419.
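One thing that may explain part of the gap: the model is fitted to `y_std`, so `gamma` and `s` live on the standardized log-time scale, and `np.exp(trace["gamma"])` never undoes the standardization. Below is a minimal sketch of the back-transformation, assuming log-time follows the survreg form $\log T = \mu + \sigma W$ with $W$ standard Gumbel-min; `weibull_from_standardized` is a hypothetical helper name, and the toy numbers are made up, not results from this data set.

```python
import numpy as np


def weibull_from_standardized(gamma, s, y_mean, y_sd):
    """Map location/scale of standardized log-time back to Weibull (eta, beta).

    Assumes y_std = (log(T) - y_mean) / y_sd and y_std = gamma + s * W, with W a
    standard minimum extreme-value (Gumbel-min) variable. Works elementwise, so
    it can be applied to arrays of posterior draws.
    """
    gamma = np.asarray(gamma, dtype=float)
    s = np.asarray(s, dtype=float)
    mu = y_mean + y_sd * gamma  # location of log(T) on the original scale
    sigma = y_sd * s            # scale of log(T); survreg reports this as r$scale
    eta = np.exp(mu)            # Weibull scale, cf. exp(r$coefficients[1])
    beta = 1.0 / sigma          # Weibull shape, cf. 1 / r$scale
    return eta, beta


# Toy check with made-up inputs
eta, beta = weibull_from_standardized(gamma=[0.0, 0.5], s=[0.5, 0.25], y_mean=7.0, y_sd=2.0)
print(eta, beta)
```

With the trace from the model above this would be something like `weibull_from_standardized(trace["gamma"][:, 0], trace["s"], y.mean(), y.std())`, but only once the likelihood uses the Gumbel-min form. If the model is instead written as `pm.Gumbel` on `-y_std`, its location estimates $-\gamma$, so the sign of `gamma` would need to be flipped first.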