株式会社神戸デジタル・ラボ DataIntelligenceチームの原口です。
前回は、Dropoutを導入することで過学習を抑制しました。
もう一度内容を確認したい方は、以下の記事をご覧ください。
連載を最初から読みたい方は、以下の記事をご覧ください。
今回はAIモデルに「ひと手間」加えることで、学習の安定性を向上させます!
前回の問題点
前回はDropoutを利用することで、モデルの過学習を抑制できました。
一方で、AIモデルの学習速度が緩やかになる問題も発生しました。
そこで今回は、学習をより早く収束させられる(かもしれない)Batch Normalizationを導入することで、さらなる精度向上を図ります。
ニューラルネットワークの収束が遅い原因
Batch Normalizationの説明をする前に、なぜニューラルネットワークの収束が遅くなったのか考えてみましょう。
ニューラルネットワークの学習は次の手順を踏んで行います。
- 学習データを計算し、予測を出力する
- 予測と正解の誤差を計算する
- 誤差逆伝搬を行い、モデルのパラメータを更新する
この手順を繰り返すことで、学習は進みます。この中のどこで問題が発生しているのでしょうか?
それは、「誤差逆伝搬を行い、モデルのパラメータを更新する」部分です。
ではどのような問題があるか見ていきましょう。
学習データが入力されるとCNNは予測を出力するために、複雑な計算を行います。
各畳み込み層は学習できるパラメータを所有しており、各層は前の層の出力に対してを掛け合わせることで出力を得ています。
ここで重要なのは、各層の計算は前の層の出力に依存するということです。つまりの出力を得るためにはが必要となり、を得るためにはが必要というわけです。
続いて学習を実行します。学習をすると、各畳み込み層のパラメータはからへと変化します。(掛け算の値が3.0→3.1に変化したようなイメージです) このとき、各畳み込み層のパラメータが学習によって変化しているため、それぞれの層が出力するも大きく変化します。
ここで問題になるのが先ほど重要であるとお伝えした「各層の計算は前の層の出力に依存する」という点です。各層のパラメータは学習前の時点で得られるデータ分布に対して最適な形に変更されました。
しかしすべての層が学習したことによって、各層が得られるデータ分布はからへと変わっています。
このとの乖離が小さければ変化は微々たるのもなのでで学習したで対応できますが、とが大きく乖離している場合はで対応できないため、もう一度大きな修正を加える必要があります。
特に学習の序盤では、パラメータは頻繁に大きく変更されるため、安定した学習ができないという結果になるのです。
こういった問題を解決するために、Batch Normalizationを用いてみましょう!
Batch Normalizationとは?
Batch Normalizationについて説明します。Batch Normalizationは出力されたデータの分布を強制的に標準化します。出力されるデータ分布が固定されることによって、後段の学習がスムーズに行われるようになります。
では具体的な式を確認しましょう。Batch Normalizationはミニバッチ(データセット群から学習のために取り出してきた一部のデータを指します)に対して出力されたデータを正規化します。つまりデータの平均と分散が必要になります。
またBatch Normalizationでは、標準化したデータに対して係数をかけ(スケール変換)、定数を加算(シフト変換)することで値の大きさと平均値を移動します。ここで出現する変数とは学習の中でより適したものに更新されます。
このように変換されることにより、後段で受け取るデータは常に標準化された分布に変化しているため、パラメータ更新によって大きな変化が生まれることが少なくなり、学習が安定するということです。
実験
実際に実験をして確かめてみましょう。学習コードは【第3回 Dropout導入編】PyTorchとCIFAR-10で学ぶCNNの精度向上のものを使います。
Batch Normalizationの実験には以下のモデルを利用しました。
class BatchNormModel(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.bn_1 = nn.BatchNorm1d(120) #新しく追加した箇所 self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = torch.flatten(x, 1) x = F.relu(self.bn_1(self.fc1(x))) x = F.relu(self.fc2(x)) x = self.fc3(x) return x
実際に学習して、グラフにどのような変化が現れるか確認しましょう。
実験結果
普通のモデル・Dropout導入モデル、Batch Normalization実験結果は次のグラフのようになりました。
Batch Normalizationを導入したモデルでは、かなり早い段階で損失が最低になっていることが分かります。一方、早期に過学習になる傾向がみられます。
過学習になれば学習を打ち止め出来るため、たくさん試行することができます。AI開発では仮説検証をたくさん行うことが重要です。多く試行できるということは、多くの仮説検証を行うことに繋がります。
つまり、Batch Normalizationを導入することで速くそして高精度に実験を繰り返すことができ、より良いAI開発に集中することができます。
まとめ
今回はBatch Normalizationをニューラルネットワークに導入しました。導入によって高速な学習が実現し、それによって多くの試行が出来るようになることを確認しました。
Batch Normalizationはどの深層学習においてスタンダードとなっている手法であり、どのモデルを見ても何らかの形で導入されています。
最新のモデルの中身を確認した際は、どこにBatch Normalizationが利用されているか探してみてください!
データインテリジェンスチーム所属
データエンジニアを担当しています。画像認識を得意としており、画像認識・ニューラルネットワーク系の技術記事を発信していきます