Hi, I'm moving from PyMC3 to PyMC 4.
The code was working fine, but when moving to PyMC 4 I had to change the log-likelihood calls, and it stopped working.
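In PyMC3 I evaluated the Beta log-densities with pm.Beta.dist(...).logp(obs); in PyMC 4 I replaced those calls with pm.logp(). A minimal sketch of just that change (the Beta(1, 1) parameters and the value 0.3 are placeholders, not values from my model):

import pymc as pm

beta_rv = pm.Beta.dist(alpha=1.0, beta=1.0)   # placeholder Beta(1, 1)

# PyMC3 style:  pm.Beta.dist(alpha=1.0, beta=1.0).logp(0.3)
# PyMC 4 style: pm.logp(dist, value) builds the log-density graph
logp_val = pm.logp(beta_rv, 0.3)
print(logp_val.eval())   # Beta(1, 1) is uniform on (0, 1), so this prints 0.0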
PyMC 4 implementation
If I run:
import numpy as np
import aesara
import aesara.tensor as at
import pymc as pm
import arviz as az

# emission1 and emission2 are the observed sequences (values in (0, 1))
observed_label1 = aesara.shared(emission1)
observed_label2 = aesara.shared(emission2)
N_states = 4

# One forward step: theta receives the transition matrix Pt via non_sequences
def onestep(obs1, obs2, gamma_, theta, A, B, C, D):
    alpha = gamma_ + at.log(theta) + at.tile(
        pm.logp(pm.Beta.dist(alpha=A, beta=B), obs1)
        + pm.logp(pm.Beta.dist(alpha=C, beta=D), obs2),
        (theta.shape[0], 1))
    return pm.math.logsumexp(alpha, axis=0).T

T = len(emission1)

with pm.Model() as model:
    Pt = pm.Dirichlet('P_transition',
                      a=np.ones((N_states, N_states)),
                      shape=(N_states, N_states))
    A = pm.Exponential("A", lam=1, shape=(N_states,))
    B = pm.Exponential("B", lam=1, shape=(N_states,))
    C = pm.Exponential("C", lam=1, shape=(N_states,))
    D = pm.Exponential("D", lam=1, shape=(N_states,))

    aux = 0
    # Initial state: log-likelihood of the first observation under each state
    gamma = (pm.logp(pm.Beta.dist(alpha=A, beta=B), emission1[0])
             + pm.logp(pm.Beta.dist(alpha=C, beta=D), emission2[0]))
    gamma = at.tile(gamma, (1, 1)).T

    result, updates = aesara.scan(fn=onestep,
                                  outputs_info=gamma,
                                  sequences=[observed_label1, observed_label2],
                                  non_sequences=[Pt, A, B, C, D],
                                  n_steps=T - 1)

    aux += pm.math.logsumexp(result[-1])
    obs_logp = pm.Potential('obs_logp', aux)

    trace = pm.sample(1000, tune=1000, chains=2)

az.plot_trace(trace, combined=True)  # trace plotting now goes through ArviZ
I obtain the following error message:
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-33-eef8df1369e1> in <module>
     49 gamma = at.tile(gamma, (1, 1)).T
     50
---> 51 result, updates = aesara.scan(fn=onestep,
     52                               outputs_info=gamma,
     53                               sequences=[observed_label1, observed_label2],

(2 intermediate frames omitted)

/usr/local/lib/python3.8/dist-packages/aesara/scan/op.py in make_node(self, *inputs)
   1131                 )
   1132         if inner_sitsot_out.ndim != outer_sitsot.ndim - 1:
-> 1133             raise ValueError(
   1134                 err_msg3
   1135                 % (
ValueError: When compiling the inner function of scan (the function called by scan in each of its iterations) the following error has been encountered: The initial state (`outputs_info` in scan nomenclature) of variable IncSubtensor{Set;:int64:}.0 (argument number 2) has 3 dimension(s), while the corresponding variable in the result of the inner function of scan (`fn`) has 1 dimension(s) (it should be one less than the initial state). For example, if the inner function of scan returns a vector of size d and scan uses the values of the previous time-step, then the initial state in scan should be a matrix of shape (1, d). The first dimension of this matrix corresponds to the number of previous time-steps that scan uses in each of its iterations. In order to solve this issue if the two varialbe currently have the same dimensionality, you can increase the dimensionality of the variable in the initial state of scan by using dimshuffle or shape_padleft.
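As far as I can tell from the message, with the default single previous-step tap the tensor passed as outputs_info must have the same number of dimensions as the value returned by fn. A minimal, self-contained sketch of that rule (separate from my model; step, x0, and d are made-up names for this example only):

import aesara
import aesara.tensor as at

d = 4
x0 = at.zeros((d,))             # 1-D initial state

def step(prev):
    return prev + 1.0           # returns a 1-D vector, matching x0's ndim

out, updates = aesara.scan(fn=step, outputs_info=x0, n_steps=3)
print(out.eval())               # shape (3, d): one row per step

If x0 instead has an extra dimension (e.g. at.zeros((d, 1))) while step still returns a vector, scan raises the same kind of "has N dimension(s), while the corresponding variable in the result of the inner function" ValueError, and the message suggests dimshuffle or at.shape_padleft to make the two agree.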
PyMC3 implementation (working):
import numpy as np
import theano
import theano.tensor as tt
import pymc3 as pm

observed_label1 = theano.shared(emission1)
observed_label2 = theano.shared(emission2)
N_states = theta.shape[0]

# One forward step: theta receives the transition matrix Pt via non_sequences
def onestep(obs1, obs2, gamma_, theta, A, B, C, D):
    alpha = gamma_ + tt.log(theta) + tt.tile(
        pm.Beta.dist(alpha=A, beta=B).logp(obs1)
        + pm.Beta.dist(alpha=C, beta=D).logp(obs2),
        (theta.shape[0], 1))
    return pm.math.logsumexp(alpha, axis=0).T

T = len(emission1)

with pm.Model() as model:
    Pt = pm.Dirichlet('P_transition',
                      a=np.ones((N_states, N_states)),
                      shape=(N_states, N_states))
    A = pm.Exponential("A", lam=1, shape=(N_states,))
    B = pm.Exponential("B", lam=1, shape=(N_states,))
    C = pm.Exponential("C", lam=1, shape=(N_states,))
    D = pm.Exponential("D", lam=1, shape=(N_states,))

    aux = 0
    gamma = (pm.Beta.dist(alpha=A, beta=B).logp(emission1[0])
             + pm.Beta.dist(alpha=C, beta=D).logp(emission2[0]))
    gamma = tt.tile(gamma, (1, 1)).T

    result, updates = theano.scan(fn=onestep,
                                  outputs_info=gamma,
                                  sequences=[observed_label1, observed_label2],
                                  non_sequences=[Pt, A, B, C, D],
                                  n_steps=T - 1)

    aux += pm.math.logsumexp(result[-1])
    obs_logp = pm.Potential('obs_logp', aux)

    trace = pm.sample(1000, tune=1000, chains=3)
    # map_est = pm.find_MAP()

pm.traceplot(trace, combined=True)
Any ideas?