`nuts_sampler="blackjax"` error with a continuous model

,

Hello, I’m getting an error when trying to sample with `nuts_sampler="blackjax"`, but I’m unsure why. I’m new to PyMC, and I cannot find any information on this error online.
Here I try to fit a model where mu and sigma are estimated for data randomly generated from a normal distribution:

import numpy as np
import pymc as pm

# Fixed seed so that both the simulated data and the posterior draws are
# reproducible from run to run (useful when debugging sampler errors).
RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)

# Simulate Fake Data ----
# 10,000 draws from Normal(loc=172, scale=25); the model below tries to
# recover these two parameters as `mu` and `sigma`.
Y = rng.normal(loc=172, scale=25, size=10000)

# Create the model container
basic_model = pm.Model()

# Define model
with basic_model:
    # Priors
    mu = pm.Normal("mu", mu=150, sigma=25)
    # NOTE(review): HalfNormal(sigma=10) places only ~1% of its prior mass
    # above 25 (the true scale used to simulate Y). With this much data the
    # likelihood should dominate, but a wider scale would be less informative.
    sigma = pm.HalfNormal("sigma", sigma=10)

    # Likelihood: ties the observed data Y to the parameters above.
    Y_obs = pm.Normal("Y_obs", mu=mu, sigma=sigma, observed=Y)


with basic_model:
    # Draw 1000 posterior samples with the external BlackJAX NUTS sampler.
    # This path compiles the model with JAX/XLA, so the installed JAX build
    # must be able to initialise its backend (e.g. CUDA/cuDNN on a GPU
    # machine); otherwise the call fails with an XlaRuntimeError.
    idata = pm.sample(
        draws=1000,
        tune=1000,
        chains=1,
        cores=1,
        nuts_sampler="blackjax",
        random_seed=RANDOM_SEED,
    )

Running the last 3 lines produces this error:

/home/user/mambaforge/envs/scwork/lib/python3.11/site-packages/pymc/sampling/mcmc.py:273: UserWarning: Use of external NUTS sampler is still experimental
  warnings.warn("Use of external NUTS sampler is still experimental", UserWarning)
Compiling...
2023-07-30 14:35:04.141546: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:445] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-07-30 14:35:04.141577: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:449] Memory usage: 3025338368 bytes free, 12619939840 bytes total.
---------------------------------------------------------------------------
XlaRuntimeError                           Traceback (most recent call last)
Cell In[4], line 3
      1 with basic_model:
      2     # draw 1000 posterior samples
----> 3     idata = pm.sample(draws = 1000, tune = 1000, chains = 1, cores = 1, nuts_sampler="blackjax")

File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/pymc/sampling/mcmc.py:660, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
    656     if not isinstance(step, NUTS):
    657         raise ValueError(
    658             "Model can not be sampled with NUTS alone. Your model is probably not continuous."
    659         )
--> 660     return _sample_external_nuts(
    661         sampler=nuts_sampler,
    662         draws=draws,
    663         tune=tune,
    664         chains=chains,
    665         target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
    666         random_seed=random_seed,
    667         initvals=initvals,
    668         model=model,
    669         progressbar=progressbar,
    670         idata_kwargs=idata_kwargs,
    671         nuts_sampler_kwargs=nuts_sampler_kwargs,
    672         **kwargs,
    673     )
    675 if isinstance(step, list):
    676     step = CompoundStep(step)

File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/pymc/sampling/mcmc.py:332, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
    329 elif sampler == "blackjax":
    330     import pymc.sampling.jax as pymc_jax
--> 332     idata = pymc_jax.sample_blackjax_nuts(
    333         draws=draws,
    334         tune=tune,
    335         chains=chains,
    336         target_accept=target_accept,
    337         random_seed=random_seed,
    338         initvals=initvals,
    339         model=model,
    340         idata_kwargs=idata_kwargs,
    341         **nuts_sampler_kwargs,
    342     )
    343     return idata
    345 else:

File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/pymc/sampling/jax.py:414, in sample_blackjax_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, keep_untransformed, chain_method, postprocessing_backend, postprocessing_chunks, idata_kwargs)
    410     init_params = [np.stack(init_state) for init_state in zip(init_params)]
    412 logprob_fn = get_jaxified_logp(model)
--> 414 seed = jax.random.PRNGKey(random_seed)
    415 keys = jax.random.split(seed, chains)
    417 get_posterior_samples = partial(
    418     _blackjax_inference_loop,
    419     logprob_fn=logprob_fn,
   (...)
    422     target_accept=target_accept,
    423 )

File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/random.py:177, in PRNGKey(seed, impl)
    160 def PRNGKey(seed: Union[int, Array], *,
    161             impl: Optional[str] = None) -> KeyArray:
    162   """Create a pseudo-random number generator (PRNG) key given an integer seed.
    163 
    164   The resulting key carries the default PRNG implementation, as
   (...)
    175     and ``fold_in``.
    176   """
--> 177   return _return_prng_keys(True, _key('PRNGKey', seed, impl))

File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/random.py:139, in _key(ctor_name, seed, impl_spec)
    135 if np.ndim(seed):
    136   raise TypeError(
    137       f"{ctor_name} accepts a scalar seed, but was given an array of "
    138       f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
--> 139 return prng.seed_with_impl(impl, seed)

File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/prng.py:406, in seed_with_impl(impl, seed)
    405 def seed_with_impl(impl: PRNGImpl, seed: int | Array) -> PRNGKeyArrayImpl:
--> 406   return random_seed(seed, impl=impl)

File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/prng.py:689, in random_seed(seeds, impl)
    687 else:
    688   seeds_arr = jnp.asarray(seeds)
--> 689 return random_seed_p.bind(seeds_arr, impl=impl)

File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/core.py:386, in Primitive.bind(self, *args, **params)
    383 def bind(self, *args, **params):
    384   assert (not config.jax_enable_checks or
    385           all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 386   return self.bind_with_trace(find_top_trace(args), args, params)

File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/core.py:389, in Primitive.bind_with_trace(self, trace, args, params)
    388 def bind_with_trace(self, trace, args, params):
--> 389   out = trace.process_primitive(self, map(trace.full_raise, args), params)
    390   return map(full_lower, out) if self.multiple_results else full_lower(out)

File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/core.py:821, in EvalTrace.process_primitive(self, primitive, tracers, params)
    820 def process_primitive(self, primitive, tracers, params):
--> 821   return primitive.impl(*tracers, **params)

File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/prng.py:701, in random_seed_impl(seeds, impl)
    699 @random_seed_p.def_impl
    700 def random_seed_impl(seeds, *, impl):
--> 701   base_arr = random_seed_impl_base(seeds, impl=impl)
    702   return PRNGKeyArrayImpl(impl, base_arr)

File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/prng.py:706, in random_seed_impl_base(seeds, impl)
    704 def random_seed_impl_base(seeds, *, impl):
    705   seed = iterated_vmap_unary(seeds.ndim, impl.seed)
--> 706   return seed(seeds)

File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/prng.py:935, in threefry_seed(seed)
    923 def threefry_seed(seed: typing.Array) -> typing.Array:
    924   """Create a single raw threefry PRNG key from an integer seed.
    925 
    926   Args:
   (...)
    933     first padding out with zeros).
    934   """
--> 935   return _threefry_seed(seed)

    [... skipping hidden 14 frame]

File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/dispatch.py:464, in backend_compile(backend, module, options, host_callbacks)
    459   return backend.compile(built_c, compile_options=options,
    460                          host_callbacks=host_callbacks)
    461 # Some backends don't have `host_callbacks` option yet
    462 # TODO(sharadmv): remove this fallback when all backends allow `compile`
    463 # to take in `host_callbacks`
--> 464 return backend.compile(built_c, compile_options=options)

XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.

For more context, I am running Ubuntu 22.04 with PyMC v5.6.1 and BlackJAX v0.9.6.

I don’t understand why it would return "Model can not be sampled with NUTS alone. Your model is probably not continuous." as nothing is discrete here.

Thank you for any help on this.

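Never mind! The "Model can not be sampled with NUTS alone" lines in the traceback were only surrounding source code shown for context (the `-->` arrow marks the line actually executing). The real failure was JAX/XLA being unable to initialize cuDNN (`CUDNN_STATUS_INTERNAL_ERROR`, followed by `DNN library initialization failed`). Fixed by reinstalling JAX with CUDA 12 support: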

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html

Sampling time decreased too, which is nice.

2 Likes