Hello, I’m getting an error when trying to sample with `nuts_sampler="blackjax"`, but I am unsure why.
I’m new to PyMC, and I cannot find any information about this error online.
Here I try to fit a model where `mu` and `sigma` are estimated for data randomly generated from a normal distribution:
# Fit a Normal(mu, sigma) model to simulated data and sample it with PyMC's
# BlackJAX NUTS backend.
import os

# The failing run logged CUDNN_STATUS_INTERNAL_ERROR with only ~3 GB of the
# GPU's ~12 GB free. JAX preallocates a large fraction of GPU memory by
# default, so preallocation (or another process, e.g. a second notebook
# kernel, already holding the GPU) is a likely cause. JAX only honours these
# variables if they are set before jax is imported, so set them before pymc.
# setdefault keeps any value already exported in the shell.
os.environ.setdefault("XLA_PYTHON_CLIENT_PREALLOCATE", "false")
# A model this small runs fine on the CPU. Uncomment to bypass the GPU
# entirely:
# os.environ.setdefault("JAX_PLATFORMS", "cpu")

import numpy as np
import pymc as pm

# A single seed makes both the fake data and the MCMC draws reproducible.
RNG_SEED = 42

# Simulate Fake Data ----
rng = np.random.default_rng(RNG_SEED)
Y = rng.normal(loc=172, scale=25, size=10000)

# Allocate variable for model
basic_model = pm.Model()

# Define model
with basic_model:
    # Priors
    mu = pm.Normal("mu", mu=150, sigma=25)
    sigma = pm.HalfNormal("sigma", sigma=10)
    # Likelihood
    Y_obs = pm.Normal("Y_obs", mu=mu, sigma=sigma, observed=Y)

with basic_model:
    # Draw 1000 posterior samples after 1000 tuning steps.
    # NOTE: the error raised here in the original run was an XlaRuntimeError
    # from JAX compiling on the GPU. The "Model can not be sampled with NUTS
    # alone" text in that traceback is only surrounding source context shown
    # by the traceback; that ValueError was never raised.
    idata = pm.sample(
        draws=1000,
        tune=1000,
        chains=1,
        cores=1,
        nuts_sampler="blackjax",
        random_seed=RNG_SEED,
    )
Running the last 3 lines produces this error:
/home/user/mambaforge/envs/scwork/lib/python3.11/site-packages/pymc/sampling/mcmc.py:273: UserWarning: Use of external NUTS sampler is still experimental
warnings.warn("Use of external NUTS sampler is still experimental", UserWarning)
Compiling...
2023-07-30 14:35:04.141546: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:445] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
2023-07-30 14:35:04.141577: E external/xla/xla/stream_executor/cuda/cuda_dnn.cc:449] Memory usage: 3025338368 bytes free, 12619939840 bytes total.
---------------------------------------------------------------------------
XlaRuntimeError Traceback (most recent call last)
Cell In[4], line 3
1 with basic_model:
2 # draw 1000 posterior samples
----> 3 idata = pm.sample(draws = 1000, tune = 1000, chains = 1, cores = 1, nuts_sampler="blackjax")
File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/pymc/sampling/mcmc.py:660, in sample(draws, tune, chains, cores, random_seed, progressbar, step, nuts_sampler, initvals, init, jitter_max_retries, n_init, trace, discard_tuned_samples, compute_convergence_checks, keep_warning_stat, return_inferencedata, idata_kwargs, nuts_sampler_kwargs, callback, mp_ctx, model, **kwargs)
656 if not isinstance(step, NUTS):
657 raise ValueError(
658 "Model can not be sampled with NUTS alone. Your model is probably not continuous."
659 )
--> 660 return _sample_external_nuts(
661 sampler=nuts_sampler,
662 draws=draws,
663 tune=tune,
664 chains=chains,
665 target_accept=kwargs.pop("nuts", {}).get("target_accept", 0.8),
666 random_seed=random_seed,
667 initvals=initvals,
668 model=model,
669 progressbar=progressbar,
670 idata_kwargs=idata_kwargs,
671 nuts_sampler_kwargs=nuts_sampler_kwargs,
672 **kwargs,
673 )
675 if isinstance(step, list):
676 step = CompoundStep(step)
File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/pymc/sampling/mcmc.py:332, in _sample_external_nuts(sampler, draws, tune, chains, target_accept, random_seed, initvals, model, progressbar, idata_kwargs, nuts_sampler_kwargs, **kwargs)
329 elif sampler == "blackjax":
330 import pymc.sampling.jax as pymc_jax
--> 332 idata = pymc_jax.sample_blackjax_nuts(
333 draws=draws,
334 tune=tune,
335 chains=chains,
336 target_accept=target_accept,
337 random_seed=random_seed,
338 initvals=initvals,
339 model=model,
340 idata_kwargs=idata_kwargs,
341 **nuts_sampler_kwargs,
342 )
343 return idata
345 else:
File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/pymc/sampling/jax.py:414, in sample_blackjax_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, keep_untransformed, chain_method, postprocessing_backend, postprocessing_chunks, idata_kwargs)
410 init_params = [np.stack(init_state) for init_state in zip(init_params)]
412 logprob_fn = get_jaxified_logp(model)
--> 414 seed = jax.random.PRNGKey(random_seed)
415 keys = jax.random.split(seed, chains)
417 get_posterior_samples = partial(
418 _blackjax_inference_loop,
419 logprob_fn=logprob_fn,
(...)
422 target_accept=target_accept,
423 )
File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/random.py:177, in PRNGKey(seed, impl)
160 def PRNGKey(seed: Union[int, Array], *,
161 impl: Optional[str] = None) -> KeyArray:
162 """Create a pseudo-random number generator (PRNG) key given an integer seed.
163
164 The resulting key carries the default PRNG implementation, as
(...)
175 and ``fold_in``.
176 """
--> 177 return _return_prng_keys(True, _key('PRNGKey', seed, impl))
File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/random.py:139, in _key(ctor_name, seed, impl_spec)
135 if np.ndim(seed):
136 raise TypeError(
137 f"{ctor_name} accepts a scalar seed, but was given an array of "
138 f"shape {np.shape(seed)} != (). Use jax.vmap for batching")
--> 139 return prng.seed_with_impl(impl, seed)
File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/prng.py:406, in seed_with_impl(impl, seed)
405 def seed_with_impl(impl: PRNGImpl, seed: int | Array) -> PRNGKeyArrayImpl:
--> 406 return random_seed(seed, impl=impl)
File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/prng.py:689, in random_seed(seeds, impl)
687 else:
688 seeds_arr = jnp.asarray(seeds)
--> 689 return random_seed_p.bind(seeds_arr, impl=impl)
File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/core.py:386, in Primitive.bind(self, *args, **params)
383 def bind(self, *args, **params):
384 assert (not config.jax_enable_checks or
385 all(isinstance(arg, Tracer) or valid_jaxtype(arg) for arg in args)), args
--> 386 return self.bind_with_trace(find_top_trace(args), args, params)
File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/core.py:389, in Primitive.bind_with_trace(self, trace, args, params)
388 def bind_with_trace(self, trace, args, params):
--> 389 out = trace.process_primitive(self, map(trace.full_raise, args), params)
390 return map(full_lower, out) if self.multiple_results else full_lower(out)
File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/core.py:821, in EvalTrace.process_primitive(self, primitive, tracers, params)
820 def process_primitive(self, primitive, tracers, params):
--> 821 return primitive.impl(*tracers, **params)
File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/prng.py:701, in random_seed_impl(seeds, impl)
699 @random_seed_p.def_impl
700 def random_seed_impl(seeds, *, impl):
--> 701 base_arr = random_seed_impl_base(seeds, impl=impl)
702 return PRNGKeyArrayImpl(impl, base_arr)
File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/prng.py:706, in random_seed_impl_base(seeds, impl)
704 def random_seed_impl_base(seeds, *, impl):
705 seed = iterated_vmap_unary(seeds.ndim, impl.seed)
--> 706 return seed(seeds)
File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/prng.py:935, in threefry_seed(seed)
923 def threefry_seed(seed: typing.Array) -> typing.Array:
924 """Create a single raw threefry PRNG key from an integer seed.
925
926 Args:
(...)
933 first padding out with zeros).
934 """
--> 935 return _threefry_seed(seed)
[... skipping hidden 14 frame]
File ~/mambaforge/envs/scwork/lib/python3.11/site-packages/jax/_src/dispatch.py:464, in backend_compile(backend, module, options, host_callbacks)
459 return backend.compile(built_c, compile_options=options,
460 host_callbacks=host_callbacks)
461 # Some backends don't have `host_callbacks` option yet
462 # TODO(sharadmv): remove this fallback when all backends allow `compile`
463 # to take in `host_callbacks`
--> 464 return backend.compile(built_c, compile_options=options)
XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed. Look at the errors above for more details.
For more context, I am running Ubuntu 22.04 with PyMC v5.6.1 and BlackJAX v0.9.6.
I don’t understand why it would raise "Model can not be sampled with NUTS alone. Your model is probably not continuous.",
as nothing in my model is discrete.
Thank you for any help on this.