Hi,
I wanted to see whether I could use PyMC3 to generate synthetic datasets, and I suspect I don't fully understand (or am misusing) `switch` and `Discrete`.
I expected the following example to generate data points in which the variable `dependent_rules` is deterministically `0` whenever `root` is `0`. However, that does not seem to be the case:
```python
import pymc3 as pm
from pymc3.math import switch

with pm.Model() as model:
    root = pm.DiscreteUniform('root', lower=0, upper=6)
    dependent = pm.DiscreteUniform('dependent', lower=0, upper=10)
    dependent_with_change = switch((root == 0.0) * 1.0, 0, dependent)
    dependent_rules = pm.Deterministic('dependent_rules', dependent_with_change)
    trace = pm.sample(100)

print(dir(trace))
print(trace.nchains)
w = trace.get_values('root', chains=1)
d = trace.get_values('dependent_rules', chains=1)
print(list(x for x in zip(w, d) if x[0] == 0))
```
On my machine, this prints:

```
[(0, 2), (0, 6), (0, 4), (0, 0), (0, 1), (0, 1), (0, 1), (0, 1), (0, 1), (0, 9), (0, 9), (0, 9), (0, 9), (0, 9), (0, 9), (0, 4), (0, 4), (0, 2)]
```
Am I completely misunderstanding the purpose of `switch` here, or is the problem the conditional clause `(root == 0.0) * 1.0`?
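A likely explanation, offered as an assumption rather than something verified against this exact PyMC3 version: as far as I know, Theano's `TensorVariable` deliberately does not overload `==` (so that variables stay hashable). `root == 0.0` would then fall back to Python's identity comparison and evaluate once, at model-building time, to the plain bool `False`. The condition becomes the constant `0.0`, so `switch` returns `dependent` for every draw, which matches the output above. If that is the cause, an elementwise comparison such as `switch(pm.math.eq(root, 0), 0, dependent)` (Theano's `tt.eq`) should give the intended behaviour. The sketch below does not use PyMC3. It mimics the effect with a hypothetical stand-in class `OpaqueTensor` and uses NumPy's `np.where` on forward-simulated draws in place of `switch`:

```python
import numpy as np


class OpaqueTensor:
    """Hypothetical stand-in for a symbolic variable that, like a Theano
    TensorVariable (assumed), does not override __eq__."""

    def __init__(self, name):
        self.name = name


root_symbol = OpaqueTensor("root")

# With no __eq__ override, Python falls back to identity comparison,
# so the "condition" is a plain bool computed once, not a symbolic test.
cond = (root_symbol == 0.0) * 1.0
print(type(root_symbol == 0.0).__name__, cond)  # bool 0.0

# Forward-simulate draws matching the two DiscreteUniform priors
# (both bounds inclusive, hence the +1 on NumPy's exclusive upper bound).
rng = np.random.default_rng(0)
root = rng.integers(0, 7, size=1000)
dependent = rng.integers(0, 11, size=1000)

# A constant 0.0 condition selects `dependent` everywhere, as in the question.
collapsed = np.where(cond, 0, dependent)

# An elementwise condition gives the intended behaviour.
dependent_rules = np.where(root == 0, 0, dependent)

print(np.array_equal(collapsed, dependent))           # True
print(bool(np.all(dependent_rules[root == 0] == 0)))  # True
```

The key design point is that the condition must be built with a symbolic, elementwise comparison so it is evaluated for each draw rather than once when the model is defined.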