「ゼロから作るDeep Learning ❷ ―自然言語処理編」の第6章前半を読んだので概要をメモ
ゼロから作るDeep Learning ❷ ―自然言語処理編
第6章ゲート付きRNN
「RNN」と言った時、それが指すレイヤは前章のRNNではなく、LSTMであることが多い。
第5章までに紹介した「RNN」は、「シンプルなRNN」、「エルマン」と呼ばれる。
「シンプルなRNN」は、あまり性能が良くなく、時系列データの長期依存関係を上手く学習できない。一方で、LSTMやGRUは「ゲート」と呼ばれる仕組みが加わっており、時系列データの長期的な依存関係を学習することができるようになっている。
RNNの問題点-勾配消失もしくは勾配爆発
「シンプルなRNN」が、時系列データの長期依存関係を上手く学習できない理由は、BPTT(Backpropagation Through Time)において、勾配消失もしくは勾配爆発が起こることに原因がある。
![](https://take-tech-engineer.com/wp-content/uploads/2022/05/abe06b778b47058cd34398672cbc682e-1024x285.jpg)
(図)RNNレイヤの時間方向の勾配の伝播の、時間方向の勾配に着目すると、「tanh」、「+」、「MatMul」の演算を行なうことが分かる。
+の逆伝播
勾配をそのまま流すだけのため、値は変わらない。
tanhの逆伝播
![](https://take-tech-engineer.com/wp-content/uploads/2022/05/4e1703175ea0774a7b6f94c11a393e9b-1024x768.jpg)
tanhの微分は1.0以下となるため、tanhを通るたびに、値がどんどん小さくなっていくため、勾配は弱められる。
※活性化関数にReLUに変えることで勾配消失を抑えることが期待できる
MatMul(行列の積)の逆伝播
![](https://take-tech-engineer.com/wp-content/uploads/2022/05/cf2ffed09d08fccdfe6fb292de157ee4.jpg)
単純化するために、tanhノードを無視したRNNレイヤの逆伝播の勾配は、図のように「MatMul」の演算によってのみ変化することがわかる。dhWhによる行列の積によって勾配が計算され、時系列データの時間サイズ分だけ繰り返される。
この行列の積には、毎回同じ重みであるWhが使われるため、勾配の大きさは繰り返すたびに、指数的に増加/減少する。
勾配爆発
勾配の大きさが発散し、最終的にオーバーフローを起こして、NaNのような値が発生する。
勾配消失
勾配の大きさが小さくなりすぎることで、重みパラメータが更新されずに、長期的な依存関係を学習できなくなる
※行列の「特異値」の最大値が1より大きい場合は、指数的に増加する可能性が高く。特異値が1より小さい場合は、指数的に減少する。
勾配クリッピング
勾配爆発対策の定番の手法で、非常に簡単な方法。
![\begin{align*}if\ ||\hat{g}|| \geqq threshold\\\hat{g}=\frac{threshold}{||\hat{g}||}\hat{g}\end{align*}](https://texclip.marutank.net/render.php/texclip20220513210911.png?s=%5Cbegin%7Balign*%7D%0Aif%5C%20%7C%7C%5Chat%7Bg%7D%7C%7C%20%5Cgeqq%20threshold%5C%5C%0A%5Chat%7Bg%7D%3D%5Cfrac%7Bthreshold%7D%7B%7C%7C%5Chat%7Bg%7D%7C%7C%7D%5Chat%7Bg%7D%0A%5Cend%7Balign*%7D&f=c&r=300&m=p&b=f&k=f)
gは、ニューラルネットワークで使われるすべてのパラメータに対する勾配を一つにまとめたもの。勾配のL2ノルムがしきい値を超えた場合に、上記式のように勾配を修正する。
勾配消失とLSTM
「シンプルなRNN」では、勾配消失の対策が難しく、アーキテクチャを根本から変える必要がある。ここで登場するのがゲート付きRNNで、代表格にLSTMとGRUがある。
LSTM
![](https://take-tech-engineer.com/wp-content/uploads/2022/05/f9c9dd97037c8aa7a30a152f74d1a776.jpg)
LSTMの全体構成は図のようになっており、「シンプルRNN」と比較すると、構造は複雑であるが、外部的にはcが追加されたのみであること分かる。このcは記憶セルと呼ばれ長期的な依存関係を学習するのに役立つ。
cは長期的な記憶、hは短期的な記憶を意味している。
LSTMのキーアイディアcとゲート
cは長期的な記憶を保存するセルで、乗算部で長期的な記憶をどの程度削除するか決め、加算部で長期的な記憶にどの程度新しい情報を追加するか決める。
![](https://take-tech-engineer.com/wp-content/uploads/2022/05/6b149a07a04920a207898938121b6239.jpg)
ゲートは選択的に情報を通す仕組みで、シグモイドニューラルネットワーク層と一点の乗算により構成されている。
![](https://take-tech-engineer.com/wp-content/uploads/2022/05/db2151515f41253ba26a5f9a27b6581e.jpg)
シグモイド関数は0~1までの数値を出力するので、0であれば情報を通さず、1であれば情報を通される。
![\begin{align*}y = \frac{1}{1+exp(-x)}\end{align*}](https://texclip.marutank.net/render.php/texclip20220513211449.png?s=%5Cbegin%7Balign*%7D%0Ay%20%3D%20%5Cfrac%7B1%7D%7B1%2Bexp(-x)%7D%0A%5Cend%7Balign*%7D&f=c&r=300&m=p&b=f&k=f)
![](https://take-tech-engineer.com/wp-content/uploads/2022/05/1bcf493e327f41ba96b392652239996f-1024x768.png)
forgetゲート
forgetゲートでは、記憶セルから不要な記憶を忘れる働きをする。
![](https://take-tech-engineer.com/wp-content/uploads/2022/05/Forget-gate-1.jpg)
![\begin{align*}\mathbf{f} = \sigma(\mathbf{x}{t}\mathbf{W}{\mathbf{x}}^{(f)}+\mathbf{h}{t-1}\mathbf{W}{\mathbf{h}}^{(f)}+\mathbf{b}^{f})\end{align*}](https://texclip.marutank.net/render.php/texclip20220513211818.png?s=%5Cbegin%7Balign*%7D%0A%5Cmathbf%7Bf%7D%20%3D%20%5Csigma(%5Cmathbf%7Bx%7D%7Bt%7D%5Cmathbf%7BW%7D%7B%5Cmathbf%7Bx%7D%7D%5E%7B(f)%7D%2B%5Cmathbf%7Bh%7D%7Bt-1%7D%5Cmathbf%7BW%7D%7B%5Cmathbf%7Bh%7D%7D%5E%7B(f)%7D%2B%5Cmathbf%7Bb%7D%5E%7Bf%7D)%0A%5Cend%7Balign*%7D&f=c&r=300&m=p&b=f&k=f)
forgetゲートの出力fは上記式によって求められ、ct-1との要素の積によってctが求められる。
新しい記憶セルとinputゲート
forgetゲートによって、前時刻の記憶セルから忘れるべき情報が削除されたが、このままだと記憶セルは忘れることしかできない。このため、新しく覚えるべき情報を記憶セルに追加する。
![](https://take-tech-engineer.com/wp-content/uploads/2022/05/Candidate-memory.jpg)
![\begin{align*}\mathbf{g} = tanh(\mathbf{x}{t}\mathbf{W}{\mathbf{x}}^{(g)}+\mathbf{h}{t-1}\mathbf{W}{\mathbf{h}}^{(g)}+\mathbf{b}^{g})\end{align*}](https://texclip.marutank.net/render.php/texclip20220513211952.png?s=%5Cbegin%7Balign*%7D%0A%5Cmathbf%7Bg%7D%20%3D%20tanh(%5Cmathbf%7Bx%7D%7Bt%7D%5Cmathbf%7BW%7D%7B%5Cmathbf%7Bx%7D%7D%5E%7B(g)%7D%2B%5Cmathbf%7Bh%7D%7Bt-1%7D%5Cmathbf%7BW%7D%7B%5Cmathbf%7Bh%7D%7D%5E%7B(g)%7D%2B%5Cmathbf%7Bb%7D%5E%7Bg%7D)%0A%5Cend%7Balign*%7D&f=c&r=300&m=p&b=f&k=f)
上記式より、新しく覚えるべき情報の候補gが算出される。ゲートではなく、新しい情報を生成するため、活性化関数はtanhが使われる。
ここで、gはそのまま記憶セルct-1に追加されず、inputゲートによって追加する情報の取捨選択が行われ、ct-1に追加される。
![](https://take-tech-engineer.com/wp-content/uploads/2022/05/Input-gate.jpg)
![\begin{align*}\mathbf{i} = \sigma(\mathbf{x}{t}\mathbf{W}{\mathbf{x}}^{(i)}+\mathbf{h}{t-1}\mathbf{W}{\mathbf{h}}^{(i)}+\mathbf{b}^{i})\\\end{align*}](https://texclip.marutank.net/render.php/texclip20220513212225.png?s=%5Cbegin%7Balign*%7D%0A%5Cmathbf%7Bi%7D%20%3D%20%5Csigma(%5Cmathbf%7Bx%7D%7Bt%7D%5Cmathbf%7BW%7D%7B%5Cmathbf%7Bx%7D%7D%5E%7B(i)%7D%2B%5Cmathbf%7Bh%7D%7Bt-1%7D%5Cmathbf%7BW%7D%7B%5Cmathbf%7Bh%7D%7D%5E%7B(i)%7D%2B%5Cmathbf%7Bb%7D%5E%7Bi%7D)%5C%5C%0A%5Cend%7Balign*%7D&f=c&r=300&m=p&b=f&k=f)
![\begin{align*}\mathbf{c}{t} = \mathbf{f} \odot \mathbf{c}{t-1} + \mathbf{g} \odot \mathbf{i}\end{align*}](https://texclip.marutank.net/render.php/texclip20220513212244.png?s=%5Cbegin%7Balign*%7D%0A%5Cmathbf%7Bc%7D%7Bt%7D%20%3D%20%5Cmathbf%7Bf%7D%20%5Codot%20%5Cmathbf%7Bc%7D%7Bt-1%7D%20%2B%20%5Cmathbf%7Bg%7D%20%5Codot%20%5Cmathbf%7Bi%7D%0A%5Cend%7Balign*%7D&f=c&r=300&m=p&b=f&k=f)
outputゲート
記憶セルctにtanh関数を適用することで隠れ状態htを出力するが、ここでもゲートを適用することで、tanh(ct)が次時刻の隠れ状態としてどれだけ重要か調整される。
![](https://take-tech-engineer.com/wp-content/uploads/2022/05/output-gate.jpg)
![\begin{align*}\mathbf{o} = \sigma(\mathbf{x}{t}\mathbf{W}{\mathbf{x}}^{(o)}+\mathbf{h}{t-1}\mathbf{W}{\mathbf{h}}^{(o)}+\mathbf{b}^{o})\end{align*}](https://texclip.marutank.net/render.php/texclip20220513212438.png?s=%5Cbegin%7Balign*%7D%0A%5Cmathbf%7Bo%7D%20%3D%20%5Csigma(%5Cmathbf%7Bx%7D%7Bt%7D%5Cmathbf%7BW%7D%7B%5Cmathbf%7Bx%7D%7D%5E%7B(o)%7D%2B%5Cmathbf%7Bh%7D%7Bt-1%7D%5Cmathbf%7BW%7D%7B%5Cmathbf%7Bh%7D%7D%5E%7B(o)%7D%2B%5Cmathbf%7Bb%7D%5E%7Bo%7D)%0A%5Cend%7Balign*%7D&f=c&r=300&m=p&b=f&k=f)
![\begin{align*}\mathbf{h}{t} = \mathbf{o} \odot tanh(\mathbf{c}{t})\end{align*}](https://texclip.marutank.net/render.php/texclip20220513212457.png?s=%5Cbegin%7Balign*%7D%0A%5Cmathbf%7Bh%7D%7Bt%7D%20%3D%20%5Cmathbf%7Bo%7D%20%5Codot%20tanh(%5Cmathbf%7Bc%7D%7Bt%7D)%0A%5Cend%7Balign*%7D&f=c&r=300&m=p&b=f&k=f)
LSTMの勾配の流れ
![](https://take-tech-engineer.com/wp-content/uploads/2022/05/138fea4548520763f765d9a267e0b17d-1024x201.jpg)
記憶セルの逆伝播では、「+」と「×」ノードだけを通ることになる。「+」ノードは、勾配をそのまま流すだけなので、勾配の変化は起きない。「×」は、要素ごとの積で、毎時刻、異なるゲート値によって要素ごとの積の計算が行われるため、勾配消失を起こしにくい。
コメント