PyTorchのBatch Normalizationがパラメータを固定しても変動する問題の解決法

起こった問題
大規模データセットで学習済みのResNetを用いて転移学習を行うと、ResNetの全てのパラメータに対してrequires_grad = Falseとしているにも関わらず、転移学習の前後でResNetの出力が変化する問題が生じました。


原因
調べると、PyTorchのBatch Normalizationの層の仕様によるものでした。これは転移学習などが収束するために必要な正常な動作のようです。
Batch Normalizationの内部計算ではデータの平均,分散, アフィンパラメータβ,γを用いています。β,γは学習パラメータですが平均と分散は学習パラメータではなくtrainingモードでforwardを呼び出すと無条件に更新される値であるため、requires_grad = Falseとしてもこの値は転移学習で用いたデータにより更新されてしまいます。


解決方法
簡単な対応として、evaluationモードではtrainingモードで記憶した平均と分散を利用するため、学習時にもBatch Normalizationの層だけevaluationモードを適用するという変則的な措置を取ります。


以下のDiscussionでこの問題について述べられていたので今回参考にさせていただきました。
discuss.pytorch.org


Batch Normalizationの計算を完全に固定してしまうと、本来の役割が果たせないため学習が収束しない可能性があり、使い時は限られると思いますが...
日本語での解説は見当たらなかったので書き残しておこうと思います。

KerasでもBatch Normalizationの学習と推論での挙動についてまとめられている記事がありました。
KerasでBatchNormalization層を転移学習をする際の注意点 - Qiita



詳細

そもそもBatch Normalizationがどのような役割でどのような計算を行なっているかについて理解しておく必要があります。詳しくは以下の記事が参考になりました。
Batch Normalizationを理解する | 楽しみながら理解するAI・機械学習入門

中間層においても特徴量の正規化を行うことで、モデル全体の学習を安定化させる効果があります。具体的な計算は以下。


 \mathit {y} = \dfrac {\mathit {x} - \mathrm {E}[x]} {\sqrt{ \mathrm {Var}[x]+ε}} \, \, * γ + β


Conv2d → BatchNorm2d → ReLU のような接続が一般的です。正規化後の値が適切な区間で活性化関数に渡され、活性化関数の非線形性が活かされるように、線形変換を行うのがβ,γの役割です。そのためこの2つはaffine parameter と呼ばれています。これらは学習パラメータなので、Requires_grad = Falseとするとtrainingモードでも固定されます。
一方で、E[x], Var[x]はモジュール内でrunning_mean, running_varとして記録されています。これらはパラメータとして扱われておらず、forwardメソッドで呼び出される時に更新されるbufferである、と上に記載したdiscussionでは述べられています。
したがって、requires_grad = False としても学習時に固定されません。
更新方法についてはBatchNorm2dのインスタンス生成時にtrack_running_statsで指定できます。track_running_stats = True がデフォルトで以下のように機能します。

trainingモード: momentum(デフォルトで0.1)の割合でrunning_mean,running_varを更新する。
evaluationモード: trainingモードで記録した値を使用して計算する。

track_running_stats = False ではtrainingモード, evaluationモードのどちらでも、都度入力されたバッチの平均,分散を計算に使用し、過去のデータを記憶しないため、evaluationモードではデータリークを招く可能性があります。




つまり、転移学習の前後でrunning_mean,running_varの値が変わらないようにするためには, track_runnning_stats = Trueの設定で学習時もBatchNorm2dのみevaluationモードを指定しておく必要があります。

model.train()
for layer in model.children():
    if isinstance(layer, nn.BatchNorm2d):
        layer.requires_grad = False
        layer.eval()

これでBatch Normalizationの計算も完全に元のモデルを再現できるはずです。ただし、繰り返しになりますが学習が収束しない可能性がありますので、その点は難ありです。