I have a model that combines discrete categorical variables with continuous priors (Beta distributions). The model ran fine in PyMC3 but was somewhat slow. (I am implementing the asthma model from Chapter 6 of Winn & Bishop’s book Model-Based Machine Learning.) So I installed pymc-dev, along with JAX and NumPyro, to try to accelerate the code. I get the error below, which seems to suggest that the categoricals are causing the problem: I am using the NUTS sampler, which requires gradients, and the traceback shows `grad` rejecting an int64 input. Apparently, PyMC will choose NUTS or Metropolis as needed depending on the distribution. What does one have to do to get JAX to work under PyMC? Thanks.
with model:
    idata = jx.sample_numpyro_nuts(target_accept=0.9)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
File <timed exec>:4, in <module>
File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/pymc/sampling_jax.py:515, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
512 if chains > 1:
513 map_seed = jax.random.split(map_seed, chains)
--> 515 pmap_numpyro.run(
516 map_seed,
517 init_params=init_params,
518 extra_fields=(
519 "num_steps",
520 "potential_energy",
521 "energy",
522 "adapt_state.step_size",
523 "accept_prob",
524 "diverging",
525 ),
526 )
528 raw_mcmc_samples = pmap_numpyro.get_samples(group_by_chain=True)
530 tic3 = datetime.now()
File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/numpyro/infer/mcmc.py:599, in MCMC.run(self, rng_key, extra_fields, init_params, *args, **kwargs)
597 states, last_state = _laxmap(partial_map_fn, map_args)
598 elif self.chain_method == "parallel":
--> 599 states, last_state = pmap(partial_map_fn)(map_args)
600 else:
601 assert self.chain_method == "vectorized"
[... skipping hidden 17 frame]
File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/numpyro/infer/mcmc.py:381, in MCMC._single_chain_mcmc(self, init, args, kwargs, collect_fields)
379 rng_key, init_state, init_params = init
380 if init_state is None:
--> 381 init_state = self.sampler.init(
382 rng_key,
383 self.num_warmup,
384 init_params,
385 model_args=args,
386 model_kwargs=kwargs,
387 )
388 sample_fn, postprocess_fn = self._get_cached_fns()
389 diagnostics = (
390 lambda x: self.sampler.get_diagnostics_str(x[0])
391 if rng_key.ndim == 1
392 else ""
393 ) # noqa: E731
File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc.py:746, in HMC.init(self, rng_key, num_warmup, init_params, model_args, model_kwargs)
726 hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
727 init_params,
728 num_warmup=num_warmup,
(...)
743 rng_key=rng_key,
744 )
745 if rng_key.ndim == 1:
--> 746 init_state = hmc_init_fn(init_params, rng_key)
747 else:
748 # XXX it is safe to run hmc_init_fn under vmap despite that hmc_init_fn changes some
749 # nonlocal variables: momentum_generator, wa_update, trajectory_len, max_treedepth,
750 # wa_steps because those variables do not depend on traced args: init_params, rng_key.
751 init_state = vmap(hmc_init_fn)(init_params, rng_key)
File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc.py:726, in HMC.init.<locals>.<lambda>(init_params, rng_key)
723 dense_mass = [tuple(sorted(z))] if dense_mass else []
724 assert isinstance(dense_mass, list)
--> 726 hmc_init_fn = lambda init_params, rng_key: self._init_fn( # noqa: E731
727 init_params,
728 num_warmup=num_warmup,
729 step_size=self._step_size,
730 num_steps=self._num_steps,
731 inverse_mass_matrix=inverse_mass_matrix,
732 adapt_step_size=self._adapt_step_size,
733 adapt_mass_matrix=self._adapt_mass_matrix,
734 dense_mass=dense_mass,
735 target_accept_prob=self._target_accept_prob,
736 trajectory_length=self._trajectory_length,
737 max_tree_depth=self._max_tree_depth,
738 find_heuristic_step_size=self._find_heuristic_step_size,
739 forward_mode_differentiation=self._forward_mode_differentiation,
740 regularize_mass_matrix=self._regularize_mass_matrix,
741 model_args=model_args,
742 model_kwargs=model_kwargs,
743 rng_key=rng_key,
744 )
745 if rng_key.ndim == 1:
746 init_state = hmc_init_fn(init_params, rng_key)
File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc.py:322, in hmc.<locals>.init_kernel(init_params, num_warmup, step_size, inverse_mass_matrix, adapt_step_size, adapt_mass_matrix, dense_mass, target_accept_prob, num_steps, trajectory_length, max_tree_depth, find_heuristic_step_size, forward_mode_differentiation, regularize_mass_matrix, model_args, model_kwargs, rng_key)
320 r = momentum_generator(z, wa_state.mass_matrix_sqrt, rng_key_momentum)
321 vv_init, vv_update = velocity_verlet(pe_fn, kinetic_fn, forward_mode_ad)
--> 322 vv_state = vv_init(z, r, potential_energy=pe, z_grad=z_grad)
323 energy = vv_state.potential_energy + kinetic_fn(
324 wa_state.inverse_mass_matrix, vv_state.r
325 )
326 zero_int = jnp.array(0, dtype=jnp.result_type(int))
File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc_util.py:278, in velocity_verlet.<locals>.init_fn(z, r, potential_energy, z_grad)
270 """
271 :param z: Position of the particle.
272 :param r: Momentum of the particle.
(...)
275 :return: initial state for the integrator.
276 """
277 if potential_energy is None or z_grad is None:
--> 278 potential_energy, z_grad = _value_and_grad(
279 potential_fn, z, forward_mode_differentiation
280 )
281 return IntegratorState(z, r, potential_energy, z_grad)
File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/numpyro/infer/hmc_util.py:246, in _value_and_grad(f, x, forward_mode_differentiation)
244 return f(x), jacfwd(f)(x)
245 else:
--> 246 return value_and_grad(f)(x)
[... skipping hidden 2 frame]
File ~/anaconda3/envs/pymc4/lib/python3.9/site-packages/jax/_src/api.py:1033, in _check_input_dtype_revderiv(name, holomorphic, allow_int, x)
1030 if (dtypes.issubdtype(aval.dtype, np.integer) or
1031 dtypes.issubdtype(aval.dtype, np.bool_)):
1032 if not allow_int:
-> 1033 raise TypeError(f"{name} requires real- or complex-valued inputs (input dtype "
1034 f"that is a sub-dtype of np.inexact), but got {aval.dtype.name}. "
1035 "If you want to use Boolean- or integer-valued inputs, use vjp "
1036 "or set allow_int to True.")
1037 elif not dtypes.issubdtype(aval.dtype, np.inexact):
1038 raise TypeError(f"{name} requires numerical-valued inputs (input dtype that is a "
1039 f"sub-dtype of np.bool_ or np.number), but got {aval.dtype.name}.")
TypeError: grad requires real- or complex-valued inputs (input dtype that is a sub-dtype of np.inexact), but got int64. If you want to use Boolean- or integer-valued inputs, use vjp or set allow_int to True.