import numpy as np
import pymc as pm
from pytensor.tensor import TensorVariable
import matplotlib.pyplot as plt
import arviz as az
import pytensor.tensor as pt
from pymc.sampling.jax import sample_blackjax_nuts, sample_numpyro_nuts
def logp(value, lower, upper):
    """Log-density of a Uniform(lower, upper) distribution.

    Returns a tensor shaped like ``value`` holding ``-log(upper - lower)``
    wherever ``lower <= value <= upper`` and ``-inf`` outside that support.
    """
    in_support = pt.bitwise_and(pt.ge(value, lower), pt.le(value, upper))
    # fill() broadcasts the constant density to the shape of `value`.
    uniform_logdensity = pt.fill(value, -pt.log(upper - lower))
    return pt.switch(in_support, uniform_logdensity, -np.inf)
with pm.Model():
    lower = 0
    upper = 10
    # Custom uniform variable defined only through its logp. Because PyMC
    # cannot infer the support from a bare logp, we must supply an interval
    # transform ourselves: it maps [lower, upper] to the whole real line, so
    # NUTS never proposes values where logp is -inf. Without it, every
    # out-of-bounds leapfrog step becomes a divergence (the ~3900 divergences
    # and rhat > 1.01 warnings seen when running the untransformed model).
    # NOTE(review): on recent PyMC (>= 5.x) this keyword may be spelled
    # `default_transform` — confirm against the installed version.
    pm.CustomDist(
        'a',
        lower,
        upper,
        logp=logp,
        transform=pm.distributions.transforms.Interval(lower=lower, upper=upper),
    )
    idata = sample_numpyro_nuts(
        draws=1000,
        tune=1000,
        chains=4,
        target_accept=0.96,
        random_seed=123,
        chain_method='sequential',
        compute_convergence_checks=True,
        progressbar=True,
        postprocessing_backend='gpu',
    )
# Observed sampler output for the untransformed model:
#   Warning: "There were 3907 divergences after tuning. Increase `target_accept`
#   or reparameterize."
#   "The rhat statistic is larger than 1.01 for some parameters. This indicates
#   problems during sampling."
# Despite the divergences, az.plot_trace() shows the draws are uniformly
# distributed over [lower, upper].