Code Review - 1 out of 4 chains stuck

Hi,

I have a custom likelihood function, and I am sharing my code below. Could somebody please review it and let me know if there is scope for improvement that would help me speed up sampling?

from pymc.math import where
import pandas as pd
import pymc as pm
from pymc.math import log, exp
from pytensor import tensor as pt
import numpy as np
from numpy import pi
from sklearn.preprocessing import StandardScaler
import pymc.sampling_jax
import numpyro
import jax

numpyro.set_host_device_count(4)


# Implementing the failure-rate (hazard) function - arguments expected = (covariates of mortgage i, t_d of mortgage i)
# A = (2 * pi * σ_d^2)^(-1/2) * t_d^(-1)
# B = exp(-0.5 * (log(t_d) - μ_d)^2 / σ_d^2)
# C = 1 - CDF((log(t_d) - μ_d) / σ_d)
# D = exp(θ_D * X_d(t))
# λ_d(t|X_d(t)) = (A * B * D) / C
# Theta is a vector of parameters
# data is the corresponding vector of covariates
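# Taking logs of the expression above (this is what computeFailureRate below returns):
# log λ_d(t|X_d(t)) = -log(t_d) - 0.5 * log(2 * pi * σ_d^2)
#                     - 0.5 * ((log(t_d) - μ_d) / σ_d)^2
#                     - log(1 - CDF((log(t_d) - μ_d) / σ_d))
#                     + θ_D * X_d(t)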

def standardize(x, mu, sigma):
    return (x - mu) / sigma


def standardNormCdf(x):
    return 0.5 + 0.5 * pm.math.erf(x / pm.math.sqrt(2))


def getContributionFromInterval(interval, mu, sigma):
    logT = log(interval)
    a = standardize(logT, mu, sigma)
    return log(1 - standardNormCdf(a))


def computeFailureRate(sigma, mu, t, theta, data):
    logT = log(t)
    a1 = standardize(log(t), mu, sigma)
    a2 = 0.5 * log(2 * pi * pow(sigma, 2))
    B = 0.5 * pow(a1, 2)
    C = log(1 - standardNormCdf(a1))
    D = (theta * data).sum(axis=-1).reshape((-1, 1))
    failureRate = ((-1) * logT) - a2 - B - C + D
    return failureRate
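

# Note: log(1 - standardNormCdf(a1)) can lose precision when a1 is large.
# If that turns out to matter, a possible alternative (just a sketch, not what
# runs above) would be the symmetric standard-normal log-CDF, e.g.:
#   C = pm.logcdf(pm.Normal.dist(0.0, 1.0), -a1)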


def getSurvival(time_to_event, mu, sigma, theta, features):
    contributionFromFinalInterval = (-1) * getContributionFromInterval(time_to_event, mu, sigma)
    theta_features_vector = (theta * features).sum(axis=-1)
    theta_features_vector = theta_features_vector.reshape((-1, 1))
    survival = (contributionFromFinalInterval * exp(theta_features_vector))
    survival = survival.sum(axis=1).reshape((-1, 1))
    return survival
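# (If my derivation is right, this is the cumulative hazard
#  Λ(t|X) = -log S_0(t) * exp(θ · X) under the proportional-hazards form,
#  with S_0 the log-normal baseline survival function.)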


def runModel():
    bcphm_model = pm.Model()
    with bcphm_model:
        data = pd.read_csv('data/precovid_SF.csv')

        data.sort_values(by=['event'], inplace=True)
        data.loc[(data.FirstTimeHomeBuyer.isna()) & data.LoanPurpose.isin(
            ['Refinance', 'CashOutRefi']), 'FirstTimeHomeBuyer'] = 'Not Applicable'
        data['PMI'].fillna(0, inplace=True)
        data = data[data.time_to_event != 0]

        data.dropna(inplace=True)
        data['ClosingDt'] = pd.to_datetime(data['ClosingDt']).dt.year

        data = data.groupby('ClosingDt', group_keys=False).apply(lambda x: x.sample(frac=0.5))

        events = np.where(data.event.values == "default", 0, np.where(data.event.values == "prepayment", 1, 2)).reshape(
            (-1, 1))

        time_to_event = data.time_to_event.values

        time_to_event_shape = time_to_event.shape
        time_to_event = time_to_event.reshape(time_to_event_shape[0], 1)
        data['isSingleBorrower'] = data['isSingleBorrower'].map({0: 'No', 1: 'Yes'})
        features = data.drop(
            columns=['LoanNumber', 'time_to_default', 'time_to_prepayment', 'State', 'ClosingDt', 'event',
                     'time_to_event'], axis=1)
        categorical = [col for col in features.columns if features[col].dtype == "O"]
        quantitative = set(features.columns) - set(categorical)
        # features = np.repeat(features[:, np.newaxis, :], lifetime.shape[1], axis=1)

        categorical_dummies = pd.get_dummies(features.loc[:, list(categorical)], columns=categorical, drop_first=True)

        sc = StandardScaler()
        standardized_quantitative = pd.DataFrame(sc.fit_transform(features.loc[:, list(quantitative)]),
                                                 columns=list(quantitative))
        model_input = pd.concat(
            [categorical_dummies.reset_index(drop=True), standardized_quantitative.reset_index(drop=True)],
            axis=1)

        model_input.replace({False: 0, True: 1}, inplace=True)

        features_shape = model_input.shape

        features_num = model_input.values
        features_num = pm.MutableData('features_num', features_num)
        events = pm.MutableData('events', events)

        theta_D = pm.Normal('theta_D', mu=0, sigma=100, shape=features_shape[1])
        theta_P = pm.Normal('theta_P', mu=0, sigma=100, shape=features_shape[1])
        mu_D = pm.Normal('mu_D', mu=0, sigma=10)
        mu_P = pm.Normal('mu_P', mu=0, sigma=10)
        sigma_D = pm.Exponential('sigma_D', .01)
        sigma_P = pm.Exponential('sigma_P', .01)

        def logp(time_to_event, mu_P, mu_D, sigma_P,
                 sigma_D, theta_D, theta_P, event,
                 features):
            failureRate = where(
                pt.eq(event, 0),
                computeFailureRate(sigma_D, mu_D, time_to_event, theta_D, features),
                where(pt.eq(event, 1),
                      computeFailureRate(sigma_P, mu_P, time_to_event, theta_P, features),
                      0)
            )
            defaultSurvival = getSurvival(time_to_event, mu_D, sigma_D,
                                          theta_D, features)
            prepaymentSurival = getSurvival(time_to_event, mu_P,
                                            sigma_P, theta_P, features)
            return failureRate - defaultSurvival - prepaymentSurival

#         l = logp(time_to_event, mu_P, mu_D,
#                  sigma_P, sigma_D, theta_D, theta_P,
#                  events, features_num)
        likelihood = pm.CustomDist('LL', mu_P, mu_D,
                                   sigma_P, sigma_D, theta_D, theta_P,
                                   events, features_num,
                                   logp=logp,
                                   observed=time_to_event)

        trace = pm.sampling_jax.sample_numpyro_nuts(draws=1000, tune=1000, chains=4, chain_method='parallel')
#         trace = pm.sample(draws=1000, tune=1000, chains=4)
    return trace

One chain seems to be stuck while the others are sampling fine. How do I diagnose the issue?
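
For example, would checking per-chain sampler statistics along these lines be the right direction? (Just a sketch, assuming ArviZ is installed as az and that the numpyro backend records divergences under sample_stats["diverging"].)

import arviz as az

trace = runModel()

# r_hat / ESS summary; large r_hat would confirm the chains disagree
print(az.summary(trace, var_names=["mu_D", "mu_P", "sigma_D", "sigma_P"]))

# number of divergences per chain
print(trace.sample_stats["diverging"].sum(dim="draw"))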
@ricardoV94 @twiecki