I am sampling using the Metropolis sampler, so the transform should not be a problem. I did not understand very well from the documentation what the moment function should be, but I tried implementing it using
def mvlognormal_moment(rv, size, mu, cov):
return pt.exp(pm.MvNormal.moment(rv, size, mu, cov))
Still, the following script gives me the error mvlognormal_moment() takes 4 positional arguments but 5 were given. Why does it happen, and what should the moment function do?
import pymc as pm
import numpy as np
import arviz as az
import pytensor.tensor as pt
def mvlognormal_dist(mu, cov, size):
return pt.exp(pm.MvNormal.dist(mu, cov, size=size))
def mvlognormal_moment(rv, size, mu, cov):
return pt.exp(pm.MvNormal.moment(rv, size, mu, cov))
class MvLogNormal:
def __new__(cls, name, mu, cov, **kwargs):
return pm.CustomDist(
name,
mu,
cov,
dist=mvlognormal_dist,
moment=mvlognormal_moment,
ndim_supp=1,
**kwargs
)
@classmethod
def dist(cls, mu, cov, **kwargs):
return pm.CustomDist.dist(
mu,
cov,
class_name="MvLogNormal",
dist=mvlognormal_dist,
moment=mvlognormal_moment,
ndim_supp=1,
**kwargs
)
with pm.Model() as m:
x = MvLogNormal("x", mu=np.zeros(3), cov=np.eye(3)*1e-3,)
y = pm.MvNormal("y", mu=x, cov=np.eye(3), observed=np.exp([1, 2, 3]))
trace = pm.sample(chains=1, step=pm.Metropolis())
az.summary(trace)