Hi,
I’ve got some slow-running models (on GPU, using the NumPyro sampler), and I’ve seen that using float32 should in theory be faster than float64. Similar to this post, I’m trying to force the use of float32, but it doesn’t appear to be working. As a very simplified repro, the following snippet raises "Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.":
import pymc as pm
import pytensor as pt
pt.config.floatX = "float32"
pt.config.warn_float64 = "raise"
with pm.Model() as model:
    sigma = pm.HalfNormal("sigma", 1.0)
    z_beta = pm.Normal("z_beta", mu=0.0, sigma=1.0)
    az_data = pm.sample(1000, tune=500)
Exception:
ERROR (pytensor.graph.rewriting.basic): Traceback (most recent call last):
  File ".venv/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py", line 1922, in process_node
    replacements = node_rewriter.transform(fgraph, node)
  File ".venv/lib/python3.10/site-packages/pytensor/graph/rewriting/basic.py", line 1082, in transform
    return self.fn(fgraph, node)
  File ".venv/lib/python3.10/site-packages/pytensor/tensor/rewriting/basic.py", line 989, in local_sum_make_vector
    add(*[cast(value, acc_dtype) for value in elements]), out_dtype
  File ".venv/lib/python3.10/site-packages/pytensor/tensor/rewriting/basic.py", line 989, in <listcomp>
    add(*[cast(value, acc_dtype) for value in elements]), out_dtype
  File ".venv/lib/python3.10/site-packages/pytensor/tensor/basic.py", line 755, in cast
    return _cast_mapping[dtype_name](x)
  File ".venv/lib/python3.10/site-packages/pytensor/graph/op.py", line 304, in __call__
    node = self.make_node(*inputs, **kwargs)
  File ".venv/lib/python3.10/site-packages/pytensor/tensor/elemwise.py", line 497, in make_node
    outputs = [
  File ".venv/lib/python3.10/site-packages/pytensor/tensor/elemwise.py", line 498, in <listcomp>
    TensorType(dtype=dtype, shape=shape)()
  File ".venv/lib/python3.10/site-packages/pytensor/graph/type.py", line 228, in __call__
    return utils.add_tag_trace(self.make_variable(name))
  File ".venv/lib/python3.10/site-packages/pytensor/graph/type.py", line 200, in make_variable
    return self.variable_type(self, None, name=name)
  File ".venv/lib/python3.10/site-packages/pytensor/tensor/var.py", line 860, in __init__
    raise Exception(msg)
Exception: You are creating a TensorVariable with float64 dtype. You requested an action via the PyTensor flag warn_float64={ignore,warn,raise,pdb}.
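For what it's worth, my reading of the traceback is that the float64 variable is created inside the local_sum_make_vector rewrite, which casts each element to acc_dtype, adds them, and then casts the sum to out_dtype. Below is a pure-Python sketch of that pattern, not the PyTensor code itself. It assumes acc_dtype is float64 and out_dtype is float32 in my case (an inference from the exception, which I have not confirmed in the PyTensor source). The helpers to_float32 and sum_make_vector_sketch are hypothetical names, and float32 is emulated with the struct module.

```python
import struct


def to_float32(x: float) -> float:
    """Round a Python float (a float64) to the nearest float32 value."""
    return struct.unpack("f", struct.pack("f", x))[0]


def sum_make_vector_sketch(elements):
    """Hypothetical stand-in for the cast / add / cast pattern in the traceback."""
    # cast(value, acc_dtype): assumed to widen each element to float64.
    widened = [float(v) for v in elements]
    # add(*...): the sum is accumulated in the wider dtype.
    total = sum(widened)
    # cast(..., out_dtype): assumed to narrow the result back to float32.
    return to_float32(total)


# Two float32 inputs, standing in for the two summed elements.
elements = [to_float32(0.1), to_float32(0.2)]
result = sum_make_vector_sketch(elements)
print(result)
```

If that reading is right, the float64 value would only be an intermediate accumulator and the final result would still be float32, but the intermediate alone would be enough to trip warn_float64 = "raise".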
Note that if I comment out either the sigma = ... line or the z_beta = ... line, no exception is raised.
This is on PyMC v5.6.1. I feel like I’m missing something simple, so any help would be appreciated.