PyMC 4.0.0b6: AttributeError: module 'jax.ops' has no attribute 'index_update'

When I use sample_numpyro_nuts() or sample_blackjax_nuts(), I get this:

AttributeError                            Traceback (most recent call last)
/home/anaconda/workspace/chen/group_code/long_rt/xi.ipynb Cell 7' in <cell line: 1>()
      1 with m1:
----> 2     idta = jax_sample.sample_blackjax_nuts(10,10)

File ~/anaconda3/envs/jax_pymc/lib/python3.10/site-packages/pymc/sampling_jax.py:306, in sample_blackjax_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, keep_untransformed, chain_method, idata_kwargs)
    303     init_params = [np.stack(init_params)]
    304     init_params = [np.stack(init_state) for init_state in zip(*init_params)]
--> 306 logprob_fn = get_jaxified_logp(model)
    308 seed = jax.random.PRNGKey(random_seed)
    309 keys = jax.random.split(seed, chains)

File ~/anaconda3/envs/jax_pymc/lib/python3.10/site-packages/pymc/sampling_jax.py:106, in get_jaxified_logp(model, negative_logp)
    104 if not negative_logp:
    105     model_logpt = -model_logpt
--> 106 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logpt])
    108 def logp_fn_wrap(x):
    109     return logp_fn(*x)[0]

File ~/anaconda3/envs/jax_pymc/lib/python3.10/site-packages/pymc/sampling_jax.py:99, in get_jaxified_graph(inputs, outputs)
     96 mode.JAX.optimizer.optimize(fgraph)
     98 # We now jaxify the optimized fgraph
---> 99 return jax_funcify(fgraph)

File ~/anaconda3/envs/jax_pymc/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/anaconda3/envs/jax_pymc/lib/python3.10/site-packages/aesara/link/jax/dispatch.py:657, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
    650 @jax_funcify.register(FunctionGraph)
    651 def jax_funcify_FunctionGraph(
    652     fgraph,
   (...)
    655     **kwargs,
    656 ):
--> 657     return fgraph_to_python(
    658         fgraph,
    659         jax_funcify,
    660         type_conversion_fn=jax_typify,
    661         fgraph_name=fgraph_name,
    662         **kwargs,
    663     )

File ~/anaconda3/envs/jax_pymc/lib/python3.10/site-packages/aesara/link/utils.py:745, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, input_storage, output_storage, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
    743 body_assigns = []
    744 for node in order:
--> 745     compiled_func = op_conversion_fn(
    746         node.op, node=node, storage_map=storage_map, **kwargs
    747     )
    749     # Create a local alias with a unique name
    750     local_compiled_func_name = unique_name(compiled_func)

File ~/anaconda3/envs/jax_pymc/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
    885 if not args:
    886     raise TypeError(f'{funcname} requires at least '
    887                     '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/anaconda3/envs/jax_pymc/lib/python3.10/site-packages/aesara/link/jax/dispatch.py:619, in jax_funcify_IncSubtensor(op, **kwargs)
    616 idx_list = getattr(op, "idx_list", None)
    618 if getattr(op, "set_instead_of_inc", False):
--> 619     jax_fn = jax.ops.index_update
    620 else:
    621     jax_fn = jax.ops.index_add

AttributeError: module 'jax.ops' has no attribute 'index_update'

Package information:

blackjax 0.4.0
jax 0.3.4
jaxlib 0.3.0
pymc 4.0.0b6
aeppl 0.0.27
aesara 2.5.1
numpyro 0.9.2

Any help would be appreciated!

Can you share the model? It's difficult to help from the error message alone. You might be using a version of JAX that is too recent, where the internals have changed.


I think jax and jaxlib must both be the exact same version for jax to work, but that doesn’t look like the case in the versions you shared.


Thanks~
It's just a toy model.

import numpy as np
import pymc as pm
import pymc.sampling_jax as sampling_jax

# simulated data
x = np.random.normal(0, 1, 100)
slope = 1.3
p = np.exp(x * slope) / (1 + np.exp(x * slope))
y = np.random.binomial(1, p)

# model
with pm.Model() as m:
    sl = pm.Normal("slope", 1, 2)
    logit_p = sl * x
    pm.Bernoulli("Y", logit_p=logit_p)  # missing observed data (observed=y was not passed)
with m:
    tr = sampling_jax.sample_numpyro_nuts()

Hmm…
I think this error was caused by the missing observed data.
When I pass observed data, it works well~


This error still occurs.
It seems to be associated with at.set_subtensor().
When I use this function, it raises AttributeError: module 'jax.ops' has no attribute 'index_update'

import numpy as np
import pymc as pm
import aesara.tensor as at
import pymc.sampling_jax as sampling_jax

Y = np.random.multivariate_normal([0, 0, 0, 0, 0], np.eye(5), size=500)

# indices of the strictly lower triangle of a 5x5 matrix
triL_idx = np.tril_indices(5, k=-1)

with pm.Model() as m1:

    # I am trying to build a covariance matrix as L_mat.dot(L_mat.T)
    # and constrain the first variance to equal 1.

    L_mat_tri = pm.Normal("L_mat_tri", 0, 2.5, shape=at.as_tensor(sum(np.arange(5))))
    L_mat = at.set_subtensor(at.zeros((5, 5))[triL_idx], L_mat_tri)
    L_mat_diag = at.concatenate(
        [[1], pm.HalfNormal("s", 2.5, shape=(5 - 1))])
    L_mat = at.set_subtensor(L_mat[np.diag_indices(5)], L_mat_diag)
    pm.MvNormal("Y", [0, 0, 0, 0, 0], chol=L_mat, observed=Y)

with m1:
    tr = sampling_jax.sample_numpyro_nuts()

Not sure whether it was fixed or introduced by this PR: "Use new JAX index update approach when available" by brandonwillard (Pull Request #864, aesara-devs/aesara). You can try installing the latest PyMC version from main and see if it works.


I also came across this error recently. I found that I could at least get it to run by downgrading jax to 0.2.22, based on this note in the JAX changelog: “The functions jax.ops.index_update , jax.ops.index_add , etc., which were deprecated in JAX 0.2.22, have been removed. Please use the jax.numpy.ndarray.at property on JAX arrays instead.”
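For anyone unsure what the suggested replacement looks like, here is a minimal sketch of the .at property on a JAX array. The array contents and the names x, idx, updated, and added are made up for illustration; only the .at[...].set() and .at[...].add() calls come from the changelog note.

```python
import jax.numpy as jnp

x = jnp.zeros(5)
idx = jnp.array([1, 3])

# Stands in for the removed jax.ops.index_update(x, idx, 7.0)
updated = x.at[idx].set(7.0)

# Stands in for the removed jax.ops.index_add(x, idx, 2.0)
added = x.at[idx].add(2.0)

# Both calls return new arrays; x itself is left unchanged.
print(updated)
print(added)
```

Like the old functions, these are functional updates, so code that relied on index_update returning a fresh array should behave the same way.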


Thanks for your help!
I think jax_funcify_IncSubtensor(op, **kwargs) may already have been fixed, but jax_funcify_AdvancedIncSubtensor at line 648 was not.

I tried to fix it in the same way, and it seems to work well~

@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op, **kwargs):

    if getattr(op, "set_instead_of_inc", False):
        # Use jax.ops.index_update when it exists (older JAX);
        # otherwise fall back to the .at[...] property.
        jax_fn = getattr(jax.ops, "index_update", None)
        if jax_fn is None:

            def jax_fn(x, indices, y):
                return x.at[indices].set(y)

    else:
        # Same fallback for the additive update.
        jax_fn = getattr(jax.ops, "index_add", None)
        if jax_fn is None:

            def jax_fn(x, indices, y):
                return x.at[indices].add(y)

    def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
        return jax_fn(x, ilist, y)

    return advancedincsubtensor

:thinking:
I tried to roll back the jax version, but didn’t realize I would need to roll it back so far.

That's amazing @qipengchen, do you want to open a PR in Aesara (aesara-devs/aesara on GitHub) to fix it?


Sure, I’d love to~


For others running into this problem, downgrading jax to 0.2.22 as discovered by @djmannion fixed this for me.

Here are the various players in my current conda environment after re-building it with the constraint on jax:

# Name                    Version                   Build  Channel
aeppl                     0.0.27             pyhd8ed1ab_0    conda-forge
aesara                    2.6.6           py310hd17ff3b_0    conda-forge
arviz                     0.12.1             pyhd8ed1ab_0    conda-forge
jax                       0.2.22             pyhd8ed1ab_0    conda-forge
jaxlib                    0.3.0           py310hc3794dd_3    conda-forge
numpyro                   0.9.2              pyhd8ed1ab_0    conda-forge
pymc                      4.0.0b5         py310h4714cba_0    conda-forge
python                    3.10.4          h8b4d769_0_cpython    conda-forge

Edit: Force installing the latest version of aesara (v2.6.6) was required.


Yes, according to the jax changelog, index_update was deprecated in jax 0.2.22 but only removed in jax 0.3.2, so you can solve this by downgrading jax to any version before 0.3.2.
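If it helps, here is a small sketch for checking whether an installed jax falls on the affected side of that cutoff. The helper names (parse_version, index_update_removed, installed_jax_version) are hypothetical, and the 0.3.2 cutoff is simply taken from the changelog mentioned in this thread; it only uses the standard library.

```python
from importlib import metadata

# Release that dropped jax.ops.index_update, per the changelog (assumption of this sketch)
REMOVAL_VERSION = (0, 3, 2)


def parse_version(text):
    """Turn '0.3.4' into (0, 3, 4); trailing non-digits in a component are ignored."""
    parts = []
    for piece in text.split(".")[:3]:
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits or 0))
    while len(parts) < 3:
        parts.append(0)
    return tuple(parts)


def index_update_removed(jax_version):
    """True if this jax version no longer ships jax.ops.index_update."""
    return parse_version(jax_version) >= REMOVAL_VERSION


def installed_jax_version():
    """Return the installed jax version string, or None if jax is not installed."""
    try:
        return metadata.version("jax")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_jax_version()
    if version is None:
        print("jax is not installed")
    elif index_update_removed(version):
        print(f"jax {version}: index_update is gone; downgrade jax or update aesara/pymc")
    else:
        print(f"jax {version}: index_update should still be available")
```

With the versions from the first post, index_update_removed("0.3.4") is true, which matches the error reported there.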


Even better, update pymc to a more recent version :wink:
