I am working on a JAXified version of the Hierarchical Gaussian Filter for reinforcement learning that I am trying to sample with PyMC v4.0.1, after wrapping the log-probability function into an Aesara Op following this example.
Sampling works nicely with this approach; I was even surprised to see that PyMC outperforms Numpyro here. However, with a different response (logp) function, NUTS gets stuck and does not sample at all, and it is hard to figure out what is happening.
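For context, the wrapping follows the usual pattern from that example: the JAX logp is jit()-ed, its gradient is taken with jax.grad(), and both are exposed through a pair of Aesara Ops. Below is a minimal sketch of that pattern, with a toy two-parameter logp standing in for the real HGF code; every name in this snippet is illustrative and not the actual HGFDistribution implementation.

import jax
import numpy as np
import aesara.tensor as at
from aesara.graph.op import Op


def toy_logp(omega_1, omega_2):
    # Placeholder for the real JAX log-probability function
    return -((omega_1 + 3.0) ** 2) - ((omega_2 + 3.0) ** 2)


jitted_logp = jax.jit(toy_logp)
jitted_grad = jax.jit(jax.grad(toy_logp, argnums=(0, 1)))


class ToyLogpGradOp(Op):
    # Returns the gradient of the logp with respect to both inputs
    itypes = [at.dscalar, at.dscalar]
    otypes = [at.dscalar, at.dscalar]

    def perform(self, node, inputs, outputs):
        grad_1, grad_2 = jitted_grad(*inputs)
        outputs[0][0] = np.asarray(grad_1, dtype="float64")
        outputs[1][0] = np.asarray(grad_2, dtype="float64")


toy_logp_grad_op = ToyLogpGradOp()


class ToyLogpOp(Op):
    # Returns the scalar logp and delegates the gradient to the Op above
    itypes = [at.dscalar, at.dscalar]
    otypes = [at.dscalar]

    def perform(self, node, inputs, outputs):
        outputs[0][0] = np.asarray(jitted_logp(*inputs), dtype="float64")

    def grad(self, inputs, output_gradients):
        grad_1, grad_2 = toy_logp_grad_op(*inputs)
        return [output_gradients[0] * grad_1, output_gradients[0] * grad_2]


toy_logp_op = ToyLogpOp()

The logp Op delegates its grad() to the second Op, which is what allows NUTS to request gradients of the wrapped JAX function.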
The model is defined as follows:
# Generate the Aesara Op corresponding to the model parameters and load the data
# as default arguments (partial function) for modularity. This should allow choosing
# arbitrary response (logp) functions accepting additional parameters without changing
# the underlying Aesara Ops for the Hierarchical Gaussian Filter.
hgf_logp_op = HGFDistribution(
    data=input_data,
    n_levels=2,
    model_type="continuous",
    response_function=hrd_behaviors,
    response_function_parameters=response_function_parameters,
)
with pm.Model() as hgf_model:

    omega_1 = pm.Normal("omega_1", -3.0, 2)
    omega_2 = pm.Normal("omega_2", -3.0, 2)
    mu_2 = pm.Normal("mu_2", -3.0, 2)
    bias = pm.Normal("bias", 0.0, 0.01)

    pm.Potential(
        "hgf_loglike",
        hgf_logp_op(
            omega_1=omega_1,
            omega_2=omega_2,
            omega_input=parameters["omega_input"],
            rho_1=parameters["rho_1"],
            rho_2=parameters["rho_2"],
            pi_1=parameters["pi_1"],
            pi_2=parameters["pi_2"],
            mu_1=parameters["mu_1"],
            mu_2=mu_2,
            kappa_1=parameters["kappa_1"],
            bias=bias,
        ),
    )
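Sampling is then launched with a standard pm.sample() call (shown here only as a sketch, exact arguments omitted); this is the step where NUTS hangs:

with hgf_model:
    hgf_idata = pm.sample()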
From what I can see, the model does not raise errors during initialization:
initial_point = hgf_model.initial_point()
initial_point
#{'omega_1': array(-3.),
#'omega_2': array(-3.),
#'mu_2': array(-3.),
#'bias': array(0.)}
pointslogs = hgf_model.point_logps(initial_point)
pointslogs
#{'omega_1': -1.61,
#'omega_2': -1.61,
#'mu_2': -1.61,
#'bias': 3.69,
#'hgf_loglike': -145.64}
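Since point_logps() only evaluates the logp terms, the gradient at the initial point can be checked in the same way with the model's compile_dlogp() method (a sketch; I am not reproducing its output here):

# Evaluate the model gradient at the initial point; NUTS needs this to be finite
dlogp_fn = hgf_model.compile_dlogp()
dlogp_fn(initial_point)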
The underlying JAX logp function can be jit()-ed and grad()-ed without errors, and the Aesara Ops wrapping it (logp and grad_logp) appear to work correctly, as far as I can tell. Also, a pure Numpyro approach can sample from the same JAX logp function.
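Concretely, the Op-level check can be done by building a small standalone Aesara graph around hgf_logp_op and taking its gradient (a sketch reusing the hgf_logp_op and parameters objects from above; the *_in symbolic variables are illustrative placeholders, and I am not showing the output here):

import aesara
import aesara.tensor as at

# Symbolic placeholders for the four free parameters
omega_1_in, omega_2_in, mu_2_in, bias_in = at.dscalars(
    "omega_1", "omega_2", "mu_2", "bias"
)
logp = hgf_logp_op(
    omega_1=omega_1_in,
    omega_2=omega_2_in,
    omega_input=parameters["omega_input"],
    rho_1=parameters["rho_1"],
    rho_2=parameters["rho_2"],
    pi_1=parameters["pi_1"],
    pi_2=parameters["pi_2"],
    mu_1=parameters["mu_1"],
    mu_2=mu_2_in,
    kappa_1=parameters["kappa_1"],
    bias=bias_in,
)
grads = aesara.grad(logp, [omega_1_in, omega_2_in, mu_2_in, bias_in])
logp_and_grad = aesara.function(
    [omega_1_in, omega_2_in, mu_2_in, bias_in], [logp, *grads]
)
logp_and_grad(-3.0, -3.0, -3.0, 0.0)  # all returned values should be finite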
Because this use case is quite specific and I cannot provide a fully reproducible example, my question is:
How can I find out why NUTS is not sampling when the model apparently evaluates correctly?