Error:
  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lax/control_flow/common.py", line 108, in _check_tree_and_avals
    raise TypeError(f"{what} must have identical types, got\n{diff}.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: true_fun and false_fun output must have identical types, got
(['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])'], ['ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])'], ['ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])']).
The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.
--------------------
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
  File "minimal.py", line 24, in <module>
    my_fit_pymc = pm.sampling_jax.sample_numpyro_nuts(random_seed=1234, tune=1000,
  File "/usr/local/lib/python3.10/dist-packages/pymc/sampling/jax.py", line 660, in sample_numpyro_nuts
    pmap_numpyro.run(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/mcmc.py", line 598, in run
    states, last_state = _laxmap(partial_map_fn, map_args)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/mcmc.py", line 160, in _laxmap
    ys.append(f(x))
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/mcmc.py", line 404, in _single_chain_mcmc
    collect_vals = fori_collect(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 358, in fori_collect
    vals = jit(_body_fn)(i, vals)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 323, in _body_fn
    val = body_fun(val)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/mcmc.py", line 172, in _sample_fn_nojit_args
    return (sampler.sample(state[0], args, kwargs),)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc.py", line 771, in sample
    return self._sample_fn(state, model_args, model_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc.py", line 467, in sample_kernel
    vv_state, energy, num_steps, accept_prob, diverging = _next(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc.py", line 407, in _nuts_next
    binary_tree = build_tree(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 1177, in build_tree
    tree, _ = while_loop(_cond_fn, _body_fn, state)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 131, in while_loop
    return lax.while_loop(cond_fun, body_fun, init_val)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 1161, in _body_fn
    tree = _double_tree(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 916, in _double_tree
    new_tree = _iterative_build_subtree(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 1061, in _iterative_build_subtree
    tree, turning, _, _, _ = while_loop(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 131, in while_loop
    return lax.while_loop(cond_fun, body_fun, init_val)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 1018, in _body_fn
    new_tree = cond(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 121, in cond
    return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 1029, in <lambda>
    lambda x: _combine_tree(*x, False),
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 766, in _combine_tree
    z_left, r_left, z_left_grad, z_right, r_right, r_right_grad = cond(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 121, in cond
    return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
TypeError: true_fun and false_fun output must have identical types, got
(['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])'], ['ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])'], ['ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])']).
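The final error means that the two branches of a `lax.cond` (here, inside numpyro's `_combine_tree`) returned scalars with different dtypes: float32 in one branch, float64 in the other. The following is a minimal sketch of the same class of failure in plain JAX. It assumes x64 mode is enabled, because otherwise JAX stores float64 requests as float32 and the mismatch cannot occur. The branch functions are hypothetical stand-ins, not numpyro's actual code:

```python
import jax
import jax.numpy as jnp
from jax import lax

# float64 only exists once x64 mode is on; without it JAX silently
# keeps everything in float32 and no dtype mismatch can appear.
jax.config.update("jax_enable_x64", True)


def branch_f32(v):
    # Hypothetical branch that (accidentally) returns float32.
    return v.astype(jnp.float32)


def branch_f64(v):
    # Hypothetical branch that returns float64.
    return v.astype(jnp.float64)


def mismatched(pred, v):
    # lax.cond traces both branches and requires identical output types,
    # so float32 vs. float64 raises TypeError at trace time.
    return lax.cond(pred, branch_f32, branch_f64, v)


def consistent(pred, v):
    # Casting the float32 branch to float64 gives both branches the
    # same output type, which lax.cond accepts.
    return lax.cond(pred, lambda u: branch_f32(u).astype(jnp.float64), branch_f64, v)


x = jnp.asarray(1.5, dtype=jnp.float64)

try:
    mismatched(True, x)
    raised = False
except TypeError as err:
    raised = True
    print("TypeError:", str(err).splitlines()[0])

print(consistent(True, x).dtype)
```

Making both branches agree on one precision removes the error. In the PyMC run above, the analogous step would be to keep every value that reaches numpyro in a single precision. This traceback alone does not show where the float32 values come from (for example, a float32 `floatX` setting mixed with x64 mode), so treat that as an assumption to check.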