How can I plot the survival curve from a Bayesian (PyMC) model?

import pymc as pm
import numpy as np
import pytensor
import pytensor.tensor as pt
import arviz as az

# Prior settings.  Both Gamma priors are specified by their mode and
# variance and converted below to PyMC's (alpha, beta) parameterization.
gamma_mode = 0.5
gamma_variance = 1.0
t_star_mode = 15.0
t_star_variance = 100.0
s_star = 0.88

# Rate constant of the survival curve S(t) = exp(-lambda_ * (t / t_star)**gamma),
# chosen so that S(t_star) = exp(-lambda_) = s_star.  Computed with NumPy so it
# is an ordinary float64 (printable/testable, independent of pytensor's
# `floatX` setting); it combines with model variables in arithmetic the same
# way the NumPy data array does.
lambda_ = -np.log(s_star)

# Mode/variance -> shape/scale for Gamma(shape=k, scale=theta) with k > 1:
#   mode = (k - 1) * theta,   variance = k * theta**2
# Eliminating theta gives k**2 - (2 + mode**2 / variance) * k + 1 = 0,
# so k is the larger root and theta = mode / (k - 1).  PyMC's `beta` is the
# rate, i.e. 1 / theta (kept here as b1 / b2).

# Prior for gamma: shape `a`, scale `b`, rate `b1`
aux1 = 2.0 + (gamma_mode**2.0) / gamma_variance
a = 0.5 * (aux1 + np.sqrt(aux1**2.0 - 4.0))
b = gamma_mode / (a - 1.0)
b1 = 1.0 / b

# Prior for t_star: shape `ap`, scale `bp`, rate `b2`
aux2 = 2.0 + (t_star_mode**2.0) / t_star_variance
ap = 0.5 * (aux2 + np.sqrt(aux2**2.0 - 4.0))
bp = t_star_mode / (ap - 1.0)
b2 = 1.0 / bp

# Observed (censored) times, one entry per unit.
t = np.array(
    [
        25.62191781, 25.35616438, 25.34246575, 25.29589041, 24.61643836,
        24.34794521, 24.30410959, 24.06027397, 23.81643836, 23.50410959,
        23.2109589, 23.14520548, 22.97808219, 22.70410959, 21.90684932,
        21.81917808, 21.36986301, 20, 19.5260274,
    ]
)

# Empty model container; priors, the survival curve and the likelihood are
# attached to it inside the `with model:` block.
model = pm.Model()

def weibull_lccdf(x, t_star, gamma, lam=None):
    """Return log S(x) for the Weibull survival curve used in this model.

    The survival function is S(x) = exp(-lam * (x / t_star) ** gamma), so the
    log complementary CDF is simply -lam * (x / t_star) ** gamma.  With
    lam = -log(s_star) this gives S(t_star) = s_star.

    Parameters
    ----------
    x : array-like or tensor
        Times at which to evaluate log S.
    t_star : scalar or tensor
        Reference time at which the survival probability equals exp(-lam).
    gamma : scalar or tensor
        Weibull shape exponent.
    lam : scalar or tensor, optional
        Rate constant.  Defaults to the module-level ``lambda_`` so existing
        calls behave exactly as before.

    Returns
    -------
    Same type as the arithmetic of the inputs (NumPy value or pytensor tensor).
    """
    if lam is None:
        lam = lambda_  # module-level constant defined with the priors
    return -(lam * (x / t_star) ** gamma)

with model:
    # Gamma priors from the mode/variance conversion; PyMC's `beta` is a rate.
    gamma = pm.Gamma('gamma', alpha=a, beta=1/b)
    t_star = pm.Gamma('t_star', alpha=ap, beta=1/bp)

    # Survival curve on the integer grid t = 0, 1, ..., 25, recorded as a
    # Deterministic so its posterior draws are stored in the returned trace.
    time_points = pt.arange(0, 26)
    scaled_time = time_points / t_star
    surv = pm.Deterministic('surv', pt.exp(-lambda_ * scaled_time ** gamma))

    # Each censored observation contributes log S(t_i) to the model log-density.
    censored_logp = weibull_lccdf(t, t_star, gamma)
    y_cens = pm.Potential("y_cens", censored_logp)

    # NUTS sampling; the result holds posterior draws for gamma, t_star and surv.
    trace = pm.sample(draws=10000, tune=1000)

# Tabulate posterior statistics for all variables in the trace and show them.
# (`pm.summary` is ArviZ's `summary` re-exported by PyMC.)
summary = az.summary(trace)
print(summary)

I want to plot the survival function `surv` from my code, but it is not working. Can anyone help with this?

I tried fixing the formatting of your code, but please let me know if I didn’t get it quite right.

I would suggest checking out the survival analysis notebooks for examples of how to visualize what is going on. If you have tried plotting and it is not working, you can share that code as well.

# Plot the posterior of the survival curve `surv` defined in the model.
#
# `surv` itself is a symbolic PyMC/pytensor variable, not an array of numbers.
# Handing it to ArviZ yields an object-dtype array, which is what raised
# "Cannot cast ufunc 'subtract' input 0 from dtype('O')".  The actual
# posterior draws are stored in the InferenceData returned by `pm.sample`.
surv_function = trace.posterior["surv"]  # dims: chain, draw, time index

# Time grid for the x-axis.  The model evaluates `surv` on pt.arange(0, 26);
# deriving the grid from the posterior's last axis keeps x and y the same
# length (the old `interval_bounds[:-1]` was one point short, which
# `az.plot_hdi` rejects).
interval_bounds = np.arange(surv_function.shape[-1])

# Credible band (highest-density interval) of S(t) at every grid point.
# ArviZ creates the figure itself, so no separate pyplot import is needed.
surv_ax = az.plot_hdi(
    interval_bounds,
    surv_function,
    hdi_prob=0.94,
    smooth=False,
    color="C0",
    figsize=(8, 6),
    fill_kwargs={"label": "94% HDI"},
)
fig = surv_ax.get_figure()

# Posterior mean of S(t), averaging over all chains and draws.
surv_ax.plot(
    interval_bounds,
    surv_function.mean(dim=("chain", "draw")),
    color="darkblue",
    label="Posterior mean",
)

# Axis cosmetics: a survival probability lies in [0, 1].
surv_ax.set_xlim(0, interval_bounds.max())
surv_ax.set_ylim(0, 1.05)
surv_ax.set_xlabel("Time $t$")
surv_ax.set_ylabel("Survival function $S(t)$")
surv_ax.legend()

fig.suptitle("Survival Function Visualization")

Thanks. I added the code above, but it still doesn't work:

UFuncTypeError                            Traceback (most recent call last)
<ipython-input-9-d5ec0b9e0bd3> in <module>
     68 
     69 # Plot credible intervals of survival function for your data
---> 70 az.plot_hdi(interval_bounds[:-1], surv_function, ax=surv_ax, smooth=False, color="C0")
     71 
     72 # Plot mean survival function

~\AppData\Roaming\Python\Python38\site-packages\arviz\plots\hdiplot.py in plot_hdi(x, y, hdi_prob, hdi_data, color, circular, smooth, smooth_kwargs, figsize, fill_kwargs, plot_kwargs, hdi_kwargs, ax, backend, backend_kwargs, show)
    154         elif not 1 >= hdi_prob > 0:
    155             raise ValueError("The value of hdi_prob should be in the interval (0, 1]")
--> 156         hdi_data = hdi(y, hdi_prob=hdi_prob, circular=circular, multimodal=False, **hdi_kwargs)
    157 
    158     hdi_shape = hdi_data.shape

~\AppData\Roaming\Python\Python38\site-packages\arviz\stats\stats.py in hdi(ary, hdi_prob, circular, multimodal, skipna, group, var_names, filter_vars, coords, max_modes, dask_kwargs, **kwargs)
    588     if isarray and ary.ndim <= 1:
    589         func_kwargs.pop("out_shape")
--> 590         hdi_data = func(ary, **func_kwargs)  # pylint: disable=unexpected-keyword-arg
    591         return hdi_data[~np.isnan(hdi_data).all(axis=1), :] if multimodal else hdi_data
    592 

~\AppData\Roaming\Python\Python38\site-packages\arviz\stats\stats.py in _hdi(ary, hdi_prob, circular, skipna)
    631     interval_idx_inc = int(np.floor(hdi_prob * n))
    632     n_intervals = n - interval_idx_inc
--> 633     interval_width = np.subtract(ary[interval_idx_inc:], ary[:n_intervals], dtype=np.float_)
    634 
    635     if len(interval_width) == 0:

UFuncTypeError: Cannot cast ufunc 'subtract' input 0 from dtype('O') to dtype('float64') with casting rule 'same_kind'