I am trying to follow this tutorial to implement a custom `Distribution` that uses a wrapped JAX function to compute the log-likelihoods. From what I understand, `Distribution.logp()` should return a vector of element-wise log-likelihoods. However, in this tutorial, the output of the vectorized JAX function is summed, so it returns the sum of the log-likelihoods. A JAX function must return a scalar to be used with `grad`, and I think that's why the summation was used. Therefore, if I want to implement the method in the tutorial, my `logp` can only return the sum of the log-likelihoods.
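To make the scalar requirement concrete, here is a minimal sketch. It is not the tutorial's code. It assumes a simple Normal(mu, 1) likelihood, and the function names `elementwise_logp` and `summed_logp` are hypothetical. `jax.grad` raises a `TypeError` on the vector-valued function, while the summed version works. Its gradient equals the sum of the per-element derivatives obtained with `jax.jacobian`:

```python
import jax
import jax.numpy as jnp

def elementwise_logp(mu, x):
    # Hypothetical example: Normal(mu, 1) log-density per observation -> vector
    return -0.5 * (x - mu) ** 2 - 0.5 * jnp.log(2 * jnp.pi)

def summed_logp(mu, x):
    # Reducing to a scalar is what makes jax.grad applicable
    return jnp.sum(elementwise_logp(mu, x))

x = jnp.array([0.5, -1.0, 2.0])
mu = 0.3

# Works: scalar output
grad_of_sum = jax.grad(summed_logp)(mu, x)

# jax.grad rejects non-scalar outputs with a TypeError
try:
    jax.grad(elementwise_logp)(mu, x)
    raised = False
except TypeError:
    raised = True

# Per-element derivatives d(logp_i)/d(mu); one entry per observation
per_element = jax.jacobian(elementwise_logp)(mu, x)
```

Here `grad_of_sum` and `jnp.sum(per_element)` agree, because differentiation is linear. Summing therefore leaves the gradient unchanged and only discards the per-element values.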
My question is: are there any consequences if the `logp` method of my custom `Distribution` returns the sum of the log-likelihoods instead of the element-wise log-likelihoods? It shouldn't affect the gradients, but are the element-wise log-likelihoods used elsewhere in PyMC?