Multivariate Normal Imputation Error

Hello,

I am currently converting the code in the pymc-resources/Rethinking_2 directory to PyMC v4. I have forked the repo and am updating the code as I read the book. In code block 15.22 I get this error message:

Here are the package versions I am running:
[screenshot of package versions]

It appears that this was possible in PyMC3, according to the issue/discussion at this Discourse link. Is there a different way to handle imputation of multivariate normal random variables in v4?

Thanks

I had the same question and ended up getting a response here that worked, so I wanted to share it:

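```python
import aesara.tensor as at
import numpy as np
import pandas as pd
import pymc as pm
from scipy import stats

# Load the primate milk data from the rethinking repository
d = pd.read_csv(
    "https://raw.githubusercontent.com/rmcelreath/rethinking/master/data/milk.csv",
    sep=";",
)

d["neocortex.prop"] = d["neocortex.perc"] / 100
d["logmass"] = np.log(d["mass"])

# Standardize; nan_policy="omit" ignores NaNs when computing mean/sd and keeps them as NaN
K = stats.zscore(d["kcal.per.g"], nan_policy="omit")
B = stats.zscore(d["neocortex.prop"], nan_policy="omit")
M = stats.zscore(d["logmass"], nan_policy="omit")

# Stack M and B into an (n, 2) array, with NaN wherever a value is missing
MB_vals = np.stack([M, B], axis=1)
```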
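Each element of `MB_impute` is tied to a missing cell through NumPy-style boolean indexing, which visits the `True` positions of the mask in row-major (C) order. A minimal NumPy-only sketch of that fill step, using a small made-up array in place of `MB_vals` and a hypothetical "draw" standing in for `MB_impute`:

```python
import numpy as np

# Toy stand-in for MB_vals: column 0 plays M, column 1 plays B (made-up numbers)
toy = np.array([
    [0.5, np.nan],
    [-1.0, 0.3],
    [1.2, np.nan],
])

missing = np.isnan(toy)          # boolean mask with the same shape as toy
n_missing = int(missing.sum())   # length the imputation vector must have

# Hypothetical values standing in for one posterior draw of MB_impute
impute_draw = np.array([10.0, 20.0])

filled = toy.copy()
filled[missing] = impute_draw    # True cells are filled in row-major order

print(n_missing)      # 2
print(filled[:, 1])   # [10.   0.3 20. ]
```

In the model, `at.set_subtensor(MB[np.isnan(MB_vals)], MB_impute)` plays the role of `filled[missing] = impute_draw`, but symbolically, assuming Aesara follows NumPy's boolean-indexing semantics (which it is designed to mimic).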

```python
with pm.Model() as m15_7:
    sigma = pm.Exponential("sigma", 1)
    muM = pm.Normal("muM", 0, 0.5)
    muB = pm.Normal("muB", 0, 0.5)
    bM = pm.Normal("bM", 0, 0.5)
    bB = pm.Normal("bB", 0, 0.5)
    a = pm.Normal("a", 0, 0.5)

    # Cholesky factor of the 2x2 covariance of (M, B), with an LKJ prior on the correlation
    chol, _, _ = pm.LKJCholeskyCov(
        "chol_cov", n=2, eta=2, sd_dist=pm.Exponential.dist(1), compute_corr=True
    )

    # Create a vector of flat variables for the unobserved components of the MvNormal
    missing = np.isnan(MB_vals)
    MB_impute = pm.Flat("MB_impute", shape=(missing.sum(),))

    # Create the symbolic value of MB, combining observed data and unobserved variables
    MB = at.as_tensor(MB_vals)
    MB = pm.Deterministic("MB", at.set_subtensor(MB[missing], MB_impute))

    # Add a Potential with the logp of the variable conditioned on `MB`
    pm.Potential(
        "MB_logp",
        pm.logp(rv=pm.MvNormal.dist(mu=at.stack([muM, muB]), chol=chol), value=MB),
    )

    # Linear model for K using the (partly imputed) B column and the fully observed M
    mu = a + bB * MB[:, 1] + bM * M
    Ki = pm.Normal("Ki", mu, sigma, observed=K)

    idata_m15_7 = pm.sample()
```
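The `pm.Potential` line adds the multivariate normal log density of every row of `MB` (observed rows and rows containing `MB_impute` entries alike) to the model's joint log-probability. As a rough NumPy/SciPy sketch of that quantity (not PyMC's internal implementation), with made-up parameter values, the row-wise log density can be computed from the lower-triangular Cholesky factor L, where Σ = L Lᵀ, and cross-checked against `scipy.stats.multivariate_normal`:

```python
import numpy as np
from scipy import linalg, stats


def mvnormal_logp_rows(x, mu, chol):
    """Row-wise MvNormal log density from a lower-triangular Cholesky factor."""
    k = x.shape[1]
    # Solve L z = (x - mu)^T, so z has shape (k, n) and ||z_i||^2 is the Mahalanobis term
    z = linalg.solve_triangular(chol, (x - mu).T, lower=True)
    log_det = 2.0 * np.sum(np.log(np.diag(chol)))  # log|Sigma| = 2 * sum(log diag L)
    return -0.5 * (np.sum(z**2, axis=0) + log_det + k * np.log(2 * np.pi))


# Hypothetical parameter values and a fully filled (n, 2) data matrix
mu = np.array([0.1, -0.2])
chol = np.array([[1.0, 0.0],
                 [0.6, 0.8]])
x = np.array([[0.5, 10.0], [-1.0, 0.3], [1.2, 20.0]])

rows = mvnormal_logp_rows(x, mu, chol)
total = rows.sum()  # the scalar that ends up in the model's joint log-probability

# Cross-check against SciPy using the covariance Sigma = L L^T
ref = stats.multivariate_normal(mean=mu, cov=chol @ chol.T).logpdf(x)
print(np.allclose(rows, ref))  # True
```

Because the rows with a missing B entry depend on `MB_impute`, the sampler treats those entries as free parameters whose values are pulled toward what the MvNormal (and the regression for K) makes plausible, which is what turns this Potential into an imputation.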