I would like to use pm.Categorical to draw multiple samples, like this (drawing 30 independent values uniformly from {0, 1, 2}):
import numpy as np
import pymc3 as pm
p = 1 / 3 * np.ones([30, 3])  # 30 rows, each uniform over 3 categories
pm.Categorical.dist(p=p).random(1)
>>> array([2, 1, 0, 0, 1, 2, 2, 2, 0, 0, 0, 1, 0, 0, 0, 2, 0, 1, 0, 0, 1, 0,
0, 2, 1, 1, 2, 1, 0, 1], dtype=int64)
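For comparison outside PyMC3, the same kind of draw (one independent categorical value per row of a 30 × 3 probability matrix) can be sketched in plain NumPy with inverse-CDF sampling. This is only an illustration under my own assumptions: sample_rows is a hypothetical helper, and this is not how PyMC3 implements its random method.

```python
import numpy as np

def sample_rows(p, rng):
    """Hypothetical helper: draw one category index per row of p.

    Each row of p is treated as a probability vector; the draws are independent.
    """
    p = np.asarray(p, dtype=float)
    # Normalize each row so it sums to 1.
    p = p / p.sum(axis=-1, keepdims=True)
    # Cumulative probabilities per row, shape (n_rows, k).
    cdf = np.cumsum(p, axis=-1)
    # One uniform number in [0, 1) per row, shape (n_rows, 1).
    u = rng.random((p.shape[0], 1))
    # Counting the cumulative values <= u gives the index of the first
    # category whose cumulative probability exceeds u.
    idx = (u >= cdf).sum(axis=-1)
    # Guard against floating-point round-off in the last cumulative value.
    return np.minimum(idx, p.shape[-1] - 1)

rng = np.random.default_rng(0)
p = 1 / 3 * np.ones([30, 3])
draws = sample_rows(p, rng)
print(draws.shape)  # (30,)
```

The result has one entry per row of p, which is the behaviour the .dist(...).random(1) call above shows.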
But when I do the same thing inside a model, pm.Categorical only draws a single value:
import theano.tensor as T

p = 1 / 3 * np.ones([30, 3])
with pm.Model() as model:
    p = T.as_tensor_variable(p)
    choice = pm.Categorical('choice', p=p)
    T.printing.Print('choice')(choice)  # prints choice __str__ = 0
    trace = pm.sample(100, tune=10, chains=1, cores=1)
I am trying to understand what is happening here. Thanks in advance!
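One way to think about the difference between the two results: in NumPy, the number of values drawn is set by an explicit size argument, not by the shape of the probabilities. The sketch below only demonstrates that NumPy behaviour; whether pm.Categorical inside a model behaves analogously (its size coming from a shape argument rather than being inferred from p) is an assumption here, not something the output above proves.

```python
import numpy as np

rng = np.random.default_rng(42)
probs = np.ones(3) / 3  # uniform over the categories 0, 1, 2

# Without a size argument, Generator.choice returns a single value.
single = rng.choice(3, p=probs)

# The number of draws has to be requested explicitly via size.
many = rng.choice(3, size=30, p=probs)

print(np.ndim(single), many.shape)  # 0 (30,)
```

In PyMC3 the corresponding knob for a named variable is, as far as I know, the shape argument, e.g. pm.Categorical('choice', p=p, shape=30); I have not run that variant here.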
Here are a few more things I tried:
Interestingly, when I pass an array of length 30 as observed, pm.Categorical draws multiple values again:
with pm.Model() as model:
    p = pm.Uniform('p', lower=0, upper=1, shape=3)
    p_tile = T.tile(p, 30).reshape([30, 3])
    choice = pm.Categorical('choice', p=p_tile, observed=np.array([0, 0, 2] * 10))
    T.printing.Print('choice')(choice)  # prints choice __str__ = [0 0 2 0 0 2 0 0 2 0 0 2 0 0 2 0 0 2 0 0 2 0 0 2 0 0 2 0 0 2]
    trace = pm.sample(5000, tune=500, chains=2, cores=1)
    print(pm.summary(trace))
>>>
        mean      sd        mc_error  hpd_2.5   hpd_97.5
p__0    0.746065  0.196305  0.013013  0.321824  0.995407  # estimates of p correspond to the observed data
p__1    0.043839  0.051256  0.002698  0.000105  0.140824
p__2    0.460127  0.208897  0.012058  0.107075  0.903796
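To make the shapes and numbers concrete, here is a NumPy-only sketch. The three values in p are made-up stand-ins for the Uniform draws, not values from the trace. It also assumes (not verified here) that Categorical rescales each row of p to sum to 1, in which case only the ratios between p__0, p__1 and p__2 matter, so the posterior means are rescaled before being compared with the observed frequencies.

```python
import numpy as np

# Made-up stand-in for the three Uniform(0, 1) parameters (not from the trace).
p = np.array([0.2, 0.5, 0.3])

# NumPy analogue of T.tile(p, 30).reshape([30, 3]): tiling the length-3
# vector gives 90 values, and the reshape turns them into 30 identical rows.
p_tile = np.tile(p, 30).reshape([30, 3])

# Observed data from the model above and its per-category frequencies.
observed = np.array([0, 0, 2] * 10)
freq = np.bincount(observed, minlength=3) / observed.size  # about [2/3, 0, 1/3]

# Posterior means of p__0, p__1, p__2 from the summary, rescaled to sum to 1
# (assumes Categorical normalizes p, so only the ratios would matter).
p_means = np.array([0.746065, 0.043839, 0.460127])
rescaled = p_means / p_means.sum()  # roughly [0.60, 0.04, 0.37]

print(p_tile.shape, freq, rescaled)
```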
The same model without observations seems to be doing something quite different:
with pm.Model() as model:
    p = pm.Uniform('p', lower=0, upper=1, shape=3)
    p_tile = T.tile(p, 30).reshape([30, 3])
    choice = pm.Categorical('choice', p=p_tile)
    T.printing.Print('choice')(choice)  # prints choice __str__ = 0
    trace = pm.sample(500, tune=500, chains=1, cores=1)
    print(pm.summary(trace))
>>>
          mean      sd        mc_error  hpd_2.5   hpd_97.5
choice    1.000000  0.000000  0.000000  1.000000  1.000000  # I don't understand where these estimates come from
p__0      0.026698  0.031869  0.001404  0.000109  0.085003
p__1      0.743670  0.184351  0.009447  0.377753  0.999541
p__2      0.026075  0.028349  0.001499  0.000077  0.081620
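A possible reading of this summary, offered only as a hypothesis that I have not checked against the PyMC3 source: if the single scalar choice is scored against each of the 30 identical rows of p_tile, and each row is rescaled to sum to 1, then the log-likelihood for choice = k is 30 · log(p_k / (p_0 + p_1 + p_2)). The sketch below uses a hypothetical helper, loglik, with the posterior means from the table plugged in, to show how strongly such a term would favour staying at choice = 1 once p__1 dominates.

```python
import numpy as np

def loglik(p, choice, n_rows=30):
    """Hypothetical log-likelihood: one scalar choice scored against
    n_rows identical rows of p, each row normalized to sum to 1."""
    p = np.asarray(p, dtype=float)
    return n_rows * np.log(p[choice] / p.sum())

# Posterior means of p__0, p__1, p__2 from the summary above.
p_mean = np.array([0.026698, 0.743670, 0.026075])

for k in range(3):
    print(k, loglik(p_mean, k))
```

Under this reading, choice = 1 scores about 100 nats higher than choice = 0 or choice = 2, so a proposal to move away from 1 would almost never be accepted, which would be consistent with the sd of 0 reported for choice. Again, this is speculation rather than a confirmed explanation.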