JAX wrapping and element-wise LOO

I’m following the tutorial on how to wrap JAX functions for use in PyMC, and I have a question about how to obtain pointwise log likelihoods for each observation.

I modified the code in this section as follows:

def logp(emission_observed, emission_signal,
         emission_noise, logp_initial_state, logp_transition):
    return hmm_logp_op(
        emission_observed,
        emission_signal,
        emission_noise,
        logp_initial_state,
        logp_transition,
    )

with pm.Model() as model:
    emission_signal = pm.Normal("emission_signal", 0, 1)
    emission_noise = pm.HalfNormal("emission_noise", 1)

    p_initial_state = pm.Dirichlet("p_initial_state", np.ones(3))
    logp_initial_state = pt.log(p_initial_state)

    p_transition = pm.Dirichlet("p_transition", np.ones(3), size=3)
    logp_transition = pt.log(p_transition)

    # use DensityDist instead of Potential
    hmm_logp_dist = pm.DensityDist(
        "hmm_logp_dist",
        emission_signal,
        emission_noise,
        logp_initial_state,
        logp_transition,
        logp=logp,
        observed=emission_observed,
    )

with model:
    idata = pm.sample(chains=2, cores=1, idata_kwargs={"log_likelihood": True})

Sampling works correctly and the results match the ones in the tutorial.

But when I run az.loo(idata) I get "UserWarning: The point-wise LOO is the same with the sum LOO, please double check the Observed RV in your model to make sure it returns element-wise logp.", even though the emission_observed array has 70 observations.
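For az.loo to return pointwise values, the log_likelihood group of the InferenceData needs a trailing observation dimension; if each draw carries only a single summed logp, that warning fires. A self-contained sketch with synthetic draws (the variable names and shapes here are illustrative, not from the tutorial):

```python
import numpy as np
import arviz as az

rng = np.random.default_rng(0)
n_chains, n_draws, n_obs = 2, 500, 70

idata = az.from_dict(
    posterior={"mu": rng.normal(size=(n_chains, n_draws))},
    # one log-likelihood value per observation per draw
    # -> az.loo can compute pointwise ELPD
    log_likelihood={"obs": rng.normal(size=(n_chains, n_draws, n_obs)) - 1.0},
)
loo = az.loo(idata)
```

If the observation dimension were missing (one value per draw), the pointwise ELPD would collapse to a single value and az.loo would emit the warning above.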

I tried to change the following function so that it sums across observations rather than summing everything together, but it still results in the point-wise LOO being the same as the sum LOO.

def vec_hmm_logp(*args):
    vmap = jax.vmap(
        hmm_logp,
        # Only the first argument needs to be vectorized
        in_axes=(0, None, None, None, None),
    )
    # For simplicity we sum across observations
    return jnp.sum(vmap(*args), axis=0)
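Note that any jnp.sum over the batch dimension collapses exactly the per-observation information LOO needs; returning the vmapped values unsummed is what yields one logp per observation. A minimal standalone sketch of the difference, with a toy scalar logp standing in for hmm_logp (which is defined in the tutorial):

```python
import jax
import jax.numpy as jnp

def toy_logp(y, mu, sigma):
    # stand-in for hmm_logp: scalar logp for a single observation
    return -0.5 * ((y - mu) / sigma) ** 2 - jnp.log(sigma)

# vectorize over the first argument only, as in the tutorial
vmap_logp = jax.vmap(toy_logp, in_axes=(0, None, None))

y = jnp.arange(5.0)
elementwise = vmap_logp(y, 0.0, 1.0)  # shape (5,): one logp per observation
summed = jnp.sum(elementwise)         # scalar: what summing hands to PyMC
```

Whether the unsummed version is appropriate depends on whether the observations are exchangeable draws or a single dependent sequence, which is the question raised below.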

Can anyone provide some help with this? What needs to change in the tutorial code to get element-wise LOO values?


Do you have a single chain with 70 observations or 70 chains with x observations? Technically speaking, 70 observations of a single chain belong to a single multivariate distribution and hence have a single non-decomposable logp.

You can think of them as conditional probabilities. I am not sure whether LOO cares/handles that distinction. @OriolAbril may be able to chime in
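The conditional-probability view can be checked numerically: for dependent observations, the joint logp decomposes via the chain rule into conditionals, not into independent marginals. A minimal NumPy sketch, using a correlated bivariate normal purely as a stand-in for the HMM's dependence structure:

```python
import numpy as np

def norm_logpdf(x, mu, sigma):
    # log density of a univariate normal
    return -0.5 * np.log(2 * np.pi * sigma**2) - (x - mu) ** 2 / (2 * sigma**2)

def mvn_logpdf(y, cov):
    # log density of a zero-mean multivariate normal
    _, logdet = np.linalg.slogdet(cov)
    quad = y @ np.linalg.solve(cov, y)
    return -0.5 * (len(y) * np.log(2 * np.pi) + logdet + quad)

rho = 0.8
cov = np.array([[1.0, rho], [rho, 1.0]])
y = np.array([0.5, -0.3])

joint = mvn_logpdf(y, cov)
# chain rule: log p(y1, y2) = log p(y1) + log p(y2 | y1)
chain = norm_logpdf(y[0], 0.0, 1.0) + norm_logpdf(y[1], rho * y[0], np.sqrt(1 - rho**2))
# treating the observations as independent gives a different number
independent = norm_logpdf(y[0], 0.0, 1.0) + norm_logpdf(y[1], 0.0, 1.0)
```

So per-observation terms for a dependent sequence are conditional logps, and it is a separate question whether LOO on those conditionals is the quantity one actually wants.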

My understanding was that, when you have n observations, elpd(y_1, ..., y_n) is different from elpd(y_1) + ... + elpd(y_n).

I guess that in the tutorial, since you are trying to estimate the parameters of a single HMM process, it makes sense to use elpd(y_1, ..., y_n) when computing LOO. I may be wrong though?