Modelling a timeseries of cumulative maximum

Update: I could not figure out how to extend the model to include extra variables, so I wrote my own sampler that takes parameter values from the trace and runs the whole sampling procedure, then used Altair for plotting:

def random(point=None, size=None, N=N, initial_obs=-np.inf):
    """Simulate one path of the running (cumulative) maximum of Gaussian draws.

    At each of ``N`` steps a value ``mu + sigma * z`` with ``z ~ Normal(0, 1)``
    is drawn. Element ``t`` of the result is the largest value seen so far,
    i.e. ``max(initial_obs, draw_0, ..., draw_t)``.

    Parameters
    ----------
    point : mapping
        Must provide ``'mu'`` and ``'sigma'`` entries; each is expected to be a
        scalar (e.g. one posterior draw taken from the trace).
    size : ignored
        Unused; presumably kept to mirror a sampler-style
        ``random(point, size)`` signature.
    N : int
        Number of timesteps to simulate. The default is the module-level ``N``
        as it was when this function was defined.
    initial_obs : float
        Record value before the first simulated step; the output never falls
        below it. Defaults to ``-inf`` (no prior record).

    Returns
    -------
    numpy.ndarray
        Non-decreasing float array of shape ``(N,)``.
    """
    # Retrieve parameters
    mu = point['mu']
    sigma = point['sigma']

    # Draw all N Gaussian innovations in one call rather than one per step.
    draws = mu + sigma * np.random.randn(N)
    # Fold in the starting record, then take the running maximum; this matches
    # the step-by-step "obs = max(obs, new_draw)" recursion.
    return np.maximum.accumulate(np.maximum(draws, initial_obs))

#@title Extrapolate to future timesteps

# Number of future timesteps to simulate beyond the observed data.
M = 100

import pandas as pd

# Generate samples from the posterior distribution of future data.
# The posterior arrays are walked chain by chain, then draw by draw, below;
# NOTE(review): assumes mu and sigma are scalar per draw, as random() expects.
mu_chains = trace['posterior']['mu'].data
sigma_chains = trace['posterior']['sigma'].data

# One simulated continuation per posterior draw, each starting from the last
# observed record so the simulated cumulative maximum never drops below it.
extrapolation_data = []
for mu_chain, sigma_chain in zip(mu_chains, sigma_chains):
    for mu, sigma in zip(mu_chain, sigma_chain):
        extrapolation_data.append(
            random({'mu': mu, 'sigma': sigma}, N=M, initial_obs=obss[-1])
        )
extrapolation_data = np.array(extrapolation_data)

# Rows are posterior draws, columns are future timesteps.
assert extrapolation_data.shape[1] == M

# Pointwise 5%, 50% and 95% quantiles across draws -> a 90% band per timestep.
lower_bound, median, upper_bound = np.quantile(
    extrapolation_data, q=[0.05, 0.5, 0.95], axis=0
)

assert median.shape == (M,)

# The M extrapolated steps are plotted at t = N, ..., N + M - 1, so the last
# summarised value belongs to time N + M - 1 (not N + M).
last_t = N + M - 1
print(f"The median guess for the record at time {last_t} is {median[-1]:.2f}")
print(f"The 90% confidence interval for the record at time {last_t} is {lower_bound[-1]:.2f} to {upper_bound[-1]:.2f}")

# Encode the extrapolated summary as a pandas DataFrame for plotting.
import pandas as pd

# Timesteps for the M simulated steps, starting at N
# (presumably the observed data occupy t = 0, ..., N - 1).
ex_ts = np.arange(N, N + M)
ext_df = pd.DataFrame({
    't': ex_ts,
    'lower': lower_bound,
    'median': median,
    'upper': upper_bound,
})

# Plot the extrapolated median together with its 5%-95% quantile band.
# https://stackoverflow.com/questions/60649486/line-chart-with-custom-confidence-interval-in-altair
import altair as alt

line_chart = alt.Chart(ext_df).mark_line().encode(
    x=alt.X('t:Q', title='Timestep'),
    y=alt.Y('median:Q', title='Cumulative maximum'),
)

# Explicit titles here too, so the shared axes read the same whichever
# layer ends up first.
band_chart = alt.Chart(ext_df).mark_area(opacity=0.5).encode(
    x=alt.X('t:Q', title='Timestep'),
    y=alt.Y('lower:Q', title='Cumulative maximum'),
    y2='upper:Q',
)

# Layers are drawn in order: put the semi-transparent band first so the
# median line is drawn on top of it instead of being dimmed underneath.
chart = (band_chart + line_chart).properties(
    width=800,
    height=400,
).configure_axis(
    labelFontSize=20,
    titleFontSize=30,
)

# Bare final expression: rendered when this is the last line of a notebook cell.
chart

There might be some bugs, but I think the approach should work!