Hello, I’m trying to implement a skew Student-t distribution using `pm.CustomDist`. The `logp` function is defined as:
import numpy as np
import pymc as pm

def logp_skewt(value, nu, mu, sigma, alpha, *args, **kwargs):
    return (
        pm.math.log(2) +
        pm.logp(pm.StudentT.dist(nu, mu=mu, sigma=sigma), value) +
        pm.logcdf(pm.StudentT.dist(nu, mu=mu, sigma=sigma), alpha * value) -
        pm.math.log(sigma)
    )
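One way to sanity-check a hand-written logp is to integrate its exponential over the real line. The sketch below mirrors `logp_skewt` with SciPy (assumption: `scipy.stats.t` and `scipy.integrate.quad` stand in for the PyMC calls; `logp_posted`, `logp_azzalini` and `total_mass` are hypothetical helper names) and compares it with the standard Azzalini-Capitanio skew-t, which skews the standardized value z = (x - mu) / sigma. Because `pm.logp` of a `StudentT` with `sigma` already includes a -log(sigma) term, the extra `- pm.math.log(sigma)` scales the density by another 1/sigma, and skewing `alpha * value` instead of `alpha * z` leaves the result unnormalized in general when mu != 0.

```python
import numpy as np
from scipy import integrate, stats


def logp_posted(x, nu, mu, sigma, alpha):
    """SciPy mirror of logp_skewt as posted."""
    return (
        np.log(2)
        + stats.t.logpdf(x, nu, loc=mu, scale=sigma)
        + stats.t.logcdf(alpha * x, nu, loc=mu, scale=sigma)
        - np.log(sigma)
    )


def logp_azzalini(x, nu, mu, sigma, alpha):
    """Azzalini-Capitanio skew-t log-density, skewing the standardized value z."""
    z = (x - mu) / sigma
    w = alpha * z * np.sqrt((nu + 1) / (nu + z**2))
    return (
        np.log(2)
        - np.log(sigma)
        + stats.t.logpdf(z, nu)
        + stats.t.logcdf(w, nu + 1)
    )


def total_mass(logp, *params):
    """Integrate exp(logp) over the real line, split at 0 for robustness."""
    f = lambda x: np.exp(logp(x, *params))
    left, _ = integrate.quad(f, -np.inf, 0)
    right, _ = integrate.quad(f, 0, np.inf)
    return left + right


params = (1.0, 0.0, 3.0, -10.0)  # nu, mu, sigma, alpha as in the post
print(total_mass(logp_posted, *params))    # about 1/3, i.e. 1/sigma
print(total_mass(logp_azzalini, *params))  # about 1.0
```

With mu = 0 the posted form integrates to 1/sigma, which is harmless while the parameters are fixed, but once sigma is a free parameter that factor adds -N*log(sigma) to the likelihood of N observations.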
I am able to sample from this distribution with fixed parameters:
with pm.Model():
    pm.CustomDist('target', 1, 0, 3, -10, logp=logp_skewt)
    model_trace = pm.sample(
        nuts_sampler="numpyro",
        draws=2_000,
        chains=1,
    )
samples = model_trace.posterior.target.to_numpy()
eps = 0.01
min_val, max_val = np.quantile(samples, [eps, 1 - eps])
valid_samples = samples[(samples >= min_val) & (samples <= max_val)]
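For generating synthetic data, NUTS is not strictly needed if the target is the Azzalini-Capitanio skew-t, which has a direct stochastic representation: with delta = alpha / sqrt(1 + alpha**2) and independent standard normals U0, U1, Z = delta*|U0| + sqrt(1 - delta**2)*U1 is skew-normal, and mu + sigma * Z / sqrt(W / nu) with W ~ chi-square(nu) is skew-t. This is a sketch under that assumption: it draws from the standardized Azzalini form, not from the exact `logp_skewt` density. The name `skewt_random` is hypothetical; its signature follows the `random(*dist_params, rng=None, size=None)` convention accepted by the `random` argument of `pm.CustomDist`.

```python
import numpy as np


def skewt_random(nu, mu, sigma, alpha, rng=None, size=None):
    """Draw from the Azzalini-Capitanio skew-t via its stochastic representation."""
    rng = np.random.default_rng() if rng is None else rng
    delta = alpha / np.sqrt(1.0 + alpha**2)
    u0 = rng.standard_normal(size)
    u1 = rng.standard_normal(size)
    # Skew-normal draw with shape parameter alpha
    z = delta * np.abs(u0) + np.sqrt(1.0 - delta**2) * u1
    # Dividing by sqrt(W / nu), W ~ chi-square(nu), turns it into a skew-t draw
    w = rng.chisquare(nu, size)
    return mu + sigma * z / np.sqrt(w / nu)


rng = np.random.default_rng(0)
draws = skewt_random(1.0, 0.0, 3.0, -10.0, rng=rng, size=2_000)
print(draws.shape)
```

If this form matches the intended model, passing `random=skewt_random` to `pm.CustomDist` should let prior draws be generated directly rather than by MCMC, though that combination is untested here.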
However, when I try to fit the model back to these samples, sampling becomes very slow:
with pm.Model() as fitted_model:
    nu = pm.HalfCauchy('nu', beta=1)
    mu = pm.Normal('mu', mu=0, sigma=1)
    sigma = pm.HalfCauchy('sigma', beta=1)
    alpha = pm.Normal('alpha', mu=0, sigma=1)
    skewt = pm.CustomDist(
        'likelihood', nu + eps, mu, sigma + eps, alpha,
        logp=logp_skewt, observed=valid_samples[:1000],
    )
    model_trace = pm.sample(
        nuts_sampler="pymc",
        draws=100,
        tune=100,
        chains=1,
    )
The following warnings are emitted:
/Users/admin/miniforge3/envs/python311/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py:680: UserWarning: Optimization Warning: The Op betainc does not provide a C implementation. As well as being potentially slow, this also disables loop fusion.
warn(
/Users/admin/miniforge3/envs/python311/lib/python3.11/site-packages/pytensor/tensor/rewriting/elemwise.py:680: UserWarning: Optimization Warning: The Op betainc_der does not provide a C implementation. As well as being potentially slow, this also disables loop fusion.
warn(
It took about 16 minutes to finish fitting with only 100 draws and 100 tuning steps.
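The `betainc` Op most likely comes from the `pm.logcdf(pm.StudentT.dist(...))` term, since the Student-t CDF is usually computed through the regularized incomplete beta function I_x(a, b): for t >= 0, F(t) = 1 - 0.5 * I_x(nu/2, 1/2) with x = nu / (nu + t**2), and F(t) = 0.5 * I_x(nu/2, 1/2) for t < 0. The Student-t density itself needs only log-gamma terms. The SciPy sketch below (`student_t_cdf_via_betainc` is a hypothetical name) checks this identity against `scipy.stats.t.cdf`.

```python
import numpy as np
from scipy import special, stats


def student_t_cdf_via_betainc(t, nu):
    """Standard Student-t CDF written with the regularized incomplete beta I_x(a, b)."""
    t = np.asarray(t, dtype=float)
    x = nu / (nu + t**2)
    tail = 0.5 * special.betainc(nu / 2.0, 0.5, x)  # probability mass beyond |t| on one side
    return np.where(t >= 0, 1.0 - tail, tail)


grid = np.linspace(-20, 20, 401)
# Largest absolute difference from SciPy's own Student-t CDF
print(np.max(np.abs(student_t_cdf_via_betainc(grid, 3.0) - stats.t.cdf(grid, 3.0))))
```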
If I switch to `nuts_sampler="numpyro"`, it fails outright with the error:
ValueError: Betainc gradient with respect to a and b not supported.
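A plausible reading of this error: with `nu` fixed, as in the first model, the gradient only needs dI_x(a, b)/dx, which is simply the Beta(a, b) density x**(a-1) * (1-x)**(b-1) / B(a, b). Once `nu` is a free parameter, the gradient also needs dI_x(a, b)/da with a = nu/2, which has no comparably simple closed form; that is the derivative the JAX backend reports as unsupported, and presumably what the `betainc_der` Op named in the warnings evaluates. The SciPy sketch below (helper names hypothetical) checks the closed form against a finite difference and estimates dI/da numerically.

```python
from scipy import special, stats


def dbetainc_dx(a, b, x):
    """Closed-form derivative of I_x(a, b) with respect to x: the Beta(a, b) pdf."""
    return stats.beta.pdf(x, a, b)


def dbetainc_da_fd(a, b, x, h=1e-6):
    """Central finite-difference estimate of dI_x(a, b)/da (no closed form used)."""
    return (special.betainc(a + h, b, x) - special.betainc(a - h, b, x)) / (2 * h)


# a = nu / 2 and b = 1/2, as in the Student-t CDF; nu = 3 here
a, b, x, h = 1.5, 0.5, 0.4, 1e-6
fd_x = (special.betainc(a, b, x + h) - special.betainc(a, b, x - h)) / (2 * h)
print(fd_x, dbetainc_dx(a, b, x))  # the two values agree closely
print(dbetainc_da_fd(a, b, x))     # negative: raising a moves Beta mass to the right
```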
Is there a common way to speed up this computation?
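One option, if a slightly different model is acceptable, is to skew the Student-t with the standard normal CDF instead of the Student-t CDF. Because the t density and Phi are both symmetric about 0, 2/sigma * t(z; nu) * Phi(alpha * z) with z = (x - mu) / sigma is still a proper density (usually called the skew-t-normal), and its log-density needs only log-gamma and log-Phi terms, so no `betainc` is involved. This is a swapped-in distribution, not the one in the post. The SciPy sketch below checks that it integrates to 1; in PyMC the same expression could presumably be assembled from `pm.logp(pm.StudentT.dist(nu), z)` and `pm.logcdf(pm.Normal.dist(), alpha * z)`, which is an untested assumption.

```python
import numpy as np
from scipy import integrate, special, stats


def logp_skew_t_normal(x, nu, mu, sigma, alpha):
    """Skew-t-normal log-density: log of 2/sigma * t(z; nu) * Phi(alpha * z)."""
    z = (x - mu) / sigma
    # special.log_ndtr computes log(Phi) stably for large negative arguments
    return np.log(2) - np.log(sigma) + stats.t.logpdf(z, nu) + special.log_ndtr(alpha * z)


def total_mass(nu, mu, sigma, alpha):
    """Integrate the density over the real line, split at mu for robustness."""
    f = lambda x: np.exp(logp_skew_t_normal(x, nu, mu, sigma, alpha))
    left, _ = integrate.quad(f, -np.inf, mu)
    right, _ = integrate.quad(f, mu, np.inf)
    return left + right


print(total_mass(1.0, 0.0, 3.0, -10.0))  # about 1.0
```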
FYI, I’m using PyMC 5.1.2.
I also posted this on GitHub as "How to speed up model fitting of CustomDist?" (pymc-devs/pymc Discussion #6661), since I’m not sure where it belongs.