Recreating convoys in pymc

I’m interested in recreating the algorithm from convoys for generalized gamma survival analysis (that converges to some conversion rate, c). The loss function is posted here

Does anyone have any pointers on how to get started for this?
I understand how the generalized_gamma_loss function calculates the log probability, but I have no idea how to make this compatible with pymc3

I found an old PR from the convoys package that actually attempted using a Weibull distribution with pymc3 and adapted it slightly to have a parameterization that more closely mimics pymc3’s. Here’s the code. It shouldn’t be too difficult to adapt this to the generalized gamma

And if anyone has suggestions for priors to reduce some of the divergences and improve regularization, I’d be very happy to hear them.

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scipy.stats as stats
import pymc3 as pm
import arviz as az
from scipy.special import expit
from pymc3.math import dot, sigmoid, log, exp
def logistic(x):
    """Logistic (sigmoid) function: exp(x) / (1 + exp(x)).

    Uses scipy's ``expit`` (already imported above) instead of the naive
    formula, which overflows and returns nan for large positive x.
    Accepts scalars or numpy arrays; returns the same shape.
    """
    return expit(x)


# --- Simulate censored conversion data ----------------------------------
N = 10000  # number of simulated subjects
days = 100  # length of the observation window in days
t = np.arange(0, days,1)  # daily time grid 0..days-1
# Latent time-to-event for every subject, Weibull(shape=2.5, scale=70),
# drawn independently of whether the subject would ever convert.
T_true = pm.Weibull.dist(2.5, 70).random(size=N) # true time to event (independent of conversion itself)

# One standard-normal covariate per subject, column vector of shape (N, 1).
X = np.random.normal(size=N)[:,None]
b = -0.2  # true regression coefficient
CVR = logistic( b*X + -0.5 )  # true per-subject conversion probability (logit intercept -0.5)

# only call it a conversion if they would have converted by now AND they convert
p = np.random.binomial(1, CVR)  # eventual-converter indicator, shape (N, 1)
B = ((T_true < t.max()) & (p.ravel() == 1))*1  # observed-conversion flag (0/1), shape (N,)

# adjust T for non-converts: ttc[i, j] is True once subject i has converted
# by day j; non-converters get a sentinel time of 10000, far past the window.
# NOTE(review): the trailing *1 is a no-op (multiplies the where() result by
# 1); it looks like it was meant to cast the boolean comparison to int.
ttc = (t >= np.where(B==1, T_true, 10000)[:,None]*1)
# Observed time: the true event time for observed converters, censored at the
# end of the window (t.max()) for non-converters and too-late converters.
T = np.where((p.ravel()==0) | (T_true > t.max()), t.max(),T_true)

# Empirical cumulative conversion curve: fraction converted by each day.
plt.plot(ttc.mean(axis=0))
plt.title("Time to Conversion ")
plt.ylabel("CVR")
plt.xlabel("Time")

# Number of observations and number of regression coefficients.
n, ncoefs = X.shape

with pm.Model() as m:
    # Logit-scale intercept for the asymptotic conversion rate c
    # (sigmoid(-5) ~ 0.007, so the prior favors a low baseline rate).
    alpha_c = pm.Normal('alpha_c', -5, 1) # intercept for conversion rate
    
    # Shared scale for the coefficients: mild hierarchical shrinkage to 0.
    beta_sd = pm.Exponential('beta_sd', 1.0)  # Weak prior for the regression coefficients
    beta = pm.Normal('beta', mu=0, sd=beta_sd, shape=(ncoefs,))  # Regression coefficients
    
    c = sigmoid(dot(X, beta) + alpha_c)  # Conversion rates for each example
    # Weibull shape parameter. NOTE(review): Lognormal(mu=0.5, sd=1) has
    # median exp(0.5) ~ 1.65, not 1 as the original comment claimed.
    k = pm.Lognormal('k', mu=0.5, sd=1.0)  # Weak prior
    lambd = pm.Exponential('lambd', 0.01)  # Weak prior (Weibull scale parameter)

    # Observed conversions contribute log[c * WeibullPDF(T; k, lambd)]:
    # PDF of Weibull: k / lambda * (x / lambda)^(k-1) * exp(-(t / lambda)^k)
    LL_observed = log(c) + log(k) - log(lambd) + (k-1)*(log(T) - log(lambd)) - (T/lambd)**k
    # Censored rows contribute log[P(never converts) + P(converts after T)],
    # i.e. (1-c) plus c times the Weibull survival function exp(-(T/lambd)^k).
    # CDF of Weibull: 1 - exp(-(t / lambda)^k)
    LL_censored = log((1-c) + c * exp(-(T/lambd)**k))

    # We need to implement the likelihood using pm.Potential (custom likelihood)
    # B (0/1) selects the observed vs censored contribution per subject.
    logp = B * LL_observed + (1 - B) * LL_censored
    logpvar = pm.Potential('logpvar', logp.sum())

    # advi+adapt_diag seeds the mass matrix from an ADVI fit, which tends to
    # help non-standard likelihoods like this one sample with fewer divergences.
    trace = pm.sample(init="advi+adapt_diag", return_inferencedata=True)