大模型(LLMs)現(xiàn)在是 NLP 領(lǐng)域的最主流方法之一了。
這個(gè)趨勢(shì)帶來的主要問題之一,就是大模型的訓(xùn)練/微調(diào)/推理需要的內(nèi)存也越來越多。
舉例來說,即使 RTX 3090 有著 24GB 的 RAM,是除了 A100 之外顯存最大的顯卡。但使用一塊 RTX 3090 依然無法 fp32 精度訓(xùn)練最小號(hào)的 LLaMA-6B。
本文總結(jié)一些 Memory-Efficient 的 LLMs 的訓(xùn)練/微調(diào)/推理方法,包括:
● fp16
●int8
●LoRA
●Gradient checkpointing
●Torch FSDP
估算模型所需的RAM
首先,我們需要了解如何根據(jù)參數(shù)量估計(jì)模型大致所需的 RAM,這在實(shí)踐中有很重要的參考意義。我們需要通過估算設(shè)置 batch_size,設(shè)置模型精度,選擇微調(diào)方法和參數(shù)分布方法等。
接下來,我們用LLaMA-6B模型為例估算其大致需要的內(nèi)存。
首先考慮精度對(duì)所需內(nèi)存的影響:
●fp32 精度,一個(gè)參數(shù)需要 32 bits, 4 bytes. ●fp16 精度,一個(gè)參數(shù)需要 16 bits, 2 bytes. ●int8 精度,一個(gè)參數(shù)需要 8 bits, 1 byte.
其次,考慮模型需要的 RAM 大致分三個(gè)部分:
●模型參數(shù) ●梯度 ●優(yōu)化器參數(shù)
模型參數(shù):等于參數(shù)量*每個(gè)參數(shù)所需內(nèi)存。
對(duì)于 fp32,LLaMA-6B 需要 6B*4 bytes = 24GB內(nèi)存
對(duì)于 int8,LLaMA-6B 需要 6B*1 byte = 6GB
梯度:同上,等于參數(shù)量*每個(gè)梯度參數(shù)所需內(nèi)存。
優(yōu)化器參數(shù):不同的優(yōu)化器所儲(chǔ)存的參數(shù)量不同。
對(duì)于常用的 AdamW 來說,需要儲(chǔ)存兩倍的模型參數(shù)(用來儲(chǔ)存一階和二階momentum)。
fp32 的 LLaMA-6B,AdamW 需要 6B*8 bytes = 48 GB
int8 的 LLaMA-6B,AdamW 需要 6B*2 bytes = 12 GB
除此之外,CUDA kernel也會(huì)占據(jù)一些 RAM,大概 1.3GB 左右,查看方式如下。
綜上,int8 精度的 LLaMA-6B 模型部分大致需要 6GB+6GB+12GB+1.3GB = 25.3GB 左右。
再根據(jù)LLaMA的架構(gòu)(hidden_size = 4096, intermediate_size =11008, num_hidden_layers = 32, context_length = 2048)計(jì)算中間變量?jī)?nèi)存。
每個(gè) instance 需要:
所以一張 A100(80GB RAM)大概可以在 int8 精度;batch_size = 50 的設(shè)定下進(jìn)行全參數(shù)訓(xùn)練。
查看消費(fèi)級(jí)顯卡的內(nèi)存和算力:
2023 GPU Benchmark and Graphics Card Comparison Chart
https://www.gpucheck.com/gpu-benchmark-graphics-card-comparison-chart
Fp16-mixed precision
混合精度訓(xùn)練的大致思路是在 forward pass 和 gradient computation 的時(shí)候使用 fp16 來加速,但是在更新參數(shù)時(shí)使用 fp32。
用 torch 實(shí)現(xiàn):
CUDA Automatic Mixed Precision examples
https://pytorch.org/docs/stable/notes/amp_examples.html
torch fp16 推理:直接使用 model.half() 將模型轉(zhuǎn)換為fp16.
使用 Huggingface Transformers:在 TrainingArguments 里聲明 fp16=True
https://huggingface.co/docs/transformers/perf_train_gpu_one#fp16-training
Int8-bitsandbytes
Int8 是個(gè)很極端的數(shù)據(jù)類型,它最多只能表示 - 128~127 的數(shù)字,并且完全沒有精度。
為了在訓(xùn)練和 inference 中使用這個(gè)數(shù)據(jù)類型,bitsandbytes 使用了兩個(gè)方法最大程度地降低了其帶來的誤差:
1. vector-wise quantization
2. mixed precision decompasition
Huggingface 在這篇文章中用動(dòng)圖解釋了 quantization 的實(shí)現(xiàn):
https://huggingface.co/blog/hf-bitsandbytes-integration
論文:
LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scalehttps://arxiv.org/abs/2208.07339
借助 Huggingface PEFT,使用 int8 訓(xùn)練 opt-6.5B 的完整流程:
https://github.com/huggingface/peft/blob/main/examples/int8_training/Finetune_opt_bnb_peft.ipynb
LoRA
Low-Rank Adaptation 是微調(diào) LLMs 最常用的省內(nèi)存方法之一。
LoRA 發(fā)現(xiàn)再微調(diào) LLMs 時(shí),更新矩陣(update matrix)往往特別 sparse,也就是說 update matrix 是低秩矩陣。LoRA 的作者根據(jù)這一特點(diǎn)將 update matrix reparametrize 為兩個(gè)低秩矩陣的積積 。 其中,,A 和 B 的秩為 r,且 。 如此一來,A+B 的參數(shù)量將大大小于 . LoRA 的論文: https://arxiv.org/pdf/2106.09685.pdf
借助 Huggingface PEFT 框架,使用 LoRA 微調(diào) mt0: https://github.com/huggingface/peft/blob/main/examples/conditional_generation/peft_lora_seq2seq.ipynb
Gradient Checkpointing
在 torch 中使用 - 把 model 用一個(gè) customize 的 function 包裝一下即可,詳見:
Explore Gradient-Checkpointing in PyTorch
https://qywu.github.io/2019/05/22/explore-gradient-checkpointing.html 在 Huggingface Transformers 中使用: https://huggingface.co/docs/transformers/v4.27.2/en/perf_train_gpu_one#gradient-checkpointing
Torch FSDP+CPU offload
Fully Sharded Data Paralle(FSDP)和 DeepSpeed 類似,均通過 ZeRO 等分布優(yōu)化算法,減少內(nèi)存的占用量。其將模型參數(shù),梯度和優(yōu)化器狀態(tài)分布至多個(gè) GPU 上,而非像 DDP 一樣,在每個(gè) GPU 上保留完整副本。 CPU offload 則允許在一個(gè) back propagation 中,將參數(shù)動(dòng)態(tài)地從 GPU -> CPU, CPU -> GPU 進(jìn)行轉(zhuǎn)移,從而節(jié)省 GPU 內(nèi)存。 Huggingface 這篇博文解釋了 ZeRO 的大致實(shí)現(xiàn)方法: https://huggingface.co/blog/zero-deepspeed-fairscale
借助 torch 實(shí)現(xiàn) FSDP,只需要將 model 用 FSDPwarp 一下;同樣,cpu_offload 也只需要一行代碼: https://pytorch.org/blog/introducing-pytorch-fully-sharded-data-parallel-api/
在這個(gè)可以查看 FSDP 支持的模型: https://pytorch.org/docs/stable/fsdp.html
在 Huggingface Transformers 中使用 Torch FSDP: https://huggingface.co/docs/transformers/v4.27.2/en/main_classes/trainer#transformers.Trainin
根據(jù)某些 issue,shard_grad_op(只分布保存 optimizer states 和 gradients)模式可能比 fully_shard 更穩(wěn)定: https://github.com/tatsu-lab/stanford_alpaca/issues/32
審核編輯 :李倩
-
RAM
+關(guān)注
關(guān)注
8文章
1351瀏覽量
114372 -
參數(shù)
+關(guān)注
關(guān)注
11文章
1733瀏覽量
31982 -
語言模型
+關(guān)注
關(guān)注
0文章
491瀏覽量
10225
原文標(biāo)題:有哪些省內(nèi)存的大語言模型訓(xùn)練/微調(diào)/推理方法?
文章出處:【微信號(hào):tyutcsplab,微信公眾號(hào):智能感知與物聯(lián)網(wǎng)技術(shù)研究所】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。
發(fā)布評(píng)論請(qǐng)先 登錄
相關(guān)推薦
評(píng)論