I am running the PyMC example code for SMC, and there seems to be an issue with the number of draws returned by `sample_smc`.
I ran the code from the example (both in a notebook and as a script) with 5 chains.
import arviz as az
import numpy as np
import pymc as pm
import pytensor.tensor as pt
import matplotlib.pyplot as plt
# Dimension of the sampled vector.
n = 4

# Centres of the two mixture components: +0.5 and -0.5 on every axis.
mu1 = np.full(n, 0.5)
mu2 = -mu1

# Both components share one isotropic covariance, stdev**2 * I.  Its inverse
# and determinant are computed once here because the log-density needs both.
stdev = 0.1
sigma = np.eye(n) * stdev**2
isigma = np.linalg.inv(sigma)
dsigma = np.linalg.det(sigma)

# Mixture weights; they sum to one.
w1 = 0.1  # mass of the component centred at mu1
w2 = 1 - w1  # mass of the component centred at mu2 (0.9)
def two_gaussians(x):
    """Return the log-density of a two-component Gaussian mixture at ``x``.

    Both components use the module-level covariance (through ``isigma`` and
    ``dsigma``); they are centred at ``mu1`` and ``mu2`` and carry the
    weights ``w1`` and ``w2``.  The weighted component log-densities are
    combined with ``pm.math.logsumexp``.

    ``x`` is presumably a length-``n`` PyTensor vector (it is used with
    ``.T`` and ``.dot``) -- confirm against the caller.
    """
    # Log normalising constant, identical for both components.
    log_norm = -0.5 * n * pt.log(2 * np.pi) - 0.5 * pt.log(dsigma)

    def component_logp(centre):
        # Normaliser minus half the quadratic form (x - centre)^T isigma (x - centre).
        offset = x - centre
        return log_norm - 0.5 * offset.T.dot(isigma).dot(offset)

    weighted_terms = [
        pt.log(w1) + component_logp(mu1),
        pt.log(w2) + component_logp(mu2),
    ]
    return pm.math.logsumexp(weighted_terms)
def main():
    """Sample the bimodal model with SMC, plot the trace, and print a summary.

    Builds a model with a uniform prior on ``X`` over [-2, 2] per component
    and the ``two_gaussians`` log-density as a potential, runs
    ``pm.sample_smc(2000)``, shows a rank/trace plot with reference lines at
    -0.5 and 0.5, and prints the fraction of posterior ``X`` values below
    zero.
    """
    ones = np.ones_like(mu1)

    with pm.Model():
        sampled = pm.Uniform(
            "X",
            shape=n,
            lower=-2.0 * ones,
            upper=2.0 * ones,
            initval=-1.0 * ones,
        )
        # The potential adds the mixture log-density to the model's log-probability.
        pm.Potential("llk", two_gaussians(sampled))
        trace = pm.sample_smc(2000)

    axes = az.plot_trace(trace, compact=True, kind="rank_vlines")
    # Reference lines at the two mode locations; their heights (as a fraction
    # of the axes) are 0.9 at -0.5 and 0.1 at +0.5, matching w2 and w1.
    first_panel = axes[0, 0]
    first_panel.axvline(-0.5, 0, 0.9, color="k")
    first_panel.axvline(0.5, 0, 0.1, color="k")
    plt.show()

    # NOTE(review): `X < 0` selects samples near mu2 (= -0.5), whose weight is
    # w2 = 0.9, yet the label says "w1" -- confirm which one is intended.
    share_below_zero = np.mean(trace.posterior["X"] < 0).item()
    print(f'Estimated w1 = {share_below_zero:.3f}')
# Run the example only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
In the example code, the number of draws is set to 2000, but when I run it I get the following warning, and the plots do not look as expected.
/opt/miniconda3/envs/pymc/lib/python3.13/site-packages/arviz/data/base.py:272: UserWarning: More chains (5) than draws (1). Passed array should have shape (chains, draws, *shape) warnings.warn(
Does anyone know what might be going wrong here?