Hi,
I’m following the tutorial on how to wrap JAX functions for use in PyMC, and I have a question about how to obtain pointwise log-likelihoods for each observation.
I modified the code in this section as follows:
def logp(emission_observed, emission_signal,
         emission_noise, logp_initial_state, logp_transition):
    return hmm_logp_op(
        emission_observed,
        emission_signal,
        emission_noise,
        logp_initial_state,
        logp_transition,
    )
with pm.Model() as model:
    emission_signal = pm.Normal("emission_signal", 0, 1)
    emission_noise = pm.HalfNormal("emission_noise", 1)

    p_initial_state = pm.Dirichlet("p_initial_state", np.ones(3))
    logp_initial_state = pt.log(p_initial_state)

    p_transition = pm.Dirichlet("p_transition", np.ones(3), size=3)
    logp_transition = pt.log(p_transition)

    # use DensityDist instead of Potential
    hmm_logp_dist = pm.DensityDist(
        "hmm_logp_dist",
        emission_signal,
        emission_noise,
        logp_initial_state,
        logp_transition,
        logp=logp,
        observed=emission_observed,
    )

with model:
    idata = pm.sample(chains=2, cores=1, idata_kwargs={"log_likelihood": True})
Sampling works correctly and the results match those in the tutorial.
But when I run az.loo(idata), I get the following warning, even though the emission_observed array has 70 observations:
UserWarning: The point-wise LOO is the same with the sum LOO, please double check the Observed RV in your model to make sure it returns element-wise logp.
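For context, here is a rough NumPy sketch of the pointwise quantity that warning compares. It is a simplification: the real az.loo also applies Pareto-smoothed importance sampling, and pointwise_lppd is a hypothetical helper, not an ArviZ function. Each column of the log-likelihood array is one "point"; if the model's logp returns a single summed value per draw, there is only one column, so the pointwise values and their sum are identical, which is the situation the warning flags.

```python
import numpy as np


def pointwise_lppd(log_lik):
    """Log pointwise predictive density, one value per data point.

    log_lik has shape (n_draws, n_points) and holds log p(y_i | theta_s).
    Returns shape (n_points,): log of the mean over draws of exp(log_lik[:, i]).
    """
    n_draws = log_lik.shape[0]
    # Subtract the per-column max before exponentiating for numerical stability
    m = log_lik.max(axis=0)
    return m + np.log(np.exp(log_lik - m).sum(axis=0)) - np.log(n_draws)


rng = np.random.default_rng(0)

# Summed custom logp: one value per draw, i.e. a single "point"
summed_ll = rng.normal(-100.0, 1.0, size=(500, 1))
lppd_i = pointwise_lppd(summed_ll)
print(lppd_i.shape)                                # (1,)
print(bool(np.equal(lppd_i.sum(), lppd_i).all()))  # True: pointwise == sum

# Hypothetical: one logp value per independent unit (7 units here)
unit_ll = rng.normal(-10.0, 1.0, size=(500, 7))
lppd_i = pointwise_lppd(unit_ll)
print(lppd_i.shape)                                # (7,)
```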
I tried changing the following function so that it sums along the observation axis instead of over everything, but the pointwise LOO is still the same as the sum LOO.
def vec_hmm_logp(*args):
    vmap = jax.vmap(
        hmm_logp,
        # Only the first argument needs to be vectorized
        in_axes=(0, None, None, None, None),
    )
    # For simplicity we sum across observations
    return jnp.sum(vmap(*args), axis=0)
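To see why axis=0 does not change anything, here is a minimal NumPy-only sketch (no JAX, so it runs anywhere) that mirrors the hmm_logp / vec_hmm_logp structure above. The helper names (logsumexp, norm_logpdf, hmm_logp_np, per_process_logp), the transition-matrix convention and the 7 x 70 data shape are assumptions for illustration. jax.vmap with in_axes=(0, None, None, None, None) returns one marginal log-likelihood per row of emission_observed, i.e. a 1-D vector, so summing it with axis=0 adds up exactly the same values as a plain sum and still yields a scalar. Because the hidden states are marginalized inside each process, the time steps of one process are not conditionally independent given the parameters, so one natural pointwise unit is the whole process.

```python
import numpy as np


# Hypothetical stand-ins for the jax.scipy functions used in the tutorial
def logsumexp(a, axis):
    """Numerically stable log(sum(exp(a))) along one axis."""
    m = np.max(a, axis=axis, keepdims=True)
    out = m + np.log(np.sum(np.exp(a - m), axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis)


def norm_logpdf(x, mu, sigma):
    """Log density of Normal(mu, sigma) evaluated at x (broadcasts)."""
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)


HIDDEN_STATES = np.array([-1, 0, 1])


def hmm_logp_np(emission_observed, emission_signal, emission_noise,
                logp_initial_state, logp_transition):
    """Marginal log-likelihood of ONE process (hidden states summed out)."""
    # Emission logp of every step under every hidden state: shape (n_steps, 3)
    logp_emission = norm_logpdf(emission_observed[:, None],
                                (HIDDEN_STATES + 1) * emission_signal,
                                emission_noise)
    # Forward algorithm: log_alpha[j] = log p(state_t = j, y_1..y_t)
    log_alpha = logp_initial_state + logp_emission[0]
    for lp in logp_emission[1:]:
        # Assumed convention: logp_transition[i, j] = log p(state_t = j | state_{t-1} = i)
        log_alpha = logsumexp(log_alpha[:, None] + logp_transition, axis=0) + lp
    return logsumexp(log_alpha, axis=0)


def per_process_logp(emission_observed, *params):
    """Like jax.vmap(..., in_axes=(0, None, None, None, None)): map over rows only."""
    return np.array([hmm_logp_np(obs, *params) for obs in emission_observed])


rng = np.random.default_rng(1)
obs = rng.normal(size=(7, 70))  # hypothetical: 7 processes x 70 steps each
params = (1.0, 0.5, np.log(np.ones(3) / 3), np.log(np.full((3, 3), 1 / 3)))

per_proc = per_process_logp(obs, *params)
print(per_proc.shape)                  # (7,): one value per process
print(np.sum(per_proc, axis=0).shape)  # (): axis=0 still collapses to a scalar
```

In the real model, returning the unsummed vector from vec_hmm_logp would presumably also require hmm_logp_op to declare a vector output instead of a scalar, and the gradient to be computed with a vector-Jacobian product (jax.vjp), since jax.grad only accepts scalar-output functions. The distribution may also need to treat each row of emission_observed as one multivariate observation, for example via the ndim_supp argument of pm.CustomDist.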
Can anyone help with this? What needs to change in the tutorial code to get element-wise LOO values?
Thanks!