Counterfactualを知りたい

about Counterfactual ML, Recommender Systems, Learning-to-Rank, Causal Inference, and Bandit Algorithms

Domain Adversarial Neural Networksの解説

はじめに

最近自分の研究分野との親和性が高いこともあり, Unsupervised Domain Adaptation (教師なしドメイン適応)の理論を勉強しています. その理論を応用した手法に, Domain Adversarial Neural Networks (DANN) というものがあり自分でも動かしてみました.

目次

  • Unsupervised Domain Adaptationとは
  • H-divergenceを用いた汎化誤差上界
  • Domain Adversarial Neural Networks
  • 簡易実験
  • さいごに

Unsupervised Domain Adaptationとは

まずUnsupervised Domain Adaptation (UDA) を定式化します. 入力空間を \mathcal{X}, 出力空間を \mathcal{Y} = \{0, 1\}とします. ここで, あるdomain  Dとは, 入力の分布 P_Dとlabeling function f_D: \mathcal{X} \rightarrow \mathcal{Y}のpair  \left( P_D, f_D \right)のことを指します. UDAは, labelが得られているsource domainからのサンプル  \mathcal{S} = \{ x_i, y_i \}_{i=1}^{n_S} とlabelが得られていないtarget domainからのサンプル  \mathcal{T} = \{ x_j \}_{j=1}^{n_T} から, 次のtarget domainにおける期待判別誤差をできるだけ小さくするような仮説 hを得ることを目指します.


 \begin{aligned}
R_T^l (h, f_T) = \mathbb{E}_{X \sim P_T} \left[ l \left( h(X), f(X) \right)  \right]
\end{aligned}


以降は, 損失関数 lとして0-1損失のみを考えます. この時,


\begin{aligned}
R_T (h, f_T) = \mathbb{E}_{X \sim P_T} \left[ \left| h(X) - f(X) \right|  \right]
\end{aligned}


です. このように表されるtarget domainにおける期待判別誤差を(一様に)boundしたいというのが理論的なモチベーションになります. 学習データとテストデータのDomainが同一であるような通常の教師あり機械学習の場合, 次のような形で予測判別誤差をboundするのが一般的です.


任意の h \in \mathcal{H}について, 少なくとも 1 - \delta ( \forall \delta \in (0, 1)) の確率で次の不等式を満たす.


\begin{aligned}
R_D (h, f_D) \leqq  \hat{R}_D (h, f_D) + complexity \left( \mathcal{H} \right) + confidence\_level (\delta)
\end{aligned}



しかし, UDAではテストデータのlabelがサンプルとして得られていないので, target domainの経験判別誤差を用いることができません. よって, 次のようなboundを得ることを目指すこととします.


任意の h \in \mathcal{H}について, 少なくとも 1 - \delta ( \forall \delta \in (0, 1)) の確率で次の不等式を満たす.


\begin{aligned}
R_T (h, f_D) \leqq  & \underbrace{\hat{R}_S (h, f_S)}_{(1)} + \underbrace{complexity \left( \mathcal{H} \right)}_{(2)}  + \underbrace{confidence\_level (\delta)}_{(3)} \\
& + \underbrace{descrepancy (P_S, P_T)}_{(4)} + \underbrace{difference (f_S, f_D)}_{(5)}
\end{aligned}



(1)は, Source Domainにおける経験判別誤差. (2)は, 仮説集合の複雑さ. (3)は \deltaに依存する項. (4)はSourceとTargetの入力分布の乖離度. (5)はSourceとTargetのlabeling functionの乖離度です. 次節では, H-divergenceと呼ばれるdiscrepancyを用いたboundについて説明します.

H-divergenceを用いた汎化誤差上界

H-divergenceは次のように定義されるdiscrepancyの一種です.


\begin{aligned}
d_{\mathcal{H}} \left(  P_S, P_T \right) & = 2 \sup_{h \in \mathcal{H}} \left|  R_S (h,  1 ) -  R_T (h,  1 ) \right| \\
& = 2 \sup_{h \in \mathcal{H}} \left| \left( 1 -   R_S (h,  0 ) \right) -  R_T (h,  1 ) \right|\\
& = 2 \sup_{h \in \mathcal{H}} \left| 1 -  \left( R_S (h,  0 ) + R_T (h,  1 ) \right) \right| 
\end{aligned}

仮説集合が対称であるとき, empiricalには


\begin{aligned}
d_{\mathcal{H}} \left(  \hat{P}_S, \hat{P}_T \right) = 2 \left(  1 - \min_{h \in \mathcal{H}} \left( \frac{1}{n_S} \sum_{i = 1}^{n_S}\mathbb{I} \left[ h(x_i) = 0 \right]  + \frac{1}{n_T} \sum_{j = 1}^{n_T} \mathbb{I} \left[ h(x_j) = 1 \right]  \right) \right)
\end{aligned}


です. よって, H-divergenceは仮説集合 \mathcal{H}がsource domainとtarget domainのデータを入力から判別する性能に依存することがわかります. 入力からどちらのDomainからのサンプルかが判別できるほど, 2つのDomainのH-Divergenceが大きくなるというイメージです.

このH-divergenceを用いると次のようなboundが得られます. ([Ganin+ 2015]のTheorem 2, [Ben David+ 2010]のTheorem 2を参考にした)


任意の h \in \mathcal{H}について, 少なくとも 1 - \delta ( \forall \delta \in (0, 1)) の確率で次の不等式を満たす.


\begin{aligned}
R_T (h, f_D) \leqq  & R_S (h, f_S) + \frac{1}{2} d_{\mathcal{H}} \left(  P_S, P_T \right) + \beta
\end{aligned}


ただし,  \beta \geqq \inf_{h \in \mathcal{H}} \left( R_S(h, f_S) + R_T (h, f_T) \right)です. つまり,  \beta R_S(h, f_S) R_T(h, f_T)の和の下限の上界です.



さらに, 既存の統計的機械学習における結果を用いてempricalに推定可能なboundを次の通りに得ます.


 \mathcal{H}を有限のVC次元 dを持つ仮説集合とする. 任意の h \in \mathcal{H}について, 少なくとも 1 - \delta ( \forall \delta \in (0, 1)) の確率で次の不等式を満たす. ただし,  n_S = n_T = nとした.


\begin{aligned}
R_T (h, f_D) \leqq  & \hat{R}_S (h, f_S) + \sqrt{  \frac{4}{n} \left(  d \log \left( \frac{2en}{d} \right) + \log \left( \frac{4}{\delta} \right)  \right)       } \\
& + \frac{1}{2} d_{\mathcal{H}} \left(  \hat{P}_S, \hat{P}_T \right) + 4 \sqrt{  \frac{1}{n} \left(  2d \log \left( 2n \right) + \log \left( \frac{4}{\delta} \right)  \right)     } + \beta
\end{aligned}



VC次元やそれを用いた予測判別誤差と経験判別誤差の差の一様boundについては, MLPシリーズ『統計的機械学習』のChapter 2に説明があります.

Domain Adversarial Neural Networks (DANN)

ようやくDANNの説明に入ります. 前節で得た R_T (h, f_T)の上界のうち, 私たちがどうにかできるのはsource domainにおける経験判別誤差 \hat{R}_S (h, f_S) P_S, P_Tの経験H-Divergence  d_{\mathcal{H}} \left( \hat{P}_S, \hat{P}_T \right) です. これらの和を小さくするために3つのlayer  G_f, G_y, G_dを考えます.  G_f (\cdot ; \theta_f)representation layerと呼び, 入力空間 \mathcal{X}をある望ましい空間 \mathcal{R}写像する役割を持ちます.  G_y (\cdot ; \theta_y)prediction layerと呼び,  G_fで得た特徴表現 \mathcal{R}からlabelを予測します. 最後に,  G_d (\cdot ; \theta_d)domain layerと呼び, 特徴表現 \mathcal{R}からSource Domainから得られたサンプルなのかTarget Domainで得られたサンプルなのかを判別します.

これらを用いてDANNの損失関数は次のように定義されます.


\begin{aligned}
E \left( \theta_f, \theta_y, \theta_d \right) = \underbrace{ \frac{1}{n_S} \sum_{i = 1}^{n_S} \mathcal{L}^i_y \left( \theta_f, \theta_y \right)}_{(1)} - \underbrace{ \lambda \left(  \frac{1}{n_S} \sum_{i = 1}^{n_S}  \mathcal{L}^i_d \left( \theta_f, \theta_d \right)  + \frac{1}{n_T} \sum_{j = 1}^{n_T}  \mathcal{L}^j_d \left( \theta_f, \theta_d \right) \right) }_{(2)}
\end{aligned}


ここで,


\begin{aligned}
&  \mathcal{L}^i_y  \left( \theta_f, \theta_y \right) = \mathcal{L}_y \left( G_y \left( G_f (x_i, \theta_f), \theta_y \right), y_i  \right)  \\
&  \mathcal{L}^i_d \left( \theta_f, \theta_d \right) = \mathcal{L}_d \left( G_d \left( G_f (x_i, \theta_f), \theta_d \right), d_i  \right)
\end{aligned}


はそれぞれサンプル iに対するprediction lossdomain lossです. 損失関数 E \left( \theta_f, \theta_y, \theta_d \right) のうち, (1)はsource domainにおける経験判別誤差を表しており(2)は,  G_fによって生成される表現 \mathcal{R}上での経験H-Divergenceと読めます.  \lambdaはそのどちらをどれだけ重視するかを司るハイパーパラメータです. 要は E \left( \theta_f, \theta_y, \theta_d \right) は,  R_T(h, f_T)の上界のうち私たちがどうにかできる項と言えます. 3つのパラメータ \theta_f, \theta_y, \theta_gはそれぞれ次のように更新します.


\begin{aligned}
\theta_f & \leftarrow \theta_f - \mu \cdot \left(  \frac{\partial \mathcal{L}_y^i }{\partial \theta_f} - \lambda \frac{\partial \mathcal{L}_d^i }{\partial \theta_f}  \right) \\
\theta_y & \leftarrow \theta_y - \mu \cdot  \frac{\partial \mathcal{L}_y^i }{\partial \theta_y} \\
\theta_d & \leftarrow \theta_d - \mu \cdot  \frac{\partial \mathcal{L}_d^i }{\partial \theta_d}
\end{aligned}


 \muは学習率です. 3つの更新式の中で最も重要なのは,  \theta_fの更新式でしょう.  \theta_fは, prediction lossを小さくするような勾配とdomain lossを大きくするような勾配によって更新されていることがわかります. これにより,  G_fはLabelの予測には役立つ ( \hat{R}_S (h, f_S)を小さくする) がDomainの予測には役立たない ( d_{\mathcal{H}}を大きくする) ような入力表現 \mathcal{R}を得るための写像に近づいていくことが期待されます.

DANNのarchitectureは次の通りです. Domain Adversarialの名は, prediction layer  G_yとdomain layer  G_dが敵対的な関係にあることに由来すると思われます.

f:id:usaito:20190413045050p:plain
[Ganin+ 2015]のFigure 1

簡易実験

DANNのイメージをより鮮明に持つため, 人工データを用いた簡易実験を行ってみます. 本節は大いにこちらのrepositoryを参考にしました.

まず, scikit learnのmake_blobsを用いて人工データを生成します. sはsource, tはtargetを表しています.

Xs, ys = make_blobs(500, centers=[[0, 0], [0, 1]], cluster_std=0.2)
Xt, yt = make_blobs(500, centers=[[1, -1], [1, 0]], cluster_std=0.2)

描画すると次のような感じです.

f:id:usaito:20190413060233p:plain
入力の初期分布

このうち学習時にsource domainの入力とラベル, target domainの入力のみを用いて, テストデータにおけるラベルを精度よく予測したいというのがUDAの目標でした. いよいよDANNを学習します. 学習とテストは8:2で分け, OptimizerはMomentum (learning_rate=0.01, momentum=0.6), batchサイズは32, epoch数は5,000としました.

結果は次の通りです. 表の結果は, テストデータにおける最終epochの結果です. ちゃんとvalidationを用意して検証すればもう少し良い結果が出ると思いますが, target domainのラベルを学習時に全く用いていないのにも関わらず, 90%以上の精度を達成しています.

Source Target Domain
Cross Entropy 0.03359 0.20099 0.68040
Accuracy (%) 99.019 90.815 55.371

f:id:usaito:20190413060745p:plain
学習の様子

一方で, domainの判別はうまくいっていないことから, representation layerでdomainの判別が付かないような( d_{\mathcal{H}}が小さいような)入力表現を得ることができていそうです. 実際, representation layerでの表現を抜き出してPCAで2次元に圧縮して描画してみると次のようになりました.

f:id:usaito:20190413060320p:plain
representation layerにおける特徴表現

これを見ると, source domain (赤, 薄赤) と target domain (青, 緑) が上と下に分かれていそうですが, 初期表現と比べるとかなり判別しにくくなっていることがわかります. 一方で, class 0とclass 1は綺麗に左右に分かれており, source domain target domainに関わらず, labelの判別はうまくいきそうなことがわかります. もちろんかなりシンプルな人工データを使ったからうまくいっているのですが, DANNのイメージが湧きやすい結果が出たのではないでしょうか.

さいごに

今回は, [Ganin+ 2015]で提案されたDomain Adversarial Neural Networksのarchitectureを理論背景も含めて整理し, 人工データを用いた追試を行ってみました. 個人的なモチベーションとしてはDANNそのものではなく, その別の分野への応用です. その話題についても今後触れようと思います.

参考

[Ben David+ 2007] Ben-David, S.; Blitzer, J.; Crammer, K.; and Pereira, F. 2007. Analysis of representations for domain adaptation. In NIPS, 137–144.
[Ben David+ 2010] Ben-David, S.; Blitzer, J.; Crammer, K.; Kulesza, A.; Pereira, F.; and Vaughan, J. W. 2010. A theory of learning from different domains. Machine Learning 79(1-2):151– 175.
[Ganin+ 2015] Ganin, Y.; Ustinova, E.; Ajakan, H.; Germain, P.; Larochelle, H.; Laviolette, F.; Marchand, M.; and Lempitsky, V. 2016. Domain-adversarial training of neural networks. Journal of Machine Learning Research 17(1):2096–2030.
[Kota Matsui 2019] Recent Advances on Transfer Learning and Related Topics. (https://www.slideshare.net/KotaMatsui/recent-advances-on-transfer-learning-and-related-topics)