<acronym id="s8ci2"><small id="s8ci2"></small></acronym>
<rt id="s8ci2"></rt><rt id="s8ci2"><optgroup id="s8ci2"></optgroup></rt>
<acronym id="s8ci2"></acronym>
<acronym id="s8ci2"><center id="s8ci2"></center></acronym>
0
  • 聊天消息
  • 系統消息
  • 評論與回復
登錄后你可以
  • 下載海量資料
  • 學習在線課程
  • 觀看技術視頻
  • 寫文章/發帖/加入社區
會員中心
創作中心

完善資料讓更多小伙伴認識你,還能領取20積分哦,立即完善>

3天內不再提示

一文解析PPO算法原理

深度學習自然語言處理 ? 來源:深度學習自然語言處理 ? 2024-01-09 12:12 ? 次閱讀

作者:LLM-Finder,某廠研究大語言模型和多模態學習

寫這篇文章的動機

1. 在筆者看來RLHF是LLMs智能的關鍵之一;

2. 國內廠商在這方面投入比較少,目前看起來并沒有很重視;

3. 大家偏向于認為ChatGPT的RLHF做法最多的線索來源于InstructGPT,但是InstructGPT原文的描述也挺含糊的,很多東西只能靠猜和結合開源的實現來解讀;

4. 通常學習強化學習所依賴鏈路比較長,筆者希望以最直觀的方式幫助大家通關。

筆者會分兩篇文章來介紹,第一篇是理論篇,第二篇是實踐篇。讀者會在第一篇學習到PPO的原理和instrcutGPT中的RLHF做法;在第二篇中學習到目前影響比較大的開源RLHF實現。

e7082e78-ae34-11ee-8b88-92fbcf53809c.jpg

據公開可獲得的信息來看,ChatGPT需要有大致三個階段的訓練過程,如上圖所示:

1.Pretraining: 在大規?!盁o監督”的語料上訓練,訓練任務是預測下一個詞。

2.Supervised Fine-Tuning(SFT):在人類標注上進行微調,所謂人類標注就是人類寫Prompt,人類寫答案。然后語言模型學習模仿人類是如何作答的。這部分通常要求數據集多樣性很好,也因為標注成本很高,通常量級很小。

3.Reinforcement Learning with human feedback(RLHF):對于同一個Prompt把模型的多個輸出給人類排序,獲取人類偏好標注。用人類的偏好標注,訓練一個reward model。訓練得到的reward model會作為PPO算法中的reawrd function,來繼續優化SFT得到的模型。

通常來說,第一步最有資源門檻,第三步最有技術門檻(同時也需要大量的資源),第二步最簡單。所以目前很多廠商是直接拿了開源的第一步的模型,做SFT,或者continue-pretrain(比較小規模的無監督訓練)再SFT。他們PR的時候可能會嘴一句,無需復雜的RLHF,只需做細致的微調也能達到很好的效果。

后面兩個步驟,通常被視作是人類偏好對齊(alignment),讓模型更好地跟隨人類的指令作回復。而一些研究發現,對齊后的模型是會有對齊稅的現象的(alignment tax),即在通用能力上會有所下降。

因此,不少人是這樣認為的:第一步預訓練得到的模型就已經決定了后續模型的能力上限;后面兩步要做的事情僅僅是在盡可能減少對齊稅的情況下,對齊人類偏好。

這里可以分兩種情況分析:

? SFT過數據太多遍了,導致大模型出現遺忘;

?安全性對齊很多模型能回答的問題,強制不讓回答肯定會對模型能力有所牽制。

在筆者看來,某種意義下RL提供了對LLM的response的Global-level的監督,在一些需要答案非常精確的場景上,RL可能可以發揮出更大的威力。這個看法的依據也很樸素:

1. 比如在coding、數學推導等場景,只要response在關鍵的地方犯了一點點錯給人的感覺就是模型不會,但是SFT的loss可能區分不出來是犯錯了還是只是寫法風格的差異。

2. SFT給定了標準答案,LLM的上限可能會被標注者的水平所限制;RLHF則只給定了人類偏好,得到了一定(有可能是很大)程度的解放,模型有可能探索出更高程度的智能。這一點并不是無中生有的想法,在游戲AI領域有太多的驗證,即在模仿人類玩法(imitation learning)之后,再用RL訓練出來的模型,就是能獲得更高的智能。這里語言模型跟游戲又有多少本質的區別呢。

InstructGPT中的RLHF

這里簡要帶過具體數據構造和訓練細節,后面會專門有一篇對InstructGPT像素級的解讀。

如前文所述,InstructGPT也是包含3階段的訓練,同時我們應該注意到,RLHF這一步訓練,實則包含兩步訓練:

1. 訓練Reward Model(RM);

2. 用Reward Model和SFT Model構造Reward Function,基于PPO算法來訓練LLM。

數據集

SFT、RM和PPO用到的數據集數據量如下表所示:

e70eebc8-ae34-11ee-8b88-92fbcf53809c.jpg

注意,上表統計的是prompts數量,在RM數據中每個prompt,對應會有4~9個responses。

在構造RM數據的時候,作者采集了用戶的prompts,每個prompts包含4~9個模型的輸出,模型的輸出會給標注員進行排序。

訓練Reward Model(RM)

目標:給pormpt-response pair打分,擬合人類的偏好。

模型:這InstructGPT的paper中,雖然用了1.3B、6B和175B的GPT-3來做實驗,但是綜合考慮下,只用6B的模型來訓練Reward Model,因為作者發現用175B的模型會不穩定。把最后的unembedding層換成一個輸出為scalar的線性層。這里讀者可能會有點混亂,眾所周知,GPT的模型結構是sequence-in,sequence-out的,怎么變成scalar呢?這里文章似乎也沒提到,根據筆者的判斷和開源實現,推測是直接用最后一個token的輸出接一個linear。

Reward Model的初始化:6B的GPT-3模型在多個公開數據((ARC, BoolQ, CoQA, DROP, MultiNLI, OpenBookQA, QuAC, RACE, and Winogrande)上fintune。不過Paper中提到其實從預訓練模型或者SFT模型開始訓練結果也差不多。

訓練:以前的做法是,RM每次比較兩個模型輸出的好壞,做法很簡單類似對比學習,兩個樣本對應兩個類別,RM對這兩個樣本分別輸出兩個得分,拼成一個logits向量;人類標注比較好的那個輸出作為label,比如第一個比較好那么label為0,第二個比較好label為1;用cross entropy約束之。

但是作者發現這么做很容易過擬合;也不高效,因為每比較一次都要重新過一下reward model。

因此作者的做法是,在一個batch里面,把每個Prompt對應的所有的模型輸出,都過一遍Reward model,并把所有兩兩組合都比較一遍。比如一個Prompt有K個模型輸出,那么模型則只需要處理K個樣本就可以一氣兒做(2K)次比較。loss的設計如下:

e718f262-ae34-11ee-8b88-92fbcf53809c.png

很直觀,其中,x是prompt,yw和yl分別是較好和較差的模型response,rθ(x,y)是Reward Model的輸出。σ在文中似乎沒有解釋,不過根據公式推斷和開源實現,應該是sigmod函數。

這里要注意一個細節:在RM訓練完之后,會讓RM的輸出減去一個bias,使得reward score在人類寫的答案上(labeler demonstrations)的平均分為0。這里筆者沒找到具體在什么數據上統計的,猜測是在SFT數據上做的,如果有讀者知道是怎么做的歡迎指出。

Reinforcement Learning(RL)

直接看需要最大化的目標函數

e71d12c0-ae34-11ee-8b88-92fbcf53809c.png

其中,πΦRL和πSFT分別是正在用RL訓練的語言模型和SFT訓練得到的模型。

上式中,

第一項期望式是在最大化reward的同時,最小化和SFT模型的per-token KL penalty,可以理解為是一種正則手段,兩者組合成關于prompt-Responce pair最終的Reward:R(x,y)=rθ(x,y)?βlog(πΦRL(y∣x)/πSFT(y∣x))。per-token KL penalty的好處如下:

1. 充當熵紅利(Entropy bonus),鼓勵policy探索并阻止其坍塌為單一模式。

2. 確保策略模型產生的輸出 與 Reward Model在訓練期間看到的輸出 不會相差太大,保證Reward的可靠性。僅含這一項就是單純使用了PPO。這里也可以看出來,Reward model的能力可能會成為RLHF的瓶頸。

第二項期望式是可選項,注意到它其實是使用預訓練的數據來做跟預訓練同樣的任務(predict next word),因為這一項的數據不是模型生成的其實跟RL是并行的目標。包含這一項的算法稱之為PPO-ptx。

PPO算法

本小節以最小知識補充為前提,快速介紹PPO,不用犯怵,很簡單而直觀。

通常來說,對于一個強化學習模型,會有一個做動作的策略網絡π,它根據自己觀測的狀態(si)做出動作(ai)跟環境交互,然后會拿到一個即刻的reward(ri), 同時進入到下一個狀態(si+1);策略網絡再繼續觀測狀態si+1做下一個動作ai+1...直到達到最終狀態。這樣,策略網絡和環境的一系列互動后最終會得到一個軌跡(trajectory):τ=s1,a1,r1,s2,a2,r2,...,sT,aT,rT。

那么,在語言模型的場景下,策略網絡就是待微調的LLM,它所能做的動作就是預測下一個token,它觀測的轉狀態就是預測下一個token時所能觀測到的context(Prompt+這個token前所生成的所有tokens)。

reward除了最后一個rT等于上文提到的R(x,y)=rθ(x,y)?βlog(πRLΦ(yT∣x,y1,...,yT?1)/πSFT(yT∣x,y1,...,yT?1))

其他的ri=?βlog(πRLΦ(yi∣x,y1,...,yi?1)/πSFT(yi∣x,y1,...,yi?1))。

好,在LLM的場景中,現在可以統一一下符號:s1=x,ai=yi,si=cat(x,y1,y2,...,yi?1),其中x是prompt,yi是第i步蹦的token??吹竭@,了解PPO的同學基本上就清晰了RLHF具體是怎么做優化的了,可以直接跳過下面的科普部分。

因為PPO原文是基于Actor-Critic算法做的,Actor-Critic算法是進階版的Policy Gradient算法。下面我們從policy gradient到Actor-Critic,再到PPO,幫助RL背景比較弱的讀者串一遍。

Policy Gradient(PG)算法

核心要義:用“Reward”作為權重,最大化策略網絡所做出的動作的概率。

偽代碼核心部分一句話的事:

e72b6604-ae34-11ee-8b88-92fbcf53809c.jpg

用策略網絡πθ采樣出一個軌跡,然后根據即刻得到的rewardrt計算 discounted rewardRt=∑i=tTγi?tri;用Rt作為權重,最大化這個軌跡下所采取的動作的概率log(π(at∣st))?Rt,用梯度上升優化之。

雖然在強化學習算法中對每一步都有一個即時的“reward”,但是每一步對后面的可能狀態都是有影響的。

即,后面的動作獲取的即時“reward”都能累計到前面的動作的貢獻。但是直接加上去可能不好,畢竟不是前面的動作直接獲取的reward,但是可以打個折扣再加上去,即乘個小于1的γ。

這里面讀者可能會有個問題:可是不好的動作也要最大化概率嗎?

這里有必要稍微展開一下:

1.Rt也可以是負的,對負的Rt那就是最小化動作at的概率,這也是為什么前面提到要對RM的輸出做歸一化的其中一個原因之一。

2.即便Rt都是正的,但只要充分采樣,同一個狀態下相對的Rt較小的動作也是會被抑制的,因為同一個狀態下的動作概率求和等于1,此消彼長,只有權重最大的動作才會得到獎勵。

可是,比如同一個狀態下,有兩個動作的Rt是正的,但是因為動作采樣本來就很稀疏的,我們很可能不幸運采樣到了相對較小的Rt對應的動作,而沒有采樣到相對較大的。但因為它是正的,這時候當前的機制下,還是會鼓勵這個動作,這樣的話網絡很容易一直沿著不太好的策略去優化。為了解決這個問題,我們引入Actor-Critic算法。

Actor-Critic (AC)算法

核心要義:再增加一個Critic網絡來構造一個Reward baseline,只有獲得的reward比baseline要好才獎勵這個動作,否則抑制它。

e72f29e2-ae34-11ee-8b88-92fbcf53809c.jpg

Actor指的是策略網絡πθ;Criticb?目的就是給定一個策略網絡,預估每個狀態st,策略網絡所能拿到期望rewardb?(st)是多少。什么是期望reward,無非就是在狀態st,對πθ采樣不同的動作at所能獲取的Rt的平均值嘛。我們要選擇的動作當然是獲取的reward比平均reward要好的動作,不比baseline好的動作就得抑制它。

觀測上面算法2,其實對比PG算法就加了兩行:

1. 原來用Reward function來加權,現在用Advantage function來加權?,F在我們把b?(st)當作一個baseline方法所能拿到的reward, 用采樣出來的at所拿到的rewardRt減去b?(st)作為最大化當前動作概率的權重:At=Rt?b?(st)。其中 A_t 通常被稱作是Advantage function(或Advantage estimator),即優勢函數。

2.拉近b?(st)和Rt的距離,初學者對這個可能會費解。實則很好理解,記住b?在做什么,要預估當前策略下Rt的期望,我只要不管三七二十一,每來一個動作的Rt都拉近一下距離,其實就是在預估平均值。更一般地:

其實上面用到的b?,它無非是換了皮的Vπθ?(st)(簡寫成V?(st)),即RL中的重要概念V function:給定策略πθ在st上的期望reward。那么最后一步 T 到達的state sT通常來講是沒有隨機性的(比如下棋,最后一個state決定贏輸就是固定的reward;LLM,最后一個token生成完,response確定了,reward也就確定了),因此rT應該和V?(sT)相等。

所以我們可以重寫上面的優勢函數:

A^t=?V?(st)+rt+γrt+1+?+γT?tV?(sT)

寫成Generalized Advantage Estimation,當λ=1 下式等于上式:

A^t=δt+(γλ)δt+1+?+(γλ)T?t+1δT?1

其中,δt=rt+γV?(st+1)?V?(st)是時序差分式(TD error)。

記住這個結論:這樣我們可以用A^t優化πθ,現在我們可以用▽θlog(πθ(at∣st))?A^t來更新策略網絡了。

PPO Finetuning

上面提到的算法,有一個最嚴重的弊端是,一個軌跡只用一次就丟掉了??墒?,采樣軌跡通常是很耗時的,對應到在LLM場景則需要做推理,眾所周知LLM的推理是比訓練費勁很多的,它需要一個一個地蹦詞??墒侵苯佑弥暗牟呗圆蓸映鰜淼臉颖緛韮灮F在的策略網絡肯定不行,如何合理復用樣本則是PPO要做的事情。

做法巨簡單,大致可以用這個思想來更新:

定義 動作概率比rt(θ)=πθold(at∣st)πθ(at∣st),用▽θrt(θ)?A^t去梯度上升更新策略網絡,注意這里stat和A^t都是只之前的策略網絡πθold采樣得到的。這個公式,在筆者看來沒有直觀的解釋,需要一丟丟推導,因為是科普向這里讀者先承認就好了,后面筆者會單開一篇文章再重新梳理一遍。

本質上是最大化這個目標函數:

e739bf6a-ae34-11ee-8b88-92fbcf53809c.jpg

但是如果πθ和πθold如果差別太大,就不能用這個式子優化了,PPO給出的做法是給rt(θ)卡閾值,太大或太小就不用這一步的樣本更新了:

e7445ae2-ae34-11ee-8b88-92fbcf53809c.png

上面的目標函數可以分類討論進行分析,對優勢函數A^t大于0和小于0兩種情況分析,這個目標函數的圖像長這樣:

e751e306-ae34-11ee-8b88-92fbcf53809c.jpg

觀測圖像:

當A^t大于0,要提高動作的概率,但是如果概率比之前大比較多了(πθ是πθold的1+?倍),就不提高了

當A^t小于0,要減少動作的概率,但是如果概率比之前小比較多了(πθ是πθold的1??倍),就不減少了

偽代碼如下:

e75c7d8e-ae34-11ee-8b88-92fbcf53809c.jpg

科普到此結束,看到這讀者就可以看懂RLHF的代碼。值得注意的是為了減少讀者負擔做了大量的敘述上的簡化,方法上是比較完備的,但是說法上不夠嚴謹。Again,更詳細的強化學習科普會單開一篇文章。

大語言模型的PPO

稍微整理一下,符號和上面的科普部分不一致,不過應該不影響理解

1.現在我們的actor是SFT初始化的LLMπΦRL;

2.為了計算reward,我們需要兩個凍住參數網絡,一個RM,一個是凍住的SFT模型πSFT用來計算KL散度,參考下面兩式子:rT=R(x,y)=rθ(x,y)?βlog(πRLΦ(yT∣x,y1,...,yT?1)/πSFT(yT∣x,y1,...,yT?1))其他步的ri=?βlog(πRLΦ(yi∣x,y1,...,yi?1)/πSFT(yi∣x,y1,...,yi?1));

3.為了執行PPO算法,我們需要引入一個估計V值的網絡Vη,它初始化來自RM。所以統共,有4個網絡,兩個訓練的actorπΦRL和criticVη;兩個用來計算reward的SFT模型πSFT和RM模型。然后actor初始化來自SFT,critic初始化來自RM。

把這四個網絡,結合reward的構造,帶入到上面提到的PPO算法中,整個過程就比較清晰了。

盜一下DeepSpeed-Chat的圖,圖解如下:

e760d8c0-ae34-11ee-8b88-92fbcf53809c.jpg

看到這,相信讀者已經可以輕易看懂的DeepSpeed-Chat代碼了。??

審核編輯:黃飛

聲明:本文內容及配圖由入駐作者撰寫或者入駐合作網站授權轉載。文章觀點僅代表作者本人,不代表電子發燒友網立場。文章及其配圖僅供工程師學習之用,如有內容侵權或者其他違規問題,請聯系本站處理。 舉報投訴
  • ChatGPT
    +關注

    關注

    28

    文章

    1475

    瀏覽量

    5385

原文標題:PPO算法

文章出處:【微信號:zenRRan,微信公眾號:深度學習自然語言處理】歡迎添加關注!文章轉載請注明出處。

收藏 人收藏

    評論

    相關推薦

    拆解大語言模型RLHF中的PPO算法

    由于本文以大語言模型 RLHF 的 PPO 算法為主,所以希望你在閱讀前先弄明白大語言模型 RLHF 的前兩步,即 SFT Model 和 Reward Model 的訓練過程。另外因為本文不是純講強化學習的文章,所以我在敘述的時候不會假設你已經非常了解強化學習了。
    的頭像 發表于 12-11 18:30 ?1483次閱讀
    拆解大語言模型RLHF中的<b class='flag-5'>PPO</b><b class='flag-5'>算法</b>

    光耦PC817中解析

    光耦PC817中解析
    發表于 08-20 14:32

    看懂PID算法

    滯后的被控對象,比例+微分(PD)控制器能改善系統在 調節過程中的動態特性。 綜上所述得到條公式,這個就是模擬PID下面是關于應用,增量式PID算法。其實PID的算法可以做很深,
    發表于 07-19 16:54

    深度學習RCNN算法

    目標檢測算法圖解:看懂RCNN系列算法
    發表于 08-29 09:50

    解析傳感器的設計要點

    好的傳感器的設計是經驗加技術的結晶。般理解傳感器是將種物理量經過電路轉換成種能以另外種直觀的可表達的物理量的描述。而下文我們將對傳感器的概念、原理特性進行逐
    發表于 08-28 08:04

    基于二分圖構造LDPC碼的校驗矩陣算法及性能解析,不看肯定后悔

    依據二分圖構造LDPC碼的算法矩陣及性能解析,看不出必然
    發表于 06-22 06:52

    用PID算法調溫的經驗

    本文主要是分享資料,講解不會太多,因為分享的資料里面就有具體的詳細解析,而且百度上面也有詳細的資料,所以本次博主要是講解我用PID算法調溫的經驗。PID算法調整溫度最大的問題的溫度的
    發表于 11-23 08:27

    PID算法解析,絕對實用

    PID算法解析,絕對實用
    發表于 01-21 07:40

    C++的G代碼解析算法研究

    在數控技術發展過程中,G 代碼的解析優劣是促進數控技術的發展因素之一。但目前的解析算法,并不能更高效的進行解析處理。經過對G 代碼進行分析,提出與以往基于C 語言編寫的
    發表于 07-21 16:36 ?0次下載

    基于KMP算法的串口通訊協議解析鄒鐵

    基于KMP算法的串口通訊協議解析_鄒鐵
    發表于 03-17 08:00 ?2次下載

    基于java的負載均衡算法解析及源碼分享

    負載均衡的算法實際上就是解決跨系統調用的時候,在考慮后端機器承載情況的前提下,保證請求分配的平衡和合理。下面是基于java的負載均衡算法解析及源碼,以供參考。
    發表于 01-01 19:29 ?2085次閱讀

    用PyTorch實現了基本的RL算法

    近日,有開發人員用PyTorch實現了基本的RL算法,比如REINFORCE, vanilla actor-critic, DDPG, A3C, DQN 和PPO。這個帖子在Reddit論壇上獲得了195個贊并引發了熱議,一起來看一下吧。
    的頭像 發表于 06-07 15:36 ?6404次閱讀
    用PyTorch實現了基本的RL<b class='flag-5'>算法</b>

    基于PPO強化學習算法的AI應用案例

    Viet Nguyen就是其中一個。這位來自德國的程序員表示自己只玩到了第9個關卡。因此,他決定利用強化學習AI算法來幫他完成未通關的遺憾。
    發表于 07-29 09:30 ?2533次閱讀

    PPO物理改性及化學改性的方法

    PPO改性方法分為物理改性(共混、填充等)和化學改性(主鏈、端基改性等),物理改性主要是與其他高性能樹脂共混形成塑料合金,化學改性是在PPO分子鏈上引入活性基團改善相容性或與其他分子進行嵌段、接枝以克服自身缺陷。
    的頭像 發表于 09-06 15:12 ?3417次閱讀

    什么是材料PPO?超聲波能夠發揮什么作用?

    PPO材料一般應用于汽配行業,電子電訊行業,家電設備行業,工業機械行業,醫療行業,辦公室設備行業等。超聲波焊接機采用PPO材料可以焊機多種產品,比如在汽配行業中,超聲波焊接機可以焊接加工儀表板、汽車
    發表于 01-10 11:19 ?1103次閱讀
    什么是材料<b class='flag-5'>PPO</b>?超聲波能夠發揮什么作用?
    亚洲欧美日韩精品久久_久久精品AⅤ无码中文_日本中文字幕有码在线播放_亚洲视频高清不卡在线观看
    <acronym id="s8ci2"><small id="s8ci2"></small></acronym>
    <rt id="s8ci2"></rt><rt id="s8ci2"><optgroup id="s8ci2"></optgroup></rt>
    <acronym id="s8ci2"></acronym>
    <acronym id="s8ci2"><center id="s8ci2"></center></acronym>