Hi, Iβm trying to get a toy model working for fitting piecewise functions. The background is that for time series data, the process might be nonstationary, and at some moment it switches into a new regime (the parameters describing the process undergo a step change). It might look something like this:
To that end, I generated some fake data with a really obvious inflection point to see if I could get a minimal sensible example working with PyMC3, and then add more complications to it later. Hereβs my code to generate the data:
import numpy as np
import pymc3 as pm
import matplotlib.pyplot as plt
print('Running on PyMC3 v{}'.format(pm.__version__))
Running on PyMC3 v3.5
times_1 = np.arange(0, 10, 0.1)
times_2 = np.arange(10, 15, 0.1)
times_all = np.append(times_1, times_2)
slope_1 = 2
slope_2 = -0.5
sigma = 0.1
prices_1 = times_1 * slope_1 + np.random.randn(len(times_1)) * sigma
prices_2 = times_2 * slope_2 + np.random.randn(len(times_2)) * sigma
prices_2 += (times_1.max() * slope_1) - (times_2.min() * slope_2)
prices_all = np.append(prices_1, prices_2)
plt.plot(times_all, prices_all)
We can imagine the horizontal axis representing time, and the vertical axis representing the price of some commodity. As can be clearly seen, its behavior is two different linear functions spliced together at t=10, with a small amount of added noise. Here is the model I want to fit:
Let w be the βswitchpointβ.
When t < w, Y = m_1 * t + b_1 + epsilon,
When t >= w, Y = m_2 * t + b_2 + epsilon
The true values from the above code are w = 10, m_1 = 2, m_2 = -0.5, b_1 = 0, b_2 = 25 (implicitly), epsilon ~ N(0, 0.1)
To infer the parameters from the data (in particular, the location of the switchpoint) I tried using the pm.math.switch
function in the same way as is used in the coal mining disasters example in the documentation. Here is my code:
switch_model = pm.Model()
with switch_model:
rv_slope_1 = pm.Normal("rv_slope_1", mu=0, sd=10) ## gradient for the first region
rv_slope_2 = pm.Normal("rv_slope_2", mu=0, sd=10) ## gradient for the second region
rv_intercept_1 = pm.Normal("rv_intercept_1", mu=0, sd=100) ## intercept for first region
rv_intercept_2 = pm.Normal("rv_intercept_2", mu=0, sd=100) ## intercept for second region
rv_sigma = pm.HalfNormal("rv_sigma", sd=10) ## noise term
## where does the trend switch to a different regime?
rv_switchpoint = pm.Uniform("rv_switchpoint", lower=times_all.min(), upper=times_all.max())
## if-statement to decide which parameters to use for the different regions
slope_either = pm.math.switch(rv_switchpoint > prices_all, rv_slope_1, rv_slope_2)
intercept_either = pm.math.switch(rv_switchpoint > prices_all, rv_intercept_1, rv_intercept_2)
mu_either = slope_either*prices_all + intercept_either ## piecewise y=mx+b
## add noise term and declare what the observed values are
trendline = pm.Normal("trendline", mu=mu_either, sd=rv_sigma, observed=prices_all)
trace = pm.sample(samples=1000, tune=2000, cores=1, chains=4)
pm.traceplot(trace)
print(pm.summary(trace))
However, this fails. It gets to a few hundred samples and then crashes with this error readout:
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (4 chains in 1 job)
NUTS: [rv_switchpoint, rv_sigma, rv_intercept_2, rv_intercept_1, rv_slope_2, rv_slope_1]
32%|ββββ | 810/2500 [00:44<01:33, 18.05it/s]
---------------------------------------------------------------------------
ValueError Traceback (most recent call last)
<ipython-input-7-b97e0927e6ae> in <module>()
18 trendline = pm.Normal("trendline", mu=mu_either, sd=rv_sigma, observed=prices_all)
19
---> 20 trace = pm.sample(samples=1000, tune=2000, cores=1, chains=4)
21 pm.traceplot(trace)
22 print(pm.summary(trace))
D:\Anaconda3\lib\site-packages\pymc3\sampling.py in sample(draws, step, init, n_init, start, trace, chain_idx, chains, cores, tune, nuts_kwargs, step_kwargs, progressbar, model, random_seed, live_plot, discard_tuned_samples, live_plot_kwargs, compute_convergence_checks, use_mmap, **kwargs)
467 _log.info('Sequential sampling ({} chains in 1 job)'.format(chains))
468 _print_step_hierarchy(step)
--> 469 trace = _sample_many(**sample_args)
470
471 discard = tune if discard_tuned_samples else 0
D:\Anaconda3\lib\site-packages\pymc3\sampling.py in _sample_many(draws, chain, chains, start, random_seed, step, **kwargs)
513 for i in range(chains):
514 trace = _sample(draws=draws, chain=chain + i, start=start[i],
--> 515 step=step, random_seed=random_seed[i], **kwargs)
516 if trace is None:
517 if len(traces) == 0:
D:\Anaconda3\lib\site-packages\pymc3\sampling.py in _sample(chain, progressbar, random_seed, start, draws, step, trace, tune, model, live_plot, live_plot_kwargs, **kwargs)
557 try:
558 strace = None
--> 559 for it, strace in enumerate(sampling):
560 if live_plot:
561 if live_plot_kwargs is None:
D:\Anaconda3\lib\site-packages\tqdm\_tqdm.py in __iter__(self)
931 """, fp_write=getattr(self.fp, 'write', sys.stderr.write))
932
--> 933 for obj in iterable:
934 yield obj
935 # Update and possibly print the progressbar.
D:\Anaconda3\lib\site-packages\pymc3\sampling.py in _iter_sample(draws, step, start, trace, chain, tune, model, random_seed)
653 step = stop_tuning(step)
654 if step.generates_stats:
--> 655 point, states = step.step(point)
656 if strace.supports_sampler_stats:
657 strace.record(point, states)
D:\Anaconda3\lib\site-packages\pymc3\step_methods\arraystep.py in step(self, point)
245
246 if self.generates_stats:
--> 247 apoint, stats = self.astep(array)
248 point = self._logp_dlogp_func.array_to_full_dict(apoint)
249 return point, stats
D:\Anaconda3\lib\site-packages\pymc3\step_methods\hmc\base_hmc.py in astep(self, q0)
113
114 if not np.isfinite(start.energy):
--> 115 self.potential.raise_ok(self._logp_dlogp_func._ordering.vmap)
116 raise ValueError('Bad initial energy: %s. The model '
117 'might be misspecified.' % start.energy)
D:\Anaconda3\lib\site-packages\pymc3\step_methods\hmc\quadpotential.py in raise_ok(self, vmap)
199 errmsg.append('The derivative of RV `{}`.ravel()[{}]'
200 ' is zero.'.format(*name_slc[ii]))
--> 201 raise ValueError('\n'.join(errmsg))
202
203 if np.any(~np.isfinite(self._stds)):
ValueError: Mass matrix contains zeros on the diagonal.
The derivative of RV `rv_slope_2`.ravel()[0] is zero.
The derivative of RV `rv_slope_1`.ravel()[0] is zero.
What am I doing wrong?
EDIT:
I fixed all the slope and intercept variables to be the known constants to see whether it would still work.
rv_slope_1 = 2 #pm.Normal("rv_slope_1", mu=0, sd=10) ## gradient for the first region
rv_slope_2 = -0.5 #pm.Normal("rv_slope_2", mu=0, sd=10) ## gradient for the second region
rv_intercept_1 = 0 #pm.Normal("rv_intercept_1", mu=0, sd=100) ## intercept for first region
rv_intercept_2 = 25 #pm.Normal("rv_intercept_2", mu=0, sd=100) ## intercept for second region
The inference finished but the sampling was very slow and n_eff was small (I had to reduce the number of samples and chains because otherwise I would get bored waiting):
Auto-assigning NUTS sampler...
Initializing NUTS using jitter+adapt_diag...
Sequential sampling (2 chains in 1 job)
NUTS: [rv_switchpoint, rv_sigma]
100%|ββββββββββ| 2000/2000 [04:43<00:00, 7.04it/s]
100%|ββββββββββ| 2000/2000 [04:55<00:00, 6.77it/s]
D:\Anaconda3\lib\site-packages\mkl_fft\_numpy_fft.py:1044: FutureWarning: Using a non-tuple sequence for multidimensional indexing is deprecated; use `arr[tuple(seq)]` instead of `arr[seq]`. In the future this will be interpreted as an array index, `arr[np.array(seq)]`, which will result either in an error or a different result.
output = mkl_fft.rfftn_numpy(a, s, axes)
The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.
The chain reached the maximum tree depth. Increase max_treedepth, increase target_accept or reparameterize.
The estimated number of effective samples is smaller than 200 for some parameters.
The estimate for the location of the switchpoint is right, but the noise estimate is inflated by an order of magnitude (possibly because of the errors in the switchpoint?)
mean sd mc_error hpd_2.5 hpd_97.5 n_eff \
rv_sigma 4.787512 0.279653 0.017103 4.285377 5.374192 222.934947
rv_switchpoint 9.938956 0.435651 0.029998 9.077874 10.792307 114.701224
Rhat
rv_sigma 1.000893
rv_switchpoint 1.012629
Any idea why this is making NUTS unhappy? Is it the if-statement screwing up the gradient calculations?