Compound steps with the JAX (NumPyro) sampler

Is there a way to run compound steps when using the JAX-based NumPyro sampler instead of the default sampler?

My model includes a Bernoulli random variable. With the default sampler, this variable is automatically assigned the BinaryGibbsMetropolis step method, while the other variables are assigned NUTS. With the JAX sampler, the same model fails with an error.
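
To make clear what I mean by a compound step, here is a toy sketch in plain Python. It is not PyMC's implementation: the model, the data, and all function names are made up for illustration, and a simple random-walk Metropolis update stands in for NUTS. The point is only that the discrete variable and the continuous variable are each updated by their own step method within one iteration.

```python
import math
import random


def log_joint(z, mu, data, p=0.5, shift=2.0):
    """Unnormalized log density of a hypothetical toy model:
    z ~ Bernoulli(p), mu ~ Normal(0, 1), y_i ~ Normal(mu + shift * z, 1)."""
    lp = math.log(p if z == 1 else 1.0 - p)
    lp += -0.5 * mu * mu
    for y in data:
        lp += -0.5 * (y - (mu + shift * z)) ** 2
    return lp


def gibbs_update_z(mu, data, rng):
    """Discrete step: draw z from its exact conditional given mu."""
    diff = log_joint(1, mu, data) - log_joint(0, mu, data)
    # Numerically stable logistic function of the log-odds.
    if diff >= 0:
        prob_one = 1.0 / (1.0 + math.exp(-diff))
    else:
        e = math.exp(diff)
        prob_one = e / (1.0 + e)
    return 1 if rng.random() < prob_one else 0


def metropolis_update_mu(z, mu, data, rng, scale=0.5):
    """Continuous step: random-walk Metropolis (a stand-in for NUTS)."""
    proposal = mu + rng.gauss(0.0, scale)
    log_ratio = log_joint(z, proposal, data) - log_joint(z, mu, data)
    accept_prob = math.exp(min(0.0, log_ratio))
    return proposal if rng.random() < accept_prob else mu


def compound_sample(data, draws=2000, seed=0):
    """Compound step: each iteration updates z, then mu, with separate samplers."""
    rng = random.Random(seed)
    z, mu = 0, 0.0
    trace = []
    for _ in range(draws):
        z = gibbs_update_z(mu, data, rng)
        mu = metropolis_update_mu(z, mu, data, rng)
        trace.append((z, mu))
    return trace


data = [2.1, 1.8, 2.4, 1.9, 2.2]  # made-up observations
trace = compound_sample(data)
```

A purely gradient-based sampler has no equivalent of `gibbs_update_z`, since a Bernoulli variable has no gradient to follow, which is why I am asking whether the JAX sampler can be combined with a separate step for the discrete variable.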