float32 fails in sample_numpyro_nuts with PyTensor (worked with Aesara)

Error:

  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lax/control_flow/common.py", line 108, in _check_tree_and_avals
    raise TypeError(f"{what} must have identical types, got\n{diff}.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: true_fun and false_fun output must have identical types, got
(['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])'], ['ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])'], ['ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])']).

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "minimal.py", line 24, in <module>
    my_fit_pymc = pm.sampling_jax.sample_numpyro_nuts(random_seed=1234, tune=1000,
  File "/usr/local/lib/python3.10/dist-packages/pymc/sampling/jax.py", line 660, in sample_numpyro_nuts
    pmap_numpyro.run(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/mcmc.py", line 598, in run
    states, last_state = _laxmap(partial_map_fn, map_args)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/mcmc.py", line 160, in _laxmap
    ys.append(f(x))
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/mcmc.py", line 404, in _single_chain_mcmc
    collect_vals = fori_collect(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 358, in fori_collect
    vals = jit(_body_fn)(i, vals)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 323, in _body_fn
    val = body_fun(val)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/mcmc.py", line 172, in _sample_fn_nojit_args
    return (sampler.sample(state[0], args, kwargs),)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc.py", line 771, in sample
    return self._sample_fn(state, model_args, model_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc.py", line 467, in sample_kernel
    vv_state, energy, num_steps, accept_prob, diverging = _next(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc.py", line 407, in _nuts_next
    binary_tree = build_tree(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 1177, in build_tree
    tree, _ = while_loop(_cond_fn, _body_fn, state)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 131, in while_loop
    return lax.while_loop(cond_fun, body_fun, init_val)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 1161, in _body_fn
    tree = _double_tree(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 916, in _double_tree
    new_tree = _iterative_build_subtree(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 1061, in _iterative_build_subtree
    tree, turning, _, _, _ = while_loop(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 131, in while_loop
    return lax.while_loop(cond_fun, body_fun, init_val)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 1018, in _body_fn
    new_tree = cond(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 121, in cond
    return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 1029, in <lambda>
    lambda x: _combine_tree(*x, False),
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 766, in _combine_tree
    z_left, r_left, z_left_grad, z_right, r_right, r_right_grad = cond(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 121, in cond
    return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
TypeError: true_fun and false_fun output must have identical types, got
(['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])'], ['ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])'], ['ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])']).
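
For reference, the same class of error can be reproduced outside PyMC. This is only a minimal sketch of the `lax.cond` dtype rule that the traceback points to, not the contents of `minimal.py` and not a diagnosis of where the float32/float64 mix originates. The helper `pick` and the values `x32`/`x64` are hypothetical names for illustration, and it assumes 64-bit mode is enabled in JAX so that both dtypes can coexist.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Assumption: x64 mode is on, so float64 arrays are not silently downcast.
jax.config.update("jax_enable_x64", True)


def pick(pred, a, b):
    # Hypothetical helper: both branches are traced, and lax.cond
    # requires their outputs to have identical shapes and dtypes.
    return lax.cond(pred, lambda a, b: a, lambda a, b: b, a, b)


x32 = jnp.float32(1.0)
x64 = jnp.float64(2.0)

# Mixed dtypes: one branch returns float32, the other float64.
try:
    pick(True, x32, x64)
    mismatch_error = None
except TypeError as err:
    mismatch_error = err

# Casting both operands to one dtype satisfies the check.
same = pick(True, x32.astype(jnp.float64), x64)

print(type(mismatch_error).__name__)
print(same.dtype)
```

In this sketch the first call raises a `TypeError` about mismatched branch outputs, while the second call succeeds once both operands share a dtype.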