こんにちは、モルフォリサーチャーの芳賀です。
1回目の記事では半教師有り学習の概要に始まり、ラベル有りデータとラベル無しデータを一度に学習に組み込むone-stage学習に焦点を当て、基本コンセプトである「consistency regularization」「entropy minimization」について具体的な手法を交えて紹介しました。
2回目の記事では最近に至るまでの数々の手法について紹介します。
Temporal Ensembling
前回の記事ではconsistency regularizationの例としてΠ-modelを紹介しました。 Π-modelでは同じ入力画像を2回ネットワークに流して得られた出力同士のMSE(平均二乗誤差)をunsupervised lossとして学習に組み込んでいました。
Temporal Ensembling[1]はΠ-modelを処理速度とノイズ耐性に関して改良した手法です。 処理速度に関しては、2回ネットワークに流すうちの片方を1つ前のエポックで出力された結果を活用するということで約2倍高速化させています。 ノイズ耐性に関しては、上記の1つ前のエポックで出力された結果を指数移動平均で更新するというやり方を用いています。
手法の説明
以下のΠ-modelとTemporal Ensemblingのフローの比較図を踏まえて、もう少し具体的に見ていきたいと思います。
インデックスの入力に対しネットワークに流して得られた出力まではΠ-modelと同様ですが、MSEを取るもう片方のは1つ前のエポックにおける同じ入力に対応する出力から以下のように更新されます。
は中間変数で、は指数移動平均の係数になります。 2行目ので割っているものはの初期値のバイアス補正によるものです。 はゼロベクトルで初期化しますが、最初の更新()では式を見ると初期値の影響を無視するようにとなっていることがわかります。学習が進むとは1に近くなっていって指数移動平均の効果が支配的になっていきます。 この補正はOptimizerで有名なAdamでも使われています。
以上の変更によりΠ-modelと比較し高速かつノイズに強い学習ができますが、以下のようなデメリットも挙げられています。
- 入力画像ごとにネットワークの出力を保持するメモリが必要
- 指数移動平均のハイパーパラメータが追加
1つ目に関して例えばデータ数10万、分類クラス100の設定の時、出力を64ビットfloatで保持するとして、100000 * 100 * 8 ~ O(100MB)程度のメモリが必要となることから、学習の規模が大きいと厳しい制約になることがわかります。
実験結果
実験ではCIFAR-10, CIFAR-100, SVHN*1のデータセットを用いています。
ここではCIFAR-10の結果を紹介します。
CIFAR-10では50000個のデータのうち4000個のラベル有りデータを用いた学習に対し、今回の手法では誤分類率12.16%を達成しており、既存手法の結果と比較してもその有効性が確認できます。 また、データ拡張を入れたことで4%程度誤分類率が下がっており、改めてデータ拡張の重要性がわかります。 一番右列は全てのデータにラベルを入れて学習させた結果ですが、既存の教師有り学習の結果と比べて良い精度が出ており、教師有り学習においてもconsistency regularizationの考え方は有効であることがわかります。
Mean Teacher
Temporal Ensemblingは入力に対する出力を過去のエポックにわたり指数移動平均させたものをunsupervised lossに組み込む手法でした。 しかしの更新間隔は1エポックであるため、大きいデータセットの学習では非効率であるという問題点が上げられています。
ここで紹介するMean Teacher[2]はこの問題点を克服すべく、ネットワークの重み自体を指数移動平均で1イテレーションごとに更新するという思い切った手法を提案しています。
手法の説明
以下のMean Teacherの概略図をもとに説明したいと思います。
Mean Teacherではstudent modelとteacher modelというものを用意しています。 以下の手順に沿ってイテレーションを回します。
- 入力データをstudent, teacher両方のモデルに流す
- 入力がラベル有りの場合は出力と教師ラベルとのロスを計算する
- student modelの出力とteacher modelの出力のconsistency cost(consistency regularizationと同義)を計算する
- 2と3のロスにより誤差逆伝搬でstudent modelの重みを更新する
- teacher modelの重みを指数移動平均を用いてstudent modelの重みで更新する
Temporal Ensembling同様かなりシンプルな手法であることがわかります。 最終的に推論に使うネットワークはteacher modelを使います。
これにより、既存手法と比較し以下の恩恵が得られます。
- 1イテレーションごとに学習のフィードバックがかかるため*2、teacher modelの推論結果の精度が高くなる
- 大きなデータセットに対してもオンラインで学習できる
実験結果
CIFAR-10のデータセットを用いた実験結果は以下になります。
CIFAR-10において、ラベル有りデータが50000個中1000個というより少ないラベル数の設定での学習で、既存手法より6%程度低い誤分類率を達成しています。
ハイパーパラメータである指数移動平均の係数(0~1)ですが、実験では学習初期は0.99に設定し、学習が進むにつれて0.999まで徐々に上げていく戦略がうまくいくと報告されています。 これは、学習初期はstudentの精度がどんどん改良していく段階でなるべく新しく学習した重みを重視するようteacherに更新するようにし、学習が成熟した段階で長いスパンでの平均的な重みをteacherに保持させるためという理由から納得できます。
Virtual Adversarial Training (VAT)
今回最後に紹介する手法は正則化に焦点を当てたVirtual Adversarial Trainingというものになります。 端的に言うと、入力画像に対し分類器としてのネットワークの出力がもっとも揺らぎやすい方向のノイズを付与しても分類結果が変わらないように制御する学習手法です。 つまり、半教師有り学習の設定でいうと上記のノイズを加えた入力と加えない生の入力との間でconsistency regularizationを考える手法になります。
この「出力がもっとも揺らぎやすい方向」をどう考えるかを含めて、まずはAdversarial Trainingから説明したいと思います。
Adversarial Training
入力画像にガウシアンノイズを付与するデータ拡張において分類ネットワークの汎化性能は向上しますが、ある特定の方向に対する弱いノイズに対しては予測結果が揺らぎやすいという性質が報告されています。[4] いわゆるadversarial attackというもので、以下のように汎化性能の高いモデルであっても人間の目には検知できない特定の方向(adversarial direction)の弱いノイズが付与された入力に対しては異なる推論結果を出してしまうという例で有名です。
Adversarial Trainingはこのadversarial attackに対する処方として考案されたもので、adversarial directionのノイズを付与させても出力が教師ラベルのそれと近くなるように学習させる手法になります。
adversarial trainingでのロスは以下で定義されます。
がラベル有りデータ、が教師ラベル、がネットワークの重み、は分布間の距離でここではcross entropyを想定、が考えうるノイズの絶対値、がadversarial directionのノイズとなります。 また、が入力に対するネットワークの出力(ベクトル)で、が教師ラベルのベクトルになります。
を求める必要がありますが、そのままでは解析的に書けないので以下の近似を施します。
- を教師ラベルのone-hotベクトルとみなす
- を変数としての一次近似を施す
以上によりcross entropyの勾配を用いてを計算することができます。
Virtual Adversarial Training
virtual adversarial training[3]は先ほどのadversarial trainingを半教師有り学習に応用させた手法になります。
consistency regularizationの考えを踏まえて、(1)式(2)式に対し
- ラベル有りデータ → ラベルの有無に関係しない入力
- 真のラベル分布 → 現在のネットワークの出力
のように置き換えると以下のようになります。
は現時点でのパラメータをfixしたものを表します。 このがvirtual adversarial perturbationの定義です。 virtualという名前はラベル無しデータの真のラベル分布を仮定するという意味合いで名づけられているそうです。 LDSと書いたロスを下げることは各データ点周辺のラベル分布を滑らかにする効果があり、consistency regularizationに対応するものとなっています。
そして、の求め方ですが様々な近似を適用することで以下のの方向の勾配を用いて表すことができます。 ここでは近似の詳細は省略しますが、論文には数学的に議論されているので気になる方はそちらを参照してください。
はランダムな初期値ベクトル*3、は微小な係数です。
学習ではLDSのネットワークの重みに関する勾配を知りたいので、アルゴリズムとしては以下のように求めることができます。
- 入力を選ぶ
- ランダムな単位ベクトルをi.i.d.ガウシアンからサンプリング
- のに関するにおける勾配を求める
- 方向の大きさのベクトルとしてを計算する
- にを代入しパラメータに関する勾配を求める
最終的なロスとしては、上記をミニバッチで平均したものと教師有りロスを足したものになります。
以上がvirtual adversarial trainingの手法になります。
実験結果
ここでも概要のみになりますが、CIFAR-10とSHVNでの結果を見てみます。
既存手法(Temporal Ensembling)と比較して精度が改善もしくは同程度であることがわかります。
以下は実際に入力から計算されたノイズを入力に施したサンプルになります。
パラメータであるを増やすとノイズの強度や範囲が画像の重要な部分を中心に大きくなっていくことがわかります。 大きなノイズをかけすぎると不自然な入力になってしまうため、論文ではそれぞれのデータセットで真ん中の列の程度の強度を採用しているようです。
まとめ
今回はTemporal Ensembling, Mean teacher, VATという半教師有り学習に関する手法を紹介しました。指数移動平均をうまく使ったりadversarial trainingの知見を生かしたりと、個人的に調べていて興味深かったです。
次回はMixMatch, ReMixMatch, FixMatchという2019年以降に発表された一連の手法について紹介したいと思います。