Use of numpyro/Jax with pymc-dev

I have a model that mixes discrete Categorical variables with continuous priors (Beta distributions). The model ran fine in pymc3 but was somewhat slow (I am implementing the Asthma model in Chapter 6 of Winn & Bishop's book Model-Based Machine Learning). So I installed pymc-dev, along with JAX and numpyro, to try to speed up the code. I get the error below, which seems to suggest that the Categorical variables are causing problems because I am using the NUTS sampler, which requires gradients. As I understand it, pymc itself chooses NUTS or Metropolis as needed, depending on each variable's distribution. What does one have to do to get JAX to work under pymc? Thanks.

with model:
    # NOTE: the numpyro/JAX backend only runs NUTS, which differentiates the
    # model's log-density. If `model` has unobserved discrete (e.g.
    # Categorical) variables, this call fails with "grad requires real- or
    # complex-valued inputs ... but got int64".
    idata = jx.sample_numpyro_nuts(target_accept=0.9)
TypeError                                 Traceback (most recent call last)
File <timed exec>:4, in <module>

File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/pymc/, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
    512 if chains > 1:
    513     map_seed = jax.random.split(map_seed, chains)
--> 515
    516     map_seed,
    517     init_params=init_params,
    518     extra_fields=(
    519         "num_steps",
    520         "potential_energy",
    521         "energy",
    522         "adapt_state.step_size",
    523         "accept_prob",
    524         "diverging",
    525     ),
    526 )
    528 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
    530 tic3 =

File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/numpyro/infer/, in, rng_key, extra_fields, init_params, *args, **kwargs)
    597     states, last_state = _laxmap(partial_map_fn, map_args)
    598 elif self.chain_method == "parallel":
--> 599     states, last_state = pmap(partial_map_fn)(map_args)
    600 else:
    601     assert self.chain_method == "vectorized"

    [... skipping hidden 17 frame]

File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/numpyro/infer/, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
    379 rng_key, init_state, init_params = init
    380 if init_state is None:
--> 381     init_state = self.sampler.init(
    382         rng_key,
    383         self.num_warmup,
    384         init_params,
    385         model_args=args,
    386         model_kwargs=kwargs,
    387     )
    388 sample_fn, postprocess_fn = self._get_cached_fns()
    389 diagnostics = (
    390     lambda x: self.sampler.get_diagnostics_str(x[0])
    391     if rng_key.ndim == 1
    392     else ""
    393 )  # noqa: E731

File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/numpyro/infer/, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
    726 hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    727     init_params,
    728     num_warmup=num_warmup,
    743     rng_key=rng_key,
    744 )
    745 if rng_key.ndim == 1:
--> 746     init_state = hmc_init_fn(init_params, rng_key)
    747 else:
    748     # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
    749     # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
    750     # wa_steps because those variables do not depend on traced args: init_params, rng_key.
    751     init_state = vmap(hmc_init_fn)(init_params, rng_key)

File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/numpyro/infer/, in HMC.init.<locals>.<lambda>(init_params, rng_key)
    723         dense_mass = [tuple(sorted(z))] if dense_mass else []
    724     assert isinstance(dense_mass, list)
--> 726 hmc_init_fn = lambda init_params, rng_key: self._init_fn(  # noqa: E731
    727     init_params,
    728     num_warmup=num_warmup,
    729     step_size=self._step_size,
    730     num_steps=self._num_steps,
    731     inverse_mass_matrix=inverse_mass_matrix,
    732     adapt_step_size=self._adapt_step_size,
    733     adapt_mass_matrix=self._adapt_mass_matrix,
    734     dense_mass=dense_mass,
    735     target_accept_prob=self._target_accept_prob,
    736     trajectory_length=self._trajectory_length,
    737     max_tree_depth=self._max_tree_depth,
    738     find_heuristic_step_size=self._find_heuristic_step_size,
    739     forward_mode_differentiation=self._forward_mode_differentiation,
    740     regularize_mass_matrix=self._regularize_mass_matrix,
    741     model_args=model_args,
    742     model_kwargs=model_kwargs,
    743     rng_key=rng_key,
    744 )
    745 if rng_key.ndim == 1:
    746     init_state = hmc_init_fn(init_params, rng_key)

File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/numpyro/infer/, in hmc.<locals>.init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, num_steps, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, regularize_mass_matrix, model_args, model_kwargs, rng_key)
    320 r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
    321 vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
--> 322 vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
    323 energy = vv_state.potential_energy + kinetic_fn(
    324     wa_state.inverse_mass_matrix, vv_state.r
    325 )
    326 zero_int = jnp.array(0, dtype=jnp.result_type(int))

File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/numpyro/infer/, in velocity_verlet.<locals>.init_fn(z, r, potential_energy, z_grad)
    270 """
    271 :param z: Position of the particle.
    272 :param r: Momentum of the particle.
    275 :return: initial state for the integrator.
    276 """
    277 if potential_energy is None or z_grad is None:
--> 278     potential_energy, z_grad = _value_and_grad(
    279         potential_fn, z, forward_mode_differentiation
    280     )
    281 return IntegratorState(z, r, potential_energy, z_grad)

File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/numpyro/infer/, in _value_and_grad(f, x, forward_mode_differentiation)
    244     return f(x), jacfwd(f)(x)
    245 else:
--> 246     return value_and_grad(f)(x)

    [... skipping hidden 2 frame]

File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/jax/_src/, in _check_input_dtype_revderiv(name, holomorphic, allow_int, x)
   1030 if (dtypes.issubdtype(aval.dtype, np.integer) or
   1031     dtypes.issubdtype(aval.dtype, np.bool_)):
   1032   if not allow_int:
-> 1033     raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype "
   1034                     f"that is a sub-dtype of np.inexact), but got {}. "
   1035                     "If you want to use Boolean- or integer-valued inputs, use vjp "
   1036                     "or set allow_int to True.")
   1037 elif not dtypes.issubdtype(aval.dtype, np.inexact):
   1038   raise TypeError(f"{name} requires numerical-valued inputs (input dtype that is a "
   1039                   f"sub-dtype of np.bool_ or np.number), but got {}.")

TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int64. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.

Hi, which OS are you using? And which versions of pymc, numpyro, and jax are you using?

Also, can you provide a small sample of code that can produce this error? Thanks

I am working on Pop!_OS 21.10. After sending this message, I decided to write an MWE by coding up a coin flip. The program worked for a while, and I got a really good speed-up. However, at some point errors started to occur again, shortly after I added numpyro with GPU support. I will provide an MWE once I have something reproducible. In the meantime, my versions are: numpyro 0.9.2, jax 0.3.13, and jaxlib 0.3.10+cuda11.cudnn82.

Here is a MWE:

import pymc as pm  # pymc4
import pandas as pd
import numpy as np
import pymc.distributions.discrete as discrete
import pymc.distributions.continuous as continuous
import aesara.tensor as tt
import aesara as ae
import pymc.sampling_jax as jx
import jax
# Select the CPU backend for the JAX computations that follow.
jax.config.update('jax_platform_name', 'cpu')
# A bare tuple expression is discarded unless it is the last line of a
# notebook cell, so print it explicitly to actually see the backend name and
# the number of local devices.
print(jax.default_backend(), jax.local_device_count())
# Confirm where a freshly created array lands.
# NOTE(review): `.device()` is the jax 0.3.x array API; newer JAX releases
# changed this accessor — adjust if upgrading.
print(jax.numpy.ones(3).device())  # TFRT_CPU_0

def create_coin(data):
    """Build a Beta-Bernoulli coin-flip model.

    Parameters
    ----------
    data : array-like
        Observed coin flips (0 or 1, the support of a Bernoulli), used as the
        observed values of the ``"flips"`` variable.

    Returns
    -------
    pm.Model
        A model with a Beta(2, 2) prior on the heads probability ``p_coin``
        and a Bernoulli likelihood ``flips`` conditioned on ``data``.
    """
    with pm.Model() as coin_model:
        p_coin = pm.Beta("p_coin", 2, 2)
        # The likelihood is registered on `coin_model` by the enclosing
        # context manager, so it needs no Python variable of its own.
        pm.Bernoulli("flips", p_coin, observed=data)
    return coin_model

p = 0.55           # true probability of heads used to simulate the data
nb_flips = 1000
nb_chains = 4
nb_draws = 20000
seed = 2022        # fixed seed so the MWE (data and sampler) is reproducible

rng = np.random.default_rng(seed)
# Simulated coin flips: 1 (heads) with probability p, 0 otherwise.
data = rng.choice([0, 1], size=nb_flips, p=[1 - p, p])
model = create_coin(data)

with model:
    # `tune` is the number of warm-up steps.
    # NOTE(review): tune=0 means no warm-up, so NUTS never adapts its step
    # size or mass matrix; use a positive value unless this is deliberate.
    idata = jx.sample_numpyro_nuts(
        target_accept=0.9,
        draws=nb_draws,
        tune=0,
        chains=nb_chains,
        chain_method='parallel',
        random_seed=seed,
    )

Here is the error (different from the one previously posted).

  File /tmp/tmp6cmi_6tv:22
    1]}, InplaceDimShuffle{x}.0, InplaceDimShuffle{x}.0)
IndentationError: unexpected indent

If one looks at the file in question, /tmp/tmp6cmi_6tv, one finds these lines:

    # Elemwise{Composite{(Switch(AND(GE(i0, i1), LE(i0, i2)), (i3 + i4 + i5), i6) + i4 + i5)}}[(0, 0)](Elemwise{sigmoid,no_inplace}.0, TensorConstant{0.0}, TensorConstant{1.0}, TensorConstant{1.791759469228055}, Elemwise{Composite{(-scalar_softplus((-i0)))}}.0, Elemwise{Composite{(-scalar_softplus(i0))}}.0, TensorConstant{-inf})
    auto_8659 = composite2(auto_7739, auto_7510, auto_7434, auto_7964, auto_8586, auto_8600, auto_7571)
    # MakeVector{dtype='bool'}(Elemwise{ge,no_inplace}.0, Elemwise{le,no_inplace}.0)
    auto_7745 = makevector(auto_7743, auto_7741)
    # Elemwise{switch,no_inplace}(flips{[1 0 1 0 0.. 1 1 1
     1]}, InplaceDimShuffle{x}.0, InplaceDimShuffle{x}.0)     <<<<<<<<< ERROR
    auto_7752 = where(flips, auto_7751, auto_7749)
    # All(MakeVector{dtype='bool'}.0)
    auto_7746 = careduce(auto_7745)

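which is clearly a comment line that has been broken inappropriately. The generated comment includes the printed value of the observed `flips` data, and that printout wraps onto a second line (` 1]}, InplaceDimShuffle{x}.0, ...`) that is no longer prefixed with `#`, so Python sees an unexpected indent.
Any ideas? Given how many lines were compiled to JAX successfully, I don't understand how this kind of error can occur. Thanks for your insights!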

Your first example was not working because it had unobserved discrete variables, which can't be sampled with NUTS. The numpyro/JAX backend only supports NUTS, unlike PyMC's default samplers, which can combine NUTS with Metropolis steppers for discrete variables.


In that case, what would it look like to write a custom sampler that alternates between NUTS (with numpyro/jax) and a Metropolis stepper for discrete variables? Is there some way to do it by writing the sampler in jax?

I think you would need to write your own JAX sampler, or use a more modular library such as BlackJAX, to be able to mix samplers. CC @junpenglao

1 Like