We just released PyMC 5.2.0, which can automatically infer the log-probability (logp) of transformed multivariate distributions.
For example, the following snippet now runs:
import pymc as pm
import numpy as np
import arviz as az
def mvlognormal_dist(mu, cov, size):
    """Build a multivariate log-normal variable as ``exp`` of an ``MvNormal``.

    Intended as the ``dist`` generator for ``pm.CustomDist``. PyMC derives the
    log-probability of the exponentiated variable from this graph.

    Parameters
    ----------
    mu
        Mean vector of the underlying multivariate normal.
    cov
        Covariance matrix of the underlying multivariate normal.
    size
        Forwarded unchanged to ``pm.MvNormal.dist``.
    """
    underlying_normal = pm.MvNormal.dist(mu, cov, size=size)
    return pm.math.exp(underlying_normal)
class MvLogNormal:
    """Multivariate log-normal distribution built on ``pm.CustomDist``.

    A draw is ``exp(MvNormal(mu, cov))``, generated by ``mvlognormal_dist``.
    PyMC infers its logp from that graph. ``ndim_supp=1`` declares the
    support of each draw to be one-dimensional (a vector).
    """

    def __new__(cls, name, mu, cov, **kwargs):
        """Create a named ``MvLogNormal`` variable via ``pm.CustomDist``.

        Extra keyword arguments (e.g. ``observed``) are forwarded to
        ``pm.CustomDist``.
        """
        # Tag the op with the same class_name that ``dist`` uses, so both
        # construction paths agree; setdefault keeps a caller override working.
        kwargs.setdefault("class_name", "MvLogNormal")
        return pm.CustomDist(
            name, mu, cov, dist=mvlognormal_dist, ndim_supp=1, **kwargs
        )

    @classmethod
    def dist(cls, mu, cov, **kwargs):
        """Create an unnamed ``MvLogNormal`` variable via ``pm.CustomDist.dist``."""
        # setdefault instead of a hard-coded keyword: a caller-supplied
        # class_name no longer triggers a duplicate-keyword TypeError.
        kwargs.setdefault("class_name", "MvLogNormal")
        return pm.CustomDist.dist(
            mu, cov, dist=mvlognormal_dist, ndim_supp=1, **kwargs
        )
with pm.Model() as m:
    mu = pm.Normal("mu", shape=(3,))
    # Very small covariance makes the observations nearly noise-free, so the
    # posterior of mu should sit close to log(observed) = [1, 2, 3].
    x = MvLogNormal("x", mu=mu, cov=np.eye(3) * 1e-3, observed=np.exp([1, 2, 3]))
    trace = pm.sample()

# az.summary returns a table object; print it so the summary is shown when this
# runs as a script, not only in an interactive session.
print(az.summary(trace))
mean sd hdi_3% hdi_97% ... mcse_sd ess_bulk ess_tail r_hat
mu[0] 0.999 0.031 0.940 1.055 ... 0.0 7103.0 3448.0 1.0
mu[1] 1.999 0.032 1.935 2.054 ... 0.0 6510.0 3054.0 1.0
mu[2] 2.997 0.032 2.937 3.055 ... 0.0 5575.0 2972.0 1.0