AttributeError: 'numpy.ndarray' object has no attribute 'at' when sampling the "LKJ Cholesky Covariance Priors for Multivariate Normal Models" example with numpyro or blackjax

I wanted to try out sampling with numpyro on an AMD GPU using a model based on this example:
https://www.pymc.io/projects/examples/en/latest/howto/LKJ.html
It works with the PyMC and nutpie samplers, but not with the JAX-based samplers.
The code I run looks like the following:

import pymc as pm
import numpy


def main():
    RANDOM_SEED = 8927
    rng = numpy.random.default_rng(RANDOM_SEED)

    N = 10000

    mu_actual = numpy.array([1.0, -2.0])
    sigmas_actual = numpy.array([0.7, 1.5])
    Rho_actual = numpy.array([[1.0, -0.4], [-0.4, 1.0]])

    Sigma_actual = numpy.diag(sigmas_actual) @ Rho_actual @ numpy.diag(sigmas_actual)
    print(Sigma_actual)

    x = rng.multivariate_normal(mu_actual, Sigma_actual, size=N)
    print(type(x))

    coords = {"axis": ["y", "z"], "axis_bis": ["y", "z"], "obs_id": numpy.arange(N)}
    with pm.Model(coords=coords) as model:
        chol, corr, stds = pm.LKJCholeskyCov(
            "chol", n=2, eta=2.0, sd_dist=pm.Exponential.dist(1.0, shape=2)
        )
        cov = pm.Deterministic("cov", chol.dot(chol.T), dims=("axis", "axis_bis"))
        mu = pm.Normal("mu", 0.0, sigma=1.5, dims="axis")
        obs = pm.MvNormal("obs", mu, chol=chol, observed=x, dims=("obs_id", "axis"))
        idata = pm.sample(
            nuts_sampler="numpyro",
            progressbar=False,
            tune=1000,
            draws=1000,
            chains=4,
            cores=1,
            mp_ctx="forkserver",
            nuts_sampler_kwargs=dict(
                chain_method="vectorized",
                postprocessing_backend="gpu",
            ),
            idata_kwargs=dict(dims={"chol_stds": ["axis"], "chol_corr": ["axis", "axis_bis"]}),
        )


if __name__ == "__main__":
    main()
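As a sanity check on the simulated data, here is a minimal NumPy-only sketch. It reuses the ground-truth values from the script above. It confirms that Sigma = diag(sigmas) @ Rho @ diag(sigmas) factors as L @ L.T, which is the same relation the model's "cov" Deterministic uses (chol.dot(chol.T)). It also shows that the sample covariance of the draws is close to Sigma, which is roughly what the posterior should recover. The name `L` is just an illustrative variable, not something from the original script.

```python
import numpy as np

# Same ground-truth parameters and seed as in the script above
mu_actual = np.array([1.0, -2.0])
sigmas_actual = np.array([0.7, 1.5])
Rho_actual = np.array([[1.0, -0.4], [-0.4, 1.0]])

# Sigma = diag(sigmas) @ Rho @ diag(sigmas)
Sigma_actual = np.diag(sigmas_actual) @ Rho_actual @ np.diag(sigmas_actual)

# Lower-triangular Cholesky factor, so that Sigma = L @ L.T
# (the model builds "cov" the same way from chol)
L = np.linalg.cholesky(Sigma_actual)

rng = np.random.default_rng(8927)
x = rng.multivariate_normal(mu_actual, Sigma_actual, size=10000)

print(Sigma_actual)                      # approximately [[0.49, -0.42], [-0.42, 2.25]]
print(np.allclose(L @ L.T, Sigma_actual))  # True
print(np.cov(x, rowvar=False))           # should be close to Sigma_actual
```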

This is the conda env.yaml I’m using:

name: test_env
channels:
  - conda-forge
dependencies:
  - ipykernel
  - ipywidgets
  - jupyter
  - jupyterlab
  - numpy
  - pip
  - pymc
  - python=3.12.7
  - numpyro
  - pip:
    - https://github.com/ROCm/jax/releases/download/rocm-jax-v0.4.35/jaxlib-0.4.35-cp312-cp312-manylinux_2_28_x86_64.whl
    - https://github.com/ROCm/jax/releases/download/rocm-jax-v0.4.35/jax_rocm60_pjrt-0.4.35-py3-none-manylinux_2_28_x86_64.whl 
    - https://github.com/ROCm/jax/releases/download/rocm-jax-v0.4.35/jax_rocm60_plugin-0.4.35-cp312-cp312-manylinux_2_28_x86_64.whl
    - https://github.com/ROCm/jax/archive/refs/tags/rocm-jax-v0.4.35.tar.gz
    - ml-dtypes==0.4.0
variables:
  ROCM_PATH: /opt/rocm-6.2.1
  LLVM_PATH: /opt/rocm-6.2.1/llvm
  ENABLE_PJRT_COMPATIBILITY: 1

When I sample with PyMC, everything works fine.
However, the numpyro sampler gives the following error:

Traceback (most recent call last):
  File "/home/eichberg/test_pymc/test_jax.py", line 46, in <module>
    main()
  File "/home/eichberg/test_pymc/test_jax.py", line 29, in main
    idata = pm.sample(
            ^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/mcmc.py", line 809, in sample
    return _sample_external_nuts(
           ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/mcmc.py", line 396, in _sample_external_nuts
    idata = pymc_jax.sample_jax_nuts(
            ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/jax.py", line 651, in sample_jax_nuts
    initial_points = _get_batched_jittered_initial_points(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/jax.py", line 245, in _get_batched_jittered_initial_points
    initial_points = _init_jitter(
                     ^^^^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/mcmc.py", line 1482, in _init_jitter
    point_logp = model_logp_fn(point)
                 ^^^^^^^^^^^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/jax.py", line 243, in eval_logp_initial_point
    return logp_fn(point.values())
           ^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pymc/sampling/jax.py", line 155, in logp_fn_wrap
    return logp_fn(*x)[0]
           ^^^^^^^^^^^
  File "/tmp/tmp0o4bd774", line 7, in jax_funcified_fgraph
    tensor_variable_2 = incsubtensor(chol_cholesky_cov_packed_, tensor_variable_1, tensor_constant)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pytensor/link/jax/dispatch/subtensor.py", line 70, in incsubtensor
    return jax_fn(x, indices, y)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/eichberg/Programs/miniforge3/envs/test_env/lib/python3.12/site-packages/pytensor/link/jax/dispatch/subtensor.py", line 58, in jax_fn
    return x.at[indices].set(y)
           ^^^^
AttributeError: 'numpy.ndarray' object has no attribute 'at'

The error persists when I remove nuts_sampler_kwargs, when I switch the sampler to blackjax, and when I use the upstream jax instead of ROCm-jax (i.e., sampling on the CPU).
Is this a bug in PyMC, PyTensor, JAX, etc., or did I do something wrong?
And if it’s a bug, which package is at fault?
My guess is that JIT compilation broke something.

Odd, somehow the JAX function is getting a NumPy array instead of a JAX one.
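A small illustration of why a NumPy input breaks here. JAX arrays are immutable, so indexed updates go through the functional `x.at[indices].set(y)` syntax. That is the pattern in the last frame of the traceback. A plain `numpy.ndarray` has no `.at` attribute, so the same call fails. The helper `set_entries` is hypothetical and only mirrors that pattern. It is not PyTensor's actual dispatch code.

```python
import numpy as np
import jax.numpy as jnp


def set_entries(x, indices, y):
    # Hypothetical helper mirroring the x.at[indices].set(y) update used by
    # PyTensor's JAX backend for IncSubtensor; returns a new array
    return x.at[indices].set(y)


packed = np.array([0.1, -0.3, 0.2])  # a plain NumPy array, like the initial point

try:
    set_entries(packed, 0, 1.0)
except AttributeError as err:
    print(err)  # 'numpy.ndarray' object has no attribute 'at'

# Converting to a JAX array first makes the same call work
updated = set_entries(jnp.asarray(packed), 0, 1.0)
print(updated)
```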

It’s a bug on our side. There were some recent changes to reuse the JAX logp function for the initial point, and they don’t handle the inputs correctly.
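Here is a rough sketch of the kind of input handling that is missing. It assumes a hypothetical `logp_fn` that stands in for the PyTensor-generated JAX function. The wrapper shapes loosely follow the `logp_fn_wrap` frame in the traceback. It shows two ways to make such a function tolerate NumPy initial points. The first coerces every input with `jnp.asarray`. The second calls it through `jax.jit`, where the inputs become JAX tracers. This is not necessarily what the actual fix in PyMC does, and the dictionary key `"packed"` is made up for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp


def logp_fn(packed):
    # Hypothetical stand-in for a PyTensor-generated JAX function: it updates
    # one entry with .at[...].set(...) and returns a tuple, like the compiled
    # function in the traceback
    packed = packed.at[0].set(jnp.exp(packed[0]))
    return (-0.5 * jnp.sum(packed ** 2),)


def logp_fn_wrap_unsafe(x):
    # Passes values through unchanged: fails if they are NumPy arrays
    return logp_fn(*x)[0]


def logp_fn_wrap_safe(x):
    # Coerce every input to a JAX array before calling the JAX function
    return logp_fn(*(jnp.asarray(v) for v in x))[0]


# Alternative: under jax.jit the inputs are traced as JAX values
logp_fn_jit = jax.jit(logp_fn)

point = {"packed": np.zeros(3)}  # hypothetical NumPy initial point

print(logp_fn_wrap_safe(point.values()))
print(logp_fn_jit(*point.values())[0])
```

Both calls evaluate to -0.5 here. The first entry becomes exp(0) = 1, and the other two entries stay 0.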

Should be fixed by pymc-devs/pymc Pull Request #7695, "Fix bug when reusing jax logp for initial point generation" (by ricardoV94).

Yep, this works, thank you! 🙂