RL論文閱讀7 - MAML2017
時間 2020-12-30
標籤
強化學習RL
人工智能
Tittle
source
標籤
總結
meta-learning的目標就是訓練一個模型,使這個模型能夠從很少的新任務的數據中快速學習一個新的任務。這個模型的訓練需要大量的不同任務作爲數據。
提出了一種meta-learning的框架,能夠用於使用梯度下降的算法,使其在應用於新的任務時,只需要很少步驟的訓練就能夠達到較好的效果。這個框架能夠用於分類任務(如圖像)和使用梯度下降來訓練策略的強化學習的任務。
其實簡單來說,就是訓練了適應一些列某類的任務的模型網絡,當有該類新任務時,只需要在這個模型上進行參數微調。
特點:
- 能夠從較少的examples中快速學習
- 隨着數據量的增多,能夠繼續增加算法的適應性
原理概述
一些標記:
- 模型 :
f
- 任務 :
T={L(x1,a1...xH,aH),q(x1),q(xt+1∣xt,at),H}
-
L損失函數
-
q(x1)初始狀態分佈
-
q(xt+1∣xt,at)狀態轉換概率分佈
- H: episode長度(多少步)
模型訓練
希望讓模型的參數處於對任務改變的敏感點,這樣任務微小的改變,都能引起很大的loss function改變,然後使用這個方向對特定任務進行更新。如下圖:
適應參數訓練
模型
fθ的參數爲
θ。當這個模型去適應一個新的任務KaTeX parse error: Undefined control sequence: \T at position 1: \̲T̲_i,那麼通過若干部梯度下降,就能夠得到針對這個任務的適應參數
θ′。
θ′使用下面這個更新公式計算(以一步gradient爲例,多步同理):
就是繼續利用
Ti的損失函數繼續優化。
α是學習率
模型參數訓練
採樣一些任務tasks,這些任務服從
p(T)分佈
然後先計算每個任務的適應參數
θ′和它的損失,然後最小化採樣任務的所有損失和來更新模型參數
θ
注意這裏計算的某個任務的損失,使用的是已經進行適應該任務的模型
fθ′,而不是通用模型
fθ
使用隨機梯度下降(SGD),那麼
θ的更新就表示爲:
β是另一個學習率
算法描述
應用到迴歸和分類問題
算法描述
注意事項:
- 定義模型的H=1,丟棄了時間步
xt,因此模型是一個輸入對應一個輸出,而不是序列輸出輸出
- 任務認爲獨立同分布
- 迴歸問題損失函數使用MSE
- 分類爲題使用交叉熵損失函數:
應用到RL問題
算法描述
注意事項:
- RL的對於任務
Ti的損失函數如下:
- 定義R爲非負, Loss之所以有負號是在RL中我們希望獎勵值最大,由於使用的是梯度下降算法,加一個負號相當於梯度上升了,向着最大的餓方向。
- 對於step8,由於策略梯度算法是on-policy算法,所以需要使用當前的適應過的策略
fθ′來採樣新的數據。