
Commit 58f1405

RNN update
1 parent 92a0cdd commit 58f1405

File tree

2 files changed (+36, -5 lines)

assets/rnn/lstm_highway.png (88.6 KB)

rnn.md

Lines changed: 36 additions & 5 deletions
@@ -224,13 +224,16 @@ For the back propagation, Let's examine how the output at the very last timestep
The partial derivative of $$h_t$$ with respect to $$h_{t-1}$$ is written as:
$$ \frac{\partial h_t}{\partial h_{t-1}} = tanh^{'}(W_{hh}h_{t-1} + W_{xh}x_t)W_{hh} $$

-We update the weights $$W$$ by getting the derivative of the loss at the very last time step $$L_{t}$$ with respect to $$W$$.
+We update the weights $$W_{hh}$$ by getting the derivative of the loss at the very last time step $$L_{t}$$ with respect to $$W_{hh}$$.
+
+$$
\begin{aligned}
-\frac{\partial L_{t}}{\partial W} = \frac{\partial L_{t}}{\partial h_{t}} \frac{\partial h_{t}}{\partial h_{t-1} } \dots \frac{\partial h_{1}}{\partial W} } \\
-= \frac{\partial L_{t}}{\partial h_{t}}(\prod_{t=2}^{T} \frac{\partial h_{t}}{\partial h_{t-1}})\frac{\partial h_{1}}{\partial W} \\
-= \frac{\partial L_{t}}{\partial h_{t}}(\prod_{t=2}^{T} tanh^{'}(W_{hh}h_{t-1} + W_{xh}x_t)W_{hh}^{T-1})\frac{\partial h_{1}}{\partial W}
+\frac{\partial L_{t}}{\partial W_{hh}} = \frac{\partial L_{t}}{\partial h_{t}} \frac{\partial h_{t}}{\partial h_{t-1}} \dots \frac{\partial h_{1}}{\partial W_{hh}} \\
+= \frac{\partial L_{t}}{\partial h_{t}}(\prod_{t=2}^{T} \frac{\partial h_{t}}{\partial h_{t-1}})\frac{\partial h_{1}}{\partial W_{hh}} \\
+= \frac{\partial L_{t}}{\partial h_{t}}(\prod_{t=2}^{T} tanh^{'}(W_{hh}h_{t-1} + W_{xh}x_t)W_{hh}^{T-1})\frac{\partial h_{1}}{\partial W_{hh}} \\
\end{aligned}
-$$
+$$
+
* **Vanishing gradient:** We see that $$tanh^{'}(W_{hh}h_{t-1} + W_{xh}x_t)$$ will almost always be less than 1: since tanh always lies between negative one and one, its derivative $$1 - tanh^{2}$$ lies between zero and one. Thus, as $$t$$ gets larger (i.e. longer timesteps), the gradient $$\frac{\partial L_{t}}{\partial W_{hh}}$$ will decrease in value and get close to zero.
This leads to the vanishing gradient problem, where gradients at future time steps rarely impact gradients at the very first time step. This is problematic when we model long sequences of inputs, because the updates will be extremely slow.
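
As a rough numerical illustration of this shrinking product (a hypothetical numpy sketch, not part of rnn.md; the hidden size, input size, and weight scale are made up), we can accumulate the per-step Jacobians $$\frac{\partial h_t}{\partial h_{t-1}}$$ and watch the norm of their product decay:

```python
# Hypothetical sketch: accumulate the Jacobians d h_t / d h_{t-1} of a vanilla RNN
# and watch the norm of their product shrink as the number of timesteps grows.
import numpy as np

rng = np.random.default_rng(0)
H, D = 16, 4                                        # hidden size, input size (arbitrary)
W_hh = rng.normal(size=(H, H)) * 0.9 / np.sqrt(H)   # small recurrent weights
W_xh = rng.normal(size=(H, D)) / np.sqrt(D)         # input-to-hidden weights

h = np.zeros(H)
prod = np.eye(H)                                    # running product of Jacobians
for t in range(1, 51):
    x_t = rng.normal(size=D)
    h = np.tanh(W_hh @ h + W_xh @ x_t)
    # d h_t / d h_{t-1} = diag(tanh'(pre)) @ W_hh, and tanh'(pre) = 1 - h^2
    jac = np.diag(1.0 - h ** 2) @ W_hh
    prod = jac @ prod
    if t % 10 == 0:
        print(f"T = {t:2d}, ||product of Jacobians|| = {np.linalg.norm(prod):.2e}")
```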

@@ -282,3 +285,31 @@ $$
<div class="fig figcenter">
<img src="/assets/rnn/lstm_mformula_2.png" width="40%" >
</div>
+
+where $$\odot$$ is the element-wise (Hadamard) product. $$g_t$$ in the formulas above is an intermediate "candidate" value: it feeds into the cell state $$c_t$$, which is later combined with the $$o$$ gate to produce $$h_t$$.
+
+Since the $$f, i, o$$ gate vectors all take values between 0 and 1 (they are squashed by the sigmoid function $$\sigma$$), when they are applied element-wise we can see that:
+
+* Forget gate $$f_t$$ at time step $$t$$ controls how much information needs to be "removed" from the previous cell state $$c_{t-1}$$.
+* Input gate $$i_t$$ at time step $$t$$ controls how much information needs to be "added" to the next cell state $$c_t$$ from the previous hidden state $$h_{t-1}$$ and the input $$x_t$$.
+* Output gate $$o_t$$ at time step $$t$$ controls how much information needs to be "shown" as output in the current hidden state $$h_t$$.
+
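
As a minimal sketch of how these gates act in a single step (assuming the usual update $$c_t = f_t \odot c_{t-1} + i_t \odot g_t$$ and $$h_t = o_t \odot tanh(c_t)$$; the shapes, weight layout, and variable names below are illustrative, not taken from rnn.md):

```python
# Hypothetical sketch of a single LSTM step with numpy.
# Assumes the standard gate equations; weight shapes and names are illustrative.
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_step(x_t, h_prev, c_prev, W, b):
    """One LSTM time step. W has shape (4H, D + H), b has shape (4H,)."""
    H = h_prev.shape[0]
    z = W @ np.concatenate([x_t, h_prev]) + b   # all four gates computed at once
    i = sigmoid(z[0:H])          # input gate: how much new info to add
    f = sigmoid(z[H:2*H])        # forget gate: how much of c_prev to keep
    o = sigmoid(z[2*H:3*H])      # output gate: how much of the cell to show
    g = np.tanh(z[3*H:4*H])      # candidate values for the cell state
    c = f * c_prev + i * g       # element-wise (Hadamard) products
    h = o * np.tanh(c)
    return h, c

# Tiny usage example with random weights.
rng = np.random.default_rng(0)
D, H = 3, 5
W = rng.normal(scale=0.1, size=(4 * H, D + H))
b = np.zeros(4 * H)
h, c = np.zeros(H), np.zeros(H)
for x_t in rng.normal(size=(4, D)):   # a short input sequence
    h, c = lstm_step(x_t, h, c, W, b)
print(h.shape, c.shape)               # (5,) (5,)
```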
+The key idea of the LSTM is the cell state, the horizontal line running between recurrent timesteps. You can think of the cell state as a kind of highway of information passing straight down the entire chain, with only some minor linear interactions. With the formulation above, it is easy for information to simply flow along this highway (Figure 5). This greatly alleviates the vanishing/exploding gradient problem we outlined above.
+
+<div class="fig figcenter fighighlight">
+<img src="/assets/rnn/lstm_highway.png" width="70%" >
+<div class="figcaption">Figure 5. LSTM cell state highway.</div>
+</div>
+
+The LSTM architecture makes it easier for the RNN to preserve information over many recurrent time steps. For example, if the forget gate is set to 1 and the input gate is set to 0, then the information in the cell state will be preserved across many recurrent time steps. For a vanilla RNN, in contrast, it is much harder to preserve information in the hidden state across recurrent time steps by relying on a single weight matrix alone.
+
+LSTMs do not guarantee that vanishing/exploding gradients will not occur, but they do provide an easier way for the model to learn long-distance dependencies.
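
To make the forget-gate-at-1, input-gate-at-0 case above concrete, here is a tiny hypothetical check (not from rnn.md; the sizes and gate values are made up) that the cell state update then acts as the identity:

```python
# Hypothetical check of the paragraph above: with the forget gate at 1 and the
# input gate at 0, the update c_t = f * c_{t-1} + i * g_t leaves the cell state unchanged.
import numpy as np

rng = np.random.default_rng(0)
H = 5
f = np.ones(H)            # forget gate set to 1
i = np.zeros(H)           # input gate set to 0
c0 = rng.normal(size=H)   # some initial cell state worth remembering

c = c0.copy()
for t in range(100):
    g = np.tanh(rng.normal(size=H))  # whatever candidate each step proposes
    c = f * c + i * g                # cell state update
print(np.allclose(c, c0))            # True: information preserved for 100 steps
```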
