Hello everyone
Following this very nice blog post and talk, I recently tried to (clumsily) implement minimal versions of some Keras layers using PyMC3.
Densely connected and embedding layers seem to work OK-ish, but I hit a wall with the LSTM: a simple one-step-ahead forecasting model compiles without issues, but it kills the Jupyter kernel when I try to approximate the posterior with ADVI.
The LSTM layer class:
```python
import pymc3 as pm
import theano
import theano.tensor as tt


class LSTM(_AbstractLayer):
    """
    A Bayesian LSTM layer with one prior per gate weight and bias.
    """
    def __init__(self, shapes, units, layer_name, prior,
                 return_sequences=False, weight_init_func='gaussian',
                 bias_init_func='gaussian', **priors_kwargs):
        """
        shapes: (batch, sequence length, features) of the input tensor.
        """
        self.units = units
        self.return_sequences = return_sequences
        # ugly shape specification
        self.shape_batch = shapes[0]
        self.shape_seq = shapes[1]
        self.shape_feat = shapes[2]
        # the gates act on the concatenation [x_t, h_{t-1}]
        self.shape_z = self.shape_feat + units
        self.layer_name = layer_name
        self.weight_init_func = getattr(self, weight_init_func)
        self.bias_init_func = getattr(self, bias_init_func)
        self.prior = prior
        self.priors_kwargs = priors_kwargs
    def __LSTM_init(self):
        """
        Test values for the forget, input, output and cell gate
        weights and biases.
        """
        floatX = theano.config.floatX
        fw_init = self.weight_init_func(
            shape=(self.shape_z, self.units)
        ).astype(floatX)
        # biases are cast to floatX as well, so the test values match
        # the dtype of the weights
        fb_init = self.bias_init_func(
            shape=self.units
        ).astype(floatX)
        iw_init = self.weight_init_func(
            shape=(self.shape_z, self.units)
        ).astype(floatX)
        ib_init = self.bias_init_func(
            shape=self.units
        ).astype(floatX)
        ow_init = self.weight_init_func(
            shape=(self.shape_z, self.units)
        ).astype(floatX)
        ob_init = self.bias_init_func(
            shape=self.units
        ).astype(floatX)
        cw_init = self.weight_init_func(
            shape=(self.shape_z, self.units)
        ).astype(floatX)
        cb_init = self.bias_init_func(
            shape=self.units
        ).astype(floatX)
        return fw_init, fb_init, iw_init, ib_init, \
            ow_init, ob_init, cw_init, cb_init
    def __LSTM_weights(self, inits):
        """
        One prior per gate weight matrix and bias vector.
        """
        fw_init, fb_init, iw_init, ib_init, ow_init, ob_init, cw_init, \
            cb_init = inits
        fw = self.prior(
            f'forget_weights_{self.layer_name}',
            shape=(self.shape_z, self.units),
            testval=fw_init,
            **self.priors_kwargs
        )
        fb = self.prior(
            f'forget_biases_{self.layer_name}',
            shape=self.units,
            testval=fb_init,
            **self.priors_kwargs
        )
        iw = self.prior(
            f'input_weights_{self.layer_name}',
            shape=(self.shape_z, self.units),
            testval=iw_init,
            **self.priors_kwargs
        )
        ib = self.prior(
            f'input_biases_{self.layer_name}',
            shape=self.units,
            testval=ib_init,
            **self.priors_kwargs
        )
        ow = self.prior(
            f'output_weights_{self.layer_name}',
            shape=(self.shape_z, self.units),
            testval=ow_init,
            **self.priors_kwargs
        )
        ob = self.prior(
            f'output_biases_{self.layer_name}',
            shape=self.units,
            testval=ob_init,
            **self.priors_kwargs
        )
        cw = self.prior(
            f'cell_weights_{self.layer_name}',
            shape=(self.shape_z, self.units),
            testval=cw_init,
            **self.priors_kwargs
        )
        cb = self.prior(
            f'cell_biases_{self.layer_name}',
            shape=self.units,
            testval=cb_init,
            **self.priors_kwargs
        )
        return fw, fb, iw, ib, ow, ob, cw, cb
    def build(self, input_tensor):
        """
        Unroll the LSTM over the sequence dimension of input_tensor,
        which is expected to have shape (batch, sequence, features).
        """
        hidden_states = []
        cell_states = []
        # learnable initial hidden and cell states
        h = pm.Normal(
            'h',
            mu=0,
            sigma=1,
            shape=(self.shape_batch, self.units)
        )
        c = pm.Normal(
            'c',
            mu=0,
            sigma=1,
            shape=(self.shape_batch, self.units)
        )
        fw, fb, iw, ib, ow, ob, cw, cb = self.__LSTM_weights(
            self.__LSTM_init()
        )

        def cell_ops(input_tensor, h, c, fw, fb, iw, ib, ow, ob, cw, cb):
            """
            A single LSTM step on one time slice.
            """
            # the same concatenated tensor feeds every gate, so there
            # is no need to clone it for each one
            z = tt.concatenate([input_tensor, h], axis=-1)
            # forget
            forget_gate = pm.math.sigmoid(
                pm.math.dot(z, fw) + fb
            )
            c *= forget_gate
            # input
            input_gate_1 = pm.math.sigmoid(
                pm.math.dot(z, iw) + ib
            )
            input_gate_2 = pm.math.tanh(
                pm.math.dot(z, cw) + cb
            )
            c += input_gate_1 * input_gate_2
            # output
            output_gate = pm.math.sigmoid(
                pm.math.dot(z, ow) + ob
            )
            h = output_gate * pm.math.tanh(c)
            return h, c

        for i in range(self.shape_seq):
            h, c = cell_ops(
                input_tensor[:, i, :],
                h=h,
                c=c,
                fw=fw,
                fb=fb,
                iw=iw,
                ib=ib,
                ow=ow,
                ob=ob,
                cw=cw,
                cb=cb
            )
            hidden_states.append(h)
            cell_states.append(c)
        if self.return_sequences:
            raise NotImplementedError
        else:
            return hidden_states[-1], cell_states[-1]
```
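For reference, `cell_ops` is meant to implement the standard LSTM update, with every gate acting on the concatenation of the current input and the previous hidden state (using the row-vector convention $zW + b$, which matches the `pm.math.dot` calls above):

$$
\begin{aligned}
z_t &= [x_t, h_{t-1}]\\
f_t &= \sigma(z_t W_f + b_f)\\
i_t &= \sigma(z_t W_i + b_i)\\
\tilde{c}_t &= \tanh(z_t W_c + b_c)\\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t\\
o_t &= \sigma(z_t W_o + b_o)\\
h_t &= o_t \odot \tanh(c_t)
\end{aligned}
$$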
Here is the notebook I am using for testing.
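In case it helps, this is roughly how I wire the layer into a model before calling ADVI. This is a minimal sketch: `X`, `y`, the number of units, and the mean read-out over the last hidden state are placeholders, not the exact notebook code.

```python
import pymc3 as pm
import theano

# X: np.ndarray of shape (batch, sequence, features); y: shape (batch,)
with pm.Model() as model:
    lstm = LSTM(
        shapes=X.shape,      # (batch, sequence, features)
        units=8,             # placeholder size
        layer_name='lstm_1',
        prior=pm.Normal,
        mu=0.0,              # forwarded to every gate prior
        sigma=1.0
    )
    h, c = lstm.build(theano.shared(X))
    # placeholder read-out: average the last hidden state
    mu = h.mean(axis=-1)
    pm.Normal('obs', mu=mu, sigma=1.0, observed=y)
    approx = pm.fit(n=10000, method='advi')  # this is where the kernel dies
```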
I am pretty sure I am doing something horribly wrong, but I cannot figure out what it is.