Is it OK to return the sum of log likelihoods in `logp()` of a `Distribution`?

Hi everyone!

I am trying to follow this tutorial to implement a custom Distribution that uses a wrapped JAX function to compute the log-likelihoods. From what I understand, `Distribution.logp()` should return a vector of element-wise log-likelihoods. However, in the tutorial the vectorized JAX function is summed to return the sum of the log-likelihoods. A JAX function needs to return a scalar for it to be used with `jax.grad`, and I think that's why the summation was used. Therefore, if I follow the method in the tutorial, my `logp` can only return the sum of log-likelihoods.
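To illustrate the restriction I mean, here is a minimal sketch (the quadratic logp is just a stand-in, not the tutorial's actual function):

```python
import jax
import jax.numpy as jnp

data = jnp.array([0.1, -0.3, 0.7])

# Stand-in element-wise log-likelihood: one value per observation.
def logp_elemwise(mu):
    return -0.5 * (data - mu) ** 2  # shape (3,)

# jax.grad only differentiates scalar-valued functions, so the
# element-wise logps have to be summed before taking the gradient:
grad_fn = jax.grad(lambda mu: logp_elemwise(mu).sum())
print(grad_fn(0.0))  # works; jax.grad(logp_elemwise)(0.0) would raise
```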

My question is: are there any consequences if the `logp` method of my custom Distribution returns the sum of log-likelihoods instead of the element-wise log-likelihoods? It shouldn't affect the gradients, but are the element-wise log-likelihoods used elsewhere in PyMC?

Thank you!

Note that in that notebook we don't create any distribution; we add the logp terms directly as a Potential.
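Something along these lines (a self-contained stand-in, assuming a recent PyMC where the backend is PyTensor; the summed logp is written directly in PyTensor here instead of a wrapped JAX function):

```python
import numpy as np
import pymc as pm
import pytensor.tensor as pt

data = np.array([0.1, -0.3, 0.7])

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    # Stand-in for the wrapped JAX function: a summed (unnormalized)
    # Normal log-likelihood written directly in PyTensor.
    summed_logp = pt.sum(-0.5 * (data - mu) ** 2)
    # A Potential just adds this scalar term to the joint model logp,
    # so it has no shape requirements.
    pm.Potential("loglike", summed_logp)
```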

To answer your question: it depends… If you are using your distribution inside something like a Mixture, or if you want to do model comparison, you need the logps to have the right (element-wise) shape. Otherwise, it doesn't really matter.
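For example, model comparison via `az.loo` / `az.waic` consumes the pointwise log-likelihood of each observation, which a summed logp can't provide. A rough sketch in recent PyMC (using `CustomDist`, which goes by other names in older versions):

```python
import numpy as np
import pymc as pm
import arviz as az

data = np.random.normal(size=50)

# Element-wise logp: one value per observation, which is what
# az.loo / az.waic need for model comparison.
def normal_logp(value, mu):
    return -0.5 * (value - mu) ** 2 - 0.5 * np.log(2 * np.pi)

with pm.Model():
    mu = pm.Normal("mu", 0.0, 1.0)
    pm.CustomDist("obs", mu, logp=normal_logp, observed=data)
    idata = pm.sample(idata_kwargs={"log_likelihood": True})

az.loo(idata)  # needs the per-observation values, not a single sum
```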

There is a guide on the PyMC Labs blog on how to wrap JAX Ops more easily, and it includes some non-scalar examples as well: How to use JAX ODEs and Neural Networks in PyMC - PyMC Labs
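One way around the scalar restriction, by the way, is `jax.vjp`, which (unlike `jax.grad`) handles non-scalar outputs. Roughly (same stand-in logp as above):

```python
import jax
import jax.numpy as jnp

data = jnp.array([0.1, -0.3, 0.7])

def logp_elemwise(mu):
    return -0.5 * (data - mu) ** 2  # vector output

# jax.vjp works on non-scalar outputs; pulling back a cotangent of
# ones gives the same gradient as differentiating the sum:
out, vjp_fn = jax.vjp(logp_elemwise, 0.0)
(grad_mu,) = vjp_fn(jnp.ones_like(out))
print(out, grad_mu)
```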

Thank you so much, @ricardoV94! I will take a look at the non-scalar examples.