画像分類タスクにおける半教師有り学習 第2回

こんにちは、モルフォリサーチャーの芳賀です。

1回目の記事では半教師有り学習の概要に始まり、ラベル有りデータとラベル無しデータを一度に学習に組み込むone-stage学習に焦点を当て、基本コンセプトである「consistency regularization」「entropy minimization」について具体的な手法を交えて紹介しました。

techblog.morphoinc.com

2回目の記事では最近に至るまでの数々の手法について紹介します。

Temporal Ensembling

前回の記事ではconsistency regularizationの例としてΠ-modelを紹介しました。 Π-modelでは同じ入力画像を2回ネットワークに流して得られた出力同士のMSE(平均二乗誤差)をunsupervised lossとして学習に組み込んでいました。

Temporal Ensembling[1]はΠ-modelを処理速度とノイズ耐性に関して改良した手法です。 処理速度に関しては、2回ネットワークに流すうちの片方を1つ前のエポックで出力された結果を活用するということで約2倍高速化させています。 ノイズ耐性に関しては、上記の1つ前のエポックで出力された結果を指数移動平均で更新するというやり方を用いています。

手法の説明

以下のΠ-modelとTemporal Ensemblingのフローの比較図を踏まえて、もう少し具体的に見ていきたいと思います。

[1]Figure.1より引用。

インデックス iの入力 x_iに対しネットワークに流して得られた出力 z_iまではΠ-modelと同様ですが、MSEを取るもう片方の \tilde{z}_iは1つ前のエポックにおける同じ入力に対応する出力から以下のように更新されます。

 Zは中間変数で、 \alphaは指数移動平均の係数になります。 2行目の 1-\alpha^tで割っているものは Zの初期値のバイアス補正によるものです。  Zはゼロベクトルで初期化しますが、最初の更新( t = 1)では式を見ると初期値の影響を無視するように \tilde{z} = zとなっていることがわかります。学習が進むと 1-\alpha^tは1に近くなっていって指数移動平均の効果が支配的になっていきます。 この補正はOptimizerで有名なAdamでも使われています。

以上の変更によりΠ-modelと比較し高速かつノイズに強い学習ができますが、以下のようなデメリットも挙げられています。

  • 入力画像ごとにネットワークの出力を保持するメモリが必要
  • 指数移動平均のハイパーパラメータ \alphaが追加

1つ目に関して例えばデータ数10万、分類クラス100の設定の時、出力を64ビットfloatで保持するとして、100000 * 100 * 8 ~ O(100MB)程度のメモリが必要となることから、学習の規模が大きいと厳しい制約になることがわかります。

実験結果

実験ではCIFAR-10, CIFAR-100, SVHN*1のデータセットを用いています。

ここではCIFAR-10の結果を紹介します。

[1]Table.1より引用。CIFAR-10による誤分類率の結果。各列は50000個のデータのうちラベル有りデータの数を変えて学習させた結果。誤分類率のばらつきは乱数シードを変えて10回学習させたときの標準偏差で求めている。

CIFAR-10では50000個のデータのうち4000個のラベル有りデータを用いた学習に対し、今回の手法では誤分類率12.16%を達成しており、既存手法の結果と比較してもその有効性が確認できます。 また、データ拡張を入れたことで4%程度誤分類率が下がっており、改めてデータ拡張の重要性がわかります。 一番右列は全てのデータにラベルを入れて学習させた結果ですが、既存の教師有り学習の結果と比べて良い精度が出ており、教師有り学習においてもconsistency regularizationの考え方は有効であることがわかります。

Mean Teacher

Temporal Ensemblingは入力に対する出力を過去のエポックにわたり指数移動平均させたもの \tilde{z}_iをunsupervised lossに組み込む手法でした。 しかし \tilde{z}_iの更新間隔は1エポックであるため、大きいデータセットの学習では非効率であるという問題点が上げられています。

ここで紹介するMean Teacher[2]はこの問題点を克服すべく、ネットワークの重み自体を指数移動平均で1イテレーションごとに更新するという思い切った手法を提案しています。

手法の説明

以下のMean Teacherの概略図をもとに説明したいと思います。

[2]FIgure.2より引用。

Mean Teacherではstudent modelとteacher modelというものを用意しています。 以下の手順に沿ってイテレーションを回します。

  1. 入力データをstudent, teacher両方のモデルに流す
  2. 入力がラベル有りの場合は出力と教師ラベルとのロスを計算する
  3. student modelの出力とteacher modelの出力のconsistency cost(consistency regularizationと同義)を計算する
  4. 2と3のロスにより誤差逆伝搬でstudent modelの重みを更新する
  5. teacher modelの重みを指数移動平均を用いてstudent modelの重みで更新する

Temporal Ensembling同様かなりシンプルな手法であることがわかります。 最終的に推論に使うネットワークはteacher modelを使います。

これにより、既存手法と比較し以下の恩恵が得られます。

  • 1イテレーションごとに学習のフィードバックがかかるため*2、teacher modelの推論結果の精度が高くなる
  • 大きなデータセットに対してもオンラインで学習できる

実験結果

CIFAR-10のデータセットを用いた実験結果は以下になります。

[2]Table.2より引用。表の見方は前節の表と同様。上段4つは既存提案手法の論文からの参照値、下段は著者環境で再構築したモデルで学習させたときの結果を表している。

CIFAR-10において、ラベル有りデータが50000個中1000個というより少ないラベル数の設定での学習で、既存手法より6%程度低い誤分類率を達成しています。

ハイパーパラメータである指数移動平均の係数 \alpha(0~1)ですが、実験では学習初期は0.99に設定し、学習が進むにつれて0.999まで徐々に上げていく戦略がうまくいくと報告されています。 これは、学習初期はstudentの精度がどんどん改良していく段階でなるべく新しく学習した重みを重視するようteacherに更新するようにし、学習が成熟した段階で長いスパンでの平均的な重みをteacherに保持させるためという理由から納得できます。

Virtual Adversarial Training (VAT)

今回最後に紹介する手法は正則化に焦点を当てたVirtual Adversarial Trainingというものになります。 端的に言うと、入力画像に対し分類器としてのネットワークの出力がもっとも揺らぎやすい方向のノイズを付与しても分類結果が変わらないように制御する学習手法です。 つまり、半教師有り学習の設定でいうと上記のノイズを加えた入力と加えない生の入力との間でconsistency regularizationを考える手法になります。

この「出力がもっとも揺らぎやすい方向」をどう考えるかを含めて、まずはAdversarial Trainingから説明したいと思います。

Adversarial Training

入力画像にガウシアンノイズを付与するデータ拡張において分類ネットワークの汎化性能は向上しますが、ある特定の方向に対する弱いノイズに対しては予測結果が揺らぎやすいという性質が報告されています。[4] いわゆるadversarial attackというもので、以下のように汎化性能の高いモデルであっても人間の目には検知できない特定の方向(adversarial direction)の弱いノイズが付与された入力に対しては異なる推論結果を出してしまうという例で有名です。

[4]のFigure.1より引用。パンダの入力画像に対し57.7%の信頼度で当てるモデルでも、ロス関数 Jの(入力 xに関する)勾配方向の符号で決めたノイズを微小に同じパンダ画像に加えると、人間には見分けがつかなくても先ほどのモデルでは99.3%の信頼度でテナガザルだと誤って推論してしまう例。

Adversarial Trainingはこのadversarial attackに対する処方として考案されたもので、adversarial directionのノイズを付与させても出力が教師ラベルのそれと近くなるように学習させる手法になります。

adversarial trainingでのロスは以下で定義されます。

 x_lがラベル有りデータ、 yが教師ラベル、 \thetaがネットワークの重み、 Dは分布間の距離でここではcross entropyを想定、 \epsilonが考えうるノイズの絶対値、 r_{adv}がadversarial directionのノイズとなります。 また、 p(y|x)が入力 xに対するネットワークの出力(ベクトル)で、 q(y|x)が教師ラベルのベクトルになります。

 r_{adv}を求める必要がありますが、そのままでは解析的に書けないので以下の近似を施します。

  •  q(y|x_l)を教師ラベルのone-hotベクトル h(y;y_l)とみなす
  •  rを変数として Dの一次近似を施す

以上によりcross entropyの勾配を用いて r_{adv}を計算することができます。

Virtual Adversarial Training

virtual adversarial training[3]は先ほどのadversarial trainingを半教師有り学習に応用させた手法になります。

consistency regularizationの考えを踏まえて、(1)式(2)式に対し

  • ラベル有りデータ x_l → ラベルの有無に関係しない入力 x_*
  • 真のラベル分布 q(y|x) → 現在のネットワークの出力 p(y|x, \hat{\theta})

のように置き換えると以下のようになります。

 \hat{\theta}は現時点でのパラメータ \thetaをfixしたものを表します。 この r_{vadv}がvirtual adversarial perturbationの定義です。 virtualという名前はラベル無しデータの真のラベル分布を仮定するという意味合いで名づけられているそうです。 LDSと書いたロスを下げることは各データ点周辺のラベル分布を滑らかにする効果があり、consistency regularizationに対応するものとなっています。

そして、 r_{vadv}の求め方ですが様々な近似を適用することで以下の D r方向の勾配を用いて表すことができます。 ここでは近似の詳細は省略しますが、論文には数学的に議論されているので気になる方はそちらを参照してください。

 dはランダムな初期値ベクトル*3 \xiは微小な係数です。

学習ではLDSのネットワークの重み \thetaに関する勾配を知りたいので、アルゴリズムとしては以下のように求めることができます。

  • 入力 xを選ぶ
  • ランダムな単位ベクトル dをi.i.d.ガウシアンからサンプリング
  •  D rに関する r=\xi dにおける勾配 gを求める
  • 方向 gの大きさ \epsilonのベクトルとして r_{vadv}を計算する
  •  D r=r_{vadv}を代入しパラメータ \thetaに関する勾配を求める

最終的なロスとしては、上記をミニバッチで平均したものと教師有りロスを足したものになります。

以上がvirtual adversarial trainingの手法になります。

実験結果

ここでも概要のみになりますが、CIFAR-10とSHVNでの結果を見てみます。

[3]Table.5より引用。SVHNとCIFAR-10においてデータ拡張を入れて学習させた誤分類率の結果。 N_lはラベル有りデータの数を表す。下から二行目が今回の手法VATの結果、最下行はVATにEntMinという別の手法を組み合わせた結果。

既存手法(Temporal Ensembling)と比較して精度が改善もしくは同程度であることがわかります。

以下は実際に入力から計算されたノイズ r_{vadv}を入力に施したサンプルになります。

[3]Fig.5の一部より引用。

パラメータである \epsilonを増やすとノイズの強度や範囲が画像の重要な部分を中心に大きくなっていくことがわかります。 大きなノイズをかけすぎると不自然な入力になってしまうため、論文ではそれぞれのデータセットで真ん中の列の程度の強度を採用しているようです。

まとめ

今回はTemporal Ensembling, Mean teacher, VATという半教師有り学習に関する手法を紹介しました。指数移動平均をうまく使ったりadversarial trainingの知見を生かしたりと、個人的に調べていて興味深かったです。

次回はMixMatch, ReMixMatch, FixMatchという2019年以降に発表された一連の手法について紹介したいと思います。

参考文献

[1] Laine, Samuli, and Timo Aila. "Temporal Ensembling for semi-supervised learning." arXiv preprint arXiv:1610.02242 (2016).

[2] Tarvainen, Antti, and Harri Valpola. "Mean teachers are better role models: Weight-averaged consistency targets improve semi-supervised deep learning results." arXiv preprint arXiv:1703.01780 (2017).

[3] Miyato, Takeru, et al. "Virtual Adversarial Training: A Regularization Method for Supervised and Semi-Supervised Learning." arXiv preprint arXiv:1704.03976 (2017).

[4] Goodfellow, Ian J., Jonathon Shlens, and Christian Szegedy. "Explaining and harnessing adversarial examples." arXiv preprint arXiv:1412.6572 (2014).

*1:street view house numbers。住居看板の数字データセット。画像中央の数字を当てるタスク等に使われる。

*2:Temporal Ensemblingも重みの更新は1イテレーションごとでしたが、1エポック前の出力でロスを計算していたため遅いフィードバックとなっていました。

*3:ランダムな単位ベクトルで問題ないかと疑問を持たれる方がいると思います。近似の中で本来は d \leftarrow \nabla_r D(r, x, \hat{\theta})|_{r = \xi d}で反復的に d r_{vadv}方向に調整していくのですが、実験では1回の反復で十分な精度が得られるためこの形式になっています。