Use of numpyro/Jax with pymc-dev

I have a model that mixes discrete categorical variables with continuous priors (Beta distributions). The model ran fine in pymc3 but was somewhat slow (I am implementing the Asthma model from Chapter 6 of Winn & Bishop's book Model-Based Machine Learning). So I installed pymc-dev, along with JAX and numpyro, to try to accelerate the code. Running the call below gives an error, which seems to suggest that the categoricals are causing problems because I am using the NUTS sampler, which requires gradients. As I understand it, pymc chooses NUTS or Metropolis for each variable depending on its distribution. What does one have to do to get JAX to work under pymc? Thanks.

import pymc.sampling_jax as jx

with model:
    idata = jx.sample_numpyro_nuts(target_accept=0.9)
Hi, which OS are you using? And which versions of pymc, numpyro, and jax do you have?

Also, can you provide a small code sample that reproduces this error? Thanks.
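One quick way to gather the requested version numbers in one step is a small helper built on the standard library's importlib.metadata. This is only a sketch: the helper name report_versions is hypothetical, and it assumes the packages are installed under their PyPI distribution names.

```python
from importlib.metadata import PackageNotFoundError, version


def report_versions(packages=("pymc", "numpyro", "jax", "jaxlib")):
    """Return {distribution name: installed version or 'not installed'}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = "not installed"
    return found


print(report_versions())
```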

I am working on Pop!_OS 21.10. After sending this message, I decided to work on an MWE by coding up a coin flip; the program worked for a while, and I got really good acceleration. However, errors started to occur again shortly after I added numpyro with GPU support. I will provide an MWE once I have something reproducible. In the meantime, my versions are numpyro 0.9.2, jax 0.3.13, and jaxlib 0.3.10+cuda11.cudnn82.

Here is an MWE:

import numpy as np
import pymc as pm  # pymc4
import pymc.sampling_jax as jx
import jax

# Force JAX onto the CPU and confirm which backend/devices are in use
jax.config.update('jax_platform_name', 'cpu')
print(jax.default_backend(), jax.local_device_count())
print(jax.numpy.ones(3).device())  # TFRT_CPU_0

def create_coin(data):
    with pm.Model() as coin_model:
        # Beta(2, 2) prior on the probability of heads
        p_coin = pm.Beta("p_coin", 2, 2)
        # Observed flips: 0 = tails, 1 = heads
        pm.Bernoulli("flips", p_coin, observed=data)
    return coin_model

p = 0.55
nb_flips = 1000
nb_chains = 4
nb_draws = 20000
data = np.random.choice([0, 1], size=nb_flips, p=[1 - p, p])
model = create_coin(data)

with model:
    # tune = number of warmup steps; tune=0 skips warmup (no step-size adaptation)
    idata = jx.sample_numpyro_nuts(
        target_accept=0.9, draws=nb_draws, tune=0,
        chains=nb_chains, chain_method='parallel',
    )

Here is the error (different from the one previously posted).

  File /tmp/tmp6cmi_6tv:22
    1]}, InplaceDimShuffle{x}.0, InplaceDimShuffle{x}.0)
IndentationError: unexpected indent

Looking at the generated file /tmp/tmp6cmi_6tv, one finds these lines:

    # Elemwise{Composite{(Switch(AND(GE(i0, i1), LE(i0, i2)), (i3 + i4 + i5), i6) + i4 + i5)}}[(0, 0)](Elemwise{sigmoid,no_inplace}.0, TensorConstant{0.0}, TensorConstant{1.0}, TensorConstant{1.791759469228055}, Elemwise{Composite{(-scalar_softplus((-i0)))}}.0, Elemwise{Composite{(-scalar_softplus(i0))}}.0, TensorConstant{-inf})
    auto_8659 = composite2(auto_7739, auto_7510, auto_7434, auto_7964, auto_8586, auto_8600, auto_7571)
    # MakeVector{dtype='bool'}(Elemwise{ge,no_inplace}.0, Elemwise{le,no_inplace}.0)
    auto_7745 = makevector(auto_7743, auto_7741)
    # Elemwise{switch,no_inplace}(flips{[1 0 1 0 0.. 1 1 1
     1]}, InplaceDimShuffle{x}.0, InplaceDimShuffle{x}.0)     <<<<<<<<< ERROR
    auto_7752 = where(flips, auto_7751, auto_7749)
    # All(MakeVector{dtype='bool'}.0)
    auto_7746 = careduce(auto_7745)

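The marked comment line has clearly been broken in two, right inside the printed value of the observed flips data: the continuation "1]}, InplaceDimShuffle{x}.0, InplaceDimShuffle{x}.0)" is not prefixed with #, so Python parses it as an unexpectedly indented statement.
Any ideas? Given how many lines were successfully compiled to JAX, I find it hard to understand how this kind of error can occur. Thanks for your insights!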
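The break appears to come from the printed value of the observed data: the generated file seems to embed a shortened str() of the flips constant in a # comment, and NumPy wraps long arrays across several lines. Below is a minimal sketch of that suspected mechanism. The compiles helper and the generated source are hypothetical stand-ins, not Aesara's actual code generator.

```python
import sys

import numpy as np

# Hypothetical stand-in for the 1000 observed coin flips in the MWE
flips = np.random.default_rng(0).integers(0, 2, size=1000)

# NumPy wraps long arrays at `linewidth` (75 characters by default)
text = str(flips)
print("\n" in text)  # True


def compiles(comment_text):
    """Compile a function whose body contains '# flips<comment_text>',
    mimicking a generated file that embeds a value's str() in a comment."""
    source = "\n".join([
        "def f(x):",
        "    y = x + 1",
        "    # flips" + comment_text,
        "    return y",
    ])
    try:
        compile(source, "<generated>", "exec")
        return True
    except SyntaxError:  # IndentationError is a subclass of SyntaxError
        return False


print(compiles(text))  # False: the wrapped lines escape the comment

# With an unlimited line width, the same 1-D array prints on a single line
with np.printoptions(linewidth=sys.maxsize):
    one_line = str(flips)
print(compiles(one_line))  # True
```

If this is the cause, widening NumPy's print width before sampling (for example np.set_printoptions(linewidth=sys.maxsize)) might sidestep the error. That is an untested guess, not a confirmed workaround.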

Your first example was not working because it had unobserved discrete variables, which can't be sampled with NUTS. The numpyro/JAX backend only supports NUTS, unlike PyMC's default samplers, which can combine NUTS for the continuous variables with Metropolis steps for the discrete ones.
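Here is why NUTS cannot handle an unobserved discrete variable. It needs the gradient of the log density, and with respect to a discrete variable that gradient either is refused outright or carries no information. A minimal JAX sketch with a hypothetical two-coin log-likelihood (not the Asthma model):

```python
import jax
import jax.numpy as jnp


def log_prob(z, p_high=0.8, p_low=0.3, heads=7, n=10):
    """Log-likelihood of `heads` out of `n` flips given which coin z selects."""
    p = jnp.where(z > 0.5, p_high, p_low)
    return heads * jnp.log(p) + (n - heads) * jnp.log1p(-p)


# Integer (discrete) input: JAX refuses to differentiate with respect to it
try:
    jax.grad(log_prob)(jnp.array(1))
    int_grad_ok = True
except TypeError:
    int_grad_ok = False

# Float input is allowed, but log_prob is a step function of z,
# so the gradient is zero and gives NUTS nothing to follow
g = jax.grad(log_prob)(1.0)
print(int_grad_ok, float(g))  # False 0.0
```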


In that case, what would it look like to write a custom sampler that alternates between NUTS (with numpyro/jax) and a Metropolis stepper for discrete variables? Is there some way to do it by writing the sampler in jax?

I think you would need to write your own JAX sampler or use a more modular library such as blackjax to be able to mix samplers. CC @junpenglao
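The sketch below shows what mixing samplers could look like in pure JAX. It uses Metropolis-within-Gibbs: each iteration applies a gradient-based update to the continuous parameter with the discrete one held fixed, then a Metropolis update to the discrete one. Several assumptions apply. Plain fixed-step HMC stands in for NUTS, since NUTS is considerably longer to write. The two-coin toy model and its data are hypothetical. No numpyro, blackjax, or PyMC API is used. A model like the Asthma one would also need categorical (not just binary) proposals.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy data: 70 heads observed in 100 flips
HEADS, N = 70, 100


def log_joint(theta, z):
    """Unnormalized log density of (theta, z).

    theta ~ Normal(0, 1.5) is the logit of a biased coin's p(heads);
    z ~ Bernoulli(0.5) says whether the flips came from that coin (z = 1)
    or from a fair coin (z = 0). z's prior is uniform, so it only adds a constant.
    """
    logit_p = jnp.where(z == 1, theta, 0.0)
    log_lik = HEADS * jax.nn.log_sigmoid(logit_p) + (N - HEADS) * jax.nn.log_sigmoid(-logit_p)
    log_prior = -0.5 * (theta / 1.5) ** 2
    return log_prior + log_lik


def hmc_step(key, theta, z, step_size=0.1, n_leapfrog=10):
    """One HMC update of the continuous theta with z held fixed (stand-in for NUTS)."""
    grad_theta = jax.grad(log_joint)  # differentiates w.r.t. theta only
    key_mom, key_acc = jax.random.split(key)
    r0 = jax.random.normal(key_mom)

    def leapfrog(_, state):
        th, r = state
        r = r + 0.5 * step_size * grad_theta(th, z)
        th = th + step_size * r
        r = r + 0.5 * step_size * grad_theta(th, z)
        return th, r

    th_new, r_new = jax.lax.fori_loop(0, n_leapfrog, leapfrog, (theta, r0))
    # Metropolis correction on the Hamiltonian -log p(theta, z) + r^2 / 2
    log_accept = (log_joint(th_new, z) - 0.5 * r_new**2) - (log_joint(theta, z) - 0.5 * r0**2)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_accept
    return jnp.where(accept, th_new, theta)


def flip_step(key, theta, z):
    """Metropolis update of the discrete z: propose 1 - z, then accept or reject."""
    z_prop = 1 - z
    log_accept = log_joint(theta, z_prop) - log_joint(theta, z)
    accept = jnp.log(jax.random.uniform(key)) < log_accept
    return jnp.where(accept, z_prop, z)


def sample(key, n_samples=2000):
    """Alternate the two kernels for n_samples iterations."""

    def one_iteration(state, key):
        theta, z = state
        key_hmc, key_flip = jax.random.split(key)
        theta = hmc_step(key_hmc, theta, z)
        z = flip_step(key_flip, theta, z)
        return (theta, z), (theta, z)

    init = (jnp.float32(0.0), jnp.int32(0))
    _, (thetas, zs) = jax.lax.scan(one_iteration, init, jax.random.split(key, n_samples))
    return thetas, zs


thetas, zs = sample(jax.random.PRNGKey(0))
```

Each kernel leaves its own conditional distribution invariant, so alternating them targets the joint posterior. A library such as blackjax provides tuned kernels, including NUTS, that could presumably replace the hand-written HMC step; check its documentation for the current API.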

