# CustomDist doesnt work as well?

Hello! Here steps to reproduce

step 1

``````import pytensor
import scipy.stats as st
import pytensor.tensor as pt
import pymc as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt

print(pm.__version__)
print(az.__version__)
print(pytensor.__version__)

np.random.seed(0)
y = np.concatenate([
st.fisk.rvs(c=6.00, loc=0.5, scale=1.0, size=30, random_state=100),
st.fisk.rvs(c=2.50, loc=2.0, scale=1.0, size=90, random_state=100)
])
t = np.arange(y.shape[0])

_, ax = plt.subplots(figsize=(7, 2.5))
ax.plot(y)
plt.show()
``````

gives:

5.10.2
0.16.1
2.18.4

step 2
logp with random for fisk (log-logistic dist)

``````def random(c: Union[np.ndarray, float],
loc: Union[np.ndarray, float],
scale: Union[np.ndarray, float],
rng: Optional[np.random.Generator]=None,
size : Optional[Tuple[int]]=None) -> Union[np.ndarray, float]:
return st.fisk(c=c, loc=loc, scale=1.0)

def logp(x: TensorVariable,
c: TensorVariable,
loc: TensorVariable) -> TensorVariable:
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.fisk.html
# Specifically, fisk.pdf(x, c, loc, scale) is identically equivalent
# to fisk.pdf(y, c) / scale with y = (x - loc) / scale.
scale = 1.0
x = (x - loc) / scale
up = c * pt.pow(x, c - 1)
dw = pt.pow(1 + x ** c, 2)
pdf = (up / dw) / scale

check = pt.all(x >= 0) & pt.all(c > 0)
res = pm.math.switch(check, pt.log(pdf), 0)
return res

model = pm.Model()
with model:
tau = pm.DiscreteUniform('switchpoint', lower=0, upper=y.shape[0])

c = pm.Uniform('c', size=2, lower=0.1, upper=20.0)
loc = pm.Uniform('loc', size=2, lower=0.1, upper=20.0)

c_obs = pm.math.switch(t < tau, c[0], c[1])
loc_obs = pm.math.switch(t < tau, loc[0], loc[1])

obs = pm.CustomDist('obs', c_obs, loc_obs,
logp=logp,
random=random,
observed=y,
size=y.shape)

# check logp with scipy.stats.fisk.logpdf
for x in  np.linspace(0.5, 10, 5):
r1 = st.fisk.logpdf(x=x, c=2.0, loc=0.5, scale=1.0)
r2 = logp(x=x,
c=pytensor.shared(2.0),
loc=pytensor.shared(0.5)).eval()
# print(f'x={x}: r1: {r1:.5f}, r2: {r2:.5f}')
assert np.allclose(r1, r1), 'not equal'

model.debug(verbose=True)
``````

gives:

point={‘switchpoint’: array(60), ‘c_interval__’: array([0., 0.]), ‘loc_interval__’: array([0., 0.])}
No problems found

step 3
for any draws and tune params it always detects switch point in the middle

``````with model:
trace = pm.sample(draws=500, tune=500, progressbar=True)
mean = trace.posterior['switchpoint'].mean()
print('mean checkpoint:\n', mean.values)
print(az.summary(trace))

# plot post
az.plot_posterior(trace, figsize=(8, 5))
# plot trace
strace = trace.posterior.stack(draws=('chain', 'draw'))
az.plot_trace(trace, figsize=(8, 5))

# plot obs
for k in ['c', 'loc']:
_, ax = plt.subplots(figsize=(5, 2))
ax.plot(y, alpha=0.6)
ax.vlines(strace['switchpoint'].mean(), y.min(), y.max(), color='C1')
avg_trace = np.zeros_like(y)
for i, year in enumerate(np.arange(y.shape[0])):
idx = year < strace['switchpoint']
avg_trace[i] = np.mean(np.where(idx, strace[k][0], strace[k][1]))

sp_hpd = az.hdi(trace, var_names=['switchpoint'])['switchpoint'].values
ax.fill_betweenx(
y=[y.min(), y.max()],
x1=sp_hpd[0],
x2=sp_hpd[1],
alpha=0.5,
color='C1',
)
ax.grid(False)
ax2 = ax.twinx()
ax2.plot(avg_trace, 'k--', lw=1.5, label=k)
ax2.legend(loc='best')
plt.show()
``````

output:

``````Multiprocess sampling (2 chains in 2 jobs)
[distfit] INFO Multiprocess sampling (2 chains in 2 jobs)
CompoundStep
[distfit] INFO CompoundStep
Metropolis: [switchpoint]
[distfit] INFO Metropolis: [switchpoint]
NUTS: [c, loc]
[distfit] INFO NUTS: [c, loc]
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.

Sampling 2 chains for 500 tune and 500 draw iterations (1_000 + 1_000 draws total) took 18 seconds.
[distfit] INFO Sampling 2 chains for 500 tune and 500 draw iterations (1_000 + 1_000 draws total) took 18 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
[distfit] INFO We recommend running at least 4 chains for robust computation of convergence diagnostics
mean checkpoint:
58.235
mean      sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk
switchpoint  58.235  35.124   4.000  115.000      2.194    1.576     265.0  \
c[0]         10.048   5.820   1.335   19.897      0.176    0.126     980.0
c[1]         10.002   5.591   0.881   19.319      0.188    0.137     719.0
loc[0]       10.160   5.851   1.413   19.785      0.232    0.164     494.0
loc[1]       10.164   5.718   1.020   19.444      0.186    0.132     916.0

ess_tail  r_hat
switchpoint     193.0   1.01
c[0]            694.0   1.00
c[1]            548.0   1.00
loc[0]          443.0   1.00
loc[1]          633.0   1.00
``````