<acronym id="s8ci2"><small id="s8ci2"></small></acronym>
<rt id="s8ci2"></rt><rt id="s8ci2"><optgroup id="s8ci2"></optgroup></rt>
<acronym id="s8ci2"></acronym>
<acronym id="s8ci2"><center id="s8ci2"></center></acronym>
0
  • 聊天消息
  • 系統消息
  • 評論與回復
登錄后你可以
  • 下載海量資料
  • 學習在線課程
  • 觀看技術視頻
  • 寫文章/發帖/加入社區
會員中心
創作中心

完善資料讓更多小伙伴認識你,還能領取20積分哦,立即完善>

3天內不再提示

如何計算transformer模型的參數量

jf_pmFSk4VX ? 來源:GiantPandaCV ? 2023-07-10 09:13 ? 次閱讀

1. 前言

最近,OpenAI推出的ChatGPT展現出了卓越的性能,引發了大規模語言模型(Large Language Model,LLM)的研究熱潮。大規模語言模型的“大”體現在兩個方面:模型參數規模大,訓練數據規模大。以GPT3為例,GPT3的參數量為1750億,訓練數據量達到了570GB。進而,訓練大規模語言模型面臨兩個主要挑戰:顯存效率和計算效率。

現在業界的大語言模型都是基于transformer模型的,模型結構主要有兩大類:encoder-decoder(代表模型是T5)和decoder-only,具體的,decoder-only結構又可以分為Causal LM(代表模型是GPT系列)和PrefixLM(代表模型是GLM)。歸因于GPT系列取得的巨大成功,大多數的主流大語言模型都采用Causal LM結構。因此,針對decoder-only框架,為了更好地理解訓練訓練大語言模型的顯存效率和計算效率,本文分析采用decoder-only框架transformer模型的模型參數量、計算量、中間激活值、KV cache。

853e25d8-1d86-11ee-962d-dac502259ad0.jpg

為了方便分析,先定義好一些數學符號。記transformer模型的層數為8568a3ee-1d86-11ee-962d-dac502259ad0.png?,隱藏層維度為85808590-1d86-11ee-962d-dac502259ad0.png?,注意力頭數為8597e2c6-1d86-11ee-962d-dac502259ad0.png?。詞表大小為85af3034-1d86-11ee-962d-dac502259ad0.png?,訓練數據的批次大小為85c04e64-1d86-11ee-962d-dac502259ad0.png?,序列長度為85cf59e0-1d86-11ee-962d-dac502259ad0.png?。

2. 模型參數量

transformer模型由8568a3ee-1d86-11ee-962d-dac502259ad0.png個相同的層組成,每個層分為兩部分:self-attention塊和MLP塊。

self-attention塊的模型參數有85f2dcbc-1d86-11ee-962d-dac502259ad0.png?的權重矩陣860b6cb4-1d86-11ee-962d-dac502259ad0.png和偏置,輸出權重矩陣?861f975c-1d86-11ee-962d-dac502259ad0.png?和偏置,4個權重矩陣的形狀為86324bae-1d86-11ee-962d-dac502259ad0.png?,4個偏置的形狀為8645eff6-1d86-11ee-962d-dac502259ad0.png?。self- attention塊的參數量為8657ee68-1d86-11ee-962d-dac502259ad0.png?。

MLP塊由2個線性層組成,一般地,第一個線性層是先將維度從85808590-1d86-11ee-962d-dac502259ad0.png?映射到867afe44-1d86-11ee-962d-dac502259ad0.png,第二個線性層再將維度從867afe44-1d86-11ee-962d-dac502259ad0.png映射到85808590-1d86-11ee-962d-dac502259ad0.png。第一個線性層的權重矩陣86ac99f4-1d86-11ee-962d-dac502259ad0.png?的形狀為86c0c348-1d86-11ee-962d-dac502259ad0.png?,偏置的形狀為86d6bf18-1d86-11ee-962d-dac502259ad0.png?。第二個線性層權重矩陣86e7af30-1d86-11ee-962d-dac502259ad0.png?的形狀為86fae0c8-1d86-11ee-962d-dac502259ad0.png?,偏置形狀為8645eff6-1d86-11ee-962d-dac502259ad0.png?。MLP塊的參數量為872bf654-1d86-11ee-962d-dac502259ad0.png?。

self-attention塊和MLP塊各有一個layer normalization,包含了2個可訓練模型參數:縮放參數873cc984-1d86-11ee-962d-dac502259ad0.png?和平移參數8753cc42-1d86-11ee-962d-dac502259ad0.png?,形狀都是8645eff6-1d86-11ee-962d-dac502259ad0.png?。2個layernormalization的參數量為?867afe44-1d86-11ee-962d-dac502259ad0.png?。

87817ce6-1d86-11ee-962d-dac502259ad0.jpg

總的,每個transformer層的參數量879c4148-1d86-11ee-962d-dac502259ad0.png?。

除此之外,詞嵌入矩陣的參數量也較多,詞向量維度通常等于隱藏層維度85808590-1d86-11ee-962d-dac502259ad0.png,詞嵌入矩陣的參數量為?87c03350-1d86-11ee-962d-dac502259ad0.png。最后的輸出層的權重矩陣通常與詞嵌入矩陣是參數共享的。

關于位置編碼,如果采用可訓練式的位置編碼,會有一些可訓練模型參數,數量比較少。如果采用相對位置編碼,例如RoPE和ALiBi,則不包含可訓練的模型參數。我們忽略這部分參數。

綜上,8568a3ee-1d86-11ee-962d-dac502259ad0.png層transformer模型的可訓練模型參數量為87d9da76-1d86-11ee-962d-dac502259ad0.png。當隱藏維度?85808590-1d86-11ee-962d-dac502259ad0.png?較大時,可以忽略一次項,?模型參數量近似為8803fd6a-1d86-11ee-962d-dac502259ad0.png?。

接下來,我們估計不同版本LLaMA模型的參數量。

實際參數量 隱藏維度h 層數l 12lh^2
6.7B 4096 32 6,442,450,944
13.0B 5120 40 12,582,912,000
32.5B 6656 60 31,897,681,920
65.2B 8192 80 64,424,509,440

2.1 訓練過程中的顯存占用分析

在訓練神經網絡的過程中,占用顯存的大頭主要分為四部分:模型參數、前向計算過程中產生的中間激活、后向傳遞計算得到的梯度、優化器狀態。這里著重分析參數、梯度和優化器狀態的顯存占用,中間激活的顯存占用后面會詳細介紹。訓練大模型時通常會采用AdamW優化器,并用混合精度訓練來加速訓練,基于這個前提分析顯存占用。

在一次訓練迭代中,每個可訓練模型參數都會對應1個梯度,并對應2個優化器狀態(Adam優化器梯度的一階動量和二階動量)。設模型參數量為881c3236-1d86-11ee-962d-dac502259ad0.png?,那么梯度的元素數量為881c3236-1d86-11ee-962d-dac502259ad0.png?,AdamW優化器的元素數量為8841de5a-1d86-11ee-962d-dac502259ad0.png。float16數據類型的元素占2個bytes,float32數據類型的元素占4個bytes。在混合精度訓練中,會使用float16的模型參數進行前向傳遞和后向傳遞,計算得到float16的梯度;在優化器更新模型參數時,會使用float32的優化器狀態、float32的梯度、float32的模型參數來更新模型參數。因此,對于每個可訓練模型參數,占用了88581f9e-1d86-11ee-962d-dac502259ad0.png。使用AdamW優化器和混合精度訓練來訓練參數量為?881c3236-1d86-11ee-962d-dac502259ad0.png的大模型,?模型參數、梯度和優化器狀態占用的顯存大小為887f5154-1d86-11ee-962d-dac502259ad0.png?。

8892c3c4-1d86-11ee-962d-dac502259ad0.jpg

2.2 推理過程中的顯存占用分析

在神經網絡的推理階段,沒有優化器狀態和梯度,也不需要保存中間激活。少了梯度、優化器狀態、中間激活,模型推理階段占用的顯存要遠小于訓練階段。模型推理階段,占用顯存的大頭主要是模型參數,如果使用float16來進行推理,推理階段模型參數占用的顯存大概是88b124fe-1d86-11ee-962d-dac502259ad0.png?。如果使用KVcache來加速推理過程,?KV cache也需要占用顯存,KVcache占用的顯存下文會詳細介紹。此外,輸入數據也需要放到GPU上,還有一些中間結果(推理過程中的中間結果用完會盡快釋放掉),不過這部分占用的顯存是很小的,可以忽略。

3. 計算量FLOPs估計

FLOPs,floating point operations,表示浮點數運算次數,衡量了計算量的大小。

如何計算矩陣乘法的FLOPs呢?

對于88c2f724-1d86-11ee-962d-dac502259ad0.png?,計算?88d9729c-1d86-11ee-962d-dac502259ad0.png?需要進行?88f02cc6-1d86-11ee-962d-dac502259ad0.png?次乘法運算和?88f02cc6-1d86-11ee-962d-dac502259ad0.png?次加法運算,共計?8913769a-1d86-11ee-962d-dac502259ad0.png?次浮點數運算,需要?8913769a-1d86-11ee-962d-dac502259ad0.png?的FLOPs。對于?893c3c60-1d86-11ee-962d-dac502259ad0.png?,計算?88d9729c-1d86-11ee-962d-dac502259ad0.png?需要的浮點數運算次數為?8962e2fc-1d86-11ee-962d-dac502259ad0.png?。

在一次訓練迭代中,假設輸入數據的形狀為897c4b84-1d86-11ee-962d-dac502259ad0.png?。我們?先分析self-attention塊的計算,計算公式如下:

89962dd8-1d86-11ee-962d-dac502259ad0.png89a87cb8-1d86-11ee-962d-dac502259ad0.png

1. 計算89bccee8-1d86-11ee-962d-dac502259ad0.png?:矩陣乘法的輸入和輸出形狀為89d2b0f0-1d86-11ee-962d-dac502259ad0.png。計算量為89e69084-1d86-11ee-962d-dac502259ad0.png。

2.89fdc10a-1d86-11ee-962d-dac502259ad0.png?矩陣乘法的輸入和輸出形狀為

8a0cb43a-1d86-11ee-962d-dac502259ad0.png。計算量為?8a280b4a-1d86-11ee-962d-dac502259ad0.png?。

3. 計算在85af3034-1d86-11ee-962d-dac502259ad0.png?上的加權?8a4c4500-1d86-11ee-962d-dac502259ad0.png?,矩陣乘法的輸入和輸出形狀為8a619a9a-1d86-11ee-962d-dac502259ad0.png。計算量為?8a280b4a-1d86-11ee-962d-dac502259ad0.png?。

4. attention后的線性映射,矩陣乘法的輸入和輸出形狀為89d2b0f0-1d86-11ee-962d-dac502259ad0.png。計算量為?8a93189a-1d86-11ee-962d-dac502259ad0.png?。

接下來分析MLP塊的計算,計算公式如下

8aaa25ee-1d86-11ee-962d-dac502259ad0.png

1. 第一個線性層,矩陣乘法的輸入和輸出形狀為8ac3882c-1d86-11ee-962d-dac502259ad0.png。計算量為?8adb754a-1d86-11ee-962d-dac502259ad0.png?。

2. 第二個線性層,矩陣乘法的輸入和輸出形狀為8af27f60-1d86-11ee-962d-dac502259ad0.png。計算量為?8adb754a-1d86-11ee-962d-dac502259ad0.png?。

將上述計算量相加,得到每個transformer層的計算量大約為8b1dfc44-1d86-11ee-962d-dac502259ad0.png?。

此外,另一個計算量的大頭是logits的計算,將隱藏向量映射為詞表大小。矩陣乘法的輸入和輸出形狀為8b35ff06-1d86-11ee-962d-dac502259ad0.png,計算量為?8b4c2218-1d86-11ee-962d-dac502259ad0.png?。

因此,對于一個8568a3ee-1d86-11ee-962d-dac502259ad0.png?層的transformer模型,輸入數據形狀為897c4b84-1d86-11ee-962d-dac502259ad0.png?的情況下,一次訓練迭代的計算量為8b7f76ea-1d86-11ee-962d-dac502259ad0.png。

3.1 計算量與參數量的關聯

當隱藏維度85808590-1d86-11ee-962d-dac502259ad0.png?比較大,且遠大于序列長度85cf59e0-1d86-11ee-962d-dac502259ad0.png?時,我們可以忽略一次項,計算量可以近似為8bb25fce-1d86-11ee-962d-dac502259ad0.png?。前面提到當模型參數量為8803fd6a-1d86-11ee-962d-dac502259ad0.png?,輸入的tokens數為8bd8c614-1d86-11ee-962d-dac502259ad0.png?,存在等式8bef6874-1d86-11ee-962d-dac502259ad0.png。我們可以近似認為:?在一次前向傳遞中,對于每個token,每個模型參數,需要進行2次浮點數運算,即一次乘法法運算和一次加法運算。

一次訓練迭代包含了前向傳遞和后向傳遞,后向傳遞的計算量是前向傳遞的2倍。因此,前向傳遞 + 后向傳遞的系數8c064c1a-1d86-11ee-962d-dac502259ad0.png。一次訓練迭代中,對于每個token,每個模型參數,需要進行8c185e50-1d86-11ee-962d-dac502259ad0.png?次浮點數運算。

接下來,我們可以估計訓練GPT3-175B所需要的計算量。對于GPT3,每個token,每個參數進行了6次浮點數運算,再乘以參數量和總tokens數就得到了總的計算量。GPT3的模型參數量為8c29c1b8-1d86-11ee-962d-dac502259ad0.png?,訓練數據量為?8c3c5f3a-1d86-11ee-962d-dac502259ad0.png?tokens。

8c4efc26-1d86-11ee-962d-dac502259ad0.png

8c661cb2-1d86-11ee-962d-dac502259ad0.jpg

3.2 訓練時間估計

模型參數量和訓練總tokens數決定了訓練transformer模型需要的計算量。給定硬件GPU類型的情況下,可以估計所需要的訓練時間。給定計算量,訓練時間(也就是GPU算完這么多flops的計算時間)不僅跟GPU類型有關,還與GPU利用率有關。計算端到端訓練的GPU利用率時,不僅要考慮前向傳遞和后向傳遞的計算時間,還要**考慮CPU加載數據、優化器更新、多卡通信和記錄日志的時間。一般來講,GPU利用率一般在8c8a6fd6-1d86-11ee-962d-dac502259ad0.png之間。

上文講到一次前向傳遞中,對于每個token,每個模型參數,進行2次浮點數計算。使用激活重計算技術來減少中間激活顯存(下文會詳細介紹)需要進行一次額外的前向傳遞,因此前向傳遞+ 后向傳遞 + 激活重計算的系數=1+2+1=4。使用激活重計算的一次訓練迭代中,對于每個token,每個模型參數,需要進行8c9e60b8-1d86-11ee-962d-dac502259ad0.png?次浮點數運算。在給定訓練tokens數、硬件環境配置的情況下,訓練transformer模型的計算時間為

8cb12194-1d86-11ee-962d-dac502259ad0.png

8cc7905a-1d86-11ee-962d-dac502259ad0.jpg

以GPT3-175B為例,在1024張40GB顯存的A100上,在300Btokens的數據上訓練175B參數量的GPT3。40GB顯存A100的峰值性能為312TFLOPS,設GPU利用率為0.45,則所需要的訓練時間為34天,這與[7]中的訓練時間是對得上的。

8cee6784-1d86-11ee-962d-dac502259ad0.png

以LLaMA-65B為例,在2048張80GB顯存的A100上,在1.4TBtokens的數據上訓練了65B參數量的模型。80GB顯存A100的峰值性能為624TFLOPS,設GPU利用率為0.3,則所需要的訓練時間為21天,這與[4]中的實際訓練時間是對得上的。

8d05f390-1d86-11ee-962d-dac502259ad0.png

4. 中間激活值分析

除了模型參數、梯度、優化器狀態外,占用顯存的大頭就是前向傳遞過程中計算得到的中間激活值了,需要保存中間激活以便在后向傳遞計算梯度時使用。這里的激活(activations)指的是:前向傳遞過程中計算得到的,并在后向傳遞過程中需要用到的所有張量。這里的激活不包含模型參數和優化器狀態,但包含了dropout操作需要用到的mask矩陣。

在分析中間激活的顯存占用時,只考慮激活占用顯存的大頭,忽略掉一些小的buffers。比如,對于layernormalization,計算梯度時需要用到層的輸入、輸入的均值8d33dd50-1d86-11ee-962d-dac502259ad0.png?和方差8d45df46-1d86-11ee-962d-dac502259ad0.png?。輸入包含了8d5b7ad6-1d86-11ee-962d-dac502259ad0.png?個元素,而輸入的均值和方差分別包含了8bd8c614-1d86-11ee-962d-dac502259ad0.png?個元素。由于85808590-1d86-11ee-962d-dac502259ad0.png?通常是比較大的(千數量級),有?8d91f19c-1d86-11ee-962d-dac502259ad0.png?。因此,對于layernormalization,中間激活近似估計為?8d5b7ad6-1d86-11ee-962d-dac502259ad0.png?,而不是8db7052c-1d86-11ee-962d-dac502259ad0.png?。

大模型在訓練過程中通常采用混合精度訓練,中間激活值一般是float16或者bfloat16數據類型的。在分析中間激活的顯存占用時,假設中間激活值是以float16或bfloat16數據格式來保存的,每個元素占了2個bytes。唯一例外的是,dropout操作的mask矩陣,每個元素只占1個bytes。在下面的分析中,單位是bytes,而不是元素個數。

每個transformer層包含了一個self-attention塊和MLP塊,并分別對應了一個layer normalization連接。

先分析self-attention塊的中間激活。self-attention塊的計算公式如下:

89962dd8-1d86-11ee-962d-dac502259ad0.png

89a87cb8-1d86-11ee-962d-dac502259ad0.png

1. 對于89bccee8-1d86-11ee-962d-dac502259ad0.png?,需要保存它們共同的輸入8df04ff8-1d86-11ee-962d-dac502259ad0.png?,這就是中間激活。輸入8df04ff8-1d86-11ee-962d-dac502259ad0.png?的形狀為8e13c38e-1d86-11ee-962d-dac502259ad0.png?,元素個數為8d5b7ad6-1d86-11ee-962d-dac502259ad0.png?,占用顯存大小為8e323ab2-1d86-11ee-962d-dac502259ad0.png?。

2. 對于89fdc10a-1d86-11ee-962d-dac502259ad0.png?矩陣乘法,需要保存中間激活8e548176-1d86-11ee-962d-dac502259ad0.png?,兩個張量的形狀都是8e13c38e-1d86-11ee-962d-dac502259ad0.png?,占用顯存大小合計為8e754262-1d86-11ee-962d-dac502259ad0.png?。

3. 對于8e8d8cc8-1d86-11ee-962d-dac502259ad0.png函數,需要保存函數的輸入?89fdc10a-1d86-11ee-962d-dac502259ad0.png?,占用顯存大小為8eaf0420-1d86-11ee-962d-dac502259ad0.png?,這里的8597e2c6-1d86-11ee-962d-dac502259ad0.png?表示注意力頭數。

8ed2bb04-1d86-11ee-962d-dac502259ad0.png

8ee428ee-1d86-11ee-962d-dac502259ad0.png?的形狀為:?8ef662e8-1d86-11ee-962d-dac502259ad0.png

8f0ab716-1d86-11ee-962d-dac502259ad0.png?的形狀為:8f2396a0-1d86-11ee-962d-dac502259ad0.png

89fdc10a-1d86-11ee-962d-dac502259ad0.png?的形狀為:8f445c96-1d86-11ee-962d-dac502259ad0.png,元素個數為?8f5a3d36-1d86-11ee-962d-dac502259ad0.png?,占用顯存大小為8eaf0420-1d86-11ee-962d-dac502259ad0.png?。

4. 計算完8e8d8cc8-1d86-11ee-962d-dac502259ad0.png函數后,會進行dropout操作。需要保存一個mask矩陣,mask矩陣的形狀與89fdc10a-1d86-11ee-962d-dac502259ad0.png?相同,占用顯存大小為8f5a3d36-1d86-11ee-962d-dac502259ad0.png?。

5. 計算在85af3034-1d86-11ee-962d-dac502259ad0.png?上的attention,即?8a4c4500-1d86-11ee-962d-dac502259ad0.png?,需要保存8fed2d6c-1d86-11ee-962d-dac502259ad0.png?,大小為8eaf0420-1d86-11ee-962d-dac502259ad0.png?;以及85af3034-1d86-11ee-962d-dac502259ad0.png?,大小為90275a96-1d86-11ee-962d-dac502259ad0.png?。二者占用顯存大小合計為90482d70-1d86-11ee-962d-dac502259ad0.png?。

6. 計算輸出映射以及一個dropout操作。輸入映射需要保存其輸入,大小為90275a96-1d86-11ee-962d-dac502259ad0.png?;dropout需要保存mask矩陣,大小為8d5b7ad6-1d86-11ee-962d-dac502259ad0.png?。二者占用顯存大小合計為90a0671a-1d86-11ee-962d-dac502259ad0.png?。

因此,將上述中間激活相加得到,self-attention塊的中間激活占用顯存大小為90b20fce-1d86-11ee-962d-dac502259ad0.png?。

接下來看MLP塊的中間激活。MLP塊的計算公式如下

8aaa25ee-1d86-11ee-962d-dac502259ad0.png

1. 第一個線性層需要保存其輸入,占用顯存大小為90275a96-1d86-11ee-962d-dac502259ad0.png?。

2. 激活函數需要保存其輸入,占用顯存大小為90dd5c56-1d86-11ee-962d-dac502259ad0.png?。

3. 第二個線性層需要保存其輸入,占用顯存大小為90dd5c56-1d86-11ee-962d-dac502259ad0.png?。

4. 最后有一個dropout操作,需要保存mask矩陣,占用顯存大小為8d5b7ad6-1d86-11ee-962d-dac502259ad0.png?。

對于MLP塊,需要保存的中間激活值為910fbe12-1d86-11ee-962d-dac502259ad0.png?。

另外,self-attention塊和MLP塊分別對應了一個layer normalization。每個layer norm需要保存其輸入,大小為90275a96-1d86-11ee-962d-dac502259ad0.png?。2個layer norm需要保存的中間激活為912fd2e2-1d86-11ee-962d-dac502259ad0.png?。

綜上,每個transformer層需要保存的中間激活占用顯存大小為91429bf2-1d86-11ee-962d-dac502259ad0.png?。對于8568a3ee-1d86-11ee-962d-dac502259ad0.png層transformer模型,還有embedding層、最后的輸出層。embedding層不需要中間激活??偟亩?,當隱藏維度85808590-1d86-11ee-962d-dac502259ad0.png?比較大,層數8568a3ee-1d86-11ee-962d-dac502259ad0.png?較深時,這部分的中間激活是很少的,可以忽略。因此,對于8568a3ee-1d86-11ee-962d-dac502259ad0.png?層transformer模型,中間激活占用的顯存大小可以近似為918e2cd4-1d86-11ee-962d-dac502259ad0.png。

4.1 對比中間激活與模型參數的顯存大小

在一次訓練迭代中,模型參數(或梯度)占用的顯存大小只與模型參數量和參數數據類型有關,與輸入數據的大小是沒有關系的。優化器狀態占用的顯存大小也是一樣,與優化器類型有關,與模型參數量有關,但與輸入數據的大小無關。而中間激活值與輸入數據的大?。ㄅ未笮?/strong>85c04e64-1d86-11ee-962d-dac502259ad0.png?和序列長度85cf59e0-1d86-11ee-962d-dac502259ad0.png?)是成正相關的,隨著批次大小85c04e64-1d86-11ee-962d-dac502259ad0.png?和序列長度85cf59e0-1d86-11ee-962d-dac502259ad0.png的增大,中間激活占用的顯存會同步增大。當我們訓練神經網絡遇到顯存不足OOM(Out OfMemory)問題時,通常會嘗試減小批次大小來避免顯存不足的問題,這種方式減少的其實是中間激活占用的顯存,而不是模型參數、梯度和優化器的顯存。

以GPT3-175B為例,我們來直觀地對比下模型參數與中間激活的顯存大小。GPT3的模型配置如下。我們假設采用混合精度訓練,模型參數和中間激活都采用float16數據類型,每個元素占2個bytes。

模型名 參數量 層數 隱藏維度 注意力頭數
GPT3 175B 96 12288 96

GPT3的模型參數量為175B,占用的顯存大小為91e94efc-1d86-11ee-962d-dac502259ad0.png。GPT3模型需要占用350GB的顯存。

GPT3的序列長度85cf59e0-1d86-11ee-962d-dac502259ad0.png?為920b2784-1d86-11ee-962d-dac502259ad0.png?。對比不同的批次大小85c04e64-1d86-11ee-962d-dac502259ad0.png?占用的中間激活:

922b29ee-1d86-11ee-962d-dac502259ad0.png?時,中間激活占用顯存為92448182-1d86-11ee-962d-dac502259ad0.png,大約是模型參數顯存的0.79倍。

925af32c-1d86-11ee-962d-dac502259ad0.png?時,中間激活占用顯存為9271fd88-1d86-11ee-962d-dac502259ad0.png,大約是模型參數顯存的50倍。

928c62ea-1d86-11ee-962d-dac502259ad0.png?時,中間激活占用顯存為

929f7ba0-1d86-11ee-962d-dac502259ad0.png,大約是模型參數顯存的101倍。

可以看到隨著批次大小85c04e64-1d86-11ee-962d-dac502259ad0.png的增大,中間激活占用的顯存遠遠超過了模型參數顯存。通常會采用?激活重計算技術來減少中間激活,理論上可以將中間激活顯存從92c0fb7c-1d86-11ee-962d-dac502259ad0.png?減少到92d78c98-1d86-11ee-962d-dac502259ad0.png,代價是增加了一次額外前向計算的時間,本質上是“時間換空間”。

5. KV cache

在推斷階段,transformer模型加速推斷的一個常用策略就是使用 KV cache。一個典型的大模型生成式推斷包含了兩個階段:

1.預填充階段:輸入一個prompt序列,為每個transformer層生成 key cache和value cache(KV cache)。

2.解碼階段:使用并更新KV cache,一個接一個地生成詞,當前生成的詞依賴于之前已經生成的詞。

92ed9d6c-1d86-11ee-962d-dac502259ad0.png?個transformer層的權重矩陣為9304c852-1d86-11ee-962d-dac502259ad0.png。其中,self-attention塊的4個權重矩陣?9319b14a-1d86-11ee-962d-dac502259ad0.png,并且MLP塊的2個權重矩陣?93302484-1d86-11ee-962d-dac502259ad0.png。

預填充階段

假設第92ed9d6c-1d86-11ee-962d-dac502259ad0.png?個transformer層的輸入為93515a78-1d86-11ee-962d-dac502259ad0.png?,self-attention塊的key、value、query和output表示為93684648-1d86-11ee-962d-dac502259ad0.png,其中,?93822d4c-1d86-11ee-962d-dac502259ad0.png。

key cache和value cache的計算過程為:

9398dd94-1d86-11ee-962d-dac502259ad0.png93af854e-1d86-11ee-962d-dac502259ad0.png

92ed9d6c-1d86-11ee-962d-dac502259ad0.png?個transformer層剩余的計算過程為:

93d233fa-1d86-11ee-962d-dac502259ad0.png93e27e72-1d86-11ee-962d-dac502259ad0.png93f8cbaa-1d86-11ee-962d-dac502259ad0.png

解碼階段

給定當前生成詞在第92ed9d6c-1d86-11ee-962d-dac502259ad0.png?個transformer層的向量表示為9418a024-1d86-11ee-962d-dac502259ad0.png。推斷計算分兩部分:更新KV cache和計算第?92ed9d6c-1d86-11ee-962d-dac502259ad0.png個transformer層的輸出。

更新key cache和value cache的計算過程如下:

943ab0a6-1d86-11ee-962d-dac502259ad0.png

94514e74-1d86-11ee-962d-dac502259ad0.png

92ed9d6c-1d86-11ee-962d-dac502259ad0.png?個transformer層剩余的計算過程為:

946e96e6-1d86-11ee-962d-dac502259ad0.png

9480832e-1d86-11ee-962d-dac502259ad0.png9492f5e0-1d86-11ee-962d-dac502259ad0.png

5.1 KV cache的顯存占用分析

假設輸入序列的長度為85cf59e0-1d86-11ee-962d-dac502259ad0.png?,輸出序列的長度為88f02cc6-1d86-11ee-962d-dac502259ad0.png?,以float16來保存KV cache,那么?KVcache的峰值顯存占用大小為94c74c82-1d86-11ee-962d-dac502259ad0.png。這里第一個2表示K/V cache,第二個2表示float16占2個bytes。

以GPT3為例,對比KV cache與模型參數占用顯存的大小。GPT3模型占用顯存大小為350GB。假設批次大小925af32c-1d86-11ee-962d-dac502259ad0.png?,輸入序列長度94ee1858-1d86-11ee-962d-dac502259ad0.png?,輸出序列長度95050914-1d86-11ee-962d-dac502259ad0.png?,則KV cache占用顯存為951b99e0-1d86-11ee-962d-dac502259ad0.png,大約是模型參數顯存的0.5倍。

6. 總結

本文首先介紹了如何計算transformer模型的參數量,基于參數量可以進一步估計模型參數、梯度和優化器狀態占用的顯存大小。接著,本文估計了訓練迭代中,在給定訓練tokens數的情況下transformer模型的計算量,給予計算量和顯卡性能可以進一步估計訓練迭代的計算耗時。然后,本文分析了transformer模型前向計算過程中產生的中間激活值的顯存大小,中間激活的顯存大小與輸入數據大小正相關,甚至會遠超過模型參數占用的顯存。最后,本文介紹了transformer模型推理過程常用的加速策略:使用KVcache??偟膩碚f,分析transformer模型的參數量、計算量、中間激活和KV cache,有助于理解大模型訓練和推斷過程中的顯存效率和計算效率。

聲明:本文內容及配圖由入駐作者撰寫或者入駐合作網站授權轉載。文章觀點僅代表作者本人,不代表電子發燒友網立場。文章及其配圖僅供工程師學習之用,如有內容侵權或者其他違規問題,請聯系本站處理。 舉報投訴
  • 模型
    +關注

    關注

    1

    文章

    2797

    瀏覽量

    47972
  • Transformer
    +關注

    關注

    0

    文章

    130

    瀏覽量

    5915
  • ChatGPT
    +關注

    關注

    28

    文章

    1485

    瀏覽量

    5654

原文標題:分析transformer模型的參數量、計算量、中間激活、KV cache

文章出處:【微信號:GiantPandaCV,微信公眾號:GiantPandaCV】歡迎添加關注!文章轉載請注明出處。

收藏 人收藏

    評論

    相關推薦

    基于卷積的基礎模型InternImage網絡技術分析

    近年來大規模視覺 Transformer 的蓬勃發展推動了計算機視覺領域的性能邊界。視覺 Transformer 模型通過擴大模型
    發表于 11-18 10:49 ?547次閱讀
    基于卷積的基礎<b class='flag-5'>模型</b>InternImage網絡技術分析

    大語言模型背后的Transformer,與CNN和RNN有何不同

    ? 電子發燒友網報道(文/李彎彎)近年來,隨著大語言模型的不斷出圈,Transformer這一概念也走進了大眾視野。Transformer是一種非常流行的深度學習模型,最早于2017年
    的頭像 發表于 12-25 08:36 ?2012次閱讀
    大語言<b class='flag-5'>模型</b>背后的<b class='flag-5'>Transformer</b>,與CNN和RNN有何不同

    【大語言模型:原理與工程實踐】大語言模型的基礎技術

    特定任務對模型進行微調。這種方法的成功不僅是自然語言處理發展的一個轉折點,還為許多現實世界的應用場帶來了前所未有的性能提升。從廣為人知的GPT到BERT,預訓練的模型參數量越來越大預訓練數據越來越多
    發表于 05-05 12:17

    你了解在單GPU上就可以運行的Transformer模型

    上一步也跑不了,因為它們的內存需求太大了。例如,完整的GPT-2模型大約包含1.5B參數。最大配置的參數數量超過每層0.5B,而層數有64 層。圖2:標準Transformer
    發表于 11-02 15:19

    Google科學家設計簡化稀疏架構Switch Transformer,語言模型參數量可擴展至 1.6 萬億

    剛剛,Google Brain 高級研究科學家 Barret Zoph 發帖表示,他們設計了一個名叫「Switch Transformer」的簡化稀疏架構,可以將語言模型參數量擴展至 1.6 萬億
    的頭像 發表于 01-13 16:50 ?2754次閱讀

    一個GPU訓練一個130億參數模型

    現在的模型動輒數百、數千億參數,普通人訓不動怎么辦? 前不久,谷歌發布了參數量為 1.6 萬億的語言模型Swith Transformer,
    的頭像 發表于 02-11 09:04 ?2275次閱讀
    一個GPU訓練一個130億<b class='flag-5'>參數</b>的<b class='flag-5'>模型</b>

    Transformer模型的多模態學習應用

    隨著Transformer在視覺中的崛起,Transformer在多模態中應用也是合情合理的事情,甚至以后可能會有更多的類似的paper。
    的頭像 發表于 03-25 09:29 ?1w次閱讀
    <b class='flag-5'>Transformer</b><b class='flag-5'>模型</b>的多模態學習應用

    使用跨界模型Transformer來做物體檢測!

    用了Transformer 架構開發的一個目標檢測模型。在這篇文章中,我將通過分析DETR架構的內部工作方式來幫助提供一些關于它的直覺。 下面,我將解釋一些結構,但是如果你只是想了解如何使用模型,可以直接跳到代碼部分
    的頭像 發表于 06-10 16:04 ?2005次閱讀
    使用跨界<b class='flag-5'>模型</b><b class='flag-5'>Transformer</b>來做物體檢測!

    超大Transformer語言模型的分布式訓練框架

    模型的預訓練計算。 大模型是大勢所趨 近年來,NLP 模型的發展十分迅速,模型的大小每年以1-2個數量
    的頭像 發表于 10-11 16:46 ?2371次閱讀
    超大<b class='flag-5'>Transformer</b>語言<b class='flag-5'>模型</b>的分布式訓練框架

    一種顯著降低Transformer計算量的輕量化方法

    然而,transformer的原始公式在輸入令牌(token)數量方面具有二次計算復雜度。鑒于這個數字通常從圖像分類的14^2到圖像去噪的128^2 = 16K不等,內存和計算的這一限
    的頭像 發表于 01-10 14:12 ?969次閱讀

    在X3派上玩轉一億參數量超大Transformer,DIY專屬你的離線語音識別

    Transformer模型在自然語言領域被提出后,目前已經擴展到了計算機視覺、語音等諸多領域。然而,雖然Transformer模型在語音識別
    的頭像 發表于 02-21 16:08 ?581次閱讀
    在X3派上玩轉一億<b class='flag-5'>參數量</b>超大<b class='flag-5'>Transformer</b>,DIY專屬你的離線語音識別

    transformer模型詳解:Transformer 模型的壓縮方法

    ?動機&背景 Transformer 模型在各種自然語言任務中取得了顯著的成果,但內存和計算資源的瓶頸阻礙了其實用化部署。低秩近似和結構化剪枝是緩解這一瓶頸的主流方法。然而,作者通過分析發現,結構化
    的頭像 發表于 07-17 10:50 ?1488次閱讀
    <b class='flag-5'>transformer</b><b class='flag-5'>模型</b>詳解:<b class='flag-5'>Transformer</b> <b class='flag-5'>模型</b>的壓縮方法

    盤古大模型參數量有多少

    盤古大模型參數量有多少 盤古大模型(PanGu-α)是由中國科學院計算技術研究所提供的一種語言生成預訓練模型。該
    的頭像 發表于 08-17 11:28 ?2273次閱讀

    盤古大模型與ChatGPT的模型基礎架構

    華為盤古大模型Transformer模型架構為基礎,利用深層學習技術進行訓練。模型的每個數量達到2.6億個,是目前世界上最大的漢語預備訓練
    的頭像 發表于 09-05 09:55 ?1669次閱讀

    基于Transformer模型的壓縮方法

    基于Transformer架構的大型模型在人工智能領域中發揮著日益重要的作用,特別是在自然語言處理(NLP)和計算機視覺(CV)領域。
    的頭像 發表于 02-22 16:27 ?344次閱讀
    基于<b class='flag-5'>Transformer</b><b class='flag-5'>模型</b>的壓縮方法
    亚洲欧美日韩精品久久_久久精品AⅤ无码中文_日本中文字幕有码在线播放_亚洲视频高清不卡在线观看
    <acronym id="s8ci2"><small id="s8ci2"></small></acronym>
    <rt id="s8ci2"></rt><rt id="s8ci2"><optgroup id="s8ci2"></optgroup></rt>
    <acronym id="s8ci2"></acronym>
    <acronym id="s8ci2"><center id="s8ci2"></center></acronym>