Hello! Here are the steps to reproduce.
step 1
import pytensor
import scipy.stats as st
import pytensor.tensor as pt
import pymc as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt
print(pm.__version__)
print(az.__version__)
print(pytensor.__version__)
np.random.seed(0)
y = np.concatenate([
    st.fisk.rvs(c=6.00, loc=0.5, scale=1.0, size=30, random_state=100),
    st.fisk.rvs(c=2.50, loc=2.0, scale=1.0, size=90, random_state=100)
])
t = np.arange(y.shape[0])
_, ax = plt.subplots(figsize=(7, 2.5))
ax.plot(y)
plt.show()
gives:
5.10.2
0.16.1
2.18.4
step 2
logp and random for fisk (the log-logistic distribution):
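For reference, the density implemented below is the one from the scipy docs linked in the code: fisk.pdf(y, c) = c * y**(c - 1) / (1 + y**c)**2 for y >= 0 and c > 0, with y = (x - loc) / scale, and the pdf of x is that value divided by scale; logp just takes its log.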
from typing import Optional, Tuple, Union

from pytensor.tensor import TensorVariable

def random(c: Union[np.ndarray, float],
           loc: Union[np.ndarray, float],
           scale: Union[np.ndarray, float] = 1.0,
           rng: Optional[np.random.Generator] = None,
           size: Optional[Tuple[int]] = None) -> Union[np.ndarray, float]:
    # CustomDist only passes c and loc, so scale defaults to 1.0;
    # return actual draws rather than a frozen scipy distribution
    return st.fisk.rvs(c=c, loc=loc, scale=scale, size=size, random_state=rng)
def logp(x: TensorVariable,
         c: TensorVariable,
         loc: TensorVariable) -> TensorVariable:
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.fisk.html
    # Specifically, fisk.pdf(x, c, loc, scale) is identically equivalent
    # to fisk.pdf(y, c) / scale with y = (x - loc) / scale.
    scale = 1.0
    x = (x - loc) / scale
    up = c * pt.pow(x, c - 1)
    dw = pt.pow(1 + x ** c, 2)
    pdf = (up / dw) / scale
    check = pt.all(x >= 0) & pt.all(c > 0)
    res = pm.math.switch(check, pt.log(pdf), 0)
    return res
model = pm.Model()
with model:
    tau = pm.DiscreteUniform('switchpoint', lower=0, upper=y.shape[0])
    c = pm.Uniform('c', size=2, lower=0.1, upper=20.0)
    loc = pm.Uniform('loc', size=2, lower=0.1, upper=20.0)
    c_obs = pm.math.switch(t < tau, c[0], c[1])
    loc_obs = pm.math.switch(t < tau, loc[0], loc[1])
    obs = pm.CustomDist('obs', c_obs, loc_obs,
                        logp=logp,
                        random=random,
                        observed=y,
                        size=y.shape)
# check logp against scipy.stats.fisk.logpdf
for x in np.linspace(0.5, 10, 5):
    r1 = st.fisk.logpdf(x=x, c=2.0, loc=0.5, scale=1.0)
    r2 = logp(x=x,
              c=pytensor.shared(2.0),
              loc=pytensor.shared(0.5)).eval()
    # print(f'x={x}: r1: {r1:.5f}, r2: {r2:.5f}')
    assert np.allclose(r1, r2), 'not equal'

model.debug(verbose=True)
gives:
point={'switchpoint': array(60), 'c_interval__': array([0., 0.]), 'loc_interval__': array([0., 0.])}
No problems found
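As an extra sanity check that is not part of the original run, the random implementation can also be exercised through prior predictive sampling. This is only a minimal sketch; it assumes the corrected random above that returns actual draws and uses the default number of prior samples:

with model:
    prior = pm.sample_prior_predictive(random_seed=0)

# observed variables land in the prior_predictive group; the expected shape is
# (chains, draws, y.shape[0]) and every draw should be positive since loc > 0
print(prior.prior_predictive['obs'].shape)
print(float(prior.prior_predictive['obs'].min()))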
step 3
For any draws and tune settings, the sampler always places the switchpoint near the middle of the series:
with model:
    trace = pm.sample(draws=500, tune=500, progressbar=True)

mean = trace.posterior['switchpoint'].mean()
print('mean switchpoint:\n', mean.values)
print(az.summary(trace))

# plot posterior
az.plot_posterior(trace, figsize=(8, 5))

# plot trace
strace = trace.posterior.stack(draws=('chain', 'draw'))
az.plot_trace(trace, figsize=(8, 5))

# plot observations with the inferred switchpoint and per-regime parameter means
for k in ['c', 'loc']:
    _, ax = plt.subplots(figsize=(5, 2))
    ax.plot(y, alpha=0.6)
    ax.vlines(strace['switchpoint'].mean(), y.min(), y.max(), color='C1')
    avg_trace = np.zeros_like(y)
    for i, step in enumerate(np.arange(y.shape[0])):
        idx = step < strace['switchpoint']
        avg_trace[i] = np.mean(np.where(idx, strace[k][0], strace[k][1]))
    sp_hpd = az.hdi(trace, var_names=['switchpoint'])['switchpoint'].values
    ax.fill_betweenx(
        y=[y.min(), y.max()],
        x1=sp_hpd[0],
        x2=sp_hpd[1],
        alpha=0.5,
        color='C1',
    )
    ax.grid(False)
    ax2 = ax.twinx()
    ax2.plot(avg_trace, 'k--', lw=1.5, label=k)
    ax2.legend(loc='best')
plt.show()
output:
Multiprocess sampling (2 chains in 2 jobs)
CompoundStep
Metropolis: [switchpoint]
NUTS: [c, loc]
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
Sampling 2 chains for 500 tune and 500 draw iterations (1_000 + 1_000 draws total) took 18 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
mean switchpoint:
58.235
               mean      sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk  ess_tail  r_hat
switchpoint  58.235  35.124   4.000  115.000      2.194    1.576     265.0     193.0   1.01
c[0]         10.048   5.820   1.335   19.897      0.176    0.126     980.0     694.0   1.00
c[1]         10.002   5.591   0.881   19.319      0.188    0.137     719.0     548.0   1.00
loc[0]       10.160   5.851   1.413   19.785      0.232    0.164     494.0     443.0   1.00
loc[1]       10.164   5.718   1.020   19.444      0.186    0.132     916.0     633.0   1.00
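One extra diagnostic I have not included in the run above, sketched here under a few assumptions: evaluate the model's joint log-probability over a grid of switchpoint values while pinning c and loc near the values used to simulate the data, to see whether the likelihood itself prefers the true change point at index 30. The to_interval helper is my own; it assumes the default interval transform behind the Uniform priors (the c_interval__/loc_interval__ value variables shown by model.debug above):

from scipy.special import logit

def to_interval(x, lower=0.1, upper=20.0):
    # map values from (lower, upper) to the unconstrained interval scale,
    # i.e. log((x - lower) / (upper - x))
    return logit((x - lower) / (upper - lower))

logp_fn = model.compile_logp()
point = model.initial_point()
point['c_interval__'] = to_interval(np.array([6.0, 2.5]))    # true simulation values
point['loc_interval__'] = to_interval(np.array([0.5, 2.0]))
for sp in [10, 30, 60, 90, 110]:
    point['switchpoint'] = np.array(sp)
    print(sp, float(logp_fn(point)))

If the logp barely changes across this grid, the flat switchpoint posterior would be expected; if it clearly peaks at 30, the issue is more likely in how the sampler explores the model.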
Thanks for reading