神戸のデータ活用塾!KDL Data Blog

KDLが誇るデータ活用のプロフェッショナル達が書き連ねるブログです。

【第4回 Batch Normalization導入編】PyTorchとCIFAR-10で学ぶCNNの精度向上

株式会社神戸デジタル・ラボ DataIntelligenceチームの原口です。

前回は、Dropoutを導入することで過学習を抑制しました。

もう一度内容を確認したい方は、以下の記事をご覧ください。

kdl-di.hatenablog.com

連載を最初から読みたい方は、以下の記事をご覧ください。

kdl-di.hatenablog.com

今回はAIモデルに「ひと手間」加えることで、学習の安定性を向上させます!

前回の問題点

前回はDropoutを利用することで、モデルの過学習を抑制できました。

一方で、AIモデルの学習速度が緩やかになる問題も発生しました。

そこで今回は、学習をより早く収束させられる(かもしれない)Batch Normalizationを導入することで、さらなる精度向上を図ります。

ニューラルネットワークの収束が遅い原因

Batch Normalizationの説明をする前に、なぜニューラルネットワークの収束が遅くなったのか考えてみましょう。

ニューラルネットワークの学習は次の手順を踏んで行います。

  • 学習データを計算し、予測を出力する
  • 予測と正解の誤差を計算する
  • 誤差逆伝搬を行い、モデルのパラメータを更新する

この手順を繰り返すことで、学習は進みます。この中のどこで問題が発生しているのでしょうか?

それは、「誤差逆伝搬を行い、モデルのパラメータを更新する」部分です。

ではどのような問題があるか見ていきましょう。

学習データが入力されるとCNNは予測を出力するために、複雑な計算を行います。

各畳み込み層は学習できるパラメータ wを所有しており、各層は前の層の出力に対して wを掛け合わせることで出力を得ています。

ここで重要なのは、各層の計算は前の層の出力に依存するということです。つまりx_3の出力を得るためにはx_2が必要となり、x_2を得るためにはx_1が必要というわけです。

学習データを計算し、予測を出力する

続いて学習を実行します。学習をすると、各畳み込み層のパラメータはwから w^{'} へと変化します。(掛け算の値が3.0→3.1に変化したようなイメージです)

学習後の状態
このとき、各畳み込み層のパラメータが学習によって変化しているため、それぞれの層が出力するxも大きく変化します。

ここで問題になるのが先ほど重要であるとお伝えした「各層の計算は前の層の出力に依存する」という点です。各層のパラメータは学習前の時点で得られるデータ分布xに対して最適な形に変更されました。

しかしすべての層が学習したことによって、各層が得られるデータ分布は xからx^{'}へと変わっています。

このxx^{'}の乖離が小さければ変化は微々たるのもなのでxで学習したw^{'}で対応できますが、xx^{'}が大きく乖離している場合はw^{'}で対応できないため、もう一度大きな修正を加える必要があります。

特に学習の序盤では、パラメータは頻繁に大きく変更されるため、安定した学習ができないという結果になるのです。

学習前・学習後のイメージ

こういった問題を解決するために、Batch Normalizationを用いてみましょう!

Batch Normalizationとは?

Batch Normalizationについて説明します。Batch Normalizationは出力されたデータの分布を強制的に標準化します。出力されるデータ分布が固定されることによって、後段の学習がスムーズに行われるようになります。

Batch Normalizationによる正規化のイメージ

では具体的な式を確認しましょう。Batch Normalizationはミニバッチ(データセット群から学習のために取り出してきた一部のデータを指します)に対して出力されたデータを正規化します。つまりデータの平均と分散が必要になります。

Batch Normalizationの仕組み

またBatch Normalizationでは、標準化したデータに対して係数をかけ(スケール変換)、定数を加算(シフト変換)することで値の大きさと平均値を移動します。ここで出現する変数 \gamma\betaは学習の中でより適したものに更新されます。

Batch Normalizationによるスケール・シフト変換

このように変換されることにより、後段で受け取るデータは常に標準化された分布に変化しているため、パラメータ更新によって大きな変化が生まれることが少なくなり、学習が安定するということです。

実験

実際に実験をして確かめてみましょう。学習コードは【第3回 Dropout導入編】PyTorchとCIFAR-10で学ぶCNNの精度向上のものを使います。

Batch Normalizationの実験には以下のモデルを利用しました。

class BatchNormModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bn_1 = nn.BatchNorm1d(120)   #新しく追加した箇所
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) 
        x = F.relu(self.bn_1(self.fc1(x)))
        x = F.relu(self.fc2(x))     
        x = self.fc3(x)
        return x

実際に学習して、グラフにどのような変化が現れるか確認しましょう。

実験結果

普通のモデル・Dropout導入モデル、Batch Normalization実験結果は次のグラフのようになりました。

実験結果

Batch Normalizationを導入したモデルでは、かなり早い段階で損失が最低になっていることが分かります。一方、早期に過学習になる傾向がみられます。

過学習になれば学習を打ち止め出来るため、たくさん試行することができます。AI開発では仮説検証をたくさん行うことが重要です。多く試行できるということは、多くの仮説検証を行うことに繋がります。

つまり、Batch Normalizationを導入することで速くそして高精度に実験を繰り返すことができ、より良いAI開発に集中することができます。

まとめ

今回はBatch Normalizationをニューラルネットワークに導入しました。導入によって高速な学習が実現し、それによって多くの試行が出来るようになることを確認しました。

Batch Normalizationはどの深層学習においてスタンダードとなっている手法であり、どのモデルを見ても何らかの形で導入されています。

最新のモデルの中身を確認した際は、どこにBatch Normalizationが利用されているか探してみてください!

原口俊樹

データインテリジェンスチーム所属
データエンジニアを担当しています。画像認識を得意としており、画像認識・ニューラルネットワーク系の技術記事を発信していきます