Counterfactualを知りたい

about Counterfactual ML, Recommender Systems, Learning-to-Rank, Causal Inference, and Bandit Algorithms

統計的学習理論(有限仮説集合の場合の予測損失の上界)

はじめに

最近自身の研究で使うため統計的学習理論の勉強をしています. 2回に渡って基本的な内容をまとめてみます.

目次

定式化

まず, 統計的学習理論のモチベーションを述べます. 入力空間 \mathcal{X}に値をとる確率変数を X, 入力空間から出力空間 \mathcal{Y}への写像 f: \mathcal{X} \rightarrow \mathcal{Y}としてlabeling functionと呼ぶことにします. また, 入力が従う確率分布を Dとします ( X \sim D). 損失関数として \ell: \mathcal{Y} \times \mathcal{Y} \rightarrow \mathbb{R}_+を用いるとします. このとき, ある仮説 h: \mathcal{X} \rightarrow \mathcal{Y}の予測損失は次のように定義されます.


\begin{aligned}
R (h) = \mathbb{E}_{ X \sim D } \left[  \ell \left( h(X), f(X) \right) \right]
\end{aligned}


つまり, 分布 Dにおける予測値 h(X)の損失の期待値です.  R(h)をできるだけ小さくするような仮説 hを見つけ出すことが目標です.

ここで, 予測損失 R(h)がわかっていれば, 問題は随分簡単になるのですが, 残念ながらデータの真の分布 Dは未知です. よって, 有限の観測データから, 予測損失をできるだけ小さくする仮説を導く必要があります. 予測損失を観測データから評価する上で最も大きな手がかりは経験損失です. データ  \left\{ X_i \right\}_{i=1}^n が観測されたときの経験損失は次のように定義されます.


\begin{aligned}
\hat{R} (h) = \mathbb{E}_{ X \sim \hat{D} } \left[  \ell \left( h(X), f(X) \right) \right] = \frac{1}{n} \sum_{i=1}^n \ell  \left( h(X_i), f(X_i) \right)
\end{aligned}


ここで,  \hat{D}をデータの経験分布としました. これは, データ数が nのときに, 確率 1 / nで各観測データ X_iの値をとるような分布のことです.

 \mathcal{H} = \{h_1, h_2, ..., h_{} \}をある有限な仮説集合とします. このとき, この仮説集合に含まれる仮説の中で予測損失と経験損失を最小化するような仮説をそれぞれ h^*, \hat{h}と表しておきます.


\begin{aligned}
h^* \in \arg \min_{h \in \mathcal{H}} R (h), \quad \hat{h} & \in \arg \min_{h \in \mathcal{H}} \hat{R} (h)
\end{aligned}


 \hat{h}は, Empirical Risk Minimizerと呼んだりもします.  h^*, \hat{h}について, 定義より次の2つの不等式が成り立ちます.


\begin{aligned}
R (h^* )  \leqq R( \hat{h} ), \quad \hat{R} (\hat{h})  \leqq  \hat{R} (h^*)
\end{aligned}


ここで私たちが導きたいのは, 観測データから計算できる経験損失を最小化する基準で得られる仮説 \hat{h}の期待損失 R ( \hat{h}) と 仮説集合 \mathcal{H}を用いたときに達成され得る最小の期待損失 R(h ^*)の差を次のように評価することです.


少なくとも 1 - \deltaの確率で次の不等式が成り立つ.


\begin{aligned}
R (\hat{h} ) \leqq   R (h^* ) + [extra \: term]
\end{aligned}



次章では, 仮説集合 \mathcal{H} \deltaに依存する [extra term]を具体的に求めます.

有限仮説集合の場合の予測損失の上界の導出

さて, 本記事で考えている仮説集合 \mathcal{H}は有限でした. またここでは , 損失関数 \ellがある定数 Mで上からboundできるとします (例えば, 01-lossの場合は,  M=1). この場合, こちらの記事で紹介したHoeffding's ineqを用いて R (\hat{h} )  R (h^* )の差を評価することができます.

まず,


\begin{aligned}
R (\hat{h} ) -   R (h^* ) 
& =  R (\hat{h} ) -   \hat{R} (\hat{h} ) + \underbrace{ \hat{R} (\hat{h} )  - \hat{R} (h^* ) }_{ \leqq 0 } + \hat{R} (h^*) -  R (h^* )  \\
& \leqq R (\hat{h} ) -   \hat{R} (\hat{h} ) +  \hat{R} (h^*) -  R (h^* )  \\
& \leqq 2 \max_{h \in \mathcal{H}} \left| \hat{R} (h) -  R (h) \right|
\end{aligned}


ここで, 右辺の裾確率をunion boundとHoeffding's ineqを用いて次のように評価します. 途中で係数2が登場しているのは, 誤差の絶対値を評価するため両側の裾確率を考慮しているからです.


\begin{aligned}
& \mathbb{P} \left( 2 \max_{h \in \mathcal{H}} \left| \hat{R} (h) -  R (h) \right|  \geqq \epsilon  \right) \\
& \leqq \sum_{h \in \mathcal{H}} \mathbb{P} \left( \left| \hat{R} (h) -  R (h) \right| \geqq \epsilon / 2 \right) \\
&  \leqq  \sum_{h \in \mathcal{H}} 2 \exp \left( - \frac{n \epsilon^2} {2M^2} \right) \quad \because Hoeffding's \,  ineq \\
& = 2 | \mathcal{H} | \exp \left( - \frac{n \epsilon^2} {2M^2} \right)
\end{aligned}


右辺を \deltaと置いて \epsilonについて解けば,


\begin{aligned}
2 | \mathcal{H} | \exp \left( - \frac{n \epsilon^2} {2M^2} \right) = \delta \Leftrightarrow \epsilon = M \sqrt {  \frac{2}{n} \log \frac{2 | \mathcal{H}| }{ \delta}  }
\end{aligned}


よって,


\begin{aligned}
\mathbb{P} \left( 2 \max_{h \in \mathcal{H}} \left| \hat{R} (h) -  R (h) \right|  \geqq M \sqrt {  \frac{2}{n} \log \frac{2 | \mathcal{H}| }{ \delta}  }  \right)  \leqq  \delta
\end{aligned}


なので, 少なくとも 1 - \deltaの確率で次の不等式が成り立ちます.


\begin{aligned}
2 \max_{h \in \mathcal{H}} \left| \hat{R} (h) -  R (h) \right|  \leqq M \sqrt {  \frac{2}{n} \log \frac{2 | \mathcal{H}| }{ \delta}  }
\end{aligned}


以上の結果を統合すると, 最終的な目標であった次の不等式が少なくとも 1 - \deltaの確率で成り立つという結果を得ます.


\begin{aligned}
R (\hat{h} ) \leqq   R (h^* )  + M \sqrt {  \frac{2}{n} \log \frac{2 | \mathcal{H}| }{ \delta}  }
\end{aligned}


したがって, 仮説集合 \mathcal{H}が有限である場合,  R ( \hat{h} )  R (h^* )に 少なくとも \mathcal{O}_p \left( \sqrt{ \log | \mathcal{H} |  / n } \right) で収束することがわかります.

さいごに

今回は, 仮説集合が有限な場合の \hat{h}の予測損失の上界をHoeffding's ineqを用いて導出しました. しかし, 今回導いた上界は仮説集合が無限である場合 (i.e.,  |\mathcal{H}| = \infty) 実用的ではありません. (上界の第2項を見れば一目瞭然.) よって, 次はRademacher Complexityという仮説集合の複雑さの指標を用いて仮説集合が無限の場合も意味を成す期待損失の上界を導いてみます.

参考

[金森 (2015)] 金森敬文. 2015. 統計的学習理論. 講談社 機械学習プロフェッショナルシリーズ.
[Mohri et al. (2012)] Mohri, M.; Rostamizadeh, A.; and Talwalkar, A. 2012. Foundations of Machine Learning. MIT Press.