How best to use a posterior sample as a prior for a future integration in another model?

EDIT: use chol @ ... instead of MvNormal

FYI, here is an implementation that I tested and that seems to work. It is designed to handle Normal and LogNormal marginal distributions, which should already cover many practical cases.

import numpy as np
from scipy.stats import norm, lognorm
import pymc as pm
import arviz

def define_params_from_trace(trace, names, label, log_params=()):
    """Define prior parameters from existing trace

    (must be called within a pymc.Model context)

    Parameters
    ----------
    trace: arviz.InferenceData returned by a previous pymc sampling
    names: parameter names to be redefined in a correlated manner
    label: used to add a "{label}_params" dimension to the model, and to name
        a new "{label}_iid" i.i.d. Normal distribution used to generate the prior
    log_params: optional, list of parameter names (subset of `names`);
        these parameters are assumed to have a long tail and are normalized
        under a log-normal assumption
    
    Returns
    -------
    list of named tensors defined in the function

    Aim
    ---
    Preserve both the marginal distributions and the covariance across parameters

    Method
    ------
    The specified parameters are extracted from the trace and are fitted 
    with a scipy.stats.norm or scipy.stats.lognorm distribution (as specified by `log_params`).
    They are then normalized so that their marginal posterior distribution follows ~N(0, 1).
    Their joint posterior is approximated as a MvNormal, and the parameters are subsequently
    transformed back to their original mean and standard deviation
    (and exponentiated in the case of a log-normal parameter).
    """

    # normalize the parameters so that their marginal distribution follows N(0, 1)
    params = arviz.extract(trace.posterior[names])
    fits = []
    for name in names:
        # find distribution parameters and normalize the variable
        x = params[name].values
        if name in log_params:
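            # scipy's lognorm.fit returns (shape, loc, scale), where shape is the
            # sigma of the underlying normal and scale = exp(mu)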
            sigma, loc, exp_mu = lognorm.fit(x)
            mu = np.log(exp_mu)
            x[:] = (np.log(x - loc) - mu) / sigma
            fits.append((mu, sigma, loc))
        else:
            mu, sigma = norm.fit(x)
            x[:] = (x - mu) / sigma
            fits.append((mu, sigma))

    # add a model dimension 
    model = pm.modelcontext(None)
    dim = f'{label}_params'
    if dim not in model.coords:
        model.add_coord(dim, names)

    # Define a MvNormal
    # mu = params.mean(axis=0) # zero mean
    cov = np.cov([params[name] for name in names], rowvar=True)
    chol = np.linalg.cholesky(cov)
    # MvNormal may lead to convergence issues
    # mv = pm.MvNormal(label+'_mv', mu=np.zeros(len(names)), chol=chol, dims=dim)
    # Better to sample from an i.i.d R.V.
    coparams = pm.Normal(label+'_iid', dims=dim)
    # ... and introduce correlations by multiplication with the cholesky factor
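    # (if z ~ N(0, I), then cov(chol @ z) = chol @ chol.T = cov,
    # so the posterior correlation structure is recovered)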
    mv = chol @ coparams

    # Transform back to parameters with proper mu, scale and possibly log-normal
    named_tensors = []
    for i, (name, fitted) in enumerate(zip(names, fits)):
        if name in log_params:
            mu, sigma, loc = fitted
            tensor = pm.math.exp(mv[i] * sigma + mu) + loc
        else:
            mu, sigma = fitted
            tensor = mv[i] * sigma + mu
        named_tensors.append(pm.Deterministic(name, tensor))

    return named_tensors
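
For completeness, here is a minimal usage sketch. The first model, the variable names "a" and "b", and the new data are made up for illustration; the point is just that the function is called inside a new model context, and the returned tensors are then used like any other random variables.

# hypothetical first model: produces the trace whose posterior we want to re-use
with pm.Model():
    a = pm.Normal("a", mu=0, sigma=1)
    b = pm.LogNormal("b", mu=0, sigma=1)
    trace = pm.sample()

# second model: re-use the joint posterior of "a" and "b" as a correlated prior
x_new = np.linspace(0., 1., 20)   # placeholder data for the new model
y_new = 2.0 + 0.5 * x_new

with pm.Model():
    a, b = define_params_from_trace(trace, ["a", "b"], "previous", log_params=["b"])
    pm.Normal("y_obs", mu=a + b * x_new, sigma=0.1, observed=y_new)
    trace2 = pm.sample()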