CustomDist doesnt work as well?

Hello! Here steps to reproduce

step 1

import pytensor
import scipy.stats as st
import pytensor.tensor as pt
import pymc as pm
import arviz as az
import numpy as np
import matplotlib.pyplot as plt

print(pm.__version__)
print(az.__version__)
print(pytensor.__version__)

np.random.seed(0)
y = np.concatenate([
    st.fisk.rvs(c=6.00, loc=0.5, scale=1.0, size=30, random_state=100),
    st.fisk.rvs(c=2.50, loc=2.0, scale=1.0, size=90, random_state=100)
])
t = np.arange(y.shape[0])

_, ax = plt.subplots(figsize=(7, 2.5))
ax.plot(y)
plt.show()

gives:

5.10.2
0.16.1
2.18.4

step 2
logp with random for fisk (log-logistic dist)

def random(c: Union[np.ndarray, float],
           loc: Union[np.ndarray, float],
           scale: Union[np.ndarray, float],
           rng: Optional[np.random.Generator]=None,
           size : Optional[Tuple[int]]=None) -> Union[np.ndarray, float]:
    return st.fisk(c=c, loc=loc, scale=1.0)


def logp(x: TensorVariable,
         c: TensorVariable,
         loc: TensorVariable) -> TensorVariable:
    # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.fisk.html
    # Specifically, fisk.pdf(x, c, loc, scale) is identically equivalent 
    # to fisk.pdf(y, c) / scale with y = (x - loc) / scale.
    scale = 1.0
    x = (x - loc) / scale
    up = c * pt.pow(x, c - 1)
    dw = pt.pow(1 + x ** c, 2)
    pdf = (up / dw) / scale

    check = pt.all(x >= 0) & pt.all(c > 0)
    res = pm.math.switch(check, pt.log(pdf), 0)
    return res


model = pm.Model()
with model:
    tau = pm.DiscreteUniform('switchpoint', lower=0, upper=y.shape[0])

    c = pm.Uniform('c', size=2, lower=0.1, upper=20.0)
    loc = pm.Uniform('loc', size=2, lower=0.1, upper=20.0)

    c_obs = pm.math.switch(t < tau, c[0], c[1])
    loc_obs = pm.math.switch(t < tau, loc[0], loc[1])

    obs = pm.CustomDist('obs', c_obs, loc_obs,
                        logp=logp,
                        random=random,
                        observed=y,
                        size=y.shape)


# check logp with scipy.stats.fisk.logpdf
for x in  np.linspace(0.5, 10, 5):
    r1 = st.fisk.logpdf(x=x, c=2.0, loc=0.5, scale=1.0)
    r2 = logp(x=x,
              c=pytensor.shared(2.0),
              loc=pytensor.shared(0.5)).eval()
    # print(f'x={x}: r1: {r1:.5f}, r2: {r2:.5f}')
    assert np.allclose(r1, r1), 'not equal'


model.debug(verbose=True)

gives:

point={‘switchpoint’: array(60), ‘c_interval__’: array([0., 0.]), ‘loc_interval__’: array([0., 0.])}
No problems found

step 3
for any draws and tune params it always detects switch point in the middle

with model:
    trace = pm.sample(draws=500, tune=500, progressbar=True)
    mean = trace.posterior['switchpoint'].mean()
    print('mean checkpoint:\n', mean.values)
    print(az.summary(trace))

    # plot post
    az.plot_posterior(trace, figsize=(8, 5))
    # plot trace
    strace = trace.posterior.stack(draws=('chain', 'draw'))
    az.plot_trace(trace, figsize=(8, 5))

    # plot obs
    for k in ['c', 'loc']:
        _, ax = plt.subplots(figsize=(5, 2))
        ax.plot(y, alpha=0.6)
        ax.vlines(strace['switchpoint'].mean(), y.min(), y.max(), color='C1')
        avg_trace = np.zeros_like(y)
        for i, year in enumerate(np.arange(y.shape[0])):
            idx = year < strace['switchpoint']
            avg_trace[i] = np.mean(np.where(idx, strace[k][0], strace[k][1]))

        sp_hpd = az.hdi(trace, var_names=['switchpoint'])['switchpoint'].values
        ax.fill_betweenx(
            y=[y.min(), y.max()],
            x1=sp_hpd[0],
            x2=sp_hpd[1],
            alpha=0.5,
            color='C1',
        )
        ax.grid(False)
        ax2 = ax.twinx()
        ax2.plot(avg_trace, 'k--', lw=1.5, label=k)
        ax2.legend(loc='best')
        plt.show()

output:

Multiprocess sampling (2 chains in 2 jobs)
[distfit] INFO Multiprocess sampling (2 chains in 2 jobs)
CompoundStep
[distfit] INFO CompoundStep
Metropolis: [switchpoint]
[distfit] INFO Metropolis: [switchpoint]
NUTS: [c, loc]
[distfit] INFO NUTS: [c, loc]
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.
WARNING (pytensor.tensor.blas): Using NumPy C-API based implementation for BLAS functions.

Sampling 2 chains for 500 tune and 500 draw iterations (1_000 + 1_000 draws total) took 18 seconds.
[distfit] INFO Sampling 2 chains for 500 tune and 500 draw iterations (1_000 + 1_000 draws total) took 18 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics
[distfit] INFO We recommend running at least 4 chains for robust computation of convergence diagnostics
mean checkpoint:
 58.235
               mean      sd  hdi_3%  hdi_97%  mcse_mean  mcse_sd  ess_bulk   
switchpoint  58.235  35.124   4.000  115.000      2.194    1.576     265.0  \
c[0]         10.048   5.820   1.335   19.897      0.176    0.126     980.0   
c[1]         10.002   5.591   0.881   19.319      0.188    0.137     719.0   
loc[0]       10.160   5.851   1.413   19.785      0.232    0.164     494.0   
loc[1]       10.164   5.718   1.020   19.444      0.186    0.132     916.0   

             ess_tail  r_hat  
switchpoint     193.0   1.01  
c[0]            694.0   1.00  
c[1]            548.0   1.00  
loc[0]          443.0   1.00  
loc[1]          633.0   1.00 



i4
i5
Thanks for reading