Yes, here is the full traceback:
Compiling...
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Input In [2], in <module>
1 with mW:
----> 2 idata2 = pm.sampling_jax.sample_numpyro_nuts()
3 idata2.extend(pm.sample_prior_predictive())
4 idata2.extend(pm.sample_posterior_predictive(idata2))
File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/pymc/sampling_jax.py:483, in sample_numpyro_nuts(draws, tune, chains, target_accept, random_seed, initvals, model, var_names, progress_bar, keep_untransformed, chain_method, postprocessing_backend, idata_kwargs, nuts_kwargs)
474 print("Compiling...", file=sys.stdout)
476 init_params = _get_batched_jittered_initial_points(
477 model=model,
478 chains=chains,
479 initvals=initvals,
480 random_seed=random_seed,
481 )
--> 483 logp_fn = get_jaxified_logp(model, negative_logp=False)
485 if nuts_kwargs is None:
486 nuts_kwargs = {}
File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/pymc/sampling_jax.py:106, in get_jaxified_logp(model, negative_logp)
104 if not negative_logp:
105 model_logpt = -model_logpt
--> 106 logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logpt])
108 def logp_fn_wrap(x):
109 return logp_fn(*x)[0]
File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/pymc/sampling_jax.py:99, in get_jaxified_graph(inputs, outputs)
96 mode.JAX.optimizer.optimize(fgraph)
98 # We now jaxify the optimized fgraph
---> 99 return jax_funcify(fgraph)
File ~/.pyenv/versions/3.9.7/lib/python3.9/functools.py:877, in singledispatch.<locals>.wrapper(*args, **kw)
873 if not args:
874 raise TypeError(f'{funcname} requires at least '
875 '1 positional argument')
--> 877 return dispatch(args[0].__class__)(*args, **kw)
File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:668, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
661 @jax_funcify.register(FunctionGraph)
662 def jax_funcify_FunctionGraph(
663 fgraph,
(...)
666 **kwargs,
667 ):
--> 668 return fgraph_to_python(
669 fgraph,
670 jax_funcify,
671 type_conversion_fn=jax_typify,
672 fgraph_name=fgraph_name,
673 **kwargs,
674 )
File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/aesara/link/utils.py:745, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, input_storage, output_storage, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
743 body_assigns = []
744 for node in order:
--> 745 compiled_func = op_conversion_fn(
746 node.op, node=node, storage_map=storage_map, **kwargs
747 )
749 # Create a local alias with a unique name
750 local_compiled_func_name = unique_name(compiled_func)
File ~/.pyenv/versions/3.9.7/lib/python3.9/functools.py:877, in singledispatch.<locals>.wrapper(*args, **kw)
873 if not args:
874 raise TypeError(f'{funcname} requires at least '
875 '1 positional argument')
--> 877 return dispatch(args[0].__class__)(*args, **kw)
File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:401, in jax_funcify_Elemwise(op, **kwargs)
398 @jax_funcify.register(Elemwise)
399 def jax_funcify_Elemwise(op, **kwargs):
400 scalar_op = op.scalar_op
--> 401 return jax_funcify(scalar_op, **kwargs)
File ~/.pyenv/versions/3.9.7/lib/python3.9/functools.py:877, in singledispatch.<locals>.wrapper(*args, **kw)
873 if not args:
874 raise TypeError(f'{funcname} requires at least '
875 '1 positional argument')
--> 877 return dispatch(args[0].__class__)(*args, **kw)
File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:406, in jax_funcify_Composite(op, vectorize, **kwargs)
404 @jax_funcify.register(Composite)
405 def jax_funcify_Composite(op, vectorize=True, **kwargs):
--> 406 jax_impl = jax_funcify(op.fgraph)
408 def composite(*args):
409 return jax_impl(*args)[0]
File ~/.pyenv/versions/3.9.7/lib/python3.9/functools.py:877, in singledispatch.<locals>.wrapper(*args, **kw)
873 if not args:
874 raise TypeError(f'{funcname} requires at least '
875 '1 positional argument')
--> 877 return dispatch(args[0].__class__)(*args, **kw)
File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:668, in jax_funcify_FunctionGraph(fgraph, node, fgraph_name, **kwargs)
661 @jax_funcify.register(FunctionGraph)
662 def jax_funcify_FunctionGraph(
663 fgraph,
(...)
666 **kwargs,
667 ):
--> 668 return fgraph_to_python(
669 fgraph,
670 jax_funcify,
671 type_conversion_fn=jax_typify,
672 fgraph_name=fgraph_name,
673 **kwargs,
674 )
File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/aesara/link/utils.py:745, in fgraph_to_python(fgraph, op_conversion_fn, type_conversion_fn, order, input_storage, output_storage, storage_map, fgraph_name, global_env, local_env, get_name_for_object, squeeze_output, **kwargs)
743 body_assigns = []
744 for node in order:
--> 745 compiled_func = op_conversion_fn(
746 node.op, node=node, storage_map=storage_map, **kwargs
747 )
749 # Create a local alias with a unique name
750 local_compiled_func_name = unique_name(compiled_func)
File ~/.pyenv/versions/3.9.7/lib/python3.9/functools.py:877, in singledispatch.<locals>.wrapper(*args, **kw)
873 if not args:
874 raise TypeError(f'{funcname} requires at least '
875 '1 positional argument')
--> 877 return dispatch(args[0].__class__)(*args, **kw)
File ~/.pyenv/versions/3.9.7/lib/python3.9/site-packages/aesara/link/jax/dispatch.py:157, in jax_funcify_ScalarOp(op, **kwargs)
155 @jax_funcify.register(ScalarOp)
156 def jax_funcify_ScalarOp(op, **kwargs):
--> 157 func_name = op.nfunc_spec[0]
159 if "." in func_name:
160 jnp_func = reduce(getattr, [jax] + func_name.split("."))
AttributeError: 'Log1mexp' object has no attribute 'nfunc_spec'