Saturday, June 30, 2018

What Keras Source Code Tells Us about LSTM Processing

A Recurrent Neural Network (RNN) is a type of neural network that can save its states and process input sequences. With this recurrent architecture, RNNs can be used for applications such as language analysis, time series prediction, and speech recognition. Sometimes it can be challenging to understand how an RNN processes its inputs. I ran into this problem when I started learning about RNNs, and reading the Keras source code helped clarify the puzzle.

Under Keras, the input to an RNN has three dimensions: (batch_size, timesteps, input_dim). batch_size controls how often the network is updated: the weights are updated once after every batch_size sequences are processed. For example, if there are 256 sequences and batch_size is set to 32, the network will be updated 8 times per pass over the whole set. timesteps is how many times an RNN cell runs; in each run, the cell starts from the saved states and generates new states and an output. input_dim is the number of features contained in each input vector. Another important parameter is the number of RNN units. For example, the instruction below creates an LSTM layer (a popular type of RNN) with 4 units. Here, timesteps is set to 1 and input_dim is set to 2. The instruction does not define batch_size, which means batch_size will be defined later. From now on, we refer to the number of units as num_units.

model.add(LSTM(4, input_shape=(1, 2)))
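
To make these shapes concrete, below is a minimal sketch (with made-up random data, a dummy Dense output layer, and mean-squared-error loss, all purely for illustration) that feeds 256 sequences into this layer with batch_size set to 32, so the weights are updated 8 times per epoch.

import numpy as np
from keras.models import Sequential
from keras.layers import LSTM, Dense

num_sequences, timesteps, input_dim = 256, 1, 2
X = np.random.rand(num_sequences, timesteps, input_dim)  # input of shape (sequences, timesteps, input_dim)
y = np.random.rand(num_sequences, 1)                     # dummy targets

model = Sequential()
model.add(LSTM(4, input_shape=(timesteps, input_dim)))   # 4 LSTM units, timesteps=1, input_dim=2
model.add(Dense(1))
model.compile(loss='mse', optimizer='adam')

model.fit(X, y, batch_size=32, epochs=1)                 # 256 / 32 = 8 weight updates per epoch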

The Keras RNN source code can be found here. Reading it reveals how an RNN processes its input. How an LSTM operates can be found in the call() function of the class LSTMCell. This part of the code (shown below) involves only input_dim and num_units, not the other parameters. For instance, the dimension of self.kernel_i/self.kernel_f/self.kernel_c/self.kernel_o is (input_dim, num_units), and inputs is a vector of input_dim elements. Therefore, after the dot product between the kernel and inputs, the input_dim dimension no longer exists, and x_i/x_f/x_c/x_o become vectors of length num_units, which matches the dimensionality of the bias. Another observation is that at this stage, each LSTM unit works independently, without interference from the other units.

            if 0 < self.dropout < 1.:
                inputs_i = inputs * dp_mask[0]
                inputs_f = inputs * dp_mask[1]
                inputs_c = inputs * dp_mask[2]
                inputs_o = inputs * dp_mask[3]
            else:
                inputs_i = inputs
                inputs_f = inputs
                inputs_c = inputs
                inputs_o = inputs
            x_i = K.dot(inputs_i, self.kernel_i)
            x_f = K.dot(inputs_f, self.kernel_f)
            x_c = K.dot(inputs_c, self.kernel_c)
            x_o = K.dot(inputs_o, self.kernel_o)
            if self.use_bias:
                x_i = K.bias_add(x_i, self.bias_i)
                x_f = K.bias_add(x_f, self.bias_f)
                x_c = K.bias_add(x_c, self.bias_c)
                x_o = K.bias_add(x_o, self.bias_o)

            if 0 < self.recurrent_dropout < 1.:
                h_tm1_i = h_tm1 * rec_dp_mask[0]
                h_tm1_f = h_tm1 * rec_dp_mask[1]
                h_tm1_c = h_tm1 * rec_dp_mask[2]
                h_tm1_o = h_tm1 * rec_dp_mask[3]
            else:
                h_tm1_i = h_tm1
                h_tm1_f = h_tm1
                h_tm1_c = h_tm1
                h_tm1_o = h_tm1
            i = self.recurrent_activation(x_i + K.dot(h_tm1_i,
                                                      self.recurrent_kernel_i))
            f = self.recurrent_activation(x_f + K.dot(h_tm1_f,
                                                      self.recurrent_kernel_f))
            c = f * c_tm1 + i * self.activation(x_c + K.dot(h_tm1_c,
                                                            self.recurrent_kernel_c))
            o = self.recurrent_activation(x_o + K.dot(h_tm1_o,
                                                      self.recurrent_kernel_o))
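
To check the shape bookkeeping described above, here is a minimal NumPy sketch of the input-gate computation alone; the variable names mirror the Keras code, but the random weights, zero bias, and single sample are made up for illustration.

import numpy as np

input_dim, num_units = 2, 4
inputs = np.random.rand(1, input_dim)                     # one input vector
kernel_i = np.random.rand(input_dim, num_units)           # shape (input_dim, num_units)
bias_i = np.zeros(num_units)
recurrent_kernel_i = np.random.rand(num_units, num_units)
h_tm1 = np.zeros((1, num_units))                          # previous hidden state

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

x_i = np.dot(inputs, kernel_i) + bias_i                   # input_dim disappears: shape (1, num_units)
i = sigmoid(x_i + np.dot(h_tm1, recurrent_kernel_i))      # input gate, one value per LSTM unit
print(i.shape)                                            # (1, 4)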

The code above has nothing to do with the concept of time, which is the signature of an RNN. So how does time get introduced into the RNN implementation? The class RNN(Layer), defined in the same file recurrent.py, is the base class for recurrent layers. In the call() function of this class, we can see a call to the backend function K.rnn, where timesteps is passed as the input_length argument. We call it a backend function because Keras runs on top of TensorFlow or Theano, and the rnn() function is provided not by Keras itself but by the TensorFlow or Theano library, depending on which one Keras relies on. Although I have not yet checked the source code of the rnn() function, I believe that inside it the above LSTMCell logic is called timesteps times to generate the final output.

        last_output, outputs, states = K.rnn(step,
                                             inputs,
                                             initial_state,
                                             constants=constants,
                                             go_backwards=self.go_backwards,
                                             mask=mask,
                                             unroll=self.unroll,
                                             input_length=timesteps)
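
Since I have not read the backend rnn() implementation, the following is only a rough sketch (with invented names) of what I believe it does: apply the step function once per timestep, carrying the states forward and collecting the outputs.

def rnn_sketch(step, inputs, initial_states):
    # inputs has shape (batch_size, timesteps, input_dim);
    # step(input_t, states) returns (output_t, new_states), e.g. the LSTMCell logic above
    states = initial_states
    outputs = []
    for t in range(inputs.shape[1]):              # loop over timesteps
        output, states = step(inputs[:, t, :], states)
        outputs.append(output)
    last_output = outputs[-1]
    return last_output, outputs, states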

For reference, the equations for the LSTM are provided here:
Input gate: \(i_{t}=\sigma(W^{(i)}x_{t}+U^{(i)}h_{t-1})\)
Forget gate: \(f_{t}=\sigma(W^{(f)}x_{t}+U^{(f)}h_{t-1})\)
Output gate: \(o_{t}=\sigma(W^{(o)}x_{t}+U^{(o)}h_{t-1})\)
New memory cell: \(\tilde{c}_{t}=\tanh(W^{(c)}x_{t}+U^{(c)}h_{t-1})\)
Final memory cell: \(c_{t}=f_{t}\circ c_{t-1}+i_{t}\circ \tilde{c}_{t}\)
Final hidden state: \(h_{t}=o_{t}\circ \tanh(c_{t})\)
where \(\circ\) denotes the element-wise (Hadamard) product


In this post, I have summarized what I found in the Keras source code. Hopefully it is useful to you.

The biggest difference between an ordinary RNN and an LSTM is that the LSTM carries an extra state, the cell state \(c_{t}\), in addition to the hidden state \(h_{t}\).
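
One way to see that extra state in Keras is to pass return_state=True: an LSTM layer returns a hidden state and a cell state, while SimpleRNN returns only one state. The toy functional-API snippet below (shapes chosen arbitrarily) illustrates this.

from keras.layers import Input, LSTM, SimpleRNN

x = Input(shape=(1, 2))
rnn_out, rnn_state = SimpleRNN(4, return_state=True)(x)      # single state
lstm_out, state_h, state_c = LSTM(4, return_state=True)(x)   # hidden state plus cell state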
