B-spline regression, out-of-sample predictions

First of all, make sure to use the development version of Bambi. We changed many internals recently, and those changes affect how you build custom families.
Second, I’m going to create the custom family twice. The first version is simpler, but it doesn’t support posterior predictive sampling. The second is a little more complicated and requires knowing a bit more about Bambi’s internals, but it’s implemented exactly the same way as a built-in family. This means that, if you want, you can use this code to open a PR adding a new skewnormal family.

import arviz as az
import bambi as bmb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy as sp
import scipy.interpolate  # make sp.interpolate available on older SciPy versions
from scipy import stats

# simulate data
knots_n = 12
size = 100
rng = np.random.default_rng(1234)

knots = np.linspace(-0.5, 1.5, knots_n)
knots_coeff = rng.normal(size=knots_n)

spline = sp.interpolate.BSpline(knots, knots_coeff, 3, extrapolate=False)

x = rng.uniform(0, 1, size)
y = spline(x) + stats.skewnorm(a=0.9, loc=0, scale=0.5).rvs(size=size, random_state=rng)

data = pd.DataFrame({"x": x, "y": y})

fig, ax = plt.subplots(figsize=(8, 6))

x_plot = np.linspace(0, 1, 200)
ax.plot(x_plot, spline(x_plot), c='k')
ax.scatter(x, y, alpha=0.75, zorder=5)
ax.set_xlim(0, 1);
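A detail of the simulation worth making explicit: with `extrapolate=False`, SciPy’s `BSpline` returns NaN outside its base interval `t[k] .. t[n]`, where `n = len(t) - k - 1`. With these 12 knots and degree 3, that interval is roughly [0.045, 0.955], so some of the simulated `y` values are NaN. Presumably that is why the model below is built with `dropna=True`. The short check below reproduces only the knot setup (not the full simulation) to show this.

```python
import numpy as np
from scipy.interpolate import BSpline

# Same knot setup as the simulation above: 12 knots, cubic (k=3) spline
knots_n = 12
k = 3
knots = np.linspace(-0.5, 1.5, knots_n)
coeff = np.random.default_rng(1234).normal(size=knots_n)

spline = BSpline(knots, coeff, k, extrapolate=False)

# Base interval is [t[k], t[n]] with n = len(t) - k - 1
n = len(knots) - k - 1
lower, upper = knots[k], knots[n]
print(f"base interval: [{lower:.4f}, {upper:.4f}]")

# Points inside [0, 1] but outside the base interval evaluate to NaN
x_grid = np.linspace(0, 1, 201)
values = spline(x_grid)
outside = (x_grid < lower) | (x_grid > upper)
print("all NaN outside base interval:", np.isnan(values[outside]).all())
print("all finite inside base interval:", np.isfinite(values[~outside]).all())
```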

Now we can create the family and the model. Note that priors are no longer specified in the family; they are passed to the model instead.

# Define the custom family
likelihood = bmb.Likelihood("SkewNormal", params=["mu", "sigma", "alpha"], parent="mu")
link = bmb.Link("identity")
family = bmb.Family("skewnormal", likelihood, link)

# Define the priors for the auxiliary parameters (all the non-parent params)
priors = {
    "sigma": bmb.Prior("HalfStudentT", nu=4, sigma=1),
    "alpha": bmb.Prior("Normal", mu=0, sigma=5),
}

# Use them in the model
model = bmb.Model("y ~ 1 + bs(x, df=9)", data, family=family, priors=priors, dropna=True)
idata = model.fit()
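For reference, the `SkewNormal` likelihood used here is, as I understand PyMC’s parameterization, the density f(y) = (2/σ) φ((y − μ)/σ) Φ(α (y − μ)/σ), where φ and Φ are the standard normal pdf and cdf. The sketch below writes that density out by hand (the helper `skewnormal_pdf` is hypothetical, for illustration only) and checks that it agrees with `scipy.stats.skewnorm` when `a=alpha`, `loc=mu` and `scale=sigma`, which is the same argument mapping the `posterior_predictive` method further down relies on.

```python
import numpy as np
from scipy import stats


def skewnormal_pdf(y, mu, sigma, alpha):
    """Skew-normal density written out explicitly:
    f(y) = (2 / sigma) * phi(z) * Phi(alpha * z), with z = (y - mu) / sigma."""
    z = (y - mu) / sigma
    return 2.0 / sigma * stats.norm.pdf(z) * stats.norm.cdf(alpha * z)


mu, sigma, alpha = 0.3, 0.5, 0.9
y_grid = np.linspace(-2, 3, 101)

manual = skewnormal_pdf(y_grid, mu, sigma, alpha)
# SciPy's parameterization: a=alpha (shape), loc=mu, scale=sigma
reference = stats.skewnorm.pdf(y_grid, alpha, loc=mu, scale=sigma)
print("max abs difference:", np.max(np.abs(manual - reference)))

# alpha = 0 removes the skew and recovers the normal distribution
normal_case = skewnormal_pdf(y_grid, mu, sigma, 0.0)
print(np.allclose(normal_case, stats.norm.pdf(y_grid, loc=mu, scale=sigma)))
```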

Mean predictions already work, without any extra code:

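new_x = np.linspace(0, 1, num=100)
# predict() adds the "y_mean" draws for the new data to idata.posterior in place
model.predict(idata, kind="mean", data=pd.DataFrame({"x": new_x}))

y_pred_mean = idata.posterior["y_mean"].mean(("chain", "draw")).to_numpy()
# 95% equal-tailed credible interval from quantiles (not an HDI, despite the name)
hdi_data = idata.posterior["y_mean"].quantile([0.025, 0.975], ("chain", "draw")).to_numpy()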

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(x_plot, spline(x_plot), color="black")
ax.scatter(x, y, alpha=0.75, zorder=5)
ax.set_xlim(0, 1);
ax.plot(new_x, y_pred_mean, color="C3")
ax.fill_between(new_x, hdi_data[0], hdi_data[1], alpha=0.4, color="C1");
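One thing to keep in mind when reading this plot: in the skew-normal, `mu` is a location parameter, not the mean. Assuming `y_mean` holds the posterior of the parent parameter `mu` (an assumption about Bambi’s naming, not checked here), the curve estimates the location, while the expected value of y is shifted by σ δ √(2/π), with δ = α / √(1 + α²). The sketch below (the helper `skewnormal_mean` is hypothetical) compares that formula with `scipy.stats.skewnorm.mean`, using the values from the simulation.

```python
import numpy as np
from scipy import stats


def skewnormal_mean(mu, sigma, alpha):
    """Mean of a skew-normal with location mu, scale sigma and shape alpha."""
    delta = alpha / np.sqrt(1 + alpha**2)
    return mu + sigma * delta * np.sqrt(2 / np.pi)


# Noise parameters used in the simulation: skewnorm(a=0.9, loc=0, scale=0.5)
mu, sigma, alpha = 0.0, 0.5, 0.9

analytic = skewnormal_mean(mu, sigma, alpha)
scipy_mean = stats.skewnorm.mean(alpha, loc=mu, scale=sigma)
print(analytic, scipy_mean)
```

With α = 0.9 and σ = 0.5 the shift is about 0.27, so the simulated `y` sits above `spline(x)` on average even though the noise has location 0.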

But notice that if you ask for draws from the posterior predictive distribution, it fails, because they are not implemented for this family. Bambi does not use PyMC’s capabilities to get draws from the posterior predictive distribution; instead, each family implements them manually (whether that is a good idea is another debate).

model.predict(idata, kind="pps")
AttributeError: 'Family' object has no attribute 'posterior_predictive'
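To make concrete what such a method has to do: for every posterior draw, sample one new observation from the likelihood using that draw’s parameters. The sketch below does this with NumPy and SciPy on made-up arrays shaped like a Bambi posterior, `(chain, draw, observation)` for the parent parameter and `(chain, draw)` for the auxiliary ones. The arrays and the function name `skewnormal_pps` are hypothetical, not Bambi internals. The only subtlety is broadcasting the per-draw `sigma` and `alpha` across observations; in the class further down, `xr.apply_ufunc` handles that by dimension name.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)

# Hypothetical posterior draws: mu varies per observation,
# sigma and alpha only vary per (chain, draw)
n_chains, n_draws, n_obs = 2, 500, 30
mu = rng.normal(size=(n_chains, n_draws, n_obs))
sigma = rng.gamma(2.0, 0.25, size=(n_chains, n_draws))
alpha = rng.normal(0.9, 0.1, size=(n_chains, n_draws))


def skewnormal_pps(mu, sigma, alpha, rng):
    """One posterior predictive draw per (chain, draw, observation)."""
    # Add a trailing axis so the per-draw parameters broadcast over observations
    return stats.skewnorm.rvs(
        alpha[..., np.newaxis],
        loc=mu,
        scale=sigma[..., np.newaxis],
        random_state=rng,
    )


y_pps = skewnormal_pps(mu, sigma, alpha, rng)
print(y_pps.shape)
```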

So if we really want them, we need to implement them ourselves. The subclass below does that, and with it `kind="pps"` works:

from bambi.utils import get_aliased_name
from bambi.families.univariate import UnivariateFamily
import xarray as xr

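class SkewNormalFamily(UnivariateFamily):
    # Links each parameter may use when it is modeled.
    # alpha is real-valued (negative means left skew), so identity rather than log.
    SUPPORTED_LINKS = {"mu": ["identity"], "sigma": ["log"], "alpha": ["identity"]}

    def posterior_predictive(self, model, posterior, **kwargs):
        # Posterior draws of the parameters, e.g. "y_mean", "y_sigma", "y_alpha"
        response_name = get_aliased_name(model.response_component.response_term)
        mean = posterior[response_name + "_mean"]
        sigma = posterior[response_name + "_sigma"]
        alpha = posterior[response_name + "_alpha"]
        # One skew-normal draw per posterior draw and observation.
        # skewnorm.rvs takes (a, loc, scale) = (alpha, mu, sigma), and
        # apply_ufunc broadcasts the arrays by dimension name.
        return xr.apply_ufunc(stats.skewnorm.rvs, alpha, mean, sigma)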

likelihood = bmb.Likelihood("SkewNormal", params=["mu", "sigma", "alpha"], parent="mu")
link = bmb.Link("identity")
family = SkewNormalFamily("skewnormal", likelihood, link)

# Define the priors for the auxiliary parameters (all the non-parent params)
priors = {
    "sigma": bmb.Prior("HalfStudentT", nu=4, sigma=1),
    "alpha": bmb.Prior("Normal", mu=0, sigma=5),
}

# Use them in the model
model = bmb.Model("y ~ 1 + bs(x, df=9)", data, family=family, priors=priors, dropna=True)
idata = model.fit()

model.predict(idata, kind="pps")