Counterfactualを知りたい

about Counterfactual ML, Recommender Systems, Learning-to-Rank, Causal Inference, and Bandit Algorithms

統計的学習理論(Rademacher Complexityを用いた期待損失の導出)

はじめに

前回の記事では, 仮説集合 \mathcal{H}が有限である場合の, 仮説 h \in \mathcal{H}の予測損失の上界をHoeffding's ineqを用いて導きました. しかし, 無限仮説集合に対しては同様の方法で実用的な上界を得ることは不可能でした. したがって, 今回は無限仮説集合に対応する方法の一つであるRademacher Complexityを用いて予測損失の上界を導いてみようと思います.

目次

定式化のおさらい

統計的学習理論のモチベーションをおさらいします. 前回記事をお読みいただいている方は読み飛ばしていただいても結構です.

入力空間 \mathcal{X}に値をとる確率変数を X, 入力空間から出力空間 \mathcal{Y}への写像 f: \mathcal{X} \rightarrow \mathcal{Y}としてlabeling functionと呼びます. また, 入力が従う確率分布を Dとします ( X \sim D). 損失関数として \ell: \mathcal{Y} \times \mathcal{Y} \rightarrow \mathbb{R}_+を用いるとします. このとき, ある仮説 h: \mathcal{X} \rightarrow \mathcal{Y}の予測損失は次のように定義されます.


\begin{aligned}
R (h) = \mathbb{E}_{ X \sim D } \left[  \ell \left( h(X), f(X) \right) \right]
\end{aligned}


つまり, 分布 Dにおける予測値 h(X)の損失の期待値です.  R(h)をできるだけ小さくするような仮説 hを見つけ出すことが機械学習の目標です.

ここで, データの真の分布 Dは未知であることが問題であり, 有限の観測データから, 予測損失をできるだけ小さくする仮説を導く必要がありました. 予測損失を観測データから評価する上で最も大きな手がかりは経験損失です. データ  \left\{  X_i  \right\}_{i=1}^n が観測されたときの経験損失は次のように定義されます.


\begin{aligned}
\hat{R} (h) = \mathbb{E}_{ X \sim \hat{D} } \left[  \ell \left( h(X), f(X) \right) \right] = \frac{1}{n} \sum_{i=1}^n \ell  \left( h(X_i), f(X_i) \right)
\end{aligned}


ここで,  \hat{D}はデータの経験分布です.

また, 仮説集合 \mathcal{H}に含まれる仮説の中で予測損失と経験損失を最小化するような仮説をそれぞれ  h^* \in \arg \min_{h \in \mathcal{H}} R (h),  \hat{h} \in \arg \min_{h \in \mathcal{H}} \hat{R} (h)と表していました.

ここで私たちが導きたいのは, 観測データから計算できる経験損失を最小化する基準で得られる仮説 \hat{h}の期待損失 R ( \hat{h}) と 仮説集合 \mathcal{H}を用いたときに達成され得る最小の期待損失 R(h ^*)の差を次のように評価することです.


少なくとも 1 - \deltaの確率で次の不等式が成り立つ.


\begin{aligned}
R (\hat{h} ) \leqq   R (h^* ) + complexity \left( \mathcal{H} \right) + confidence ( \delta )
\end{aligned}



以降の目標は, 仮説集合 \mathcal{H}のcomplexity項と \deltaに依存するconfidence項を具体的に求めることです.

Rademacher Complexity

Rademacher Complexityの導入

何度か名前が出てきていますが, 実数値関数の集合の複雑さの指標として解釈されるEmpirical Rademacher ComplexityRademacher Complexityを定義します.

Empirical Rademacher Complexity: 空間 \mathcal{Z}上の実数値関数からなる集合を \mathcal{G} \subset \{ g: \mathcal{Z} \rightarrow \mathbb{R} \}とする. また, 入力点の集合を \mathcal{S} = \{ z_1, ..., z_n \}としする. さらに, 同数の+1と-1を等確率でとる独立な確率変数 (Rademacher variables)を \boldsymbol {\sigma} = \{ \sigma_1, ..., \sigma_n \}とする. このとき, 集合 \mathcal{G}のEmpirical Rademacher Complexityは次のように定義される.


\begin{aligned}
\hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{G}) =  \mathbb{E}_{ \boldsymbol {\sigma} } \left[ \sup_{g \in \mathcal{G}} \frac{1}{n} \sum_{i=1}^n \sigma_i g(z_i)  \right]
\end{aligned}


さらに, それぞれのサンプル z_iはある分布 Dに独立に従う( \mathcal{S} \sim D^{n}) とき, 次のRademacher Complexityを定義します.

Rademacher Complexity: 入力点 z_iがある分布 Dに従う確率変数であるとする. また, ある入力点のサンプル集合 \mathcal{S} = \{ z_1, ..., z_n \}が与えられたとする. このとき, 集合 \mathcal{G}のRademacher Complexityは次のように定義される.


\begin{aligned}
\mathfrak{R}_{n } (\mathcal{G}) 
& =  \mathbb{E}_{ S \sim D^n } \mathbb{E}_{ \boldsymbol {\sigma} } \left[ \sup_{g \in \mathcal{G}} \frac{1}{n} \sum_{i=1}^n \sigma_i g(z_i)  \right] \\
& =  \mathbb{E}_{ S \sim D^n }  \left [ \hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{G})   \right]
\end{aligned}


Rademacher variables  \boldsymbol{ \sigma } \mathcal{S}とは独立にサンプルされます. よって, 入力点とそれに対するランダムなラベル付け (z_1, \sigma_1), ..., (z_n, \sigma_n)を 最大でどれくらい予測できてしまう関数を含むか?を測っていると解釈できます. 入力集合 \mathcal{S}とはなんら無関係なランダムノイズの組みに対してよくfittingできてしまうほど,  \mathcal{G}は複雑である, ということです.

Rademacher Complexityの推定

さて, 定義よりRademacher ComplexityはEmpirical Rademacher Complaxityをデータ集合について期待値をとったものでした. しかし, ある入力点のサンプル集合 \mathcal{S}が与えられたとき, Empirical Rademacher Complexityは必ずしもRademacher Complexityに一致するとは限りません. よって, Rademacher Complexityの推定誤差の裾確率を評価することで, 誤差の大きさを見積もってみようと思います.

Rademacher Complexity vs Empirical Rademacher Complexity: 入力空間 \mathcal{Z}上の実数値関数の集合を \mathcal{G} \subset \{ g: \mathcal{Z} \rightarrow [0, M ] \} とする. このとき, 少なくとも 1 - \delta / 2 ( \delta > 0)の確率で次の不等式が成り立つ.


\begin{aligned}
\mathfrak{R}_{n } (\mathcal{G})  \leqq  \hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{G}) + M \sqrt {  \frac{ \log \frac{2}{ \delta } }{ 2n }  }
\end{aligned}


導出
 g(z_1, ..., z_n) = \hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{G}) =  \mathbb{E}_{ \boldsymbol {\sigma} } \left[ \sup_{g \in \mathcal{G}} \frac{1}{n} \sum_{i=1}^n \sigma_i g(z_i)  \right]とおきます. このとき,


\begin{aligned}
& | g(z_1, ..., z_i, .., z_n) - g(z_1, ..., z'_i, .., z_n) | \\
& =\mathbb{E}_{ \boldsymbol {\sigma} } \left[ \sup_{g \in \mathcal{G}} \left( \frac{1}{n} \sum_{j=1}^n \sigma_j g(z_j) \right) 
-  \sup_{g \in \mathcal{G}} \left( \frac{1}{n} \sum_{j=1}^n \sigma_j g(z_j) - \frac{1}{n} \sigma_i g(z_i) + \frac{1}{n} \sigma_{i} g(z_i') \right) \right]  \\
& \leqq  \mathbb{E}_{ \boldsymbol {\sigma} } \left[ \sup_{g \in \mathcal{G}} \left( \frac{1}{n} \sum_{j=1}^n \sigma_j g(z_j) \right) 
-  \sup_{g \in \mathcal{G}} \left( \frac{1}{n} \sum_{j=1}^n \sigma_j g(z_j) \right)  + \sup_{g \in \mathcal{G}} \frac{1}{n} |  \sigma_i g(z_i) - \sigma_{i} g(z_i') |   \right]  \\
& =  \mathbb{E}_{ \boldsymbol {\sigma} } \left[ \sup_{g \in \mathcal{G}} \frac{1}{n} |  \sigma_i g(z_i) - \sigma_{i} g(z_i') |   \right]
\leqq \frac{M}{n}
\end{aligned}


ここでは, 関数 gの値域が [0, M]であることとと,  \sigma \in \{-1, +1\}であることを用いました. これにて, 関数 gこちらの記事で導出したMacDiamid's ineqのboundedness conditonを満たすことがわかります. したがって, MacDiamid's ineqで c_i = M / nとおくと, 任意の \epsilon > 0に対して, 次の不等式が成り立つことがわかります.


\begin{aligned}
\mathbb{P} \left(  \mathbb{E} \left [ g ( \mathcal{S} ) \right] -   g ( \mathcal{S} ) \geqq \epsilon  \right) \Leftrightarrow
\mathbb{P} \left(  \mathfrak{R}_{n } (\mathcal{G}) -   \hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{G}) \geqq \epsilon  \right) \leqq \exp \left( - \frac{ 2 n \epsilon^2 }{M^2} \right)
\end{aligned}


ここで, 左辺を \delta / 2とおくと,


\begin{aligned}
\exp \left( - \frac{ 2 n \epsilon^2 }{M^2} \right) = \frac{\delta}{2} \:  \Leftrightarrow  \: M \sqrt {  \frac{ \log \frac{2}{ \delta } }{ 2n }  }
\end{aligned}


です. よって,


\begin{aligned}
\mathbb{P} \left(  \mathfrak{R}_{n } (\mathcal{G}) -   \hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{G}) \leqq  M \sqrt {  \frac{ \log \frac{2}{ \delta } }{ 2n }  } \right) 
>  1 - \frac{\delta}{2}
\end{aligned}


ですので, 所望の不等式を得ます.

Uniform law of large numbers

さて, 前章でRademacher Complexityを導入しました. 本章では, これを用いて次のUniform law of large numbersを導きます.

Uniform law of large numbers: 入力空間 \mathcal{Z}上の実数値関数の集合を \mathcal{G} \subset \{ g: \mathcal{Z} \rightarrow [0, M ] \} とする. また, ある入力点のサンプル集合を \mathcal{S} = \{ z_1, ..., z_n \}とし, 一つ一つのサンプルは独立に zと同じ分布 Dに従うとする. このとき, 少なくとも 1 - \delta ( \delta > 0)の確率で次の不等式が成り立つ.


\begin{aligned}
\mathbb{E} [g (z) ]  \leqq \frac{1}{n} \sum_{i=1}^n g (z_i) + 2  \hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{G}) + 3 M \sqrt {  \frac{ \log \frac{2}{ \delta } }{ 2n }  }
\end{aligned}



導出
まず,


\begin{aligned}
\varphi( z_1, ..., z_n)  = \sup_{g \in \mathcal{G}}  \left( \mathbb{E} [g (z) ]  - \frac{1}{n} \sum_{i=1}^n g (z_i)  \right)
\end{aligned}


と置きます. ここで先ほど同様に \varphi( z_1, ..., z_n)  に対してMacDiamid's ineqを適用します.


\begin{aligned}
\varphi( z_1, ..., z_i, ..., z_n)  - \varphi( z_1, ..., z'_i, ..., z_n)  \leqq   \sup_{g \in \mathcal{G}} \frac{g(z_i) - g(z_i') }{n}  \leqq \frac{M}{n}
\end{aligned}


より,  \varphi( z_1, ..., z_n)はboundedness conditionを満たすので, 少なくとも 1 - \delta / 2の確率で, 次の不等式が成り立ちます.


\begin{aligned}
\varphi( \mathcal{S} )  \leqq \mathbb{E}_{S \sim D^n} \left[  \varphi( \mathcal{S} ) \right] + M \sqrt {  \frac{ \log \frac{2}{ \delta } }{ 2n }  }
\end{aligned}


ここで,  \varphi( z_1, ..., z_n) = \varphi( \mathcal{S} ) としました.

次に, 今導いた不等式の右辺の第1項をRademacher Complexityを用いて評価します.


\begin{aligned}
 \mathbb{E}_{S} \left[  \varphi( \mathcal{S} ) \right]  
& =  \mathbb{E}_{S} \left[ \sup_{g \in \mathcal{G}}  \left( \mathbb{E} [g (z) ]  - \frac{1}{n} \sum_{i=1}^n g (z_i)  \right)  \right]   \\
& =  \mathbb{E}_{S} \left[ \sup_{g \in \mathcal{G}}  \left( \mathbb{E}_{S'} \left [ \frac{1}{n} \sum_{i=1}^n g (z'_i)  \right ]  - \frac{1}{n} \sum_{i=1}^n g (z_i)  \right)  \right] \\
& =  \mathbb{E}_{S} \left[ \sup_{g \in \mathcal{G}}  \left( \mathbb{E}_{S'} \left[ \frac{1}{n} \sum_{i=1}^n g (z'_i)   - \frac{1}{n} \sum_{i=1}^n g (z_i)  \right ] \right)   \right] \\
& \leqq \mathbb{E}_{S, S'} \left[ \sup_{g \in \mathcal{G}}  \frac{1}{n} \sum_{i=1}^n (g (z'_i)   -  g (z_i) )    \right] 
\end{aligned}


ここで,  z_i z'_iは同一分布 Dに従います. また,  \sigma_iは+1と-1を等確率でとるRademacher variableです. このとき,  (g (z'_i)   -  g (z_i) )   \sigma_i (g (z'_i)   -  g (z_i) )  は同一分布に従います. よって,


\begin{aligned}
& \mathbb{E}_{S, S'} \left[ \sup_{g \in \mathcal{G}}  \frac{1}{n} \sum_{i=1}^n (g (z'_i)   -  g (z_i) )    \right]  \\
&= \mathbb{E}_{ \boldsymbol{\sigma}, S, S'} \left[ \sup_{g \in \mathcal{G}}  \frac{1}{n} \sum_{i=1}^n \sigma_i (g (z'_i)   -  g (z_i) )    \right] \\
& \leqq \mathbb{E}_{ \boldsymbol{\sigma}, S, S'} \left[ \sup_{g \in \mathcal{G}}  \frac{1}{n} \sum_{i=1}^n \sigma_i g (z'_i)    \right] 
+ \mathbb{E}_{ \boldsymbol{\sigma}, S, S'} \left[ \sup_{g \in \mathcal{G}}  \frac{1}{n} \sum_{i=1}^n - \sigma_i g (z_i)    \right] \\
& = \mathbb{E}_{ \boldsymbol{\sigma}, S} \left[ \sup_{g \in \mathcal{G}}  \frac{1}{n} \sum_{i=1}^n \sigma_i g (z_i)    \right] 
+ \mathbb{E}_{ \boldsymbol{\sigma}, S} \left[ \sup_{g \in \mathcal{G}}  \frac{1}{n} \sum_{i=1}^n  \sigma_i g (z_i)    \right] \\
& = 2 \mathfrak{R}_{n } (\mathcal{G})
\end{aligned}


これらの結果を合わせると, 少なくとも 1 - \delta / 2の確率で次の不等式が成り立ちます.


\begin{aligned}
\sup_{g \in \mathcal{G}}  \left( \mathbb{E} [g (z) ]  - \frac{1}{n} \sum_{i=1}^n g (z_i)  \right) =  \varphi( z_1, ..., z_n)  \leqq 2 \mathfrak{R}_{n } (\mathcal{G}) + M \sqrt {  \frac{ \log \frac{2}{ \delta } }{ 2n }  }
\end{aligned}


さらに, 前章の (Rademacher Complexity vs Empirical Rademacher Complexity) の結果を用いると, 少なくとも 1 - \delta / 2の確率で次の不等式が成り立ちます.


\begin{aligned}
 2 \mathfrak{R}_{n } (\mathcal{G}) \leqq 2 \hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{G})  + 2M \sqrt {  \frac{ \log \frac{2}{ \delta } }{ 2n }  }
\end{aligned}


以上の結果を統合し, union boundを用いることで, 少なくとも 1 - \deltaの確率で次の不等式が成り立つことを得ます.


\begin{aligned}
\mathbb{E} [g (z) ]  \leqq \frac{1}{n} \sum_{i=1}^n g (z_i) + 2  \hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{G}) + 3 M \sqrt {  \frac{ \log \frac{2}{ \delta } }{ 2n }  }
\end{aligned}


 L_p損失を用いた時の予測損失の上界

さてようやく本題です. ここでは, 次にように表される L_p損失を用いたときの, 予測損失の確率的上界を導出します.  p \geqq 1で, また損失は定数 M^{p}でboundされるとします.


\begin{aligned}
R (h) = \mathbb{E}_{ X \sim D } \left[  \ell_{p} \left( h(X), f(X) \right) \right] = \mathbb{E}_{ X \sim D } \left[ \left| h(X) -  f(X) \right|^p \right] 
\end{aligned}


ここで,  L_p損失と仮説の合成関数の集合 \mathcal{H}_p = \left\{ x \rightarrow | h(x) - f(x) |^p : h \in \mathcal{H} \right\}のEmpirical Rademacher Complexityを仮説集合 \mathcal{H}のEmpirical Rademacher Complexityで評価します.

まず, 先ほどの集合 \mathcal{H}_pを関数 \phi_p: x \rightarrow |x|^pと集合 \mathcal{H}' = \{x \rightarrow h(x) - f(x): h \in \mathcal{H} \}を用いて,  \mathcal{H}_p = \{ \phi_p \circ  \mathcal{H}' \}と表しておきます. 関数 \phi_pは,  pM^{p-1}-Lipschitzなので, Talagrand's lemma*1を用いると,


\begin{aligned}
 \hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{H}_p) \leqq  pM^{p-1}  \hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{H}')
\end{aligned}


が成り立ちます. また,


\begin{aligned}
 \hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{H}') 
& = \mathbb{E}_{ \boldsymbol {\sigma} } \left[ \sup_{h \in \mathcal{H}} \frac{1}{n} \sum_{i=1}^n \sigma_i (h(x_i) - f(x_i))  \right] \\
& \leqq \mathbb{E}_{ \boldsymbol {\sigma} } \left[ \sup_{h \in \mathcal{H}} \frac{1}{n} \sum_{i=1}^n \sigma_i h(x_i)   \right] 
+ \underbrace{ \mathbb{E}_{ \boldsymbol {\sigma} } \left[ \sup_{h \in \mathcal{H}} \frac{1}{n} \sum_{i=1}^n \sigma_i f(x_i)   \right]}_{= 0} \\
& =  \hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{H}) 
\end{aligned}


したがって, 結局のところ


\begin{aligned}
 \hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{H}_p) \leqq  pM^{p-1}  \hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{H})
\end{aligned}


なので,  L_p損失と仮説集合の合成関数の集合のEmpirical Rademacher Complexityは, 仮説集合のEmpirical Rademacher Complexityで上から評価できます. この事実と, Uniform law of large numbersにおいて,  \mathcal{G} \mathcal{H}_pとすれば, 任意の仮説 h \in \mathcal{H}に対して, 少なくとも 1 - \deltaの確率で次の不等式が成り立ちます.


\begin{aligned}
R (h) -  \hat{R} (h) = & \mathbb{E}_{X \sim D} [ \ell_{p} (h(X),  f(X ))  ] - \frac{1}{n} \sum_{i=1}^n \ell_{p} (h(X_i),  f(X_i)) \\
&  \leqq   2pM^{p-1}  \hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{G}) + 3 M^p \sqrt {  \frac{ \log \frac{2}{ \delta } }{ 2n }  }
\end{aligned}


この不等式を用いれば, 少なくとも 1 - \deltaの確率で次の不等式が成り立つことを導くことができます.


\begin{aligned}
R (\hat{h} ) -   R (h^* ) 
& =  R (\hat{h} ) -   \hat{R} (\hat{h} ) + \underbrace{ \hat{R} (\hat{h} )  - \hat{R} (h^* ) }_{ \leqq 0 } + \hat{R} (h^*) -  R (h^* )  \\
& \leqq R (\hat{h} ) -   \hat{R} (\hat{h} ) +  \hat{R} (h^*) -  R (h^* )  \\
& \leqq 2 \sup_{h \in \mathcal{H}} \left| \hat{R} (h) -  R (h) \right| \\
& \leqq 2 \left(  2pM^{p-1}  \hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{H}) + 3 M^p \sqrt {  \frac{ \log \frac{2}{ \delta } }{ 2n }  }  \right)
\end{aligned}


したがって, 所望の確率的上界を次のように得ることができました.


\begin{aligned}
R (\hat{h} )  \leqq  R (h^* )  + \underbrace{ 4 pM^{p-1}  \hat{\mathfrak{R}}_{ \mathcal{S} } (\mathcal{H})}_{complexity}
+   \underbrace{ 6 M^p \sqrt {  \frac{ \log \frac{2}{ \delta } }{ 2n }  } }_{confidence}
\end{aligned}


さいごに

本記事では,  L_p損失を用いた場合の予測損失の確率的上界をRademacher Complexityを用いて導出してみました. 相変わらず, 私の誤解で誤った記述をしている可能性が大いにありますので, 見つけた場合はご指摘いただけたら幸いです.

参考

[金森 (2015)] 金森敬文. 2015. 統計的学習理論. 講談社 機械学習プロフェッショナルシリーズ.
[Mohri et al. (2012)] Mohri, M.; Rostamizadeh, A.; and Talwalkar, A. 2012. Foundations of Machine Learning. MIT Press.

*1:Mohri et al. (2012)のlemma 4.2