What were you doing in v3 to make plots? The ArviZ API hasn’t changed much (at all?), so my guess is you just need some pointers on working with xarray objects. It would be good to see your starting point.
You've only posted partial code, and the plotting code you posted will not make the plot you posted (there is nothing to compute the HDI, draw sample trajectories, etc.).
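For the HDI band itself, ArviZ's `az.hdi` is the usual tool. As a rough sketch of what it computes, here is a NumPy-only version that sorts the draws at each point and keeps the narrowest window holding `hdi_prob` of them. It assumes the draws are already flattened to shape `(n_samples, n_points)` and that each point's distribution is unimodal; the name `hdi_from_draws` is hypothetical, not part of any library.

```python
import numpy as np


def hdi_from_draws(draws, hdi_prob=0.94):
    """Narrowest interval containing hdi_prob of the draws, computed per column.

    draws: array of shape (n_samples, n_points), e.g. posterior predictive
    draws of a GP with chain and draw flattened into one sample axis.
    Returns an array of shape (n_points, 2) holding (lower, upper) bounds.
    """
    draws = np.sort(np.asarray(draws, dtype=float), axis=0)
    n = draws.shape[0]
    # Number of sorted steps spanned by an interval holding hdi_prob of the draws
    width_idx = int(np.floor(hdi_prob * n))
    if width_idx < 1 or width_idx >= n:
        raise ValueError("hdi_prob is too small or too large for this many draws")
    # Width of every candidate interval [draws[i], draws[i + width_idx]]
    widths = draws[width_idx:] - draws[: n - width_idx]
    # Pick the narrowest candidate separately for each point (column)
    start = np.argmin(widths, axis=0)
    cols = np.arange(draws.shape[1])
    lower = draws[start, cols]
    upper = draws[start + width_idx, cols]
    return np.column_stack([lower, upper])
```

The resulting `band` can then be shaded under the sample trajectories with `ax.fill_between(x, band[:, 0], band[:, 1], alpha=0.3)`, where `x` is the prediction grid.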
Here is a simple model and a plot of a conditional distribution. Perhaps it will be helpful?
```python
import matplotlib.pyplot as plt
import numpy as np
import pymc as pm

# Generate fake data
rng = np.random.default_rng(123)
true_lengthscale = 0.2
true_eta = 2.0
true_cov = true_eta**2 * pm.gp.cov.ExpQuad(1, true_lengthscale)
# Add white noise (sigma=1) so the fake data are noisy and K is well conditioned
true_cov += pm.gp.cov.WhiteNoise(1)
X = np.linspace(0, 2, 200)[:, None]
true_K = true_cov(X).eval()
y = pm.draw(
    pm.MvNormal.dist(mu=np.zeros(len(true_K)), cov=true_K, shape=true_K.shape[0]),
    draws=1,
    random_seed=rng,
)

# Fit GP
with pm.Model() as m:
    ls = pm.Gamma('ls', alpha=2, beta=1)
    eta = pm.Exponential('eta', 1)
    cov = eta ** 2 * pm.gp.cov.ExpQuad(1, ls)
    gp = pm.gp.Marginal(cov_func=cov)
    sigma = pm.Exponential('sigma', 1)
    y_hat = gp.marginal_likelihood('y', X=X, y=y, sigma=sigma)
    # nuts_sampler='numpyro' requires numpyro (and JAX) to be installed
    idata = pm.sample(nuts_sampler='numpyro')

# Sample the conditional distribution at new inputs
Xnew = np.linspace(0, 2, 100)[:, None]
with m:
    f_star = gp.conditional("f_star", Xnew=Xnew)
    # extend_inferencedata=True adds a `predictions` group to idata
    # instead of replacing idata (and its posterior) with the new samples
    pm.sample_posterior_predictive(
        idata, var_names=['f_star'], predictions=True, extend_inferencedata=True
    )

# Plot the results: one grey line per posterior draw of f_star, data in black
fig, ax = plt.subplots(figsize=(14, 4), dpi=144)
ax.plot(Xnew.ravel(),
        idata.predictions.stack(sample=['chain', 'draw'])['f_star'].values,
        alpha=0.1,
        c='0.5')
ax.plot(X.ravel(), y, c='k')
plt.show()
```
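The only real xarray step in the plot is `.stack(sample=['chain', 'draw'])`. It merges `chain` and `draw` (chain outer, draw inner) into a new `sample` dimension that xarray places last, so the values come out as `(n_points, n_samples)`. Since `ax.plot(x, Y)` with a 2-D `Y` draws one line per column, each column is one posterior sample. A NumPy sketch of the same reshaping, using random numbers as a hypothetical stand-in for `idata.predictions['f_star'].values` with 4 chains of 1000 draws:

```python
import numpy as np

# Hypothetical stand-in for idata.predictions['f_star'].values,
# whose dims are (chain, draw, f_star_dim_0)
n_chains, n_draws, n_points = 4, 1000, 100
rng = np.random.default_rng(0)
f_star = rng.normal(size=(n_chains, n_draws, n_points))

# Equivalent of .stack(sample=['chain', 'draw']).values: flatten chain and
# draw in C order (chain outer, draw inner), then move the sample axis last
trajectories = f_star.reshape(n_chains * n_draws, n_points).T

print(trajectories.shape)  # (100, 4000)
```

Column `j` of `trajectories` is chain `j // n_draws`, draw `j % n_draws`, which matches the ordering of the stacked `sample` index.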
Thanks very much, it’s helpful. Besides, when I use a Gaussian process, sampling takes a very long time to finish. Is there any way to make it faster?