How to force float32?

Hi,

I’ve got some slow-running models (on GPU, using the numpyro sampler), and I’ve read that float32 should in theory be faster than float64. Similar to this post, I’m trying to force the use of float32, but it doesn’t appear to be working. As a very simplified repro, the following snippet raises "Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.":

import pymc as pm
import pytensor

# Use float32 as the default float type, and raise on any float64 TensorVariable
pytensor.config.floatX = "float32"
pytensor.config.warn_float64 = "raise"

with pm.Model() as model:
    sigma = pm.HalfNormal("sigma", 1.0)
    z_beta = pm.Normal("z_beta", mu=0.0, sigma=1.0)
    az_data = pm.sample(1000, tune=500)

Exception:

ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File ".venv/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py", line 1922, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File ".venv/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py", line 1082, in transform
    return self.fn(fgraph, node)
  File ".venv/lib/python3.10/site-packages/pytensor/tensor/rewriting/basic.py", line 989, in local_sum_make_vector
    add(*[cast(value, acc_dtype) for value in elements]), out_dtype
  File ".venv/lib/python3.10/site-packages/pytensor/tensor/rewriting/basic.py", line 989, in <listcomp>
    add(*[cast(value, acc_dtype) for value in elements]), out_dtype
  File ".venv/lib/python3.10/site-packages/pytensor/tensor/basic.py", line 755, in cast
    return _cast_mapping[dtype_name](x)
  File ".venv/lib/python3.10/site-packages/pytensor/graph/op.py", line 304, in __call__
    node = self.make_node(*inputs, **kwargs)
  File ".venv/lib/python3.10/site-packages/pytensor/tensor/elemwise.py", line 497, in make_node
    outputs = [
  File ".venv/lib/python3.10/site-packages/pytensor/tensor/elemwise.py", line 498, in <listcomp>
    TensorType(dtype=dtype, shape=shape)()
  File ".venv/lib/python3.10/site-packages/pytensor/graph/type.py", line 228, in __call__
    return utils.add_tag_trace(self.make_variable(name))
  File ".venv/lib/python3.10/site-packages/pytensor/graph/type.py", line 200, in make_variable
    return self.variable_type(self, None, name=name)
  File ".venv/lib/python3.10/site-packages/pytensor/tensor/var.py", line 860, in __init__
    raise Exception(msg)
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.

Note that if I comment out either the sigma = ... line or the z_beta = ... line, no exception is raised.
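Looking at the traceback, the error comes from the local_sum_make_vector rewrite, at the line add(*[cast(value, acc_dtype) for value in elements]), i.e. while the individual terms of a sum are each cast to an accumulator dtype. With only one variable there may be only one term to combine, which would fit the error disappearing. My unverified guess is that the accumulator dtype picked for float32 terms is float64. The NumPy sketch below only mimics that cast-then-add pattern; the two terms are made-up stand-ins for per-variable log-probabilities, and this is not PyTensor's actual code:

```python
import numpy as np

# Hypothetical float32 terms, standing in for two per-variable logp values
terms = [np.float32(-0.92), np.float32(-1.42)]

# Adding the terms directly keeps the result in float32
acc32 = np.add(terms[0], terms[1])

# Casting every term to a float64 accumulator first (the same shape as
# add(*[cast(value, acc_dtype) for value in elements]) in the traceback)
# produces a float64 result, even though every input was float32
acc_dtype = np.float64
acc64 = np.add(*[np.asarray(v).astype(acc_dtype) for v in terms])

print(acc32.dtype, acc64.dtype)  # float32 float64
```

If that guess is right, it would be the upcast accumulator, not my model variables, that trips warn_float64="raise".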

This is on pymc v5.6.1. I feel like I’m missing something simple, so any help would be appreciated.