Soptq

Soptq

Probably a full-stack, mainly focusing on Distributed System / Consensus / Privacy-preserving Tech etc. Decentralization is a trend, privacy must be protected.
twitter
github
bilibili

如何計算深度學習優化方法的收斂率

收斂率是什麼?#

我們小學三年級學的數值分析告訴我們,如果函數 f(x)f(x) 是收斂的,即 limkxkx=0\lim_{k \rightarrow\infty}\vert\vert x_k - x^* \vert\vert = 0 ,其中 limkf(xk)=x\lim_{k \rightarrow\infty}f(x_k) = x^*,那麼有:

limkxk+1xxkx=a\lim_{k \rightarrow\infty} \frac{\vert\vert x_{k+1} - x^*\vert\vert}{\vert\vert x_k-x^*\vert\vert} = a

其中, aa 就是 f(x)f(x) 的收斂率。

基本理論#

SGD 基礎#

在深度學習的問題當中,我們一般是去解決這樣的問題:

minxf(x)=i=1nfi(x)\min_{x}f(x) = \sum_{i=1}^{n}f_i(x)

其中,f(x)f(x) 是模型, ii 是每一個樣本, xx 是我們要優化的參數, nn 是所有的樣本。

然後在利用 SGD 對模型進行更新的時候,一般是這樣更新的:

xt+1=xtμfixtx_{t+1} = x_{t} - \mu \nabla f_i{x_t}

這個應該大家都可以理解吧,就是一個梯度更新公式。

LL-Lipschitz 和 μ\mu-Strongly Convex#

在深度學習領域,我們看 10,000 篇跟優化沾邊的論文,9,900 篇都要在證明前加一句:

... Let ff be LL-smooth and μ\mu-strongly convex …

一切都是那麼的理所當然,就只有一個問題 —— ** 這兩個東西到底是什麼呢?** 還有很多辣雞養成了看到這兩個詞就跳過這段文字的條件反射(比如我)。是時候來直面恐懼了!

μ\mu-Strongly Convex#

首先,μ\mu-Strongly Convex 表示的是函數 f(x)f(x) 是強的,數學表達為:

f(x2)f(x1)+f(x1)T(x2x1)+μ2x2x12f(x_2) \ge f(x_1) + \nabla f(x_1)^{T}(x_2 - x_1) + \frac{\mu}{2}\vert\vert x_2 - x_1 \vert\vert ^2

其中 x1,x2Qx_1, x_2 \in QQQ f(x)f(x) 的定義域,我們知道,一個凸函數的定義如下:

f(x2)f(x1)+f(x1)T(x2x1)f(x_2) \ge f(x_1) + \nabla f(x_1)^{T}(x_2 - x_1)

這個式子的直觀表示就是對於任意 x1x_1f(x)f(x) 上的切線 f(x1)f^{'}(x_1) ,有 f(x)f(x1)f(x) \le f^{'}(x_1)

而我們的強凸的數學表達式其實比凸多了一個二項式 μ2x2x12\frac{\mu}{2}\vert\vert x_2 - x_1 \vert\vert ^2。因為在凸函數中,我們只限定了函數必須在切線以上,但沒有說以上多少。也就是說函數可以無限貼近切線,使得這樣的函數在優化中不可行。所以我們相當於給凸「度」限定了下界,使得優化可以被量化。

詳細證明的話可以參看這篇文章

LL-Lipschitz#

對於 LL-Lipschitz,有一篇知乎專欄我覺得講的很好,大家有時間可以去看一下。沒有時間的話我在下面也會把我的理解大概說一下。

Lipschitz Continus 是說,如果對於函數 f(x)f(x) 來說,就是指對於所有的 x1,x2Qx_1, x_2 \in QQQ f(x)f(x) 的定義域,滿足條件

f(x1)f(x2)Lx1x2\vert\vert f(x_1) - f(x_2)\vert\vert \le L\vert\vert x_1-x_2 \vert\vert

非常直觀的,上面這個公式表達的意思就是 f(x)L\vert\vert f^{'}(x) \vert\vert \le L ,所以 f(x)f(x) 的函數取值是被限定到一個範圍內的。

除了 Lipschitz Continus 以外,還有 Lipschitz Continus GradientLipschitz Continus HessianLipschitz Continus Gradient 是對於函數 f(x)f(x) 的梯度 / 導數來說的。換句話說,如果 f(x)f^{'}(x) 滿足 Lipschitz Continus,則 f(x)f(x) 滿足 Lipschitz Continus Gradient。 Lipschitz Continus Hessian 同理,若 f(x)f^{''}(x) 滿足 Lipschitz Continus,則 f(x)f(x) 滿足 Lipschitz Continus Hessian。

我們在深度學習中常用的是 Lipschitz Continus Gradient (LL-Smooth)。所以我們主要要理解它的數學表達:

f(x2)f(x1)f(x2),x2x1L2x2x12 \vert f(x_2) - f(x_1) - \langle f^{'}(x_2), x_2-x_1 \rangle \le \frac{L}{2}\vert\vert x_2-x_1 \vert\vert ^2

直觀上理解它和我們理解強凸非常相似,即我們對 f(x)f(x) 的變化趨勢做了一個限制。我們把絕對值打開這個限制會體現得更清楚一點:

{f(x2)f(x1)+f(x2),x2x1+L2x2x12f(x2)f(x1)+f(x2),x2x1+L2x2x12\left\{\begin{array}{lr} f(x_2) \le f(x_1) + \langle f^{'}(x_2), x_2-x_1 \rangle + \frac{L}{2}\vert\vert x_2-x_1 \vert\vert ^2&\\f(x_2) \ge f(x_1) + \langle f^{'}(x_2), x_2-x_1 \rangle + \frac{L}{2}\vert\vert x_2-x_1 \vert\vert ^2&\end{array}\right.

詳細證明可以參看這篇文章

SGD 的收斂率#

那我們來小試牛刀,計算一下 SGD 的收斂率吧。

首先 LL-Smooth 的條件先擺出來:

f(xt+σ)f(xt)+f(xt)Tσ+L2σ2f(x_{t}+\sigma) \le f(x_{t}) + \nabla f(x_{t})^{T}\sigma + \frac{L}{2}\vert\vert\sigma\vert\vert ^2

其中 σ=xt+1xt=μfixt\sigma = x_{t+1} - x_{t} = - \mu \nabla f_i{x_t} 。我們把 σ\sigma 帶入上式有:

f(xt+σ)=f(xtμfixt)=f(xt+1)f(xt)μf(xt)Tfi(xt)+μ2L2fi(xt)2\begin{aligned}f(x_t + \sigma) &= f(x_t - \mu \nabla f_i{x_t}) \\ &= f(x_{t+1})\\&\le f(x_{t}) - \mu \nabla f(x_t)^{T} \nabla f_{i}(x_t) + \frac{\mu ^2 L}{2} \vert\vert \nabla f_{i}(x_{t}) \vert\vert ^2\end{aligned}

然後我們對上式兩邊 ii 取期望有:

Ei(f(xt+1))f(xt)μf(xt)TEi(fi(xt))+μ2L2Eifi(xt)2\mathbb{E}_{i}(f(x_{t+1})) \le f(x_{t}) - \mu \nabla f(x_t)^{T} \mathbb{E}_{i}(\nabla f_{i}(x_t)) + \frac{\mu ^2 L}{2} \mathbb{E}_{i}\vert\vert \nabla f_{i}(x_{t}) \vert\vert ^2

可以看到我們有四個期望,第一个和第二个期望我們暫時不動,第三個期望 Ei(fi(xt))\mathbb{E}_{i}(\nabla f_{i}(x_t)) 有:

Ei(fi(xt))=1ni=1nfi(xt)=f(xi)\begin{aligned}\mathbb{E}_{i}(\nabla f_{i}(x_t)) &= \frac{1}{n}\sum_{i=1}^{n}\nabla f_i(x_t) \\ &= \nabla f(x_i)\end{aligned}

第四項期望 Eifi(xt)2\mathbb{E}_{i}\vert\vert \nabla f_{i}(x_{t}) \vert\vert ^2 稍微複雜一點,我們要通過定義方差來求。

我們假設梯度的方差 Var=1ni=1nfi(xt)f(xt)2Var = \frac{1}{n}\sum_{i=1}^{n}\vert\vert\nabla f_i(x_t) - \nabla f(x_t) \vert\vert ^2,我們把這個方差展開有:

Var=1ni=1nfi(xt)f(xt)2=1n[i=1nfi(xt)2]+1n[i=1nf(xt)2]2n[i=1nfi(xt),f(xt)]=f(xt)2+1n[i=1nf(xt)2]2f(xt)2=1ni=1nf(xt)2f(xt)2\begin{aligned}Var &= \frac{1}{n}\sum_{i=1}^{n}\vert\vert\nabla f_i(x_t) - \nabla f(x_t) \vert\vert ^2 \\ &= \frac{1}{n}[\sum_{i=1}^{n}\vert\vert \nabla f_i(x_t)\vert\vert ^2] + \frac{1}{n}[\sum_{i=1}^{n}\vert\vert\nabla f(x_t)\vert\vert ^2] - \frac{2}{n}[\sum_{i=1}^n\langle\nabla f_i(x_t), \nabla f(x_t)\rangle] \\ &=\vert\vert \nabla f(x_t)\vert\vert ^2 + \frac{1}{n}[\sum_{i=1}^{n}\vert\vert\nabla f(x_t)\vert\vert ^2] - 2\vert\vert \nabla f(x_t)\vert\vert ^2 \\ &= \frac{1}{n}\sum_{i=1}^{n}\vert\vert\nabla f(x_t)\vert\vert ^2 - \vert\vert \nabla f(x_t)\vert\vert ^2\end{aligned}

所以對於 Eifi(xt)2\mathbb{E}_{i}\vert\vert \nabla f_{i}(x_{t}) \vert\vert ^2 有:

Eifi(xt)2=1ni=1nfi(xt)2=Var+f(xt)2\mathbb{E}_{i}\vert\vert \nabla f_{i}(x_{t}) \vert\vert ^2 = \frac{1}{n}\sum_{i=1}^{n}\vert\vert\nabla f_{i}(x_{t})\vert\vert ^2 = Var + \vert\vert \nabla f(x_t)\vert\vert ^2

於是,我們把上面算出的兩個單項期望帶入整體有:

E(f(xt+1))E(f(xt))μf(xt)Tf(xi)+μ2L2(Var+f(xt)2)=E(f(xt))(μμ2L2)f(xt)2+μ2L2Var\begin{aligned}\mathbb{E}(f(x_{t+1})) &\le \mathbb{E}(f(x_{t})) - \mu \nabla f(x_t)^{T} \nabla f(x_i) + \frac{\mu ^2 L}{2}(Var + \vert\vert \nabla f(x_t)\vert\vert ^2)\\&=\mathbb{E}(f(x_t))-(\mu-\frac{\mu^2L}{2})\vert\vert \nabla f(x_t)\vert\vert ^2 + \frac{\mu^2L}{2}Var\end{aligned}

接下來,我們還有 λ\lambda-Strong Convex (這裡用 λ\lambda 是為了避免和例子中學習率搞混) 的數學表達如下:

f(x)f(x0)+f(x0)T(xx0)+λ2xx02f(x) \ge f(x_0) + \nabla f(x_0)^T (x-x_0) + \frac{\lambda}{2}\vert\vert x - x_0 \vert\vert ^2

我們令 x=x,x0=xx = x^*, x_0 = x 有,把加減組合成一個和的平方有:

f(x)f(x)+f(x)T(xx)+λ2xx2=f(x)+12λf(x)+λ(xx)212λf(x)2f(x)12λf(x)2\begin{aligned}f(x^*) &\ge f(x) + \nabla f(x)^T (x^*-x) + \frac{\lambda}{2}\vert\vert x^* - x \vert\vert ^2\\ &=f(x) + \frac{1}{2\lambda}\vert\vert\nabla f(x) + \lambda (x^*-x)\vert\vert ^2 - \frac{1}{2\lambda}\vert\vert\nabla f(x)\vert\vert ^2\\ &\ge f(x) - \frac{1}{2\lambda}\vert\vert\nabla f(x)\vert\vert ^2\end{aligned}

所以有:

f(x)22λ(f(x)f(x))\vert\vert\nabla f(x)\vert\vert ^2 \ge 2\lambda (f(x) - f(x^*))

我們把上式帶入期望的不等式有:

E(f(xt+1))E(f(xt))2λ(μμ2L2)(f(xt)f(x))+μ2L2Var\mathbb{E}(f(x_{t+1})) \le \mathbb{E}(f(x_t))-2\lambda(\mu-\frac{\mu^2L}{2})(f(x_t) - f(x^*)) + \frac{\mu^2L}{2}Var

最後,令 μ=1L\mu = \frac{1}{L},有:

E[f(xt+1)f(x)f(xt)f(x)]=(1σL)+H\mathbb{E}[\frac{f(x_{t+1}) - f(x^*)}{f(x_{t}) - f(x^*)}] = (1-\frac{\sigma}{L}) + H

其中 HH 是 SGD 相對於 GD 帶來的干擾。

載入中......
此文章數據所有權由區塊鏈加密技術和智能合約保障僅歸創作者所有。