Update: I could not figure out how to extend the model to include extra variables, so I wrote my own sampler instead. It takes parameter values from the trace, runs the whole sampling procedure itself, and then uses Altair for plotting:
def random(point=None, size=None, N=N, initial_obs=-np.inf):
    """Simulate one running-maximum (record) trajectory of length ``N``.

    At every step an innovation ``aux = mu + sigma * z`` with ``z ~ N(0, 1)``
    is drawn from NumPy's global RNG. The trajectory value at that step is the
    maximum of ``initial_obs`` and all innovations drawn so far.

    Args:
        point: Mapping with ``'mu'`` and ``'sigma'`` entries (e.g. one
            posterior draw). Required in practice despite the ``None`` default.
        size: Accepted but unused. Presumably kept to match a sampler-style
            ``random(point, size)`` signature; unverified.
        N: Number of timesteps to simulate. Defaults to the module-level ``N``
            as it was when this function was defined.
        initial_obs: Record value before the first step. The returned
            trajectory never drops below it. ``-inf`` means "no prior record".

    Returns:
        1-D float ``numpy`` array of length ``N`` holding the cumulative maximum.
    """
    # Retrieve parameters
    mu = point['mu']
    sigma = point['sigma']
    # Draw all N innovations at once instead of one per Python-level step.
    draws = mu + sigma * np.random.randn(N)
    # Running maximum of the draws, floored at the previous record.
    # (`np.math.inf` was removed in NumPy 2.0; `np.inf` is the supported name.)
    return np.maximum(np.maximum.accumulate(draws), initial_obs)
#@title Extrapolate to future timesteps
M = 100
import pandas as pd
# Generate samples from the posterior distribution of future data:
# one extrapolated trajectory of M steps per posterior draw, pooled over chains.
mu_chains = trace['posterior']['mu'].data
sigma_chains = trace['posterior']['sigma'].data
# NOTE(review): assumes scalar parameters stored as 2-D (chain, draw) arrays, so
# a C-order flatten visits draws chain by chain, like a nested chain/draw loop.
extrapolation_data = np.array([
    # Each future trajectory continues from the last observed record.
    random({'mu': mu, 'sigma': sigma}, N=M, initial_obs=obss[-1])
    for mu, sigma in zip(mu_chains.reshape(-1), sigma_chains.reshape(-1))
])
# One row per posterior draw, one column per future timestep.
assert extrapolation_data.shape[1] == M
# Preprocess extrapolated posterior: pointwise 5%, 50% and 95% quantiles across
# all simulated trajectories (axis 0), giving one value per future timestep.
lower_bound, median, upper_bound = np.quantile(
    extrapolation_data, q=[0.05, 0.5, 0.95], axis=0
)
assert median.shape == (M,)
# Print guesses for the last extrapolated observation.
# The model steps with `range(N)`, so the N observations presumably occupy
# timesteps 0..N-1. The M extrapolated values then cover N..N+M-1, matching the
# plotted range N..N+M-1. The final one is therefore at N+M-1, not N+M.
last_t = N + M - 1
print(f"The median guess for the record at time {last_t} is {median[-1]:.2f}")
print(f"The 90% confidence interval for the record at time {last_t} is {lower_bound[-1]:.2f} to {upper_bound[-1]:.2f}")
# Encode as pandas DataFrame
import pandas as pd
# Future timesteps N..N+M-1, one per column of the extrapolated samples.
ex_ts = np.arange(N, N + M)
ext_df = pd.DataFrame({
    't': ex_ts,
    'lower': lower_bound,
    'median': median,
    'upper': upper_bound,
})
# Plot the extrapolated data
# https://stackoverflow.com/questions/60649486/line-chart-with-custom-confidence-interval-in-altair
import altair as alt
# Both layers carry identical axis titles. In a layered chart Vega-Lite merges
# differing layer titles (e.g. "Cumulative maximum, lower, upper").
line_chart = alt.Chart(ext_df).mark_line().encode(
    x=alt.X('t:Q', title='Timestep'),
    y=alt.Y('median:Q', title='Cumulative maximum'),
)
# Shaded 5%-95% band between the `lower` and `upper` quantile columns.
band_chart = alt.Chart(ext_df).mark_area(opacity=0.5).encode(
    x=alt.X('t:Q', title='Timestep'),
    y=alt.Y('lower:Q', title='Cumulative maximum'),
    y2='upper:Q',
)
# Layer the band first so the median line is drawn on top of it rather than
# being covered by the semi-transparent area. The expression is left as the
# cell's final value so the notebook renders it.
(band_chart + line_chart).properties(
    width=800,
    height=400,
).configure_axis(
    labelFontSize=20,
    titleFontSize=30,
)
There might be some bugs, but I think the approach should work!
