I have a mixed model that uses priors with a normal distribution as well as a Bernoulli distribution. Recently, I switched to the JAX sampler after gaining access to a GPU. Previously, I was able to run this model fine: the Bernoulli variable was assigned to the BinaryGibbsMetropolis sampler while the other variables were assigned to NUTS. With the JAX sampler I now get an error. Is there a way to resolve this? There is a similar issue that is still open.
The JAX sampler is a pure NUTS sampler. It does not support compound sampling, where discrete variables get their own step method alongside NUTS, the way pm.sample does.
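One common workaround, which the answer above does not mention, is to marginalize the discrete variable out so the model contains only continuous parameters that a NUTS-only sampler can handle. The sketch below assumes the Bernoulli variable z is a latent indicator that selects between two normal means (mu0, mu1, p and sigma are hypothetical names, since the actual model is not shown). It sums z out by hand in plain Python and checks the result against brute-force enumeration. In PyMC the same idea is usually written with pm.Mixture or pm.NormalMixture in place of an explicit Bernoulli variable.

```python
import math


def normal_logpdf(y, mu, sigma):
    # Log density of Normal(mu, sigma) evaluated at y
    return -0.5 * math.log(2 * math.pi * sigma**2) - (y - mu) ** 2 / (2 * sigma**2)


def logsumexp2(a, b):
    # Numerically stable log(exp(a) + exp(b))
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def marginal_loglik(y, p, mu0, mu1, sigma):
    # Hypothetical model: z ~ Bernoulli(p), y | z ~ Normal(mu0 if z == 0 else mu1, sigma).
    # Summing z out gives p(y) = (1 - p) N(y | mu0, sigma) + p N(y | mu1, sigma),
    # which depends only on continuous parameters, so NUTS can sample them.
    return logsumexp2(
        math.log1p(-p) + normal_logpdf(y, mu0, sigma),
        math.log(p) + normal_logpdf(y, mu1, sigma),
    )


def enumerated_loglik(y, p, mu0, mu1, sigma):
    # Brute-force check: enumerate both values of z and add the joint densities
    total = 0.0
    for z, pz in ((0, 1 - p), (1, p)):
        mu = mu1 if z == 1 else mu0
        total += pz * math.exp(normal_logpdf(y, mu, sigma))
    return math.log(total)


if __name__ == "__main__":
    print(marginal_loglik(0.3, 0.25, -1.0, 2.0, 1.0))
    print(enumerated_loglik(0.3, 0.25, -1.0, 2.0, 1.0))
```

Because the marginal log-likelihood is a smooth function of p, mu0, mu1 and sigma, gradients exist everywhere NUTS needs them. Whether this trick applies depends on how the Bernoulli variable enters your model; with several independent indicators, each one is summed out separately in the same way, and the posterior over z can be recovered afterwards from the sampled continuous parameters.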