EDIT: use chol @ ... instead of MvNormal (see the code below).
FYI, here is an implementation that I tested and which seems to work. It is designed to deal with Normal and LogNormal distributions, which should already cover many practical cases.
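Quick aside on why the chol @ trick from the EDIT works: if z is a vector of i.i.d. N(0, 1) variables and chol is the Cholesky factor of a covariance matrix cov (i.e. cov = chol @ chol.T), then chol @ z has covariance cov. A minimal numpy sketch with hypothetical numbers, separate from the implementation below:

import numpy as np

rng = np.random.default_rng(0)
cov = np.array([[1.0, 0.6],
                [0.6, 2.0]])  # hypothetical target covariance
chol = np.linalg.cholesky(cov)

z = rng.standard_normal((2, 100_000))  # i.i.d. N(0, 1) draws
y = chol @ z                           # correlated draws
print(np.cov(y))                       # ~cov, up to sampling noise

Now the implementation itself: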
import numpy as np
from scipy.stats import norm, lognorm
import pymc as pm
import arviz

def define_params_from_trace(trace, names, label, log_params=[]):
    """Define prior parameters from an existing trace
    (must be called within a pymc.Model context)

    Parameters
    ----------
    trace: pymc.InferenceData from a previous sampling
    names: parameter names to be redefined in a correlated manner
    label: used to add a "{label}_params" dimension to the model, and to name
        a new "{label}_iid" i.i.d. Normal distribution used to generate the prior
    log_params: optional, list of parameter names (subset of `names`);
        these parameters are assumed to have a long tail and are normalized
        under a log-normal assumption

    Returns
    -------
    list of named tensors defined in the function

    Aim
    ---
    Conserve both the marginal distribution of each parameter and the
    covariance across parameters.

    Method
    ------
    The specified parameters are extracted from the trace and fitted with a
    scipy.stats.norm or scipy.stats.lognorm distribution (as specified by
    `log_params`). They are then normalized so that their marginal posterior
    distribution follows ~N(0, 1). Their joint posterior is approximated as
    a MvNormal, and the parameters are subsequently transformed back to their
    original mean and standard deviation (the exponent is taken in the case
    of a log-normal parameter).
    """
    # normalize the parameters so that their marginal distribution follows N(0, 1)
    params = arviz.extract(trace.posterior[names])
    fits = []
    for name in names:
        # fit the marginal distribution and normalize the variable in place
        x = params[name].values
        if name in log_params:
            # scipy parameterizes lognorm as (s, loc, scale) with scale = exp(mu)
            sigma, loc, exp_mu = lognorm.fit(x)
            mu = np.log(exp_mu)
            x[:] = (np.log(x - loc) - mu) / sigma
            fits.append((mu, sigma, loc))
        else:
            mu, sigma = norm.fit(x)
            x[:] = (x - mu) / sigma
            fits.append((mu, sigma))
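    # NOTE: after this loop each params[name] is approximately ~N(0, 1),
    # so the covariance computed below is close to a correlation matrix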
    # add a model dimension
    model = pm.modelcontext(None)
    dim = f'{label}_params'
    if dim not in model.coords:
        model.add_coord(dim, names)
    # Define the joint distribution of the normalized parameters
    # mu = params.mean(axis=0)  # zero mean by construction
    cov = np.cov([params[name].values for name in names], rowvar=True)
    chol = np.linalg.cholesky(cov)

    # A MvNormal may lead to convergence issues:
    # mv = pm.MvNormal(label+'_mv', mu=np.zeros(len(names)), chol=chol, dims=dim)
    # Better to sample from an i.i.d. R.V. ...
    coparams = pm.Normal(label + '_iid', dims=dim)
    # ... and introduce correlations by multiplication with the Cholesky factor
    mv = chol @ coparams
    # Transform back to parameters with proper mu, scale, and possibly log-normal
    named_tensors = []
    for i, (name, fitted) in enumerate(zip(names, fits)):
        if name in log_params:
            mu, sigma, loc = fitted
            tensor = pm.math.exp(mv[i] * sigma + mu) + loc
        else:
            mu, sigma = fitted
            tensor = mv[i] * sigma + mu
        named_tensors.append(pm.Deterministic(name, tensor))
    return named_tensors
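And a sketch of how it would be called (the variable names 'a' and 'b' and the previous `trace` are hypothetical, just to show the calling convention; assume `trace` is the InferenceData from an earlier model that sampled 'a' and 'b'):

with pm.Model() as new_model:
    # 'b' is treated as long-tailed via log_params
    a, b = define_params_from_trace(trace, ['a', 'b'], 'prior', log_params=['b'])
    # ... build the rest of the model on top of a and b, e.g. a likelihood:
    # pm.Normal('obs', mu=a * x_obs + b, sigma=1.0, observed=y_obs)
    idata = pm.sample()

The returned tensors are registered as pm.Deterministic under their original names, so they show up in the new trace like regular parameters.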