Saving intermediate results using MCMC in pyMC4

Thanks, that's a great suggestion. I think I got it to work; I'm just not sure whether I am supplying the starting point for the iter_sample() method correctly. The documentation on this is not great. All it says is:

start : dict
    Starting point in parameter space (or partial point).
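
My reading of "partial point" is that the dict may name only some of the model's variables, with the rest filled in from the model's default initial values. The sketch below shows that idea in plain Python only. It is not PyMC's implementation, and `complete_start` and the variable names `"X"` and `"scale"` are hypothetical, made up for illustration.

```python
def complete_start(partial, defaults):
    """Fill a partial starting point with default values.

    Hypothetical helper: both arguments map variable names to values.
    """
    unknown = set(partial) - set(defaults)
    if unknown:
        # A name that is not in the model is most likely a typo.
        raise KeyError(f"unknown variables in start: {sorted(unknown)}")
    full = dict(defaults)   # copy so the defaults stay untouched
    full.update(partial)    # user-supplied values override the defaults
    return full


# Made-up defaults for a model with two variables.
defaults = {"X": [-1.0, -1.0, -1.0, -1.0], "scale": 1.0}

# Only "scale" is given; "X" falls back to its default.
start = complete_start({"scale": 0.5}, defaults)
print(start)
```

If that reading is right, passing the full dict from model.initial_point(), as I do in my code, should be a valid (non-partial) starting point.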

The following code does save the trace at every 10th iteration. It doesn't use the likelihood I actually want; it is just a self-contained example adapted from tutorials and previous examples.

import aesara.tensor as at
import numpy as np
import pymc as pm
import cloudpickle
from pymc import iter_sample

n = 4

mu1 = np.ones(n) * (1.0 / 2)
mu2 = -mu1

stdev = 0.1
sigma = np.power(stdev, 2) * np.eye(n)
isigma = np.linalg.inv(sigma)
dsigma = np.linalg.det(sigma)

w1 = 0.1  # one mode with 0.1 of the mass
w2 = 1 - w1  # the other mode with 0.9 of the mass


def two_gaussians(x):
    log_like1 = (
            -0.5 * n * at.log(2 * np.pi)
            - 0.5 * at.log(dsigma)
            - 0.5 * (x - mu1).T.dot(isigma).dot(x - mu1)
    )
    log_like2 = (
            -0.5 * n * at.log(2 * np.pi)
            - 0.5 * at.log(dsigma)
            - 0.5 * (x - mu2).T.dot(isigma).dot(x - mu2)
    )
    return pm.math.logsumexp([at.log(w1) + log_like1, at.log(w2) + log_like2])

with pm.Model() as model:
    X = pm.Uniform(
        "X",
        shape=n,
        lower=-2.0 * np.ones_like(mu1),
        upper=2.0 * np.ones_like(mu1),
        initval=-1.0 * np.ones_like(mu1),
    )
    llk = pm.Potential("llk", two_gaussians(X))
    step = pm.Metropolis()
    iter_counter = 0
    # iter_sample yields the trace collected so far after each draw
    for trace in iter_sample(draws=100,
                             start=model.initial_point(),
                             step=step,  # reuse the step method created above
                             model=model):
        iter_counter += 1
        if iter_counter % 10 == 0:  # checkpoint every 10th iteration
            print(f'Saving trace for iteration {iter_counter}')
            with open(f'iteration{iter_counter}.pkl', mode='wb') as file:
                cloudpickle.dump(trace, file)
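
To check that the checkpointing pattern itself behaves as intended, independent of PyMC, here is a minimal sketch that uses a stand-in generator in place of iter_sample() and the standard-library pickle module in place of cloudpickle. `fake_sampler`, `run_with_checkpoints` and `load_latest` are hypothetical names; the only thing taken from the code above is the `iteration<N>.pkl` file naming.

```python
import os
import pickle
import tempfile


def fake_sampler(draws):
    """Stand-in for iter_sample: yields the growing list of draws so far."""
    trace = []
    for i in range(draws):
        trace.append(i * 0.5)  # placeholder value instead of a real sample
        yield trace


def run_with_checkpoints(draws, every, out_dir):
    """Pickle the current trace every `every` iterations; return the paths."""
    saved = []
    for iteration, trace in enumerate(fake_sampler(draws), start=1):
        if iteration % every == 0:
            path = os.path.join(out_dir, f"iteration{iteration}.pkl")
            with open(path, "wb") as fh:
                pickle.dump(trace, fh)
            saved.append(path)
    return saved


def load_latest(paths):
    """Load the checkpoint with the highest iteration number."""
    def iteration_of(path):
        name = os.path.basename(path)
        return int(name[len("iteration"):-len(".pkl")])

    latest = max(paths, key=iteration_of)
    with open(latest, "rb") as fh:
        return pickle.load(fh)


out_dir = tempfile.mkdtemp()
paths = run_with_checkpoints(draws=100, every=10, out_dir=out_dir)
latest_trace = load_latest(paths)
print(len(paths), len(latest_trace))
```

Sorting on the parsed iteration number rather than on the file name matters here, because a plain string sort would put `iteration100.pkl` before `iteration20.pkl`.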