Float32 not working with pytensor, worked with aesara

I used these lines to get single precision with aesara:

import aesara
aesara.config.floatX="float32"

Simply replacing aesara with pytensor

import pytensor
pytensor.config.floatX="float32"

results in a lot of error messages along these lines as soon as sampling starts:
'DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[]) '
Is single precision supported by pytensor?

It is supported, but it’s also easy for float64 to sneak in. One place where this happens frequently is with discrete distributions: their values default to int64, and almost any floating-point operation applied to them (e.g., log) then produces float64.

The same happens with shape operations, because Aesara/PyTensor always represents shapes as int64.

If you can share the model that used to work and now fails (in a way that’s fully reproducible), someone may be able to point out the source of the problem.

You can also use the warn_float64 flag to find where float64 variables are being introduced; see the "config: PyTensor Configuration" page of the PyTensor dev documentation.

import pandas as pd
import pymc as pm
import pymc.sampling_jax
import random
random.seed(1)

import os
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"
import pytensor
pytensor.config.floatX = "float32"
# Raise an exception wherever a float64 variable is created
pytensor.config.warn_float64 = "raise"
import jax
print(jax.default_backend())
df = pd.read_csv("minimal.csv")

my_model = pm.Model()
with my_model:
    Intercept = pm.Normal("Intercept", mu=0, sigma=5)
    theta_as_logit = Intercept
    y = pm.ConstantData("y", df["Y"] == True)
    theta_as_probability = pm.invlogit(theta_as_logit)
    Y_obs = pm.Bernoulli("Y_obs", p=theta_as_probability, observed=y)
    my_fit_pymc = pm.sampling_jax.sample_numpyro_nuts(random_seed=1234, tune=1000,
        draws=20, target_accept=0.47, chains=2, chain_method="parallel",
        idata_kwargs=dict(log_likelihood=False),
        postprocessing_backend="cpu")

minimal.csv attached

Error:

  File "/usr/local/lib/python3.10/dist-packages/jax/_src/lax/control_flow/common.py", line 108, in _check_tree_and_avals
    raise TypeError(f"{what} must have identical types, got\n{diff}.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: true_fun and false_fun output must have identical types, got
(['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])'], ['ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])'], ['ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])']).

The stack trace below excludes JAX-internal frames.
The preceding is the original exception that occurred, unmodified.

--------------------

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "minimal.py", line 24, in <module>
    my_fit_pymc = pm.sampling_jax.sample_numpyro_nuts(random_seed=1234, tune=1000,
  File "/usr/local/lib/python3.10/dist-packages/pymc/sampling/jax.py", line 660, in sample_numpyro_nuts
    pmap_numpyro.run(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/mcmc.py", line 598, in run
    states, last_state = _laxmap(partial_map_fn, map_args)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/mcmc.py", line 160, in _laxmap
    ys.append(f(x))
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/mcmc.py", line 404, in _single_chain_mcmc
    collect_vals = fori_collect(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 358, in fori_collect
    vals = jit(_body_fn)(i, vals)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 323, in _body_fn
    val = body_fun(val)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/mcmc.py", line 172, in _sample_fn_nojit_args
    return (sampler.sample(state[0], args, kwargs),)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc.py", line 771, in sample
    return self._sample_fn(state, model_args, model_kwargs)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc.py", line 467, in sample_kernel
    vv_state, energy, num_steps, accept_prob, diverging = _next(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc.py", line 407, in _nuts_next
    binary_tree = build_tree(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 1177, in build_tree
    tree, _ = while_loop(_cond_fn, _body_fn, state)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 131, in while_loop
    return lax.while_loop(cond_fun, body_fun, init_val)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 1161, in _body_fn
    tree = _double_tree(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 916, in _double_tree
    new_tree = _iterative_build_subtree(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 1061, in _iterative_build_subtree
    tree, turning, _, _, _ = while_loop(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 131, in while_loop
    return lax.while_loop(cond_fun, body_fun, init_val)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 1018, in _body_fn
    new_tree = cond(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 121, in cond
    return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 1029, in <lambda>
    lambda x: _combine_tree(*x, False),
  File "/usr/local/lib/python3.10/dist-packages/numpyro/infer/hmc_util.py", line 766, in _combine_tree
    z_left, r_left, z_left_grad, z_right, r_right, r_right_grad = cond(
  File "/usr/local/lib/python3.10/dist-packages/numpyro/util.py", line 121, in cond
    return lax.cond(pred, true_operand, true_fun, false_operand, false_fun)
TypeError: true_fun and false_fun output must have identical types, got
(['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])'], ['ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float32[]) vs. ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])'], ['ShapedArray(float64[])'], ['DIFFERENT ShapedArray(float64[]) vs. ShapedArray(float32[])']).