When I use sample_numpyro_nuts() or sample_blackjax_nuts(), I get the following error:
AttributeError Traceback (most recent call last)
/home/anaconda/workspace/chen/group_code/long_rt/xi.ipynb Cell 7' in <cell line: 1>()
1 with m1:
----> 2 idta = jax_sample.sample_blackjax_nuts(10,10)
File ~/anaconda3/envs/jax_pymc/lib/python3.10/site-packages/pymc/sampling_jax.py:306, in sample_blackjax_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, keep_untransformed, chain_method, idata_kwargs)
303 init_params = [np.stack(init_params)]
304 init_params = [np.stack(init_state) for init_state in zip(*init_params)]
--> 306 logprob_fn = get_jaxified_logp(model)
308 seed = jax.random.PRNGKey(random_seed)
309 keys = jax.random.split(seed, chains)
File ~/anaconda3/envs/jax_pymc/lib/python3.10/site-packages/pymc/sampling_jax.py:106, in get_jaxified_logp(model, negative_logp)
104 if not negative_logp:
105 model_logpt = -model_logpt
--> 106 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logpt])
108 def logp_fn_wrap(x):
109 return logp_fn(*x)[0]
File ~/anaconda3/envs/jax_pymc/lib/python3.10/site-packages/pymc/sampling_jax.py:99, in get_jaxified_graph(inputs, outputs)
96 mode.JAX.optimizer.optimize(fgraph)
98 # We now jaxify the optimized fgraph
---> 99 return jax_funcify(fgraph)
File ~/anaconda3/envs/jax_pymc/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
885 if not args:
886 raise TypeError(f'{funcname} requires at least '
887 '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)
File ~/anaconda3/envs/jax_pymc/lib/python3.10/site-packages/aesara/link/jax/dispatch.py:657, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
650 @jax_funcify.register(FunctionGraph)
651 def jax_funcify_FunctionGraph(
652 fgraph,
(...)
655 **kwargs,
656 ):
--> 657 return fgraph_to_python(
658 fgraph,
659 jax_funcify,
660 type_conversion_fn=jax_typify,
661 fgraph_name=fgraph_name,
662 **kwargs,
663 )
File ~/anaconda3/envs/jax_pymc/lib/python3.10/site-packages/aesara/link/utils.py:745, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, input_storage, output_storage, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
743 body_assigns = []
744 for node in order:
--> 745 compiled_func = op_conversion_fn(
746 node.op, node=node, storage_map=storage_map, **kwargs
747 )
749 # Create a local alias with a unique name
750 local_compiled_func_name = unique_name(compiled_func)
File ~/anaconda3/envs/jax_pymc/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
885 if not args:
886 raise TypeError(f'{funcname} requires at least '
887 '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)
File ~/anaconda3/envs/jax_pymc/lib/python3.10/site-packages/aesara/link/jax/dispatch.py:619, in jax_funcify_IncSubtensor(op, **kwargs)
616 idx_list = getattr(op, "idx_list", None)
618 if getattr(op, "set_instead_of_inc", False):
--> 619 jax_fn = jax.ops.index_update
620 else:
621 jax_fn = jax.ops.index_add
AttributeError: module 'jax.ops' has no attribute 'index_update'
Package information:
blackjax 0.4.0
jax 0.3.4
jaxlib 0.3.0
pymc 4.0.0b6
aeppl 0.0.27
aesara 2.5.1
numpyro 0.9.2
Any help would be appreciated!