With the new version, I was finally able to get something that works. I am posting it here so that others can benefit from it as well. Thank you, @ricardoV94, for your help, and please feel free to correct my code below if you spot any mistakes.
def mvlognormal_dist(mu, cov, size):
    """Generative graph for the multivariate log-normal.

    Draws from ``MvNormal(mu, cov)`` with the requested ``size`` and maps the
    draw through ``exp``, so every component of the result is positive.
    """
    gaussian_draw = pm.MvNormal.dist(mu, cov, size=size)
    return pt.exp(gaussian_draw)
def mvlognormal_moment(rv, size, mu, cov, *args, **kwargs):
    """Representative point for the multivariate log-normal.

    Computes the MvNormal moment for the same ``rv``/``size``/``mu``/``cov``
    and exponentiates it. Because ``exp`` is strictly positive, the returned
    point lies inside the log-normal's support. It is presumably used by PyMC
    as a starting value (confirm against the PyMC version in use). Extra
    positional and keyword arguments are accepted and ignored.
    """
    gaussian_moment = pm.MvNormal.moment(rv, size, mu, cov)
    return pt.exp(gaussian_moment)
def mvlognormal_logp(value, mu, cov, *args, **kwargs):
    """Log-density of the multivariate log-normal.

    If ``log(X) ~ MvNormal(mu, cov)``, the change of variables ``y = log(x)``
    has Jacobian ``|dy/dx| = prod(1 / x)``. The log-density is therefore::

        log p(x) = MvNormal.logp(log(x); mu, cov) - sum(log(x))

    (The previous version divided the MvNormal log-probability by ``prod(x)``,
    which is not a valid log-density.)

    Parameters
    ----------
    value : tensor
        Point(s) to evaluate. The last axis is the support dimension.
    mu, cov :
        Mean vector and covariance matrix of the underlying normal.
    *args, **kwargs :
        Accepted and ignored.

    Returns
    -------
    tensor
        Log-density with the last axis reduced. It is ``-inf`` wherever any
        component of ``value`` is not strictly positive.
    """
    positive = pt.gt(value, 0.0)
    in_support = pt.all(positive, axis=-1)
    # switch evaluates both branches. Replace non-positive entries with a dummy
    # 1 *before* taking the log, so neither the value nor its gradient ever
    # involves log(<=0). Such points are masked to -inf below regardless.
    safe_value = pt.switch(positive, value, 1.0)
    log_value = pt.log(safe_value)
    normal_logp = pm.logp(pm.MvNormal.dist(mu, cov), log_value)
    # Log-Jacobian of the exp transform: log prod(1/x) = -sum(log x).
    log_jacobian = -pt.sum(log_value, axis=-1)
    return pt.switch(in_support, normal_logp + log_jacobian, -np.inf)
class MvLogNormal:
    """Multivariate log-normal distribution built on ``pm.CustomDist``.

    Both constructors forward ``mu`` and ``cov`` as distribution parameters,
    together with the module-level generative, moment and log-density
    functions. They also pass ``ndim_supp=1``, meaning each draw is a vector.
    """

    @staticmethod
    def _shared_options():
        """Keyword arguments common to the named and unnamed constructors."""
        # Built at call time so the current module-level functions are used.
        return {
            "dist": mvlognormal_dist,
            "moment": mvlognormal_moment,
            "logp": mvlognormal_logp,
            "ndim_supp": 1,
        }

    def __new__(cls, name, mu, cov, **kwargs):
        """Named constructor: forwards to ``pm.CustomDist(name, mu, cov, ...)``."""
        return pm.CustomDist(name, mu, cov, **cls._shared_options(), **kwargs)

    @classmethod
    def dist(cls, mu, cov, **kwargs):
        """Unnamed constructor: forwards to ``pm.CustomDist.dist(mu, cov, ...)``."""
        return pm.CustomDist.dist(
            mu,
            cov,
            class_name="MvLogNormal",
            **cls._shared_options(),
            **kwargs,
        )