【論文解説】非IIDデータを用いたFederated Learningで学習精度が悪化する問題への対応策

この記事では以下の論文の解説をします
arxiv.org

Federated Learningの分野において、頻繁に引用されている論文ですが、
日本語での解説記事があまり見当たらなかったので簡単に紹介しようと思います
不正確な内容がありましたら是非ご指摘ください

そもそもFederated Learning(FL)とは何か?

Federated Learningは主に機械学習においてプライバシーとデータセキュリティの確保を目的とした手法の1つに位置付けられており、医療や金融, IoTの領域での応用が期待されています。
学習データセットが複数に分散しており、生データには互いにアクセスできない状態を維持しつつ、各データセットを用いて学習することによる恩恵だけを集めて優れた機械学習モデルを得ることを目指します。
具体的には、最もシンプルなFederated Learningの手法として、FedAvg と呼ばれる方法があります。FedAvgでは、まずデータセットを所有している各機関(クライアント)でアーキテクチャと重みの初期値が同一の機械学習モデルをそれぞれ別個に学習させます。次に、一定のエポックごとに入力された元のデータではなくモデルの重みだけを共有し、加重平均をとって中央(サーバー)の機械学習モデルの重みとして採用します。中央サーバーのモデルの重みを初期値としてクライアントへ再配布し、このプロセスを繰り返すことで学習を進めます。

論文中で前提とされているFederated Learningの手法も,このFedAvgになります。

本論文の要旨

  • 高度に歪んだ非IIDデータを用いてFederated Learningを行ったニューラルネットワークは最大55%までaccuracyが低下する
  • この精度低下は各クライアントのクラス分布と母集団分布の間のEMD(Earth Mover's Distance)により定量化できる
  • 解決策として、全てのクライアントで共有されるグローバルな小規模データセットを導入することが有効であり、CIFAR10の場合、5%の共有データで30%の精度改善がみられた

解説

この論文では、クライアントが持つデータセットが非IID(独立同一分布)である場合にサンプルの確率分布の差が大きいほど、FedAvg法によるニューラルネットワークの学習精度が一般的な学習手法と比較して低下することを示し、その解決策として、クライアント全体で共有できるデータセットを用意することが有効だと主張しています。このData-sharing Strategyでは、重みの初期値として共有データセット学習済みモデルを使用し、クライアント側の学習ではクライアント固有のデータセットと共有データセットを混ぜて学習させます。

イメージとしては以下の図がわかりやすいかと思います

重みの初期値が同一であっても、non-IIDのデータセットを用いて学習するとクライアントのモデルごとの重みにばらつきが生じます。イテレーションごとにFedAvg法により重みをマージしたとしても、データセット全体を用いて学習した場合の重みとは異なるため、この差がイテレーションごとに広がっていき、モデルの精度の大きな低下につながります。

論文中でこの性質について数学的証明がなされていますが、ここでは割愛します。

非IID性と精度悪化の関係性

実験ではデータセットの分割方法を変えて、中央サーバーのCNNモデルの学習精度を比較しました。
用いられたデータセットはMNIST, CIFAR10の画像データとSpeech commands データセットと呼ばれる音声データです。
比較される学習手法は以下のように設定されました。

IID
10クラスの一様分布のデータセットが各クライアントへ提供される

1class non-IID
1つのクラスのデータだけを各クライアントは受け取り、固定の正解クラスを選択するよう学習する

2class non-IID
1クラスを2分割し、各クライアントはランダムに二種のクラスの分割データを受け取り、正解クラスを選択するように学習する

SGD
一般的な学習手法でデータセット全体を用いて学習する

pre-trained
全体のデータセットである程度事前学習済みのモデルに対し追加の学習を行う


これらの学習手法を比較した結果が以下の図です。


IIDのデータセットを用いたFLではSGDとほとんど同じ精度が得られています。
一方、non-IIDデータは特に1classの場合で顕著にその精度を悪化させています。
pre-trainedモデルを用いてもその後ほとんど精度が向上しない、あるいは精度が低下することがわかります。
また、クライアントごとの学習におけるローカルエポック数(E)の値は精度に大きな影響は与えていません。

この図は分割前のデータセットとクライアントのデータセットのクラス分布の差が大きくなるほど、モデルの重みのばらつきも大きくなることを示しています。横軸のEMDは以下から算出されます。

p(y=i)はデータセットにおける正解ラベルがiであるデータの割合を示します。左項はクライアントkの, 右項は元のデータセットの場合です。
この割合の差に対しなんらかの距離尺度を適用し、全クラスの総和を求めます。
この値は確率分布間の距離 EMDに相当すると論文中で述べられています。

EMDについては以下のサイトが分かりやすかったです。
Earth Mover's Distance (EMD) - 人工知能に関する断創録

実際にEMDが上昇するとある段階から急激に test accuracyが減少することが確認されています。

non-IIDデータセットによる精度低下への対応策

論文の後半では、データセットの非IID性によるtest accuracyの低下への対応策が示されます。
方法は単純です。
・図のように、全てのクライアントで共有可能なグローバルデータセットGを用意する
・Gで中央サーバーのモデルを学習させる
・学習済みモデルの重みを初期値としてクライアントへ配布
・Gのうち割合αのデータをランダムに抽出しクライアントへ配布
・クライアントはクライアントのデータセットDとα×Gを混合してFedAvg法を行う

data-sharing strategyによるtest accuracyへの影響を確認しました。
クライアントのデータセットDのサイズに対するGのデータサイズの割合をβとして、
直感に違わずβが大きくなるにしたがい、EMDは減少しtest accuracyは改善されます。
αについては、一定の大きさでなければ改善効果が低いことがわかります。

結論と感想

data-sharing strategyは、学習精度を向上・安定化させたい事を考えると真っ先に思いつく手法だと思いますが、
この論文では実際に、これが有効である事を簡潔に確認しています。
今回は割愛しましたが、non-IIDデータにおける精度低下を数学的に証明している点でこの論文は評価されているのだと思います。
本論文ではMNIST, CIFAR10, Speech commands datasetのみ検証されています。
医用画像のデータセットの場合などについても気になるので今後調べていきたいと思います。

PyTorchのBatch Normalizationがパラメータを固定しても変動する問題の解決法

起こった問題
大規模データセットで学習済みのResNetを用いて転移学習を行うと、ResNetの全てのパラメータに対してrequires_grad = Falseとしているにも関わらず、転移学習の前後でResNetの出力が変化する問題が生じました。


原因
調べると、PyTorchのBatch Normalizationの層の仕様によるものでした。これは転移学習などが収束するために必要な正常な動作のようです。
Batch Normalizationの内部計算ではデータの平均,分散, アフィンパラメータβ,γを用いています。β,γは学習パラメータですが平均と分散は学習パラメータではなくtrainingモードでforwardを呼び出すと無条件に更新される値であるため、requires_grad = Falseとしてもこの値は転移学習で用いたデータにより更新されてしまいます。


解決方法
簡単な対応として、evaluationモードではtrainingモードで記憶した平均と分散を利用するため、学習時にもBatch Normalizationの層だけevaluationモードを適用するという変則的な措置を取ります。


以下のDiscussionでこの問題について述べられていたので今回参考にさせていただきました。
discuss.pytorch.org


Batch Normalizationの計算を完全に固定してしまうと、本来の役割が果たせないため学習が収束しない可能性があり、使い時は限られると思いますが...
日本語での解説は見当たらなかったので書き残しておこうと思います。

KerasでもBatch Normalizationの学習と推論での挙動についてまとめられている記事がありました。
KerasでBatchNormalization層を転移学習をする際の注意点 - Qiita



詳細

そもそもBatch Normalizationがどのような役割でどのような計算を行なっているかについて理解しておく必要があります。詳しくは以下の記事が参考になりました。
Batch Normalizationを理解する | 楽しみながら理解するAI・機械学習入門

中間層においても特徴量の正規化を行うことで、モデル全体の学習を安定化させる効果があります。具体的な計算は以下。


 \mathit {y} = \dfrac {\mathit {x} - \mathrm {E}[x]} {\sqrt{ \mathrm {Var}[x]+ε}} \, \, * γ + β


Conv2d → BatchNorm2d → ReLU のような接続が一般的です。正規化後の値が適切な区間で活性化関数に渡され、活性化関数の非線形性が活かされるように、線形変換を行うのがβ,γの役割です。そのためこの2つはaffine parameter と呼ばれています。これらは学習パラメータなので、Requires_grad = Falseとするとtrainingモードでも固定されます。
一方で、E[x], Var[x]はモジュール内でrunning_mean, running_varとして記録されています。これらはパラメータとして扱われておらず、forwardメソッドで呼び出される時に更新されるbufferである、と上に記載したdiscussionでは述べられています。
したがって、requires_grad = False としても学習時に固定されません。
更新方法についてはBatchNorm2dのインスタンス生成時にtrack_running_statsで指定できます。track_running_stats = True がデフォルトで以下のように機能します。

trainingモード: momentum(デフォルトで0.1)の割合でrunning_mean,running_varを更新する。
evaluationモード: trainingモードで記録した値を使用して計算する。

track_running_stats = False ではtrainingモード, evaluationモードのどちらでも、都度入力されたバッチの平均,分散を計算に使用し、過去のデータを記憶しないため、evaluationモードではデータリークを招く可能性があります。




つまり、転移学習の前後でrunning_mean,running_varの値が変わらないようにするためには, track_runnning_stats = Trueの設定で学習時もBatchNorm2dのみevaluationモードを指定しておく必要があります。

model.train()
for layer in model.children():
    if isinstance(layer, nn.BatchNorm2d):
        layer.requires_grad = False
        layer.eval()

これでBatch Normalizationの計算も完全に元のモデルを再現できるはずです。ただし、繰り返しになりますが学習が収束しない可能性がありますので、その点は難ありです。

物体検出のDeticを導入する際にトラブった件

はてなブログ開設の1記事目としては全く唐突ですが、物体検出器 Deticを導入するにあたって、起こったトラブルについて調べたことを備忘録がてら残しておこうと思います。

Deticの元論文は以下になります。

arxiv.org

今回起こった問題

githubからDeticをクローンしてローカル環境で実行しようとした所、パッケージの衝突が起こりました。そのうちの一つとして、実行時に以下のエラーが表示されました。

module 'torch.jit' has no attribute '_script_if_tracing'

これは最新のPyTorchを使用している場合に生じるエラーのようです。

github.com

解決方法

  • Python 3.8
  • PyTorch 1.8
  • torchvision 0.9.0

仮想環境下で上記のバージョンのパッケージを揃えると正常に動作しました。

 

背景

あるあるなトラブルですが、解決にかなり苦労したのでその過程で調べたことをメモしておきます。

githubのインストールDemoではインターフェースとしてdetectron2を用いています。また、requirements.txtに記載されているようにCLIPを内部でgithubよりクローンしています。したがって、Detic, detectron2, CLIP全てが機能するように依存パッケージを揃える必要があります。

その際に重要なのはPython, PyTorch, torchvision 3つの依存関係でした。

torchvision のドキュメントに依存関係の対応表が記載されています。これを元にパッケージのバージョンの確認を行います。

pypi.org

最終的に上述のパッケージをインストールすると正常に機能しました。

短いですが今回は以上です。機会があればDetic導入時の他のトラブルについてもまとめようと思います。