Hi everyone!
I am trying to follow this tutorial to implement a custom Distribution that uses a wrapped JAX function to compute the log-likelihoods. From what I understand, Distribution.logp() should return a vector of element-wise log-likelihoods. However, in this tutorial, the output of the vectorized JAX function is summed, so it returns the total log-likelihood. A JAX function needs to return a scalar to be used with grad, and I think that's why the summation was used. Therefore, if I want to follow the approach in the tutorial, my logp can only return the sum of the log-likelihoods.
My question is: are there any consequences if the logp method of my custom Distribution returns the sum of the log-likelihoods instead of the element-wise log-likelihoods? It shouldn't affect the gradients, but are the element-wise log-likelihoods used elsewhere in PyMC?
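A small sketch can make the trade-off in this question concrete. It uses plain NumPy rather than JAX or PyMC, and the Normal likelihood, the function names, and the data are hypothetical stand-ins for the custom Distribution. Summing the element-wise log-likelihoods gives the scalar that a grad-style function requires, and the gradient with respect to a parameter is unchanged, because the derivative of a sum is the sum of the derivatives. What is lost is the per-observation vector: after summing, only one number remains.

```python
import numpy as np

# Hypothetical Normal likelihood standing in for the custom Distribution's logp.
def normal_logp_elementwise(x, mu, sigma):
    """Per-observation Normal log-density: one value per element of x."""
    return -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((x - mu) / sigma) ** 2

def normal_logp_summed(x, mu, sigma):
    """Total log-likelihood as a scalar, the shape a grad-style function needs."""
    return np.sum(normal_logp_elementwise(x, mu, sigma))

def dlogp_dmu_elementwise(x, mu, sigma):
    """Analytic derivative of each element's log-density with respect to mu."""
    return (x - mu) / sigma**2

x = np.array([0.5, -1.2, 2.0, 0.3])
mu, sigma = 0.1, 1.5

pointwise = normal_logp_elementwise(x, mu, sigma)
total = normal_logp_summed(x, mu, sigma)

# Central finite difference of the summed logp with respect to mu.
eps = 1e-6
grad_of_sum = (
    normal_logp_summed(x, mu + eps, sigma) - normal_logp_summed(x, mu - eps, sigma)
) / (2 * eps)

# Summing the per-observation derivatives should give the same gradient.
sum_of_grads = np.sum(dlogp_dmu_elementwise(x, mu, sigma))

print(pointwise.shape)                        # (4,)
print(np.ndim(total))                         # 0
print(np.isclose(grad_of_sum, sum_of_grads))  # True
```

The gradient check passes either way; the difference is only whether the four per-observation values survive. Pointwise diagnostics such as LOO or WAIC are computed from exactly those per-observation values, so that is one place where returning only the sum could plausibly matter.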
Thank you!