The only solution I found was to create a custom random variable. This guide from PyMC is very useful for that.
In my case, I just had to copy and modify the code that already exists in these modules:
This is the modified code:
from pytensor.tensor.random.basic import ScipyRandomVariable
from pymc.distributions.distribution import Discrete
from pymc.pytensorf import floatX
import pytensor.tensor as pt
import scipy.stats as stats
from pymc.distributions.dist_math import check_parameters
from pymc.distributions.shape_utils import rv_size_is_none
class MyBernoulliRV(ScipyRandomVariable):
    """Bernoulli random-variable op that declares ``bool`` as its dtype.

    According to the inline note on ``dtype``, this is a copy of the
    Bernoulli op shipped with pytensor in which only the dtype was changed.
    """

    name = "bernoulli"
    # NOTE(review): presumably "scalar support, one scalar parameter" --
    # confirm against the pytensor RandomVariable documentation.
    ndim_supp = 0
    ndims_params = [0]
    dtype = "bool"  # this is basically the only change from the code in pytensor library
    _print_name = ("Bernoulli", "\\operatorname{Bernoulli}")

    def __call__(self, p, size=None, **kwargs):
        """Delegate to the parent op, passing ``p`` positionally and ``size`` by keyword."""
        return super().__call__(p, size=size, **kwargs)

    @classmethod
    def rng_fn_scipy(cls, rng, p, size):
        """Draw Bernoulli samples with SciPy.

        ``rng`` is handed to SciPy as ``random_state``; ``p`` and ``size``
        are forwarded unchanged to ``stats.bernoulli.rvs``.
        """
        draws = stats.bernoulli.rvs(p, size=size, random_state=rng)
        return draws


# Module-level op instance; ``MyBernoulli.rv_op`` refers to it.
bernoulli = MyBernoulliRV()
class MyBernoulli(Discrete):
    """Discrete Bernoulli distribution backed by the module-level ``bernoulli`` op.

    Accepts either a probability ``p`` or its logit ``logit_p``.

    NOTE(review): ``logp``, ``moment`` and ``logcdf`` are written without
    ``self``/``cls``; presumably the PyMC ``Distribution`` machinery
    dispatches them as plain functions -- confirm against the PyMC version
    in use.
    """

    rv_op = bernoulli

    @classmethod
    def dist(cls, p=None, logit_p=None, *args, **kwargs):
        """Build the distribution from exactly one of ``p`` or ``logit_p``.

        Parameters
        ----------
        p
            Success probability. Mutually exclusive with ``logit_p``.
        logit_p
            Logit of the success probability; converted with ``pt.sigmoid``.
        *args
            Accepted for signature compatibility but not forwarded.
        **kwargs
            Forwarded to the parent ``dist``.

        Raises
        ------
        ValueError
            If both or neither of ``p`` and ``logit_p`` are given.
        """
        if p is not None and logit_p is not None:
            raise ValueError("Incompatible parametrization. Can't specify both p and logit_p.")
        if p is None and logit_p is None:
            raise ValueError("Incompatible parametrization. Must specify either p or logit_p.")

        if logit_p is not None:
            p = pt.sigmoid(logit_p)

        p = pt.as_tensor_variable(floatX(p))
        return super().dist([p], **kwargs)

    def logp(value, p):
        """Return the log-probability of ``value`` given success probability ``p``.

        Yields ``-inf`` where ``value`` is below 0 or above 1, ``log(p)``
        where ``value`` is truthy and ``log1p(-p)`` otherwise. The result is
        wrapped in ``check_parameters`` with the conditions ``0 <= p`` and
        ``p <= 1``.
        """
        # ``float("-inf")`` is the same Python float as ``-np.inf`` but does
        # not depend on numpy being imported in this module.
        out_of_support = pt.or_(pt.lt(value, 0), pt.gt(value, 1))
        res = pt.switch(
            out_of_support,
            float("-inf"),
            pt.switch(value, pt.log(p), pt.log1p(-p)),
        )

        return check_parameters(
            res,
            0 <= p,
            p <= 1,
            msg="0 <= p <= 1",
        )

    def moment(rv, size, p):
        """Return 0 where ``p < 0.5`` and 1 elsewhere.

        When ``rv_size_is_none(size)`` is false, ``p`` is first expanded to
        ``size`` with ``pt.full``.
        """
        if not rv_size_is_none(size):
            p = pt.full(size, p)
        return pt.switch(p < 0.5, 0, 1)

    def logcdf(value, p):
        """Return the log of the cumulative distribution function at ``value``.

        Yields ``-inf`` where ``value < 0``, ``log1p(-p)`` where
        ``0 <= value < 1`` and ``0`` elsewhere. The result is wrapped in
        ``check_parameters`` with the conditions ``0 <= p`` and ``p <= 1``.
        """
        # Same import-free replacement for ``-np.inf`` as in ``logp``.
        res = pt.switch(
            pt.lt(value, 0),
            float("-inf"),
            pt.switch(
                pt.lt(value, 1),
                pt.log1p(-p),
                0,
            ),
        )

        return check_parameters(
            res,
            0 <= p,
            p <= 1,
            msg="0 <= p <= 1",
        )
Important note: If you create a custom random variable and change the dtype to bool, the PyMC samplers may not work. I tried a case where the samplers were BinaryGibbsMetropolis + NUTS, and it didn’t work. The error was “Can only compute the gradient of continuous types: T”. This problem can be solved by using only samplers that work with bools, or by creating a custom one.