Hi,

I’m following the tutorial on how to use JAX functions in PyMC, and I have a question about how to obtain pointwise log likelihoods for each observation.

I modified the code in this section as follows:

```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt

# hmm_logp_op and emission_observed are defined earlier in the tutorial


def logp(emission_observed, emission_signal,
         emission_noise, logp_initial_state, logp_transition):
    # DensityDist passes the observed value first, then the distribution parameters
    return hmm_logp_op(
        emission_observed,
        emission_signal,
        emission_noise,
        logp_initial_state,
        logp_transition,
    )


with pm.Model() as model:
    emission_signal = pm.Normal("emission_signal", 0, 1)
    emission_noise = pm.HalfNormal("emission_noise", 1)

    p_initial_state = pm.Dirichlet("p_initial_state", np.ones(3))
    logp_initial_state = pt.log(p_initial_state)

    p_transition = pm.Dirichlet("p_transition", np.ones(3), size=3)
    logp_transition = pt.log(p_transition)

    # use DensityDist instead of Potential
    hmm_logp_dist = pm.DensityDist(
        "hmm_logp_dist",
        emission_signal, emission_noise, logp_initial_state, logp_transition,
        logp=logp,
        observed=emission_observed,
    )

with model:
    idata = pm.sample(chains=2, cores=1, idata_kwargs={"log_likelihood": True})
```

Sampling works correctly and the results match the ones in the tutorial.

But when I run `az.loo(idata)`, I get

`UserWarning: The point-wise LOO is the same with the sum LOO, please double check the Observed RV in your model to make sure it returns element-wise logp.`

even though the `emission_observed` array has 70 observations.

I tried changing the following function so that it sums only along the observation axis (`axis=0`) instead of summing everything together, but the point-wise LOO is still the same as the sum LOO.

```python
import jax
import jax.numpy as jnp


def vec_hmm_logp(*args):
    vmap = jax.vmap(
        hmm_logp,
        # Only the first argument needs to be vectorized
        in_axes=(0, None, None, None, None),
    )
    # For simplicity we sum across observations
    return jnp.sum(vmap(*args), axis=0)
```
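A minimal sketch of why `axis=0` makes no difference here. It uses a hypothetical stand-in, `toy_seq_logp` (an i.i.d. Gaussian logp, not the tutorial's HMM forward algorithm), which, like `hmm_logp`, returns one scalar per sequence. The observation shape `(7, 10)` is also an assumption chosen for illustration. Because `jax.vmap` over the first axis already yields a 1-D array with one value per mapped row, summing along `axis=0` collapses it to a scalar, the same result as summing everything.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for hmm_logp: returns ONE scalar log-probability
# for a whole sequence (plain i.i.d. Gaussian logp, not an HMM).
def toy_seq_logp(seq, mu, sigma):
    z = (seq - mu) / sigma
    return jnp.sum(-0.5 * z**2 - jnp.log(sigma) - 0.5 * jnp.log(2 * jnp.pi))


# Vectorize over the first axis of the observations only, as in vec_hmm_logp
batched = jax.vmap(toy_seq_logp, in_axes=(0, None, None))

obs = jnp.zeros((7, 10))  # hypothetical: 7 sequences of 10 steps each
per_seq = batched(obs, 0.0, 1.0)
print(per_seq.shape)  # (7,): one value per sequence

summed = jnp.sum(per_seq, axis=0)
print(summed.shape)  # (): axis=0 is the only axis, so this is a scalar
```

Returning `per_seq` unsummed keeps one value per row of the mapped axis (per sequence here, not per time step). If the PyTensor Op wrapping the JAX function declares a scalar output, that output type (and its gradient) would presumably also have to become a vector before PyMC could store more than one log-likelihood value per draw.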

Can anyone help with this? What needs to change in the tutorial code to get element-wise LOO values?

Thanks!