【論文読み】Semi-Supervised Learning with Generative Adversarial Networks
新年早々だがいいスタートをきりたいという三日坊主感があるものの、とりあえず論文を読んだ。
今回読んだ論文はSemi-Supervised Learning with Generative Adversarial Networksである。
GANを利用した半教師あり学習で高い精度がでているらしいので興味を持った。
Abstract
- Generated Adversarial Networks(GANs)で利用されるDiscriminatorの出力をクラス分類にすることによって半教師あり学習に拡張
- N個のクラスに属しているデータを利用してGeneratorとDiscriminatorを学習
- 学習時のDiscriminatorはGeneratorで生成したデータをN+1番目のクラスに分類できるようにクラスを拡張して予測
- MNISTデータセットを利用して本手法がより効率のいい分類器の生成と従来のGANよりも質の高いデータを生成できることを示す
Introduction
- 生成モデルの一つとしてGANが提案され、一つのニューラルネットワークでより良いデータを生成できるように
- GANを利用して半教師あり学習での分類器と生成モデルを同時に学習するようにすることを考える
- 深層生成モデルを利用した半教師あり学習の研究:Semi-Supervised Learning with Deep Generative Models
- GANを利用した半教師あり学習の研究:Unsupervised and Semi-supervised Learning with Categorical Generative Adversarial Networks
なぜGANを利用するといいのか?
生成データと訓練データを見極めるモデルをD、分類モデルをCとする。
この2つのモデルの関係はDの精度が向上するとCの精度が向上するということは妥当らしい。
また逆も同様である。
今回利用するGANのなかのDiscriminatorはCとDのモデルを併用したものを利用してお互いの学習を助けるようにする。
Semi-Supervised GAN Model
従来のGANのDiscriminatorはデータを入力しそのデータが生成データである確率を出力するような構造となっている。
従来GANのDiscriminator
$$
\left[ Real, Fake \right]
$$
出力関数はシグモイド関数やソフトマックス関数が利用されることが多い。
今回のSGANはクラス分類がN個だとするとN+1個のユニットが出力されるような構造をとる。
SGANのDiscriminator
$$
\left[ Class1, Class2, Class3, \dots, Fake \right]
$$
このDiscriminatorが分類器の役割を持つように学習を行う。
具体的な学習フローは以下の通り。
- ノイズから生成データをドロー
- 訓練データをバッジ数ドロー
- 分類器のためのDの教師あり学習
- 生成データか判別するDの学習
- 新たなノイズから生成データをドロー
- Gを学習する
Results
MNISTを利用して検証を行った。
生成データの検証
従来のGANと生成された画像を比較する。
従来のGANよりもくっきりとした画像を生成することができる。
分類モデルの検証
SGANの分類器が優れているかどうかの検証。
検証を行う時にGは更新しなかったらしい。
ここは正直よくわからなかった。
結果を見るとサンプル数が少なくても高い精度を保っていることがわかる。
感想
まあ生成データと訓練データの判別とラベルの分類の学習は関係があるからGANで半教師あり学習をしているという雰囲気はわかった。
今後の検証としては生成データと訓練データでラベルを分けていたがそれをクラス毎に分けるという検証も必要だとfuture workには書いてあった。
つまりラベルが2N個になるということかな?
ともあれ一回実装して試してみたい。
GANを利用した半教師あり学習の論文はこちらのほうが有名かもしれないので今度はこれを詳しく読んでいきたい。