Hello everyone,
I am currently working on a model that uses NormalMixture() as the likelihood, and I am seeing what I think is strange behaviour from sample_prior_predictive(). I expect the prior predictive samples to look like a collection of normal mixture distributions, as in my second plot (green lines). However, plot_ppc() from ArviZ does not show this (first plot), and the prior predictive samples of the observed variable are not distributed that way either (red lines in the second plot). Is this a bug, or have I misunderstood something?
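For reference, here is a minimal NumPy-only sketch of the generative process behind that expectation. It assumes the priors used in the example below: `pos ~ Normal(data mean, 1)` with two components, `w1 ~ Uniform(0, 1)` and `sigma ~ Uniform(0, 1)` with two components. It is not PyMC3's implementation, and the function name `prior_predictive_mixture` is hypothetical. Each prior draw gets one set of parameters. Every observation within that draw then independently picks component 0 with probability `w1` (component 1 otherwise), so each row of `obs` should look like a two-component mixture.

```python
import numpy as np

# Hypothetical helper: a plain-NumPy stand-in for the model's generative process,
# not how PyMC3 implements sample_prior_predictive().
def prior_predictive_mixture(data_mean, n_draws=500, n_obs=1000, seed=123):
    rng = np.random.default_rng(seed)
    # one set of parameters per prior draw
    pos = rng.normal(loc=data_mean, scale=1.0, size=(n_draws, 2))
    w1 = rng.uniform(0.0, 1.0, size=n_draws)
    sigma = rng.uniform(0.0, 1.0, size=(n_draws, 2))
    # per observation: component 0 with probability w1, component 1 otherwise
    comp = (rng.uniform(size=(n_draws, n_obs)) >= w1[:, None]).astype(int)
    rows = np.arange(n_draws)[:, None]
    mu = pos[rows, comp]      # shape (n_draws, n_obs)
    sd = sigma[rows, comp]    # shape (n_draws, n_obs)
    obs = rng.normal(loc=mu, scale=sd)
    return w1, pos, sigma, comp, obs

# 212.25 is roughly the mean of the example data (0.75 * 212 + 0.25 * 213)
w1, pos, sigma, comp, obs = prior_predictive_mixture(data_mean=212.25)
print(obs.shape)  # (500, 1000)
```

Plotting a KDE of a single row such as `obs[0]` should then resemble the green curves rather than a single normal.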
I installed pymc3 following the macOS installation instructions on GitHub.
I’m running:
pymc: 3.11.4 (pip)
theano-pymc: 1.1.2 (pip)
arviz: 0.11.2 (pip)
python: 3.9.7
OS: macOS Big Sur (M1 chip)
A short example of my problem:
```python
import pymc3 as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt

random_seed = 123
np.random.seed(random_seed)  # np.random.seed() returns None, so keep the seed value in its own variable

# create some multimodal data
data = pm.NormalMixture.dist(w=[0.75, 0.25], mu=[212, 213], sigma=0.05).random(size=1000)

with pm.Model() as model:
    # priors
    # positions of normal distributions
    pos = pm.Normal("pos", mu=data.mean(), sigma=1, shape=2)
    # weights sum to 1
    w1 = pm.Uniform("w1", lower=0, upper=1)
    w2 = 1 - w1
    weights = [w1, w2]
    # standard deviations of normal distributions
    sigma = pm.Uniform("sigma", lower=0, upper=1, shape=2)
    # Likelihood
    obs = pm.NormalMixture("obs", w=weights, mu=pos, sigma=sigma, observed=data)

    # strange behaviour?
    priorpc = pm.sample_prior_predictive(random_seed=random_seed)
    az.plot_ppc(az.from_pymc3(prior=priorpc), group="prior")
    plt.show()

# what I would expect
n_draws = len(priorpc["w1"])  # 500 prior samples by default
rand_indices = np.random.choice(np.arange(n_draws), size=50, replace=False)
for r in rand_indices:
    w1_r = priorpc["w1"][r]
    pos_r = priorpc["pos"][r]
    sigma_r = priorpc["sigma"][r]
    mixture_sample = pm.NormalMixture.dist(w=[w1_r, 1 - w1_r], mu=pos_r, sigma=sigma_r).random(size=1000)
    # in green: sampled mixture distributions (that look like mixture distributions)
    az.plot_kde(mixture_sample, plot_kwargs={"color": "green"})
    # in red: corresponding arrays of the observed variable in the prior predictive dict
    az.plot_kde(priorpc["obs"][r], plot_kwargs={"color": "red"})
# in black: the data
az.plot_kde(data, plot_kwargs={"color": "black"})
plt.show()
```
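My output: [plots: (1) prior predictive check from plot_ppc(); (2) KDEs of sampled mixture distributions (green), prior predictive draws of obs (red) and the data (black)]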
Thank you very much in advance for any help!