I am trying to recreate the Facebook Prophet model from scratch in PyMC. Unfortunately, I am already running into performance issues at the very beginning.
I am using the peyton_manning dataset (a univariate time series of 2905 rows) from the FB Prophet docs. This snippet fetches and preprocesses it:
import pandas as pd

df = pd.read_csv('https://raw.githubusercontent.com/facebook/prophet/main/examples/example_wp_log_peyton_manning.csv')
df["ds"] = pd.to_datetime(df["ds"])
# scale the observations by their maximum
y_max = df["y"].max()
df["y_scaled"] = df["y"] / y_max
# min-max scale time to the interval [0, 1]
df["t"] = (df["ds"] - df["ds"].min()) / (df["ds"].max() - df["ds"].min())
My model looks like this:
import numpy as np
import pymc as pm
import pytensor as pyt  # assumed alias behind the `pyt` name used below

def trend_model(y, t, n_changepoints=25, changepoints_prior_scale=0.05,
                growth_prior_scale=5, changepoint_range=0.8):
    """
    The piecewise linear trend with changepoint implementation in PyMC.
    :param y: (np.array) Observations scaled by their maximum.
    :param t: (np.array) MinMax scaled time.
    :param n_changepoints: (int) The number of changepoints to model.
    :param changepoints_prior_scale: (float) The scale of the Laplace prior on the delta vector.
    :param growth_prior_scale: (float) The standard deviation of the prior on the growth.
    :param changepoint_range: (float) Proportion of history in which trend changepoints will be estimated.
    :return: (pm.Model) The unsampled model.
    """
    model = pm.Model()
    # evenly spaced changepoints over the first `changepoint_range` of the history
    s = np.linspace(0, changepoint_range * np.max(t), n_changepoints + 1)[1:]
    # indicator matrix: A[i, j] is 1 when t[i] lies past changepoint s[j]
    # * 1 casts the boolean to integers
    A = (t[:, None] > s) * 1
    with model:
        # initial growth
        k = pm.Normal('k', 0, growth_prior_scale)
        # rate of change
        delta = pm.Laplace('delta', 0, changepoints_prior_scale, shape=n_changepoints)
        # offset
        m = pm.Normal('m', 0, 5)
        # offset adjustments that keep the trend continuous at the changepoints
        gamma = -s * delta
        trend = pm.Deterministic("trend", (k + pyt.tensor.dot(A, delta)) * t + (m + pyt.tensor.dot(A, gamma)))
        sigma = pm.HalfCauchy('sigma', 0.5, initval=1)
        pm.Normal('obs', mu=trend, sigma=sigma, observed=y)
    return model
model = trend_model(df["y_scaled"], np.array(df["t"]))
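To make the trend term easier to reason about, here is a minimal NumPy-only sketch of the same formula, evaluated for fixed parameter values instead of sampled ones. The helper name `piecewise_linear_trend` and all numbers in it are made up for illustration and do not come from the dataset. It checks that the slope in each segment is `k` plus the cumulative `delta`, and that the `gamma` offsets keep the line continuous at the changepoints.

```python
import numpy as np


def piecewise_linear_trend(t, k, m, delta, s):
    """Evaluate the trend for fixed (not sampled) parameter values."""
    # A[i, j] is 1.0 when t[i] lies after changepoint s[j], else 0.0.
    A = (t[:, None] > s).astype(float)
    # Offset adjustments that keep the line continuous at each changepoint.
    gamma = -s * delta
    return (k + A @ delta) * t + (m + A @ gamma)


# Hypothetical values, chosen only to make the arithmetic easy to follow.
k, m = 1.0, 0.5
s = np.array([0.25, 0.65])
delta = np.array([1.0, -2.0])

# Two points inside each of the three segments.
t = np.array([0.0, 0.1, 0.3, 0.4, 0.7, 0.8])
trend = piecewise_linear_trend(t, k, m, delta, s)

# Slope within each segment: k, k + delta[0], k + delta[0] + delta[1].
segment_slopes = (trend[1::2] - trend[0::2]) / (t[1::2] - t[0::2])

# Size of the jump across each changepoint; should be numerically zero.
eps = 1e-9
jump = (piecewise_linear_trend(s + eps, k, m, delta, s)
        - piecewise_linear_trend(s - eps, k, m, delta, s))

print(segment_slopes)
print(np.abs(jump).max())
```

The PyMC model computes the same expression, except that `k`, `m`, and `delta` are random variables and the dot products are symbolic.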
Finally, I am sampling:
with model:
    linear_trace = pm.sample(return_inferencedata=True)
I stopped the run after 20 minutes of sampling, at which point it was 40% done. Is it just me, or is this unreasonably slow for a model of this size?