back to index

Meta Learning – MAML (7/9)


link |
00:00.000
我們來看一下,接下來就會用到一些數學,那不想聽的人就可以休息一下。
link |
00:05.600
那其實這一段呢,我其實是今天早上剛抵達臺灣啦,剛從美國回來。
link |
00:11.720
這一段是在飛機上面在亂流中做的啦,所以一定是會有很多的錯誤啊,所以你可以找找看,看能不能找到錯誤的符號。
link |
00:19.840
好,這個是我們的training的過程,就我們的MDNL怎麼train呢?我們說你就是用歸電decent train,你要用你的Phi去對大L算它的歸電,大L長什麼樣子?
link |
00:35.840
大L呢,就是所有你手上的task小L,所有你手上的task,它們用theta hat算出來的loss function。
link |
00:45.120
theta hat是什麼?theta hat是我們把這個Phi做一次update以後得到的model參數,叫做theta hat。
link |
00:53.480
那我們現在就實際上來算一下,這一項歸電應該長什麼樣子?長什麼樣子?
link |
01:00.800
好,那這項歸電長什麼樣子呢?我們先把大L用這個summation over所有小L這一項替代,把大L用summation over小L替代。
link |
01:10.440
然後這個歸電這一項本來是放在summation外面,但是你其實可以把它拿到summation裡面。
link |
01:16.800
所以接下來我們就只要計算這個Phi這個參數,對小寫的loss,某一個task的loss function算出來的歸電的值。
link |
01:30.280
然後這個loss function它現在是用什麼樣的參數算出來這個loss function的值呢?是用theta hat算出這個loss function的值。
link |
01:39.720
好,那這個歸電是什麼意思呢?大家都知道說所謂的歸電就是它的每一個dimension代表了某一個參數對你的loss function的偏微分的結果。
link |
01:51.840
所以今天把Phi對小L做歸電得到的是一個vector,這個vector的第一位就是拿Phi1,Phi1就是Phi這個model它的第一個參數。
link |
02:04.400
Phi1其實是一個network的初始參數的值,network的初始參數的值就是一大堆的weight跟bias的值。
link |
02:11.560
所以把它通通拿出來拼成一個vector,然後它的第一個參數的值對loss的偏微分,第二個參數的值對loss的偏微分,然後到第i一個參數的值對loss的偏微分。
link |
02:22.840
那我們就拿其中一項來算算看說它應該長什麼樣,那也會算一項,那也就會算其他項。
link |
02:29.880
好,那我們來算一下拿Phi1對loss的偏微分的結果。
link |
02:37.440
那我們來看一下這個Phi1對loss的關係是什麼,那我知道所謂偏微分的意思就是,如果Phi1做一下小小的變化,到底loss會產生什麼樣的變化。
link |
02:50.800
那Phi1它是初始參數,它是訓練的時候的初始參數,那這個初始的參數會影響你最終訓練出來的結果,也就是set ahead。
link |
03:02.640
那最終訓練出來的model會影響你最終算出來的loss,也就是loss.
link |
03:08.960
所以Phi1是透過了set ahead裡面的每一個參數去影響了最終的loss function的值。
link |
03:17.160
那根據Chain rule,你就可以把Phi1對loss的偏微分寫成這樣子,就每一條Path上面,每一條Path的偏微分通通把它加起來。
link |
03:29.200
所以Phi1對loss的偏微分就是submission over set ahead裡面所有的參數,然後把set ahead裡面的這個參數拿來對loss的偏微分,然後再把Phi1拿來對set ahead,就是把Phi1的第i的參數對set ahead的第j的參數拿來做偏微分,把它們乘起來。
link |
03:56.720
然後submission over set ahead裡面所有的參數,就得到了Phi1對loss的偏微分,根據Chain rule搞出來的就是這個樣子。
link |
04:06.760
接下來就真的去算一下這一項應該是長什麼樣子,這一項應該長什麼樣子呢?前面這一項沒有問題,
link |
04:15.400
set ahead對loss的偏微分,set ahead裡面的某一個參數對loss的偏微分,沒有問題,這個就depend on你的loss長什麼樣子嘛,比如說它是cross entropy還是regression,這depend on你的loss長什麼樣子。
link |
04:30.240
其實也depend on你的testing的資料長什麼樣子,根據你的測試資料,根據你的訓練任務裡面的測試資料,跟你的這些訓練任務的loss function,你就可以算出set ahead的這一個參數對loss的偏微分。
link |
04:52.480
我們真正需要算一下的是Phi1的第i一個參數對set ahead的第j一個參數做偏微分以後的結果,這個就可以分成兩個case來考慮,我們先把這個式子只考慮其中一維就好。
link |
05:09.120
這個set ahead是一個項量,Phi1是一個項量,這個歸點是一個項量,先是把項量減項量會得到項量,我們這邊取那個項量的第j一維出來看。如果取那個項量的第j一維出來看,就是set ahead的第j一維等於Phi1的第j一維減掉f0,f0其實也是一個learning rate,
link |
05:29.600
其實在這個meta learning裡面,你會有兩個learning rate,一個是每一個參數在訓練的時候自己的learning rate,另外一個是你在訓練初始化參數的learning rate,這兩個learning rate當然不需要是一樣的,也沒有什麼理由是一樣的,然後其實也是需要調的。
link |
05:47.360
好,這個set ahead的第j一個參數等於Phi1的第j一個參數減掉learning rate乘上Phij對L of Phi的偏微分的結果,如果今天i不等於j的時候會發生什麼事呢?
link |
06:06.720
如果i不等於j,那你拿Phii去對這個式子做偏微分會發生什麼事呢?第一項就直接消失掉了,對不對?拿Phii對Phij做偏微分,那Phij跟Phii沒有什麼關係,所以它就消失掉了,只剩下後面這一項,後面這一項變成你把本來的Phij對L of Phi的偏微分再加上Phii的偏微分。
link |
06:31.520
好,那如果是i等於j呢?如果是i等於j的話,那前面就多了一項1,因為如果i等於j的話,Phij對Phij做偏微分等於1,所以前面就多了一項1。
link |
06:44.800
好,那你就把這些東西通通都帶進去,你就可以把這個微分的值算出來,你就可以做NNL。但是這樣顯然是你還要做二次微分,顯然是很花時間的。所以在NNL的原始paper裡面,它提出了一個想法就是,可不可以不要算二次微分,就假裝沒看到那個二次微分的項這樣。
link |
07:09.120
如果沒有看到那個二次微分的項,你就會發現說這個式子變得非常的簡單,當i不等於j的時候,這一項就是0,i等於j的時候,這一項就是1。
link |
07:19.120
那這樣會得到什麼結果?這樣就會得到說,我們在做summation的時候,我們不需要對所有的j做summation,我們只需要考慮i等於j那一項就好。
link |
07:28.080
因為當i不等於j的時候,紅框框裡面那一項一定會是0,i不等於j的時候,紅框框裡面那一項假裝是0,所以我們只需要考慮i等於j的配合。所以算出來,我們就可以得到這一個approximation的式子。
link |
07:41.840
我們就可以說,y i對L of zeta hat的偏微分,可以用zeta hat對L of zeta hat的偏微分來解釋。所以在Paper裡面,他是微言大意,他就提了一句說,他用了first order的approximation,他其實沒有解釋說這句話是什麼意思,但是根據後人的推測,其實就是這個意思。
link |
08:08.480
我們現在已經知道說,這邊的每一項都可以做approximation,本來我們是算y i對L of zeta hat的偏微分,現在可以改成zeta hat對L of zeta hat的偏微分。
link |
08:24.160
所以在上面這個式子裡面的phi,我們就可以把它用zeta hat取代掉。那麼實際上在運算的時候,你就不是拿phi對L of zeta hat做偏微分,這項運算起來有點麻煩,你其實是拿zeta hat去對L of zeta hat直接做偏微分,那這個運算起來就簡單很多。
link |
08:50.240
如果這邊講的你都沒有聽懂的話,就很麻煩了。