I wanted to try out sampling with numpyro on an AMD GPU using a model based on this example:
https://www.pymc.io/projects/examples/en/latest/howto/LKJ.html
It works with the PyMC and nutpie samplers, but not with the JAX-based samplers (numpyro, blackjax).
The code I run looks like the following:
import pymc as pm
import numpy
def main():
    """Simulate correlated 2-D normal data and fit it with PyMC's NumPyro NUTS sampler.

    Draws 10000 observations from a bivariate normal with a known mean,
    known standard deviations and a known correlation, prints the true
    covariance matrix and the type of the simulated data, then builds a
    multivariate-normal model with an LKJ Cholesky covariance prior and
    samples it with ``nuts_sampler="numpyro"`` (4 vectorized chains,
    1000 tuning steps and 1000 draws each).

    Returns:
        None. The sampling result is not used further.
    """
    seed = 8927
    generator = numpy.random.default_rng(seed)
    n_obs = 10000

    # Ground-truth parameters used to simulate the observations.
    true_mean = numpy.array([1.0, -2.0])
    true_sigmas = numpy.array([0.7, 1.5])
    true_corr = numpy.array([[1.0, -0.4], [-0.4, 1.0]])

    # Covariance = diag(sigma) @ correlation @ diag(sigma).
    scale = numpy.diag(true_sigmas)
    true_cov = scale @ true_corr @ scale
    print(true_cov)

    x = generator.multivariate_normal(true_mean, true_cov, size=n_obs)
    print(type(x))

    coords = {
        "axis": ["y", "z"],
        "axis_bis": ["y", "z"],
        "obs_id": numpy.arange(n_obs),
    }

    with pm.Model(coords=coords):
        # Prior on the Cholesky factor of the covariance matrix.
        sd_prior = pm.Exponential.dist(1.0, shape=2)
        chol, _corr, _stds = pm.LKJCholeskyCov(
            "chol", n=2, eta=2.0, sd_dist=sd_prior
        )
        # Track the implied covariance matrix alongside the sampled variables.
        pm.Deterministic("cov", chol.dot(chol.T), dims=("axis", "axis_bis"))

        mu = pm.Normal("mu", 0.0, sigma=1.5, dims="axis")
        pm.MvNormal("obs", mu, chol=chol, observed=x, dims=("obs_id", "axis"))

        # Options forwarded to the external (NumPyro) NUTS sampler.
        sampler_options = {
            "chain_method": "vectorized",
            "postprocessing_backend": "gpu",
        }
        # Dimension labels for the variables LKJCholeskyCov adds to the trace.
        idata_options = {
            "dims": {
                "chol_stds": ["axis"],
                "chol_corr": ["axis", "axis_bis"],
            }
        }

        pm.sample(
            draws=1000,
            tune=1000,
            chains=4,
            cores=1,
            nuts_sampler="numpyro",
            progressbar=False,
            mp_ctx="forkserver",
            nuts_sampler_kwargs=sampler_options,
            idata_kwargs=idata_options,
        )


# Only run the reproduction when executed as a script, not on import.
if __name__ == "__main__":
    main()
This is the conda env.yaml I'm using:
name: test_env
channels:
- conda-forge
dependencies:
- ipykernel
- ipywidgets
- jupyter
- jupyterlab
- numpy
- pip
- pymc
- python=3.12.7
- numpyro
- pip:
- https://github.com/ROCm/jax/releases/download/rocm-jax-v0.4.35/jaxlib-0.4.35-cp312-cp312-manylinux_2_28_x86_64.whl
- https://github.com/ROCm/jax/releases/download/rocm-jax-v0.4.35/jax_rocm60_pjrt-0.4.35-py3-none-manylinux_2_28_x86_64.whl
- https://github.com/ROCm/jax/releases/download/rocm-jax-v0.4.35/jax_rocm60_plugin-0.4.35-cp312-cp312-manylinux_2_28_x86_64.whl
- https://github.com/ROCm/jax/archive/refs/tags/rocm-jax-v0.4.35.tar.gz
- ml-dtypes==0.4.0
variables:
ROCM_PATH: /opt/rocm-6.2.1
LLVM_PATH: /opt/rocm-6.2.1/llvm
ENABLE_PJRT_COMPATIBILITY: 1
When I sample with PyMC, everything works fine.
However, the numpyro sampler gives the following error:
Traceback (most recent call last):
File "/home/eichberg/test_pymc/test_jax.py", line 46, in <module>
main()
File "/home/eichberg/test_pymc/test_jax.py", line 29, in main
idata = pm.sample(
^^^^^^^^^^
File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/mcmc.py", line 809, in sample
return _sample_external_nuts(
^^^^^^^^^^^^^^^^^^^^^^
File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/mcmc.py", line 396, in _sample_external_nuts
idata = pymc_jax.sample_jax_nuts(
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/jax.py", line 651, in sample_jax_nuts
initial_points = _get_batched_jittered_initial_points(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/jax.py", line 245, in _get_batched_jittered_initial_points
initial_points = _init_jitter(
^^^^^^^^^^^^^
File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/mcmc.py", line 1482, in _init_jitter
point_logp = model_logp_fn(point)
^^^^^^^^^^^^^^^^^^^^
File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/jax.py", line 243, in eval_logp_initial_point
return logp_fn(point.values())
^^^^^^^^^^^^^^^^^^^^^^^
File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/jax.py", line 155, in logp_fn_wrap
return logp_fn(*x)[0]
^^^^^^^^^^^
File "/tmp/tmp0o4bd774", line 7, in jax_funcified_fgraph
tensor_variable_2 = incsubtensor(chol_cholesky_cov_packed_, tensor_variable_1, tensor_constant)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pytensor/link/jax/dispatch/subtensor.py", line 70, in incsubtensor
return jax_fn(x, indices, y)
^^^^^^^^^^^^^^^^^^^^^
File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pytensor/link/jax/dispatch/subtensor.py", line 58, in jax_fn
return x.at[indices].set(y)
^^^^
AttributeError: 'numpy.ndarray' object has no attribute 'at'
The error persists when removing nuts_sampler_kwargs, when setting the sampler to blackjax, and also when using the original jax instead of ROCm-jax (thus running on the CPU).
Is that a bug in PyMC, pytensor, jax, etc. or did I do something wrong?
And if it’s a bug, who’s at fault?
My guess is that jit broke something.