【論文解説】非IIDデータを用いたFederated Learningで学習精度が悪化する問題への対応策

この記事では以下の論文の解説をします
arxiv.org

Federated Learningの分野において、頻繁に引用されている論文ですが、
日本語での解説記事があまり見当たらなかったので簡単に紹介しようと思います
不正確な内容がありましたら是非ご指摘ください

そもそもFederated Learning(FL)とは何か?

Federated Learningは主に機械学習においてプライバシーとデータセキュリティの確保を目的とした手法の1つに位置付けられており、医療や金融, IoTの領域での応用が期待されています。
学習データセットが複数に分散しており、生データには互いにアクセスできない状態を維持しつつ、各データセットを用いて学習することによる恩恵だけを集めて優れた機械学習モデルを得ることを目指します。
具体的には、最もシンプルなFederated Learningの手法として、FedAvg と呼ばれる方法があります。FedAvgでは、まずデータセットを所有している各機関(クライアント)でアーキテクチャと重みの初期値が同一の機械学習モデルをそれぞれ別個に学習させます。次に、一定のエポックごとに入力された元のデータではなくモデルの重みだけを共有し、加重平均をとって中央(サーバー)の機械学習モデルの重みとして採用します。中央サーバーのモデルの重みを初期値としてクライアントへ再配布し、このプロセスを繰り返すことで学習を進めます。

論文中で前提とされているFederated Learningの手法も,このFedAvgになります。

本論文の要旨

  • 高度に歪んだ非IIDデータを用いてFederated Learningを行ったニューラルネットワークは最大55%までaccuracyが低下する
  • この精度低下は各クライアントのクラス分布と母集団分布の間のEMD(Earth Mover's Distance)により定量化できる
  • 解決策として、全てのクライアントで共有されるグローバルな小規模データセットを導入することが有効であり、CIFAR10の場合、5%の共有データで30%の精度改善がみられた

解説

この論文では、クライアントが持つデータセットが非IID(独立同一分布)である場合にサンプルの確率分布の差が大きいほど、FedAvg法によるニューラルネットワークの学習精度が一般的な学習手法と比較して低下することを示し、その解決策として、クライアント全体で共有できるデータセットを用意することが有効だと主張しています。このData-sharing Strategyでは、重みの初期値として共有データセット学習済みモデルを使用し、クライアント側の学習ではクライアント固有のデータセットと共有データセットを混ぜて学習させます。

イメージとしては以下の図がわかりやすいかと思います

重みの初期値が同一であっても、non-IIDのデータセットを用いて学習するとクライアントのモデルごとの重みにばらつきが生じます。イテレーションごとにFedAvg法により重みをマージしたとしても、データセット全体を用いて学習した場合の重みとは異なるため、この差がイテレーションごとに広がっていき、モデルの精度の大きな低下につながります。

論文中でこの性質について数学的証明がなされていますが、ここでは割愛します。

非IID性と精度悪化の関係性

実験ではデータセットの分割方法を変えて、中央サーバーのCNNモデルの学習精度を比較しました。
用いられたデータセットはMNIST, CIFAR10の画像データとSpeech commands データセットと呼ばれる音声データです。
比較される学習手法は以下のように設定されました。

IID
10クラスの一様分布のデータセットが各クライアントへ提供される

1class non-IID
1つのクラスのデータだけを各クライアントは受け取り、固定の正解クラスを選択するよう学習する

2class non-IID
1クラスを2分割し、各クライアントはランダムに二種のクラスの分割データを受け取り、正解クラスを選択するように学習する

SGD
一般的な学習手法でデータセット全体を用いて学習する

pre-trained
全体のデータセットである程度事前学習済みのモデルに対し追加の学習を行う


これらの学習手法を比較した結果が以下の図です。


IIDのデータセットを用いたFLではSGDとほとんど同じ精度が得られています。
一方、non-IIDデータは特に1classの場合で顕著にその精度を悪化させています。
pre-trainedモデルを用いてもその後ほとんど精度が向上しない、あるいは精度が低下することがわかります。
また、クライアントごとの学習におけるローカルエポック数(E)の値は精度に大きな影響は与えていません。

この図は分割前のデータセットとクライアントのデータセットのクラス分布の差が大きくなるほど、モデルの重みのばらつきも大きくなることを示しています。横軸のEMDは以下から算出されます。

p(y=i)はデータセットにおける正解ラベルがiであるデータの割合を示します。左項はクライアントkの, 右項は元のデータセットの場合です。
この割合の差に対しなんらかの距離尺度を適用し、全クラスの総和を求めます。
この値は確率分布間の距離 EMDに相当すると論文中で述べられています。

EMDについては以下のサイトが分かりやすかったです。
Earth Mover's Distance (EMD) - 人工知能に関する断創録

実際にEMDが上昇するとある段階から急激に test accuracyが減少することが確認されています。

non-IIDデータセットによる精度低下への対応策

論文の後半では、データセットの非IID性によるtest accuracyの低下への対応策が示されます。
方法は単純です。
・図のように、全てのクライアントで共有可能なグローバルデータセットGを用意する
・Gで中央サーバーのモデルを学習させる
・学習済みモデルの重みを初期値としてクライアントへ配布
・Gのうち割合αのデータをランダムに抽出しクライアントへ配布
・クライアントはクライアントのデータセットDとα×Gを混合してFedAvg法を行う

data-sharing strategyによるtest accuracyへの影響を確認しました。
クライアントのデータセットDのサイズに対するGのデータサイズの割合をβとして、
直感に違わずβが大きくなるにしたがい、EMDは減少しtest accuracyは改善されます。
αについては、一定の大きさでなければ改善効果が低いことがわかります。

結論と感想

data-sharing strategyは、学習精度を向上・安定化させたい事を考えると真っ先に思いつく手法だと思いますが、
この論文では実際に、これが有効である事を簡潔に確認しています。
今回は割愛しましたが、non-IIDデータにおける精度低下を数学的に証明している点でこの論文は評価されているのだと思います。
本論文ではMNIST, CIFAR10, Speech commands datasetのみ検証されています。
医用画像のデータセットの場合などについても気になるので今後調べていきたいと思います。