Multivariate LogNormal distribution

With the new version I was finally able to get something that works. I am posting it here so that others may benefit from it as well. Thank you @ricardoV94 for your help, and please feel free to correct my code below if you think there are any mistakes.

def mvlognormal_dist(mu, cov, size):
    """Build the random variable as the elementwise exponential of an MvNormal.

    Parameters
    ----------
    mu, cov
        Forwarded unchanged to ``pm.MvNormal.dist`` as its first two
        positional arguments.
    size
        Forwarded to ``pm.MvNormal.dist`` as the ``size`` keyword.

    Returns
    -------
    The result of applying ``pt.exp`` to the MvNormal variable.
    """
    underlying_normal = pm.MvNormal.dist(mu, cov, size=size)
    return pt.exp(underlying_normal)


def mvlognormal_moment(rv, size, mu, cov, *args, **kwargs):
    """Return the exponential of the value given by ``pm.MvNormal.moment``.

    Parameters
    ----------
    rv, size, mu, cov
        Forwarded positionally, in this order, to ``pm.MvNormal.moment``.
    *args, **kwargs
        Accepted for signature compatibility and ignored.

    Returns
    -------
    ``pt.exp`` applied elementwise to the MvNormal moment.
    """
    normal_moment = pm.MvNormal.moment(rv, size, mu, cov)
    return pt.exp(normal_moment)


def mvlognormal_logp(value, mu, cov, *args, **kwargs):
    """Log-density of the multivariate log-normal distribution at ``value``.

    If ``Y = exp(X)`` with ``X ~ MvNormal(mu, cov)``, the change of variables
    gives ``p_Y(y) = p_X(log y) / prod(y)`` and therefore::

        log p_Y(y) = log p_X(log y) - sum(log y)

    where the sum/product runs over the last axis of ``value``.

    Parameters
    ----------
    value
        Point(s) at which to evaluate the log-density; the last axis is
        treated as the event (support) dimension.
    mu, cov
        Forwarded unchanged to ``pm.MvNormal.dist``.
    *args, **kwargs
        Accepted for signature compatibility and ignored.

    Returns
    -------
    The log-density, reduced over the last axis of ``value``. It is ``-inf``
    wherever any entry along that axis is not strictly positive.
    """
    is_positive = pt.gt(value, 0.0)
    # Mask *before* taking the log so that no NaN/-inf is ever produced for
    # out-of-support entries (those could otherwise leak into gradients).
    # The placeholder 1.0 maps to log(1) = 0; affected rows are set to -inf
    # below, so the placeholder never influences the result.
    log_value = pt.log(pt.switch(is_positive, value, 1.0))
    in_support = pt.all(is_positive, axis=-1)

    normal_logp = pm.logp(pm.MvNormal.dist(mu, cov), log_value)
    # Log-Jacobian of the transform y -> log(y): -sum(log y). This must be
    # added in log space; dividing the log-density by prod(y) is incorrect.
    log_jacobian = -pt.sum(log_value, axis=-1)

    return pt.switch(in_support, normal_logp + log_jacobian, -np.inf)


class MvLogNormal:
    """Multivariate log-normal distribution assembled with ``pm.CustomDist``.

    Both constructors wire the module-level ``mvlognormal_dist``,
    ``mvlognormal_moment`` and ``mvlognormal_logp`` callables into
    ``pm.CustomDist`` and pass ``ndim_supp=1``.
    """

    def __new__(cls, name, mu, cov, **kwargs):
        """Return the result of ``pm.CustomDist(name, mu, cov, ...)``.

        Parameters
        ----------
        name
            Passed as the first positional argument to ``pm.CustomDist``.
        mu, cov
            Passed positionally after ``name``.
        **kwargs
            Extra keyword arguments forwarded to ``pm.CustomDist``.
        """
        custom_dist_options = {
            "dist": mvlognormal_dist,
            "moment": mvlognormal_moment,
            "logp": mvlognormal_logp,
            "ndim_supp": 1,
        }
        return pm.CustomDist(name, mu, cov, **custom_dist_options, **kwargs)

    @classmethod
    def dist(cls, mu, cov, **kwargs):
        """Return the result of ``pm.CustomDist.dist(mu, cov, ...)``.

        Parameters
        ----------
        mu, cov
            Passed positionally to ``pm.CustomDist.dist``.
        **kwargs
            Extra keyword arguments forwarded to ``pm.CustomDist.dist``.
        """
        custom_dist_options = {
            "class_name": "MvLogNormal",
            "dist": mvlognormal_dist,
            "moment": mvlognormal_moment,
            "logp": mvlognormal_logp,
            "ndim_supp": 1,
        }
        return pm.CustomDist.dist(mu, cov, **custom_dist_options, **kwargs)
2 Likes