Cannot create Dirichlet with "float32" type

Versions
Theano: 1.0.4
pymc3: 3.6

Test:

import pymc3 as pm
import numpy as np
str_dtype = "float32"
with pm.Model() as op_model:
    mat = pm.Dirichlet('A', np.ones(4), shape=(4, 4), 
                       dtype=str_dtype)

Error:

Traceback (most recent call last):
  File "test_dirichlet.py", line 8, in <module>
    dtype=str_dtype)
  File "<...>/pymc3/distributions/distribution.py", line 42, in __new__
  File "<...>/pymc3/model.py", line 816, in Var
  File "<...>/pymc3/model.py", line 1499, in __init__
  File "<...>/theano/tensor/var.py", line 275, in <lambda>
    shape = property(lambda self: theano.tensor.basic.shape(self))
  File "<...>/theano/gof/op.py", line 625, in __call__
    storage_map[ins] = [self._get_test_value(ins)]
  File "<...>/theano/gof/op.py", line 562, in _get_test_value
    ret = v.type.filter(v.tag.test_value)
  File "<...>/theano/tensor/type.py", line 140, in filter
    raise TypeError(err_msg)
TypeError: For compute_test_value, one input test value does not have the requested type.

The error when converting the test value to that variable type:
TensorType(float32, matrix) cannot store a value of dtype float64 without risking loss of precision.

If you do not mind this loss, you can: 
1) explicitly cast your data to float32, or 
2) set "allow_input_downcast=True" when calling "function".

Value: "array([[0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25],
       [0.25, 0.25, 0.25, 0.25]])"

Passing a float32 parameter array, np.ones(4, dtype=np.float32), does not fix the problem either; the same error is raised:

import pymc3 as pm
import numpy as np
str_dtype = "float32"
with pm.Model() as op_model:
    mat = pm.Dirichlet('A', np.ones(4, dtype=np.float32), shape=(4, 4), 
                       dtype=str_dtype)
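One likely reason the cast on np.ones does not help: the value the traceback rejects is a test value attached to an input of a Theano op (see `_get_test_value` in the traceback), printed as a 4x4 float64 array, not the parameter array passed in. The minimal NumPy sketch below illustrates the casting rule the message describes. It uses np.can_cast as an analogy only; it is not Theano's internal check, and the array is simply a copy of the printed value.

```python
import numpy as np

# A float64 array like the rejected test value (4x4, every entry 0.25).
test_value = np.full((4, 4), 0.25)
print(test_value.dtype)  # float64

# float64 -> float32 is not a "safe" cast: precision can be lost,
# which is what the TypeError complains about.
print(np.can_cast(np.float64, np.float32))  # False
# float32 -> float64 is safe, so upcasting is never a problem.
print(np.can_cast(np.float32, np.float64))  # True

# Option 1 from the error message: cast explicitly and accept the loss.
cast_value = test_value.astype(np.float32)
print(cast_value.dtype)  # float32
```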

Try setting Theano's floatX to float32 in the environment.

That works. Thanks.

import pymc3 as pm
import numpy as np
import theano
theano.config.floatX = 'float32'  # default float dtype for the model built below

s = 4
str_dtype = "float32"
with pm.Model() as op_model:
    mat = pm.Dirichlet('A', np.ones(s, dtype=np.float32),
                       shape=(s, s), dtype=str_dtype)
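Besides assigning theano.config.floatX in code, as above, Theano also reads the THEANO_FLAGS environment variable (comma-separated `name=value` pairs) when it is first imported. A minimal sketch, assuming the script can set the variable itself before pymc3 or theano are imported; it keeps any other flags already present and replaces an existing floatX entry:

```python
import os

# THEANO_FLAGS is parsed when theano is first imported, so this must run
# before `import theano` or `import pymc3` (which imports theano).
existing = os.environ.get("THEANO_FLAGS", "")
# Keep unrelated flags, drop any earlier floatX setting.
flags = [f for f in existing.split(",") if f and not f.startswith("floatX=")]
flags.append("floatX=float32")
os.environ["THEANO_FLAGS"] = ",".join(flags)

# import pymc3 as pm  # graphs built after this default to float32
```

Setting it from the shell for a single run (THEANO_FLAGS=floatX=float32 python test_dirichlet.py) has the same effect without touching the script.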

Great! It’s a bit disappointing, but our present implementation imposes floatX on many parts of distributions, models and step methods, so there’s not much freedom to choose arbitrary dtypes for different parts of the model.
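To illustrate what "imposing floatX" means, here is a hypothetical NumPy-only sketch, modelled loosely on the floatX helper in pymc3.theanof: values are cast to one global float dtype regardless of the dtype they arrived with. FLOATX is an illustrative stand-in for theano.config.floatX, not a real setting.

```python
import numpy as np

# Hypothetical stand-in for theano.config.floatX (one global setting).
FLOATX = "float64"


def floatX(x):
    """Cast x to the global float dtype, ignoring its original dtype."""
    return np.asarray(x).astype(FLOATX)


a = np.ones(4, dtype=np.float32)
mean = floatX(a / a.sum())
print(mean.dtype)  # float64, despite the float32 input
```

When a cast like this sits inside distributions and step methods, a per-variable dtype argument has little effect unless the global setting matches it.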