Modeling segmented (piecewise) linear regression

Here’s the code for a piecewise linear model, but your data calls for a piecewise constant model. See if you can adapt it.

import pymc as pm
import numpy as np
import arviz as az
import matplotlib.pyplot as plt

# Simulate piecewise-constant data with unit-variance Gaussian noise:
# level 10 on t in [0, 100), level 0 on [100, 200), level 20 on [200, n_obs).
rng = np.random.default_rng(42)  # fixed seed so the example is reproducible
n_obs = 300
t = np.arange(0, n_obs)
y = rng.standard_normal(n_obs)
y[:100] += 10  # breakpoint one at t=100
y[200:] += 20  # breakpoint two at t=200; runs through the LAST sample (t=299)

# Compute the piecewise linear function based on knots
def piecewise_linear(x, beta_0, beta_1, beta_knots, knots):
    """Evaluate a continuous piecewise-linear (hinge) function at ``x``.

    The value is ``beta_0 + beta_1 * x`` plus, for every knot ``k_i``, the
    hinge term ``beta_knots[i] * T(x, k_i)``. Each ``beta_knots[i]`` is
    therefore the change in slope that takes effect after knot ``k_i``.

    Args:
        x: Input positions (a NumPy array in this script).
        beta_0: Intercept.
        beta_1: Slope of the first segment (before the first knot).
        beta_knots: Indexable sequence of slope changes, one per knot.
        knots: Iterable of knot locations.

    Returns:
        The function values at ``x``. Coefficients may be plain numbers or
        PyMC random variables (as in the model below).
    """
    y = beta_0 + beta_1 * x
    for i, knot in enumerate(knots):
        # Rebind instead of ``+=``: an in-place add on an integer NumPy array
        # raises a casting error when the hinge term is float.
        y = y + beta_knots[i] * T(x, knot)
    return y

# Helper: hinge basis term that is zero up to the knot and rises with slope 1 after it
def T(x, knot):
    """Return ``max(0, x - knot)`` element-wise (NumPy broadcasting rules)."""
    offset = x - knot
    return np.maximum(0, offset)
        
# Define the PyMC model
with pm.Model() as model:
    # Breakpoints (knots) where the slope may change; same locations as the
    # breakpoints used when simulating the data.
    knots = [100, 200]

    # Intercept and slope of the first segment (before the first knot)
    beta_0 = pm.Normal("beta_0", mu=10, sigma=5)
    beta_1 = pm.Normal("beta_1", mu=0, sigma=1)

    # Change in slope at each knot, one coefficient per knot
    beta_knots = pm.Normal("beta_knots", mu=0, sigma=1, shape=len(knots))

    # Expected value of y: a continuous piecewise-linear mean.
    # NOTE(review): the simulated data jump in level at the knots, which a
    # continuous hinge model cannot represent exactly (see the note above
    # about a piecewise-constant model).
    mu = piecewise_linear(t, beta_0, beta_1, beta_knots, knots)

    # Observation noise and likelihood
    sigma = pm.HalfNormal("sigma", sigma=1)
    y_obs = pm.Normal("y_obs", mu=mu, sigma=sigma, observed=y)

    # Inference. A fixed seed makes posterior and posterior-predictive draws
    # reproducible from run to run.
    sampling_seed = 42
    idata = pm.sample(random_seed=sampling_seed)
    idata.extend(pm.sample_posterior_predictive(idata, random_seed=sampling_seed))

# Plotting: observed data, posterior predictive mean, and HDI band
hdi_prob = 0.94  # single source of truth for the band and the title

plt.figure(figsize=(10, 6))
plt.plot(t, y, "o", markersize=3, label="Observed Data")

# Posterior predictive mean, averaged over all chains and draws
posterior_predictive_mean = idata.posterior_predictive["y_obs"].mean(dim=["chain", "draw"])
plt.plot(t, posterior_predictive_mean, label="Posterior Predictive Mean", color="orange")

# HDI band via ArviZ; label the fill so it appears in the legend
az.plot_hdi(
    t,
    idata.posterior_predictive["y_obs"],
    hdi_prob=hdi_prob,
    color="lightblue",
    smooth=False,
    fill_kwargs={"label": f"{hdi_prob:.0%} HDI"},
)
plt.xlabel("Time")
plt.ylabel("y")
plt.legend()
plt.title(f"Piecewise Linear Regression with Posterior Predictive and {hdi_prob:.0%} HDI")
plt.show()