Switch and deterministic rules

Hi,

I wanted to see whether I could use PyMC3 to generate synthetic datasets, and I suspect I don't fully understand (or am not correctly using) `switch` and discrete variables.

I expected the following example to generate data points in which the variable `dependent_rules` is deterministically 0 whenever `root` is 0. That does not seem to be the case:

```python
import pymc3 as pm
from pymc3.math import switch


with pm.Model() as model:
    root = pm.DiscreteUniform('root', lower=0, upper=6)
    dependent = pm.DiscreteUniform('dependent', lower=0, upper=10)
    # Intended: 0 when root is 0, otherwise the value of dependent
    dependent_with_change = switch((root == 0.0) * 1.0, 0, dependent)
    dependent_rules = pm.Deterministic('dependent_rules', dependent_with_change)
    trace = pm.sample(100)


print(dir(trace))
print(trace.nchains)
w = trace.get_values('root', chains=1)
d = trace.get_values('dependent_rules', chains=1)
# Keep only the draws where root == 0
print(list(x for x in zip(w, d) if x[0] == 0))
```

This results in [(0, 2), (0, 6), (0, 4), (0, 0), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 9), (0, 9), (0, 9), (0, 9), (0, 9), (0, 9), (0, 4), (0, 4), (0, 2)] on my machine.

Am I completely misunderstanding the purpose of switch here? Or is this a problem with the conditional clause (root == 0.0) * 1.0?
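To make the intended rule concrete, here is a plain-Python sketch that forward-samples the same two uniform variables and applies the rule to each sampled value. It does not use PyMC3; the function name `generate` and the fixed seed are illustrative choices, not part of the original model. Since the goal is synthetic data rather than posterior inference, direct forward sampling like this (newer PyMC3 versions also offer `pm.sample_prior_predictive`) may be a simpler fit than MCMC via `pm.sample`.

```python
import random


def generate(n, seed=0):
    """Forward-sample n (root, dependent_rules) pairs with the intended rule.

    root ~ DiscreteUniform(0, 6) and dependent ~ DiscreteUniform(0, 10), as in
    the model above; dependent_rules is forced to 0 whenever root == 0.
    """
    rng = random.Random(seed)
    rows = []
    for _ in range(n):
        root = rng.randint(0, 6)        # randint includes both endpoints
        dependent = rng.randint(0, 10)
        # The comparison runs on the sampled value, so the rule applies per draw.
        dependent_rules = 0 if root == 0 else dependent
        rows.append((root, dependent_rules))
    return rows


rows = generate(1000)
print([r for r in rows if r[0] == 0][:5])  # every such pair has dependent_rules == 0
```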

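The problem is the condition. Python's `==` is not overloaded as an elementwise comparison for Theano variables, so `root == 0.0` is evaluated immediately by Python and gives a plain `False`. `(root == 0.0) * 1.0` is therefore the constant `0.0`, and `switch` always returns `dependent` (the single `(0, 0)` pair in your output is just a draw where `dependent` happened to be 0). Build the comparison symbolically with `tt.eq` instead: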

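```python
import pymc3 as pm
import theano.tensor as tt

with pm.Model() as model:
    ...  # root and dependent defined as before
    # tt.eq builds a symbolic comparison that is evaluated on each draw
    dependent_with_change = tt.switch(tt.eq(root, 0), 0, dependent)
    ...  # rest of the model as before
```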
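The mechanism behind the bug can be reproduced without Theano. The sketch below uses a hypothetical `SymbolicVar` class that, like a Theano variable, does not overload `==`, and a hypothetical eager `toy_switch`; neither is Theano's real class or function. It shows that Python resolves the condition to a constant before any sampling happens, which is why the branch never changes.

```python
class SymbolicVar:
    """Hypothetical stand-in for a symbolic graph variable.

    It does not define __eq__, so Python falls back to the default
    comparison (object identity), which returns a plain bool.
    """

    def __init__(self, name):
        self.name = name


def toy_switch(cond, if_true, if_false):
    """Hypothetical switch that picks a branch from an already-evaluated condition."""
    return if_true if cond else if_false


root = SymbolicVar("root")

# Evaluated by Python right now: root is not the float 0.0, so == gives False,
# and False * 1.0 is the constant 0.0 regardless of what root is later sampled as.
cond = (root == 0.0) * 1.0
print(cond)

# The branch is fixed here, once, so the "dependent" branch is always chosen.
chosen = toy_switch(cond, 0, "dependent")
print(chosen)
```

With `tt.eq`, by contrast, the comparison becomes a node in the graph and is recomputed for every sampled value of `root`.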

Thanks a lot!