Help with LSTM implementation

Hello everyone :slight_smile:

Following this very nice blog post and talk, I recently tried to (clumsily) implement minimal versions of some Keras layers using PyMC3.

The densely connected and embedding layers seem to work OK-ish, but I hit a wall with the LSTM: a simple one-step-ahead forecast model compiles without issues, yet it kills the Jupyter kernel as soon as I try to approximate the posterior with ADVI.
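To make the setup concrete, here is a stripped-down sketch of the kind of model that triggers the crash, built with the layer class below (it relies on the init helpers from my _AbstractLayer base class; shapes and variable names are made up for illustration, not my actual data):

import numpy as np
import theano
import pymc3 as pm

# dummy data: (batch, sequence, features) inputs, one scalar target per sample
X = np.random.randn(32, 10, 3).astype(theano.config.floatX)
y = np.random.randn(32).astype(theano.config.floatX)

with pm.Model():
    lstm = LSTM(shapes=X.shape, units=8, layer_name='l1',
                prior=pm.Normal, mu=0, sigma=1)
    # last hidden state, shape (batch, units)
    h, _ = lstm.build(theano.shared(X))
    w_out = pm.Normal('w_out', mu=0, sigma=1, shape=(8,))
    pm.Normal('y', mu=pm.math.dot(h, w_out), sigma=1, observed=y)
    approx = pm.fit(method='advi', n=10000)  # the kernel dies here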

The LSTM layer class

import theano
import theano.tensor as tt
import pymc3 as pm


class LSTM(_AbstractLayer):
    """Minimal LSTM layer with priors over the gate weights and biases."""

    def __init__(self, shapes, units, layer_name, prior,
                 return_sequences=False, weight_init_func='gaussian',
                 bias_init_func='gaussian', **priors_kwargs):
        """
        """
        self.units = units
        self.return_sequences = return_sequences

        # ugly shape specification
        self.shape_batch = shapes[0]
        self.shape_seq = shapes[1]
        self.shape_feat = shapes[2]
        self.shape_z = self.shape_feat + units  # width of [x, h] fed to each gate

        self.layer_name = layer_name
        self.weight_init_func = getattr(self, weight_init_func)
        self.bias_init_func = getattr(self, bias_init_func)
        self.prior = prior
        self.priors_kwargs = priors_kwargs

    def __LSTM_init(self):
        """
        """
        floatX = theano.config.floatX

        # one weight matrix and bias vector per gate; the biases are cast to
        # floatX too, so every testval matches theano's configured precision
        fw_init = self.weight_init_func(
            shape=(self.shape_z, self.units)
        ).astype(floatX)
        fb_init = self.bias_init_func(
            shape=self.units
        ).astype(floatX)

        iw_init = self.weight_init_func(
            shape=(self.shape_z, self.units)
        ).astype(floatX)
        ib_init = self.bias_init_func(
            shape=self.units
        ).astype(floatX)

        ow_init = self.weight_init_func(
            shape=(self.shape_z, self.units)
        ).astype(floatX)
        ob_init = self.bias_init_func(
            shape=self.units
        ).astype(floatX)

        cw_init = self.weight_init_func(
            shape=(self.shape_z, self.units)
        ).astype(floatX)
        cb_init = self.bias_init_func(
            shape=self.units
        ).astype(floatX)

        return fw_init, fb_init, iw_init, ib_init, \
            ow_init, ob_init, cw_init, cb_init

    def __LSTM_weights(self, inits):
        """Wrap the initial values into one PyMC3 prior per gate parameter."""
        fw_init, fb_init, iw_init, ib_init, ow_init, ob_init, cw_init, \
            cb_init = inits

        fw = self.prior(
            f'forget_weights_{self.layer_name}',
            shape=(self.shape_z, self.units),
            testval=fw_init,
            **self.priors_kwargs
        )
        fb = self.prior(
            f'forget_biases_{self.layer_name}',
            shape=self.units,
            testval=fb_init,
            **self.priors_kwargs
        )

        iw = self.prior(
            f'input_weights_{self.layer_name}',
            shape=(self.shape_z, self.units),
            testval=iw_init,
            **self.priors_kwargs
        )
        ib = self.prior(
            f'input_biases_{self.layer_name}',
            shape=self.units,
            testval=ib_init,
            **self.priors_kwargs
        )

        ow = self.prior(
            f'output_weights_{self.layer_name}',
            shape=(self.shape_z, self.units),
            testval=ow_init,
            **self.priors_kwargs
        )
        ob = self.prior(
            f'output_biases_{self.layer_name}',
            shape=self.units,
            testval=ob_init,
            **self.priors_kwargs
        )

        cw = self.prior(
            f'cell_weights_{self.layer_name}',
            shape=(self.shape_z, self.units),
            testval=cw_init,
            **self.priors_kwargs
        )
        cb = self.prior(
            f'cell_biases_{self.layer_name}',
            shape=self.units,
            testval=cb_init,
            **self.priors_kwargs
        )

        return fw, fb, iw, ib, ow, ob, cw, cb

    def build(self, input_tensor):
        """
        """
        hidden_states = []
        cell_states = []

        # priors over the initial hidden and cell states
        h = pm.Normal(
            'h',
            mu=0,
            sigma=1,
            shape=(self.shape_batch, self.units)
        )
        c = pm.Normal(
            'c',
            mu=0,
            sigma=1,
            shape=(self.shape_batch, self.units)
        )

        fw, fb, iw, ib, ow, ob, cw, cb = self.__LSTM_weights(
            self.__LSTM_init()
        )

        def cell_ops(input_tensor, h, c, fw, fb, iw, ib, ow, ob, cw, cb):
            """One LSTM step: update the cell and hidden states."""
            # the same concatenated [x_t, h_{t-1}] feeds all four gates;
            # cloning it per gate only duplicates nodes in the graph
            z = tt.concatenate([input_tensor, h], axis=-1)

            # forget gate
            forget_gate = pm.math.sigmoid(
                pm.math.dot(z, fw) + fb
            )
            c *= forget_gate

            # input gate and candidate cell state
            input_gate_1 = pm.math.sigmoid(
                pm.math.dot(z, iw) + ib
            )
            input_gate_2 = pm.math.tanh(
                pm.math.dot(z, cw) + cb
            )
            input_cell = input_gate_1 * input_gate_2
            c += input_cell

            # output gate
            output_gate = pm.math.sigmoid(
                pm.math.dot(z, ow) + ob
            )
            co = pm.math.tanh(c)
            h = output_gate * co

            return h, c

        for i in range(self.shape_seq):

            h, c = cell_ops(
                input_tensor[:, i, :],
                h=h,
                c=c,
                fw=fw,
                fb=fb,
                iw=iw,
                ib=ib,
                ow=ow,
                ob=ob,
                cw=cw,
                cb=cb
            )

            hidden_states.append(h)
            cell_states.append(c)

        if self.return_sequences:
            raise NotImplementedError
        else:
            return hidden_states[-1], cell_states[-1]

The notebook I am using for testing.

I am pretty sure I am doing something horribly wrong, but I have not been able to figure out what it is.
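One thing I did wonder about: I unroll the recurrence with a plain Python for loop, so the Theano graph grows linearly with the sequence length (four dot products per timestep), which could plausibly blow up memory once ADVI starts building gradients. As far as I understand, the usual Theano idiom for recurrences is theano.scan, which keeps the graph size independent of the sequence length. A rough, untested sketch of just the loop, reusing the same weights and initial states as in build above:

def step(x_t, h_tm1, c_tm1, fw, fb, iw, ib, ow, ob, cw, cb):
    # one LSTM step, same math as cell_ops above
    z = tt.concatenate([x_t, h_tm1], axis=-1)
    f = tt.nnet.sigmoid(tt.dot(z, fw) + fb)   # forget gate
    i = tt.nnet.sigmoid(tt.dot(z, iw) + ib)   # input gate
    g = tt.tanh(tt.dot(z, cw) + cb)           # candidate cell state
    o = tt.nnet.sigmoid(tt.dot(z, ow) + ob)   # output gate
    c_t = f * c_tm1 + i * g
    h_t = o * tt.tanh(c_t)
    return h_t, c_t

# scan iterates over the leading axis, so move the sequence axis first:
# (batch, seq, feat) -> (seq, batch, feat)
(hs, cs), _ = theano.scan(
    fn=step,
    sequences=input_tensor.dimshuffle(1, 0, 2),
    outputs_info=[h, c],
    non_sequences=[fw, fb, iw, ib, ow, ob, cw, cb],
)
last_h, last_c = hs[-1], cs[-1]

Would expressing the loop with scan be the right direction here, or is the problem somewhere else entirely?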