0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫(xiě)文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

FlashAttenion-V3: Flash Decoding詳解

jf_pmFSk4VX ? 來(lái)源:GiantPandaCV ? 2023-10-31 16:18 ? 次閱讀

Flash Attention V1和V2的作者又推出了Flash Decoding,真是太強(qiáng)了!

Flash-Decoding借鑒了FlashAttention的優(yōu)點(diǎn),將并行化維度擴(kuò)展到keys/values序列長(zhǎng)度。這種方法幾乎不收序列長(zhǎng)度影響(這對(duì)LLM模型能力很重要),可以充分利用GPU,即使在batch size較小時(shí)(inference特點(diǎn)),也可以極大提高了encoding速度。

相關(guān)背景知識(shí)先推薦閱讀:

FlashAttention圖解(如何加速Attention)

FlashAttention2詳解(性能比FlashAttention提升200%)

Motivation

最近,像ChatGPT或Llama這樣的LLM模型受到了空前的關(guān)注。然而,它們的運(yùn)行成本卻非常高昂。雖然單次回復(fù)的成本約為0.01美元(例如在AWS 8塊A100上運(yùn)行幾秒鐘),但是當(dāng)擴(kuò)展到數(shù)十億用戶的多次交互時(shí),成本會(huì)迅速上升。而且一些場(chǎng)景的成本更高,例如代碼自動(dòng)補(bǔ)全,因?yàn)橹灰脩糨斎胍粋€(gè)新字符就會(huì)執(zhí)行。由于LLM應(yīng)用非常廣泛且還在迅速增長(zhǎng),即使稍微提升其運(yùn)行效率也會(huì)產(chǎn)生巨大的收益。

LLM inference(或稱(chēng)為decoding)是一個(gè)迭代的過(guò)程:預(yù)測(cè)的tokens是逐個(gè)生成的。如果生成的句子有N個(gè)單詞,那么模型需要進(jìn)行N次forward。一個(gè)常用的優(yōu)化技巧是KV Cache,該方法緩存了之前forward的一些中間結(jié)果,節(jié)約了大部分運(yùn)算(如MatMul),但是attention操作是個(gè)例外。隨著輸出tokens長(zhǎng)度增加,attention操作的復(fù)雜度也極具上升。

然而我們希望LLM能處理長(zhǎng)上下文。增加了上下文長(zhǎng)度,LLM可以輸出更長(zhǎng)的文檔、跟蹤更長(zhǎng)的對(duì)話,甚至在編寫(xiě)代碼之前處理整個(gè)代碼庫(kù)。例如,2022年大多數(shù)LLM的上下文長(zhǎng)度最多為2k(如GPT-3),但現(xiàn)在LLM上下文長(zhǎng)度可以擴(kuò)展到32k(Llama-2-32k),甚至最近達(dá)到了100k(CodeLlama)。在這種情況下,attention操作在推理過(guò)程中占據(jù)了相當(dāng)大的時(shí)間比例。此外,當(dāng)batch size增加時(shí),即使在相對(duì)較小的上下文中,attention操作也可能成為瓶頸。這是因?yàn)樵摬僮餍枰獙?duì)內(nèi)存的訪問(wèn)會(huì)隨著batch size增加而增加,而模型中其他操作只和模型大小相關(guān)。

因此,本文提出了Flash-Decoding,可以推理過(guò)程中顯著加速attention操作(例如長(zhǎng)序列生成速度提高8倍)。其主要思想是最大化并行加載keys和values的效率,通過(guò)重新縮放組合得到正確結(jié)果。

Multi-head attention for decoding

在decoding過(guò)程中,每個(gè)生成的新token需要與先前的tokens合并后,才能繼續(xù)執(zhí)行attention操作,即936fb5aa-77c1-11ee-939d-92fbcf53809c.png。Attention操作在訓(xùn)練過(guò)程的瓶頸主要卡在訪問(wèn)內(nèi)存讀寫(xiě)中間結(jié)果(例如93895640-77c1-11ee-939d-92fbcf53809c.png)的帶寬,相關(guān)加速方案可以參考FlashAttention和FlashAttention2。

然而,上述優(yōu)化不適合直接應(yīng)用于推理過(guò)程。因?yàn)樵谟?xùn)練過(guò)程中,F(xiàn)lashAttention對(duì)batch size和query length進(jìn)行了并行化加速。而在推理過(guò)程中,query length通常為1,這意味著如果batch size小于GPU上的SM數(shù)量(例如A100上有108個(gè)SMs),那么整個(gè)計(jì)算過(guò)程只使用了GPU的一小部分!特別是當(dāng)上下文較長(zhǎng)時(shí),通常會(huì)減小batch size來(lái)適應(yīng)GPU內(nèi)存。例如batch size = 1時(shí),F(xiàn)lashAttention對(duì)GPU利用率小于1%!

下面展示了FlashAttention的計(jì)算示意圖,該示例將keys和values分為了2個(gè)block:

93a173e2-77c1-11ee-939d-92fbcf53809c.png

FlashAttention示意圖

對(duì)應(yīng)的計(jì)算公式:

93b5acae-77c1-11ee-939d-92fbcf53809c.png

FlashAttention示意圖對(duì)應(yīng)的計(jì)算公式

注意93bdf760-77c1-11ee-939d-92fbcf53809c.png的計(jì)算過(guò)程依賴(lài)93c63aba-77c1-11ee-939d-92fbcf53809c.png,從下圖也可以看出,F(xiàn)lashAttention是按順序更新output的,其實(shí)當(dāng)時(shí)我在看FlashAttention這篇文章時(shí)就覺(jué)得這個(gè)順序操作可以優(yōu)化的,因?yàn)榉凑家猺escale,不如最后統(tǒng)一rescale,沒(méi)必要等之前block計(jì)算完(為了獲取上一個(gè)block的max值)

93d525ac-77c1-11ee-939d-92fbcf53809c.jpg

flashattention計(jì)算過(guò)程

A faster attention for decoding: Flash-Decoding

上面提到FlashAttention對(duì)batch size和query length進(jìn)行了并行化加速,F(xiàn)lash-Decoding在此基礎(chǔ)上增加了一個(gè)新的并行化維度:keys/values的序列長(zhǎng)度。即使batch size很小,但只要上下文足夠長(zhǎng),它就可以充分利用GPU。與FlashAttention類(lèi)似,F(xiàn)lash-Decoding幾乎不用額外存儲(chǔ)大量數(shù)據(jù)到全局內(nèi)存中,從而減少了內(nèi)存開(kāi)銷(xiāo)。

93e66074-77c1-11ee-939d-92fbcf53809c.gif

flashdecoding計(jì)算過(guò)程

Flash Decoding主要包含以下三個(gè)步驟(可以結(jié)合上圖來(lái)看):

將keys和values分成較小的block

使用FlashAttention并行計(jì)算query與每個(gè)block的注意力(這是和FlashAttention最大的區(qū)別)。對(duì)于每個(gè)block的每行(因?yàn)橐恍惺且粋€(gè)特征維度),F(xiàn)lash Decoding會(huì)額外記錄attention values的log-sum-exp(標(biāo)量值,用于第3步進(jìn)行rescale)

對(duì)所有output blocks進(jìn)行reduction得到最終的output,需要用log-sum-exp值來(lái)重新調(diào)整每個(gè)塊的貢獻(xiàn)

實(shí)際應(yīng)用中,第1步中的數(shù)據(jù)分塊不涉及GPU操作(因?yàn)椴恍枰谖锢砩戏珠_(kāi)),只需要對(duì)第2步和第3步執(zhí)行單獨(dú)的kernels。雖然最終的reduction操作會(huì)引入一些額外的計(jì)算,但在總體上,F(xiàn)lash-Decoding通過(guò)增加并行化的方式取得了更高的效率。

Benchmarks on CodeLlama 34B

作者對(duì)CodeLLaMa-34b的decoding throughput進(jìn)行了基準(zhǔn)測(cè)試。該模型與Llama 2具有相同的架構(gòu)。作者在各種序列長(zhǎng)度(從512到64k)上測(cè)試了decoding速度,并比較了多種attention計(jì)算方法:

PyTorch:使用純PyTorch primitives運(yùn)行注意力計(jì)算(不使用FlashAttention)。

FlashAttention v2(v2.2之前的版本)。

FasterTransformer:使用FasterTransformer attention kernel

Flash-Decoding

將從內(nèi)存中讀取整個(gè)模型和KV Cache所需的時(shí)間作為上限

940efbf6-77c1-11ee-939d-92fbcf53809c.png

Untitled

從上圖可以看出,F(xiàn)lash-Decoding在處理非常大的序列時(shí)速度可以提高8倍,并且比其他方法具有更好的可擴(kuò)展性。所有方法在處理small prompts時(shí)表現(xiàn)相似,但隨著序列長(zhǎng)度從512增加到64k,其他方法的性能都變差了,而Flash-Decoding對(duì)序列長(zhǎng)度的增加并不敏感(下圖也是很好的證明)

9422813a-77c1-11ee-939d-92fbcf53809c.png

micro-benchmark on A100

Using Flash-Decoding

作者還通了Flash-Decoding使用方式:

基于FlashAttention package ,從版本2.2開(kāi)始。

xFormers,在版本0.0.22中提供了xformers.ops.memory_efficient_attention模塊

作者也提供了LLaMa v2/CodeLLaMa的repo1和xFormers repo2。此外,作者還提供了一個(gè)針對(duì)LLaMa v1/v2的最小示例。

個(gè)人總結(jié)

Flash-Decoding對(duì)LLM在GPU上inference進(jìn)行了顯著加速(尤其是batch size較小時(shí)),并且在處理長(zhǎng)序列時(shí)具有更好的可擴(kuò)展性。

聲明:本文內(nèi)容及配圖由入駐作者撰寫(xiě)或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場(chǎng)。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問(wèn)題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴
  • gpu
    gpu
    +關(guān)注

    關(guān)注

    27

    文章

    4650

    瀏覽量

    128494
  • 模型
    +關(guān)注

    關(guān)注

    1

    文章

    3081

    瀏覽量

    48595
  • LLM
    LLM
    +關(guān)注

    關(guān)注

    0

    文章

    254

    瀏覽量

    289

原文標(biāo)題:FlashAttenion-V3: Flash Decoding詳解

文章出處:【微信號(hào):GiantPandaCV,微信公眾號(hào):GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。

收藏 人收藏

    評(píng)論

    相關(guān)推薦

    Flash基本操作——Flash基礎(chǔ)(1)#多媒體技術(shù)

    FlaSh
    未來(lái)加油dz
    發(fā)布于 :2023年05月24日 10:43:53

    Flash基本操作——Flash工具1(3)#多媒體技術(shù)

    FlaSh
    未來(lái)加油dz
    發(fā)布于 :2023年05月24日 10:46:17

    Flash基本操作——Flash工具2(3)#多媒體技術(shù)

    FlaSh
    未來(lái)加油dz
    發(fā)布于 :2023年05月24日 10:48:11

    Flash基本操作——Flash工具3(1)#多媒體技術(shù)

    FlaSh
    未來(lái)加油dz
    發(fā)布于 :2023年05月24日 10:49:01

    Flash基本操作——Flash工具3(2)#多媒體技術(shù)

    FlaSh
    未來(lái)加油dz
    發(fā)布于 :2023年05月24日 10:49:44

    Flash基本操作——Flash工具3(3)#多媒體技術(shù)

    FlaSh
    未來(lái)加油dz
    發(fā)布于 :2023年05月24日 10:50:22

    Necessary to disable "Above 4G Decoding" for View with vGPU?

    /grid-vgpu-deployment-guide.pdf 在第17頁(yè),它為幾個(gè)服務(wù)器制造商提供了BIOS建議。 它建議禁用SuperMicro的“Above 4G Decoding”。 對(duì)于Dom0為32位
    發(fā)表于 09-04 15:36

    3~25V與10安3~15V電壓可調(diào)電壓電路原理圖詳解

    3~25V與10安3~15V電壓可調(diào)電壓電路原理圖詳解
    發(fā)表于 04-16 20:47

    模電Flash動(dòng)畫(huà)詳解

    模電Flash動(dòng)畫(huà)詳解,一共有161個(gè)!
    發(fā)表于 09-27 08:15

    Flash Magic V2.45

    Flash Magic V2.45 Flash Magic V2.45軟件
    發(fā)表于 05-10 11:24 ?8次下載

    基于MSP430功能模塊詳解系列之——FLASH存儲(chǔ)器

    基于MSP430功能模塊詳解系列之——FLASH存儲(chǔ)器
    發(fā)表于 10-12 15:27 ?11次下載
    基于MSP430功能模塊<b class='flag-5'>詳解</b>系列之——<b class='flag-5'>FLASH</b>存儲(chǔ)器

    MP3-FLASH-16P 使用說(shuō)明書(shū) V1.0

    藍(lán)板MP3-FLASH-16P使用說(shuō)明書(shū) V1.0 MP3-FLASH-16P 是一個(gè)提供串口的語(yǔ)音模塊,完美的集成了 MP3、WAV 的硬解碼。同時(shí)軟件支持工業(yè)級(jí)別的串口通信協(xié)議,以
    發(fā)表于 11-28 14:08 ?24次下載

    【轉(zhuǎn)載】keil將程序裝入外部FLASH詳解

    【轉(zhuǎn)載】keil將程序裝入外部FLASH詳解
    發(fā)表于 12-01 20:21 ?14次下載
    【轉(zhuǎn)載】keil將程序裝入外部<b class='flag-5'>FLASH</b><b class='flag-5'>詳解</b>

    開(kāi)源軟件-Morse_Encoding_Decoding摩斯密碼工具

    ./oschina_soft/Morse_Encoding_Decoding.zip
    發(fā)表于 06-28 11:52 ?1次下載
    開(kāi)源軟件-Morse_Encoding_<b class='flag-5'>Decoding</b>摩斯密碼工具

    瑞薩Flash程序員V3 發(fā)布說(shuō)明

    電子發(fā)燒友網(wǎng)站提供《瑞薩Flash程序員V3 發(fā)布說(shuō)明.pdf》資料免費(fèi)下載
    發(fā)表于 02-19 09:37 ?1次下載
    瑞薩<b class='flag-5'>Flash</b>程序員<b class='flag-5'>V3</b> 發(fā)布說(shuō)明