How to debug a model that is not sampling?

I am working on a JAXified version of the Hierarchical Gaussian Filter for reinforcement learning that I am trying to sample with PyMC v4.0.1, after wrapping the log-probability function into an Aesara Op following this example.
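
For reference, the Op roughly follows the usual black-box likelihood pattern: a logp Op whose perform() calls the jitted JAX function, plus a second Op that returns the gradient. Here is a minimal sketch of that pattern with a toy two-parameter logp standing in for the real HGF function (toy_logp, Logp, and LogpGrad are illustrative names, not my actual code):

import jax
import numpy as np
import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import Op


def toy_logp(omega_1, omega_2):
    # Stand-in for the real HGF log-probability (assumption, for illustration only)
    return -0.5 * (omega_1 ** 2 + omega_2 ** 2)


jitted_logp = jax.jit(toy_logp)
jitted_grad = jax.jit(jax.grad(toy_logp, argnums=(0, 1)))


class LogpGrad(Op):
    """Gradient of the log-probability with respect to each input."""

    def make_node(self, omega_1, omega_2):
        inputs = [at.as_tensor_variable(omega_1), at.as_tensor_variable(omega_2)]
        outputs = [inp.type() for inp in inputs]
        return Apply(self, inputs, outputs)

    def perform(self, node, inputs, outputs):
        grads = jitted_grad(*inputs)
        for i, grad in enumerate(grads):
            outputs[i][0] = np.asarray(grad, dtype=node.outputs[i].dtype)


class Logp(Op):
    """Scalar log-probability Op calling the jitted JAX function."""

    def make_node(self, omega_1, omega_2):
        inputs = [at.as_tensor_variable(omega_1), at.as_tensor_variable(omega_2)]
        return Apply(self, inputs, [at.dscalar()])

    def perform(self, node, inputs, outputs):
        outputs[0][0] = np.asarray(jitted_logp(*inputs), dtype="float64")

    def grad(self, inputs, output_gradients):
        # Chain rule: scale each partial derivative by the upstream gradient
        grads = logp_grad_op(*inputs)
        return [output_gradients[0] * grad for grad in grads]


logp_grad_op = LogpGrad()
logp_op = Logp()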

Sampling works nicely with this approach; I was even surprised to see that PyMC outperforms Numpyro here. However, with a different response (logp) function, NUTS gets stuck and does not sample at all, and it is hard to figure out what is happening.

The model is defined as follows:

# Generate the Aesara Op corresponding to the model parameters and load the data
# as default arguments (partial function) for modularity (this should allow choosing
# arbitrary response (logp) functions accepting additional parameters without changing
# the underlying Aesara Ops for the Hierarchical Gaussian Filter)
hgf_logp_op = HGFDistribution(
    data=input_data,
    n_levels=2,
    model_type="continuous",
    response_function=hrd_behaviors,
    response_function_parameters=response_function_parameters
    )

with pm.Model() as hgf_model:

    omega_1 = pm.Normal("omega_1", -3.0, 2)
    omega_2 = pm.Normal("omega_2", -3.0, 2)
    mu_2 = pm.Normal("mu_2", -3.0, 2)
    bias = pm.Normal("bias", 0.0, 0.01)

    pm.Potential(
        "hgf_loglike",
        hgf_logp_op(
            omega_1=omega_1,
            omega_2=omega_2,
            omega_input=parameters["omega_input"],
            rho_1=parameters["rho_1"],
            rho_2=parameters["rho_2"],
            pi_1=parameters["pi_1"],
            pi_2=parameters["pi_2"],
            mu_1=parameters["mu_1"],
            mu_2=mu_2,
            kappa_1=parameters["kappa_1"],
            bias=bias,
        ),
    )
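
Sampling is then launched with a plain call along these lines (the exact call is a placeholder, nothing exotic is passed):

with hgf_model:
    # Standard NUTS call; this is where the sampler hangs with the new response function
    hgf_idata = pm.sample()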

From what I can see, the model is not raising any errors during initialization:

initial_point = hgf_model.initial_point()
initial_point
#{'omega_1': array(-3.),
#'omega_2': array(-3.),
#'mu_2': array(-3.),
#'bias': array(0.)}

pointslogs = hgf_model.point_logps(initial_point)
pointslogs
#{'omega_1': -1.61,
#'omega_2': -1.61,
#'mu_2': -1.61,
#'bias': 3.69,
#'hgf_loglike': -145.64}
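
A further check is to evaluate the joint log-probability and its gradient at this initial point (this sketch assumes PyMC v4's compile_logp / compile_dlogp helpers); both should return finite values if the graph is healthy:

logp_fn = hgf_model.compile_logp()
dlogp_fn = hgf_model.compile_dlogp()

logp_fn(initial_point)   # should be a finite scalar
dlogp_fn(initial_point)  # should be a finite gradient for every free parameter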

The underlying JAX logp function can be jit()-ed and grad()-ed without errors, and the Aesara Ops wrapping it (logp and grad_logp) work correctly as far as I can see. Also, a pure Numpyro approach can sample the JAX logp function.
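
The kind of check I mean is along these lines, shown with the toy Op from the sketch above rather than the real one (which takes more parameters):

import aesara
import aesara.tensor as at

# Symbolic inputs for the sampled parameters (placeholder names)
omega_1_in = at.dscalar("omega_1_in")
omega_2_in = at.dscalar("omega_2_in")

logp = logp_op(omega_1_in, omega_2_in)
grads = aesara.grad(logp, wrt=[omega_1_in, omega_2_in])  # exercises the gradient Op

fn = aesara.function([omega_1_in, omega_2_in], [logp] + grads)
fn(-3.0, -3.0)  # the value and the gradients should all be finite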

Because this use case is very specific and I cannot provide a fully reproducible example, my question is then:

How can I find out why NUTS is not sampling when my model appears to be working correctly?