How to track a 'nan energy'?

If the energy problem does not come from an invalid starting value (i.e., a model.test_point that gives a non-finite logp), it is usually caused by a non-finite gradient. This can be difficult to diagnose, so here are all the steps I would go through to identify the problem:

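import pymc3 as pm

with pm.Model() as model:
    # placeholder RV: replace with your own model definition
    mu = pm.Normal("mu", 0.0, 1.0)

# make sure all test values are finite
print(model.test_point)

# make sure the logp of every RV is finite at the test point
print(model.check_test_point())

with model:
    step = pm.HamiltonianMC()

# note: the underscore-prefixed attributes below are PyMC3 internals and may differ between versions
# flatten the test point into the array the sampler works with
q0 = step._logp_dlogp_func.dict_to_array(model.test_point)
# draw a momentum from the potential and make sure it is finite
p0 = step.potential.random()
print(p0)

# total energy of the starting state; a non-finite value here is the "nan energy"
start = step.integrator.compute_state(q0, p0)
print(start.energy)

# make sure the model logp and its gradient are finite
logp, dlogp = step.integrator._logp_dlogp_func(q0)
print(logp)
print(dlogp)

# make sure the velocity and the kinetic energy are finite
v = step.integrator._potential.velocity(p0)
print(v)
kinetic = step.integrator._potential.energy(p0, velocity=v)
print(kinetic)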
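Rather than reading each printout by eye, you could collect the quantities from the checks above in order and stop at the first one that is not finite. This is a minimal sketch that assumes only NumPy; the helper name first_non_finite is hypothetical, and the demo values are made up to stand in for q0, logp and dlogp.

```python
import numpy as np


def first_non_finite(checks):
    """Return the label of the first (label, value) pair whose value contains nan or inf.

    Values may be scalars or arrays. Returns None if everything is finite.
    """
    for label, value in checks:
        if not np.all(np.isfinite(value)):
            return label
    return None


# Made-up stand-ins; in practice pass q0, p0, logp, dlogp, v, kinetic, start.energy
checks = [
    ("q0 (start point)", np.array([0.0, 1.0])),
    ("logp", -3.2),
    ("dlogp", np.array([0.5, np.nan])),
]
print(first_non_finite(checks))  # dlogp
```

Ordering the checks from the model side (test point, logp, gradient) to the sampler side (momentum, velocity, energy) tells you whether to look at the model or at the step method first.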

Whenever you see an array containing a non-finite element, you can map it back to a dict to see which RV is causing the problem. For example, if dlogp contains a non-finite value:

step._logp_dlogp_func.array_to_dict(dlogp)
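With many or vector-valued RVs, that dict can still be long. A small filter can list only the offending RVs and the flat positions of their bad entries. This sketch assumes the dict maps RV names to NumPy arrays; the helper non_finite_entries and the demo dict (including the name sigma_log__) are hypothetical.

```python
import numpy as np


def non_finite_entries(point):
    """Map each name whose value holds nan/inf to the flat indices of those entries.

    `point` is a {name: array-like} dict; names whose values are all finite are omitted.
    """
    bad = {}
    for name, value in point.items():
        mask = ~np.isfinite(np.asarray(value))
        if mask.any():
            bad[name] = np.flatnonzero(mask)
    return bad


# Made-up stand-in for step._logp_dlogp_func.array_to_dict(dlogp)
grad_by_rv = {
    "mu": np.array(0.3),
    "sigma_log__": np.array(np.inf),
    "beta": np.array([1.0, np.nan, 2.0]),
}
for name, idx in non_finite_entries(grad_by_rv).items():
    print(name, idx.tolist())  # prints "sigma_log__ [0]", then "beta [1]"
```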

Then adjust the prior for that RV accordingly.
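To see why the prior matters, here is one way a gradient can stop being finite, shown as a standalone NumPy sketch rather than PyMC code. It assumes a single hypothetical observation Y from a Normal whose standard deviation is exp(z), i.e. a scale parameter sampled on the log scale. If the prior on z is very wide, the sampler can reach values of z where exp overflows, and both logp and its gradient become non-finite. The functions logp and dlogp_dz are illustrative only.

```python
import numpy as np

Y = 1.0  # a single hypothetical observation


def logp(z):
    """Log-density of Y under Normal(0, sigma) with sigma = exp(z)."""
    with np.errstate(over="ignore"):  # let exp overflow to inf silently
        return -0.5 * np.log(2 * np.pi) - z - 0.5 * (Y * np.exp(-z)) ** 2


def dlogp_dz(z):
    """Derivative of logp with respect to z: -1 + (Y / sigma)**2."""
    with np.errstate(over="ignore"):
        return -1.0 + (Y * np.exp(-z)) ** 2


# the gradient grows quickly as z decreases, and at z = -800 both values are non-finite
for z in (0.0, -5.0, -800.0):
    print(z, logp(z), dlogp_dz(z))
```

A prior that puts very little mass on such extreme values of z keeps the sampler in the region where both quantities stay finite.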

Hope this is clear!
