Hi everyone!

I am trying to follow this tutorial to implement a custom `Distribution` that uses a wrapped JAX function to compute the log-likelihoods. From what I understand, `Distribution.logp()` should return a vector of element-wise log-likelihoods. However, in this tutorial, the output of the vectorized JAX function is summed, so it returns the total log-likelihood. A JAX function needs to return a scalar for it to be used with `grad`, and I think that’s why the summation was used. Therefore, if I follow the approach in the tutorial, my `logp` can only return the sum of the log-likelihoods.
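To make the scalar-output point concrete, here is a minimal JAX sketch. It assumes a Normal likelihood as a stand-in for the tutorial's model, and `elementwise_logp` and `summed_logp` are hypothetical names. It shows that `jax.grad` needs the summed (scalar) version, while the element-wise values can still be computed from the same function. Because differentiation is linear, the gradient of the sum equals the sum of the per-observation gradients.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import norm


def elementwise_logp(mu, sigma, data):
    # Hypothetical stand-in likelihood: one Normal log-density per observation
    return norm.logpdf(data, loc=mu, scale=sigma)


def summed_logp(mu, sigma, data):
    # Scalar output, which is what jax.grad requires
    return jnp.sum(elementwise_logp(mu, sigma, data))


data = jnp.array([0.5, -1.2, 2.0, 0.3])
mu, sigma = 0.1, 1.5

# Element-wise log-likelihoods are still available: shape (4,)
logp_vec = elementwise_logp(mu, sigma, data)

# Gradient of the summed log-likelihood w.r.t. (mu, sigma)
grad_sum = jax.grad(summed_logp, argnums=(0, 1))(mu, sigma, data)

# Per-observation gradients: map grad over the data, keeping mu and sigma fixed
per_obs_grad = jax.vmap(
    jax.grad(elementwise_logp, argnums=(0, 1)),
    in_axes=(None, None, 0),
)(mu, sigma, data)

# Summing the per-observation gradients recovers the gradient of the sum
print(grad_sum[0], jnp.sum(per_obs_grad[0]))
print(grad_sum[1], jnp.sum(per_obs_grad[1]))
```

So the summation only changes what gets returned, not the gradient with respect to the parameters.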

My question is: are there any consequences if the `logp` method of my custom `Distribution` returns the sum of log-likelihoods instead of element-wise log-likelihoods? It shouldn’t affect the gradients, but are the element-wise log-likelihoods used elsewhere in PyMC?

Thank you!