back to index
【機器學習2021】生成式對抗網路 (Generative Adversarial Network, GAN) (二) – 理論介紹與WGAN

link |
剛才就是用了一堆比喻,告訴你 GAN 是怎麼運作的,也實際上告訴你 GAN 的操作是怎麼操作的。
link |
接下來我們要講一點理論的部分,講一點農場文不會講到的部分,告訴你說實際上為什麼 GAN 的這一番操作,為什麼這個 Generator 跟 Discriminator 的互動可以讓我們的 Generator 產生像是真正的人臉的圖片,這背後的互動在做的到底是什麼樣的事情。
link |
那我們先來弄清楚我們今天訓練的目標到底是什麼。
link |
你知道我們在訓練 Network 的時候,你就是要訂一個 Loss Function 嘛,訂完以後用 Gradient Descent 去調你的參數,去 Minimize 那個 Loss Function 就結束了。
link |
那在這個 Generation 的問題裡面,到底我們要 Minimize 的或者是我們要 Maximize 的到底是什麼樣的東西呢?
link |
我們要把這件事弄清楚,才能夠做接下來的事情。
link |
那在 Generator 裡面,我們到底想要 Minimize 或者是 Maximize 什麼樣的東西呢?
link |
我們想要 Minimize 的東西是這個樣子的。
link |
我們有一個 Generator,給他一大堆的 Vector,給他從 Normal Distribution Sample 出來的東西,丟進這個 Generator 以後,會產生一個比較複雜的 Distribution。
link |
這個複雜的 Distribution,我們叫他 PG,然後我們有一堆的 Data,這個是真正的 Data,真正的 Data 也形成了另外一個 Distribution 叫做 P-Data,我們期待 PG 跟 P-Data 越接近越好。
link |
如果你一下子沒有辦法想像這個 PG、P-Data 是怎麼一回事的話,那我們用一維的狀況來跟大家說明。
link |
我們假設 Generator 的 Input 是一個一維的向量,Generator 的 Output 也是一維的向量,我們的真正的 Data 也是一維的向量。
link |
那我們的 Normal Distribution 就長這個樣子,沒有問題。
link |
那丟到 Generator 以後,這邊這每一個點,假設你輸入五個點,這邊這每一個點,他的位置會改變,你就產生一個新的 Distribution。
link |
那可能本來大家都集中在中間,通過這個 Generator,通過一個 Network,裡面很複雜,不知道做了什麼事情以後,這些點就分成兩邊,所以你的 Distribution 就變成這個樣子。
link |
而 P-Data 是指真正的資料的分布,真正的資料的分布可能長這個樣子,他分兩邊的狀況是更極端的,左邊的東西比較多,右邊的東西比較少,那你期待左邊這個分布跟右邊這個分布越接近越好。
link |
如果寫成式子的話,你可以寫成這個樣子,你這邊這個 DIV of PG 跟 P-Data,他指的意思就是 PG 跟 P-Data 這兩個 Distribution 之間的 Divergence。
link |
那 Divergence 這邊指的是什麼意思呢? Divergence 這邊指的意思就是,你可以想成是這兩個 Distribution 之間的某種距離。
link |
如果這個 Divergence 越大,就代表這兩個 Distribution 越不像,Divergence 越小,就代表這兩個 Distribution 越相近。Divergence 就是衡量兩個 Distribution 相似度的一個 measure。
link |
然後呢,我們現在的目標是要去找一個 generator,所謂的找一個 generator,實際上骨子裡做的事情是找一個 generator 裡面的參數,找一組 generator 裡面的參數,generator 也是一個 network,裡面有一大堆的 weight 跟 bias,找一組 generator 的參數,他可以讓我們產生出來的 PG 跟 P-Data 之間的 Divergence 越小越好。
link |
我們要找的就是這樣子的 generator,我們這邊把它寫作 G-SPOT。
link |
所以我們這邊要做的事情跟一般的 train network 其實非常的像,我們第一堂課就告訴你說,我們定義了 loss function,找一組參數 minimize loss function,我們現在其實也定義了我們的 loss function。
link |
在 generation 這個問題裡面,我們的 loss function 就是 PG 跟 P-Data 的 divergence,就是他們兩個之間的距離,他們兩個越近,就代表產生出來的 PG 跟 P-Data 越像。
link |
所以 PG 跟 P-Data,我們希望他們越相像越好,所以我們希望 PG 跟 P-Data 的 divergence 越小越好,我們要做的事情就是找一個 G,讓 divergence 變得最小。
link |
但是我們這邊遇到一個困難的問題,怎麼樣困難的問題呢?這個 loss 我們是可以算的,但是這個 divergence 要怎麼樣算呢?
link |
你可能知道一些 divergence 的式子,比如說 KL divergence,比如說 JS divergence,這些 divergence 用在這種 continuous distribution 上面,你要做一個很複雜的,在實作上你幾乎不知道要怎麼算的積分。
link |
那我們根本就無法把這個 divergence 算出來,我們算不出這個 divergence,我們又要如何去找一個 G,去 minimize 這個 divergence 呢?這個就是 GAN 所遇到的問題。
link |
這就是我們在 train 這種 generator 的時候會遇到的問題。而 GAN 是一個很神奇的做法,它可以突破我們不知道怎麼計算 divergence 的限制。
link |
所以我們現在遇到的問題就是,不知道怎麼計算 divergence。而 GAN 告訴我們的就是,只要你知道怎麼從 pg 和 pdata 這兩個 distribution sample 東西出來,就有辦法算 divergence。
link |
你不需要知道 pg 和 pdata 實際上的 formulation 長什麼樣子,你只要能夠 sample 就能夠算 divergence。而 pg 和 pdata 是可以 sample 的嗎?是可以 sample 的。
link |
怎麼從真正的 data 裡面 sample 出東西來呢?你就把你的圖庫拿出來,從圖庫裡面隨機產生,隨機 sample 一些圖片出來,你就得到 pdata 了。
link |
那怎麼從 generator 裡面產生一些東西出來呢?那你就把你的 generator 輸入從 normal distribution sample 出來的 vector,丟到 generator 裡面。
link |
我們剛才也說過說,你這邊的 distribution,你拿來 sample 那個 distribution,要是簡單的,要是你有辦法 sample 的,所以我們選 normal distribution,
link |
我們是知道的,是有辦法 sample 的,我們從 normal distribution 裡面 sample 一堆 vector 出來,丟給 generator,讓 generator 產生一堆圖片出來。
link |
這些圖片就是從 pg sample 出來的結果。所以我們有辦法從 pg 做 sample,我們有辦法從 pdata 做 sample。
link |
接下來,GAN 這一整個系列的 work 就是要告訴你說,怎麼在只有做 sample 的前提之下,我根本不知道 pg 跟 pdata 實際上完整的 formulation 長什麼樣子,
link |
在只能做 sample 的前提之下,居然就算出了,居然就估測出了 divergence。
link |
那這個就是要靠 discriminator 的力量。我們剛才講過說,discriminator 是怎麼訓練出來的呢?我們有一大堆的 real data,這個 real data 就是從 pdata sample 出來的結果。
link |
我們有一大堆 generated data,generated data 就可以看出是從 pg sample 出來的結果。
link |
根據 real data 跟 generated data,我們會去訓練一個 discriminator。訓練的目標是看到 real data 就給他比較高的分數,看到 generated data 就給他比較低的分數。
link |
我們剛才就說,discriminator 訓練的目標就是要分辨好的圖跟不好的圖,分辨真的圖跟生成的圖,所以看到真的圖給他高分,看到生成的圖給他低分。
link |
那實際上剛才講的,你也可以把它寫成四字,把它當作是一個 optimization 的問題。
link |
這個 optimization 的問題是這樣子的,我們要訓練一個 discriminator。
link |
這個 discriminator 可以去 maximize 某一個 function,這邊叫做 objective function,就是我們要 maximize 的東西,我們會叫 objective function。
link |
如果要 minimize,我們就會叫它 loss function。我們現在要找一個 d,它可以 maximize 這個 objective function。
link |
這個 objective function 長什麼樣子呢?這個 objective function 長這個樣子,我們有一堆 y,它是從 p data 裡面 sample 出來的,也就是它們是真正的 image。
link |
而我們把這個真正的 image 丟到 d 裡面得到一個分數,再去 log。那另外一方面,我們有一堆 y,它是從 pg,從 generator 所產生出來的。
link |
把這些圖片也丟到 discriminator 裡面得到一個分數,再去 log 1-b of y。
link |
那我們希望這個 objective function b 越大越好。我們希望 b 越大越好,意味著我們希望這邊的 d of y 越大越好。我們希望 y 如果是從 p data sample 出來的,它就要越大越好。
link |
我們希望說如果 y 是從 pg sample 出來的,它就要越小越好。
link |
那我們就去最大化 b 這個式子,找一個 d 可以 maximize 這個 objective function,我們其實就是讓 d of y 越大越好。
link |
讓這邊的 d of y,也就是 generator 生成的圖片的值越小越好, discriminator output 的值越小越好。
link |
那這件事情其實又等同於你可能覺得沒事突然寫出這個式子有點奇怪,那你不一定要把這個 objective function 寫成這個樣子,它完全可以有其他的寫法。
link |
那最早年之所以寫成這個樣子是有一個很明確的理由,有一個很明確的動機,是為了要把 discriminator 跟 binary 的 classification 跟分類,跟二元的分類扯上關係。
link |
怎麼說呢?事實上這個 objective function 它就是 cross entropy 成一個負號。
link |
我們知道我們在訓練一個 classifier 的時候,我們就是要 minimize cross entropy,所以當我們 maximize 這個 objective function, maximize cross entropy 成一個負號的時候,其實等同於 minimize cross entropy,也就等同於是在訓練一個 classifier。
link |
所以這個 discriminator 做的事情,如果 discriminator 做的事情是去 maximize 這個 objective function,那這個 discriminator 其實可以當作是一個 classifier。
link |
它做的事情就是把藍色這些點,從 p data sample 出來的真實的 image 當作 class 1,把從 pg sample 出來的這些假的 image 當作 class 2。
link |
由兩個 class 的 data 訓練一個 binary 的 classifier,訓練完就等同於是解了這個 optimization 的問題。
link |
那這邊最神奇的地方是以下這句話,這一個式子,這個紅框框裡面的數值,它跟 JS divergence 有關。事實上有趣的事情是,我覺得最原始的 Gantt paper,它的發想可能真的是從 binary classifier 來的。
link |
一開始是把 discriminator 寫成 binary classifier,然後有了這樣的 objective function,然後再經過一番推導以後,這個 objective function 它的 maximum,就是你找到一個 d,可以讓這個 objective function 它的值最大的時候,這個最大的值跟 JS divergence 是有關的。
link |
它們沒有完全一模一樣,所以顯然一開始並不是針對 JS divergence 設計的,而是經過一番推導以後發現它們是非常有關聯的。
link |
至於實際上的推導過程,你可以參見原始 Ian Goodfellow 寫的文章,其實裡面的推導過程我覺得寫得算是蠻清楚的。
link |
真正神奇的地方就是,這一個 objective function 的最大值,它跟 divergence 是有關的。所以我們剛才說,我們不知道怎麼算 divergence,沒關係,train 你的 discriminator,train 完以後,看看它的 objective function 可以到多大,那個值就跟 divergence 有關。
link |
這邊我們並沒有把證明拿出來跟大家講,但是我們還是可以從直觀上來理解一下,為什麼這個 objective function 的值會跟 divergence 有關呢?
link |
這個直觀的理解並沒有很困難,因為你可以想想看,假設 pg 跟 pdata 它的 divergence 很小,也就是 pg 跟 pdata 很像,它們差距沒有很大,它們很像 pg 跟 pdata sample 出來的藍色的星星跟紅色的星星,它們是混在一起的。
link |
這個時候對 discriminator 來說,discriminator 就是在 train 一個 binary 的 classifier,對 discriminator 來說,既然這兩堆資料是混在一起的,那就很難分開,這個問題很難。
link |
既然這個問題很難,你在解這個 optimization problem 的時候,你就沒有辦法讓這個 objective 的值非常的大,所以這個 objective,這個 v 的 maximum 的值就比較小,所以小的 divergence 對應到小的 objective function 的 maximum 的值。
link |
所以不是 objective function 的值本身,是 objective function 在窮取所有 discriminator 以後可以得到的最大的值。如果今天你的兩組 data 很不像,它們的 divergence 很大,那對 discriminator 而言就可以輕易地把它分開。
link |
當 discriminator 可以輕易地把它分開的時候,這個 objective function 就可以衝得很大,所以當你有大的 divergence 的時候,這個 objective function 的 maximum 的值就可以很大。當然這邊是用直觀的方法來跟你講的。
link |
詳細的證明請參見 GANN 原始的 paper,裡面再做了一些假設,比如說 discriminator 的 capacity 是無窮大等等的假設以後,可以做出這個 maximum 的值跟 JS divergence 一些相關的推導。
link |
所以我們說,我們本來的目標是要找一個 generator 去 minimize pg 跟 p data 的 divergence 的值,那我們卡在不知道怎麼計算 divergence。那我們現在要知道,我們只要訓練一個 discriminator,訓練完以後,這個 objective function 的最大值就是這個 divergence,就跟這個 divergence 有關。
link |
那我們何不就把紅框框裡面這一項跟 divergence 做替換呢?我們何不就把 divergence 替換成紅框框裡面這一項呢?所以我們就有了這樣一個 objective function。
link |
這個 objective function 乍看之下有點複雜,它有一個 minimum 又有一個 maximum,所以你不小心就會腦筋轉不過來。我們是要找一個 generator 去 minimize 紅色框框裡面這件事,但是紅框框裡面這件事又是另外一個 optimization problem,它是在給定 generator 的情況下去找一個 discriminator,這個 discriminator 可以讓 v 這個 objective function 越大越好。
link |
我們要找一個 g 讓紅框框裡面的值最小,這個 g 就是我們要的 generator。
link |
而剛才我們講的 generator 跟 discriminator 互動、互相欺騙這個過程,其實就是想解這個有 minimize 又有 maximize 這個 min-max 的問題,就是透過下面這個我們剛才講的 GAN 的 algorithm 來解的。至於實際上為什麼下面這個 algorithm 可以解這個問題,你也可以參見原始的 GAN 的 paper。
link |
講到這邊,也許你就會問說,為什麼是 JS Divergence,而且還不是真的 JS Divergence,是跟 JS Divergence 相關而已,怎麼不用真正的 JS Divergence 或不用別的,比如說 AL Divergence,你完全可以這麼做。你只要改了那個 objective function,你就可以量各式各樣的 divergence。
link |
至於怎麼樣設計 objective function 得到不同的 divergence,有一篇叫做 FGAN 的 paper 裡面有非常詳細的證明,它有很多的 table 告訴你說不同的 divergence 要怎麼設計它的 objective function,你設計什麼樣的 objective function 去找它的 maximum value 就會變成什麼樣的 divergence,在這篇文章裡面都有詳細的記載。
link |
所以一開始有人會覺得說,GAN 之所以沒有很好 train,也許是因為我們沒有在真的 minimize JS Divergence,但是有了 FGAN 這篇 paper 以後,它就告訴你說,我們有辦法 minimize JS Divergence,但就算你真的可以 minimize JS Divergence,結果也還是沒有很好,GAN 還是沒有很好 train。
link |
所以 GAN 是以不好 train 而聞名的,所以俗話就說,no pain, no GAN,我們就要講一些 GAN 訓練的小技巧。
link |
GAN 有什麼樣訓練的小技巧呢?其實 GAN 訓練的小技巧非常多,我想了很久以後,我們只挑一個最知名的來講。
link |
這個最知名的就是很多人都聽過的 WGAN。這個 WGAN 是什麼呢?在講 WGAN 之前,我們先講 JS Divergence 有什麼樣的問題。在最早的 GAN 說,我們要 minimize 的是 JS Divergence。
link |
選擇 JS Divergence 的時候會有什麼問題呢?在講 JS Divergence 的問題之前,我們先看一下 pg 跟 pdata 有什麼樣的特性。
link |
pg 跟 pdata 有一個非常關鍵的特性是,pg 跟 pdata 它們重疊的部分往往非常少。
link |
這邊有兩個理由,第一個理由是來自於 data 本身的特性。pg 跟 pdata 它們都是要產生圖片,圖片其實是高維空間裡面的一個低維的 metaphor。
link |
怎麼知道圖片是高維空間裡面低維的 metaphor 呢?因為你想想看,你在高維空間裡面隨便 sample 一個點,它通常都沒有辦法構成一個二次元人物的頭像。所以二次元人物的頭像的分布在高維空間中其實是非常狹窄的。
link |
所以二次元頭像的分布,這個圖片的分布,其實是高維空間中的低維的 metaphor,或者是如果是以二維空間來想的話,那圖片的分布可能就是二維空間的一條線。
link |
二維空間中多數的點都不是圖片,就高維空間中隨便 sample 一個點都不是圖片,只有非常小的範圍,sample 出來它會是圖片。
link |
所以從這個角度來看,從資料本身的特性來看,Pg 跟 Pdata 它們都是 low-dimensional 的 metaphor。
link |
用二維空間來講,Pg 跟 Pdata 都是二維空間中的兩條直線。而二維空間中的兩條直線,除非它剛好重合,不然它們相交的範圍幾乎是可以忽略的。這是第一個理由。
link |
也許有人說,圖片根本就不是 low-dimensional 的 metaphor,那會不會第一個理由就不成立了呢?那我給你第二個理由。
link |
第二個理由是,我們從來都不知道 Pg 跟 Pdata 長什麼樣子。我們對 Pg 跟 Pdata 分布的理解其實來自於 sample。
link |
所以也許 Pg 跟 Pdata 它們是有非常大的 overlap 的範圍,但是我們實際上在了解 Pg 跟 Pdata,在計算它們的 divergence 的時候,
link |
我們是從 Pdata 裡面 sample 一些點出來,從 Pg 裡面 sample 一些點出來。
link |
如果你 sample 的點不夠多,你 sample 的點不夠密,那就算是這兩個 divergence 實際上,這兩個 distribution 實際上有重疊,
link |
但是假設你 sample 的點不夠多,對 discriminator 來說,它也是沒有重疊的。
link |
這個藍色的分布跟紅色的分布明明是有重疊的,但如果你從藍色分布 sample 一些點、紅色分布 sample 一些點,
link |
這些點你又 sample 的不夠多,你根本就可以畫一條楚河漢界,把紅色的點跟藍色的點完全地分開來,
link |
然後說紅色的點的分布就是在楚河漢界的右邊,藍色的點就是在左邊,它們完全是沒有任何重疊的。
link |
以上給你兩個理由,試圖說服你說 Pg 跟 Pdata 這兩個分布,它們重疊的範圍是非常小的。
link |
而幾乎沒有重疊這件事情,對於 JS divergence 會造成什麼問題呢?
link |
JS divergence 有一個特性,是兩個沒有重疊的分布,JS divergence 算出來就永遠都是 log2,
link |
不管這兩個分布長什麼樣子。所以兩個分布只要沒有重疊,算出來就一定是 log2,
link |
不管它們長什麼樣子,算出來都是 log2。
link |
所以舉例來說,假設這是你的 Pdata,這是你的 Pg,假設它們都是一條直線,然後中間有很長的距離,
link |
你算它們的 JS divergence 是 log2。假設你的 Pg 跟 Pdata 其實蠻接近的,
link |
那中間的間隔其實是比較小的,算出來結果還是 log2。
link |
除非你的 Pg 跟 Pdata 有重合,不然這個 Pg 跟 Pdata 只要它們是兩條直線,它們這兩條直線沒有相交,
link |
那算出來就是 log2。這個 case 算出來是 log2,這個 case 算出來也是 log2。
link |
但是明明這個 case 就比這個 case 好啊,中間這個 case,中間這個 generator 明明就比左邊這個 generator 好啊,但是你不知道。
link |
明明藍色的線就跟紅色的線比較近啊,但是從 JS divergence 上面看不出這樣子的現象。
link |
那既然從 JS divergence 上看不出這樣子的現象,你在 training 的時候,你根本就沒有辦法把這樣子的 generator,update 參數變成這樣子的 generator。
link |
因為對你的 loss 來說,對你的目標來說,這兩個 generator 是一樣好或者是一樣糟的。
link |
那以上是從比較理論的方向來說明,如果我們從更直觀的實際操作的角度來說明,
link |
你會發現當你是用 JS divergence 的時候,也就是假設你今天在 train 一個 binary classifier 去分辨 real 的 image 跟 generated image,
link |
你會發現實際上你通常 train 完以後,正確率幾乎都是百分之百。為什麼?因為你 sample 的圖片根本就沒幾張啊。
link |
對你的 discriminator 來說,你 sample 256 張 real 的圖片、256 張 fake 的圖片,他直接用硬背的都可以把這兩組圖片分開,知道說誰是 real 的、誰是 fake 的。
link |
所以實際上如果你有自己 train 過 game 的話,你會發現如果你用 binary classifier train 下去,
link |
你會發現你幾乎每次 train 完你的 classifier 以後,也就是你 train 完你的 discriminator 以後,正確率都是 100%。
link |
我們本來會期待說這個 discriminator 的 loss 也許代表了某些事情,
link |
這個 binary classifier loss 也許代表某些事情,這個 loss 越來越大代表問題越來越難,代表我們的 generated data 跟 real 的 data 越來越接近。
link |
但實際上,你在實際操作的時候你根本觀察不到這個現象,這個 binary classifier 訓練完的 loss 根本沒有什麼意義,
link |
因為他總是可以讓他的正確率變到 100%,兩組 image 都是 sample 出來的,他硬背都可以得到 100% 的正確率,
link |
你根本就沒有辦法看出你的 generated 有沒有越來越好。
link |
所以過去,尤其是在還沒有 WGAN 這樣的技術,在我們還用 binary classifier 當作 discriminator 的時候,
link |
train game 真的就很像巫術、黑魔法,你根本就不知道你 train 的時候有沒有越來越好。
link |
所以怎麼辦呢?那時候的做法就是,你每次 update 幾次 generator 以後,你就要把你的圖片 print 出來看,
link |
然後你就要一邊吃飯一邊看圖片生成的結果,然後跑一跑就發現,哇,壞掉了,然後 cut 掉重做這樣子。
link |
所以以前你根本就沒有,不像我們在 train 一般的 network 的時候,你有一個 loss function,
link |
然後那個 loss 隨著訓練的過程會慢慢慢慢變小,那你就會看說,
link |
ok, loss 慢慢變小,你就放心知道說你的 network 有在 train。
link |
那會不會 overfitting 是另外一件事?至少他的 training data 上有越來越好。
link |
但是對 game 而言,本來我們期待 classifier 的 loss 可以提供一些資訊,
link |
但是當你的 classifier 是一個簡單一般的 binary classifier 的時候,他訓練的結果就沒有任何資訊,
link |
你每次訓練出來,正確率都是 100%,你根本不知道你的 generator 有沒有越來越好,
link |
變成你只能夠用人眼看,用人眼守在電腦前面看,發現結果不好, cut 掉,
link |
重新用一組 hyperparameter,重新調一下 network 架構重做。
link |
所以過去訓練 game 是有點辛苦的。
link |
那既然是 JS divergence 的問題,於是有人就想說,
link |
那會不會換一個衡量兩個 distribution 的相似程度的方式,
link |
換一種 divergence 就可以解決這個問題了呢?
link |
於是就有了使用 Wasserstein distance 的想法。
link |
那 Wasserstein distance 這邊有一個冷知識,
link |
就是這個 w,其實這邊是唸 v,不是發 w 的音,不是唸 Wasserstein,而是唸 Wasserstein。
link |
那這個 Wasserstein distance 是怎麼計算的呢?
link |
它的想法是這個樣子,假設你有兩個 distribution,
link |
一個 distribution 我們叫它 p,另外一個 distribution 我們叫它 q,
link |
Wasserstein distance 它計算的方法就是想像你在開一台推土機,
link |
推土機的英文叫做 earth mover,
link |
想像你在開一台推土機,那你把 p 想成是一堆土,
link |
把 q 想成是你要把土堆放的目的地。
link |
那這個推土機把 p 這邊的土挪到 q 所移動的平均距離就是 Wasserstein distance。
link |
在這個例子裡面,我們假設 p 都集中在這個點,q 都集中在這個點,
link |
對推土機而言,假設它要把 p 這邊的土挪到 q 這邊,
link |
所以在這個例子裡面,假設 p 集中在一個點,
link |
q 集中在一個點,這兩個點之間的距離是 d 的話,
link |
那 p 跟 q 的 Wasserstein distance 就是 d。
link |
那因為在講這個 Wasserstein distance 的時候,
link |
你要想像有一個 earth mover,有一個推土機在推土,
link |
所以其實 Wasserstein distance 就叫 earth mover distance。
link |
那但是呢,如果是更複雜的 distribution,
link |
你要算 Wasserstein distance 就有點困難了。
link |
怎麼說呢?假設這是你的 p,假設這是你的 q,
link |
假設你開了一個推土機,想要把 p 把它重新塑造一下形狀,
link |
讓 p 的形狀跟 q 比較接近一點。
link |
你可能的 moving plan,就是你把 q 呢,
link |
把 p 重新塑造成 q 的方法,有無窮多種。
link |
我把這邊的土搬到這裡來,把 p 變成 q。
link |
那你也可以捨近求遠說,我把這裡的土搬到這裡來,
link |
把這裡的土搬到這裡來,捨近求遠一樣還是可以把 p 變成 q。
link |
所以當我們考慮比較複雜的 distribution 的時候,
link |
把 p 變成 q 的方法是有非常非常多不同的方法的。
link |
你有各式各樣不同的 moving plan。
link |
用不同的 moving plan,你算出來的距離,
link |
在左邊這個例子裡面,推土機平均走的距離比較小。
link |
那難道這個 p 跟 q 他們之間的 vessel stand distance
link |
會根據你的不同的方法,不同的推土機行進的方法,
link |
為了讓 vessel stand distance 只有一個值,
link |
所以這邊 vessel stand 的定義是窮取所有的 moving plan,
link |
然後看哪一個推土的方法,哪一個 moving 的計劃,
link |
哪一個推土的計劃可以讓平均的距離最小。
link |
那個最小的值才是 vessel stand distance。
link |
所以會窮取所有把 p 變成 q 的方法,
link |
那選最短的那個距離當作是 vessel stand distance。
link |
所以其實要計算 vessel stand distance 是挺麻煩的。
link |
你會發現說,你光我只是要計算一個 distance,
link |
我居然還要解一個 optimization 的問題,
link |
解說這個 optimization 的問題才能算 vessel stand distance。
link |
好,那我們先不講怎麼計算 vessel stand distance 這件事,
link |
我們先來講假設我們能夠計算 vessel stand distance 的話,
link |
那假設 pg 跟 pdata 它們的距離是 d0,
link |
那在這個例子裡面, vessel stand distance 算出來就是 d0。
link |
在這個例子裡面, pg 跟 pdata 它們之間的距離是 d1,
link |
那 vessel stand 算出來的距離就是 d1。
link |
那假設 d1 比較小, d0 比較大,
link |
那算 vessel stand 的時候,
link |
這個 case 的 vessel stand 就比較小,
link |
這個 case 的 vessel stand 就比較大。
link |
由左向右的時候, vessel stand 是越來越小的。
link |
所以如果你觀察 vessel stand 的話,
link |
會知道說從左到右,我們的 generator 越來越進步。
link |
但是如果你觀察 discriminator,
link |
對 discriminator 而言,
link |
這邊每一個 case 算出來的 divergence 都是一樣的,
link |
但是如果換成 vessel stand distance,
link |
我們會知道說,我們的 generator 做得越來越好。
link |
所以我們換一個計算 divergence 的方式,
link |
我們就可以解決 this divergence 有可能帶來的問題。
link |
在皮膚上經過突變產生一些感光的細胞,
link |
但是天擇突變,怎麼可能產生這麼複雜的器官呢?
link |
從感光細胞到眼睛,中間其實是有連續的步驟的。
link |
舉例來說,感光的細胞可能會出現在一個比較凹陷的地方。
link |
皮膚凹陷下去,這樣感光細胞就可以接受來自不同方向的光源。
link |
後來覺得說,乾脆把凹陷的地方蓋起來。
link |
後來覺得蓋起來的地方裡面可以放一些液體,
link |
但是這邊每一小步都可以讓一個生命存活的機率變大,
link |
現在這邊每一個步驟都可以讓生命繁衍的機率變高,
link |
當你使用Vessel Stance Distance來衡量你的Divergence的時候,
link |
你要它一步從這裡跑到這裡,一步從這裡跑到這裡,
link |
讓Pg0跟Pdata直接align在一起,是不可能的,
link |
對Jace Divergence而言,它需要做到直接從這一步跳到這一步,
link |
它的Jace Divergence Loss才會有差異。
link |
你只要每次有稍微把Pg往Pdata挪近一點,
link |
W Distance有變化,你才有辦法train你的Generator去minimize W Distance。
link |
所以這就是為什麼當我們從Jace Divergence
link |
換成Vessel Stance Distance的時候,可以帶來的好處。
link |
好,那W Gain實際上就是當你用Vessel Stance Distance
link |
來取代Jace Divergence的時候,
link |
Vessel Stance Distance要怎麼算呢?
link |
Pdata跟Pg它的Vessel Stance Distance要怎麼計算呢?
link |
Vessel Stance Distance是一個非常複雜的東西,
link |
我光要算個Vessel Stance Distance還要解一個optimization的問題,
link |
解下面這個optimization的problem,解出來以後,
link |
你得到的值就是Vessel Stance Distance,
link |
就是Pdata跟Pg的Vessel Stance Distance。
link |
所以這邊的X在前面的投影片裡面其實都是Y,
link |
他們指的都是一張圖片,他們指的都是Network的輸出,
link |
我們就觀察一下這個式子,這個式子裡面有說,
link |
X如果是從Pdata來的,那我們要計算它的Dx的期望值,
link |
X如果是從Pg來的,我們計算它的Dx的期望值,
link |
所以如果你要Maximize這個Objective Function,
link |
你會達成,如果X是從Pdata Sample出來的,
link |
Dx,就Discriminator的Output,要越大越好。
link |
如果X是從Pg,從Generator Sample出來的,
link |
那Dx,也就是Discriminator的Output,應該要越小越好。
link |
它不是光叫你把這裡大括號裡面的值變大就好,
link |
還有一個限制是,D不能夠是一個隨便的Function,
link |
D必須要是一個One Distance的Function。
link |
可能會問說,One Distance是什麼東西呢?
link |
如果你不知道是什麼的話,也沒有關係,
link |
我們這邊你就想成,D必須要是一個足夠平滑的Function,
link |
它不可以是變動很劇烈的Function,
link |
它必須要是一個足夠平滑的Function。
link |
那為什麼足夠平滑這件事情是非常重要的呢?
link |
這是Generated的資料的分布。
link |
如果我們沒有這個限制,只看大括號裡面的值的話,
link |
那要讓Generated的值,它的Dx越小越好。
link |
只單純要這邊的值越大越好,這邊的值越小越好,
link |
也就是真正的Image跟Generated的Image沒有任何重疊的情況下,
link |
你的Discriminator會做什麼?
link |
它會給Real的Image無限大的正值,
link |
給Generated的Image無限大的負值。
link |
所以你這個Training根本就沒有辦法收斂,
link |
而且你會發現說,只要這兩堆Data沒有重疊,
link |
你算出來的這個Maximum值都是無限大。
link |
這顯然不是我們要的,這不就跟JS Divergence的問題一模一樣嗎?
link |
你這個Maximum的值才會是Versus Distance,
link |
那為什麼加上這個限制就可以解決剛才的問題呢?
link |
因為這個限制是要求Discriminator不可以太變化劇烈,
link |
那如果你要求你的Discriminator夠平滑的時候,
link |
假設Real跟Generated的Data距離比較近,
link |
那你就沒有辦法讓Real的Data值非常大,
link |
那這個Discriminator變化就很劇烈了,
link |
它就不平滑了,它就不是One Decision了。
link |
那為了要是One Decision,
link |
這邊的值沒辦法很大,這邊的值沒辦法很小。
link |
如果Real跟Generated的Data差距離很遠,
link |
你的Best of Same Distance就比較大。
link |
如果Real跟Generated很近,
link |
有了這個寫在Max下面的One Decision Function的限制,
link |
因為有One Decision Function的限制,
link |
所以Real Data的值跟Generated Data的值就沒有辦法差很多。
link |
所以算出來的Best of Same Distance就會比較小。
link |
怎麼確保Discriminator一定符合One Decision Function的限制呢?
link |
所以最早的一篇WGAN的paper,
link |
最早使用Best of Same的那篇paper,
link |
它做了一個比較rough,比較粗糙的處理方法。
link |
它是說,我就train a network,
link |
那train network的時候呢,
link |
如果超過C,就Gradient Descent Update以後超過C就設為C,
link |
Gradient Descent Update以後小於-C就直接設為-C。
link |
並不一定真的能夠讓Discriminator變成One Decision Function。
link |
也許真的可以讓我們的Discriminator比較平滑,
link |
但它並沒有真的去解這個optimization的problem,
link |
它並沒有真的讓Discriminator符合這個限制。
link |
有一個想法叫做Gradient Penalty。
link |
Gradient Penalty是出自ImproveWGAN這篇paper。
link |
ImproveWGAN這篇paper是說,
link |
假設這個是你的real data的分布,
link |
這個是你的fake data的分布,
link |
那我在real data這邊取一個sample,
link |
fake data這邊取一個sample,
link |
我要求這個點,它的gradient要接近1。
link |
那就詳盡ImproveWGAN的paper。
link |
其實我覺得你現在也不一定要真的非常較真說,
link |
這件事情Gradient Penalty跟這個限制之間的關係,
link |
那其實後來ImproveWGAN之後,
link |
還有什麼Improve的ImproveWGAN,
link |
又有另外一篇paper就叫ImproveWGAN。
link |
真的把第一限制讓它是one list function,
link |
這個叫做Spectral Normalization,
link |
那如果你要train真的非常好的GAN,
link |
你可能會需要用到Spectral Normalization。
link |
理論是要量現在的generator的output,
link |
你只要看現在的generator的output就好了。
link |
我現在generator只generate一部分的資料,
link |
我資料比較多,也許我訓練出來的結果會比較robust。
link |
這是一個Trend Game的Tip。
link |
因為我們說我們Trend那個discriminator,
link |
maximum的值就代表了我們的divergence,
link |
那其實照理說我們每次discriminator,
link |
重新initialize一個discriminator,
link |
因為實際上我們在updatediscriminator的時候,
link |
我們並沒有辦法真的跑很多iteration。
link |
跑到discriminator真的收斂為止。
link |
因為你每次如果discriminator都要跑到收斂,
link |
才update一次generator,
link |
discriminator只update幾次,
link |
就要輪到generatorupdate。
link |
所以discriminator其實update的次數,
link |
我們並沒有辦法真的去maximize,
link |
那個objective function。
link |
所以變成說我們discriminator,
link |
它的參數都是從前一個iteration,
link |
我們在訓練discriminator的時候,
link |
能不能夠用之前generator sample出來的資料。
link |
divergence的function其實是已經定好的。
link |
在這個固定好的divergence的
link |
還是說它是我們要算的divergence,
link |
就是VesselStandDistance。
link |
這個value就是VesselStandDistance。