0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

改動(dòng)一行代碼,PyTorch訓(xùn)練三倍提速!這些技術(shù)是關(guān)鍵!

CVer ? 來源:機(jī)器之心 ? 2023-08-14 13:07 ? 次閱讀

用對了方法,加速 PyTorch 訓(xùn)練,有時(shí)也不是那么復(fù)雜。

近日,深度學(xué)習(xí)領(lǐng)域知名研究者、Lightning AI 的首席人工智能教育者 Sebastian Raschka 在 CVPR 2023 上發(fā)表了主題演講「Scaling PyTorch Model Training With Minimal Code Changes」。

為了能與更多人分享研究成果,Sebastian Raschka 將演講整理成一篇文章。文章探討了如何在最小代碼更改的情況下擴(kuò)展 PyTorch 模型訓(xùn)練,并表明重點(diǎn)是利用混合精度(mixed-precision)方法和多 GPU 訓(xùn)練模式,而不是低級機(jī)器優(yōu)化。

文章使用視覺 Transformer(ViT)作為基礎(chǔ)模型,ViT 模型在一個(gè)基本數(shù)據(jù)集上從頭開始,經(jīng)過約 60 分鐘的訓(xùn)練,在測試集上取得了 62% 的準(zhǔn)確率。

809a7692-39f3-11ee-9e74-dac502259ad0.png

GitHub 地址:https://github.com/rasbt/cvpr2023

以下是文章原文:

構(gòu)建基準(zhǔn)

在接下來的部分中,Sebastian 將探討如何在不進(jìn)行大量代碼重構(gòu)的情況下改善訓(xùn)練時(shí)間和準(zhǔn)確率。

想要注意的是,模型和數(shù)據(jù)集的詳細(xì)信息并不是這里的主要關(guān)注點(diǎn)(它們只是為了盡可能簡單,以便讀者可以在自己的機(jī)器上復(fù)現(xiàn),而不需要下載和安裝太多的依賴)。所有在這里分享的示例都可以在 GitHub 找到,讀者可以探索和重用完整的代碼。

80a66e5c-39f3-11ee-9e74-dac502259ad0.png

腳本 00_pytorch-vit-random-init.py 的輸出。

不要從頭開始訓(xùn)練

現(xiàn)如今,從頭開始訓(xùn)練文本或圖像的深度學(xué)習(xí)模型通常是低效的。我們通常會(huì)利用預(yù)訓(xùn)練模型,并對模型進(jìn)行微調(diào),以節(jié)省時(shí)間和計(jì)算資源,同時(shí)獲得更好的建模效果。

如果考慮上面使用的相同 ViT 架構(gòu),在另一個(gè)數(shù)據(jù)集(ImageNet)上進(jìn)行預(yù)訓(xùn)練,并對其進(jìn)行微調(diào),就可以在更短的時(shí)間內(nèi)實(shí)現(xiàn)更好的預(yù)測性能:20 分鐘(3 個(gè)訓(xùn)練 epoch)內(nèi)達(dá)到 95% 的測試準(zhǔn)確率。

80ae2476-39f3-11ee-9e74-dac502259ad0.png

00_pytorch-vit-random-init.py 和 01_pytorch-vit.py 的對比。

提升計(jì)算性能

我們可以看到,相對于從零開始訓(xùn)練,微調(diào)可以大大提升模型性能。下面的柱狀圖總結(jié)了這一點(diǎn)。

80c25bee-39f3-11ee-9e74-dac502259ad0.png

00_pytorch-vit-random-init.py 和 01_pytorch-vit.py 的對比柱狀圖。

當(dāng)然,模型效果可能因數(shù)據(jù)集或任務(wù)的不同而有所差異。但對于許多文本和圖像任務(wù)來說,從一個(gè)在通用公共數(shù)據(jù)集上預(yù)訓(xùn)練的模型開始是值得的。

接下來的部分將探索各種技巧,以加快訓(xùn)練時(shí)間,同時(shí)又不犧牲預(yù)測準(zhǔn)確性。

開源庫 Fabric

在 PyTorch 中以最小代碼更改來高效擴(kuò)展訓(xùn)練的一種方法是使用開源 Fabric 庫,它可以看作是 PyTorch 的一個(gè)輕量級包裝庫 / 接口。通過 pip 安裝。

pip install lightning

下面探索的所有技術(shù)也可以在純 PyTorch 中實(shí)現(xiàn)。Fabric 的目標(biāo)是使這一過程更加便利。

在探索「加速代碼的高級技術(shù)」之前,先介紹一下將 Fabric 集成到 PyTorch 代碼中需要進(jìn)行的小改動(dòng)。一旦完成這些改動(dòng),只需要改變一行代碼,就可以輕松地使用高級 PyTorch 功能。

PyTorch 代碼和修改后使用 Fabric 的代碼之間的區(qū)別是微小的,只涉及到一些細(xì)微的修改,如下面的代碼所示:

80cfc2b6-39f3-11ee-9e74-dac502259ad0.png

普通 PyTorch 代碼(左)和使用 Fabric 的 PyTorch 代碼

總結(jié)一下上圖,就可以得到普通的 PyTorch 代碼轉(zhuǎn)換為 PyTorch+Fabric 的三個(gè)步驟:

導(dǎo)入 Fabric 并實(shí)例化一個(gè) Fabric 對象。

使用 Fabric 設(shè)置模型、優(yōu)化器和 data loader。

損失函數(shù)使用 fabric.backward (),而不是 loss.backward ()。

80dbd06a-39f3-11ee-9e74-dac502259ad0.png

這些微小的改動(dòng)提供了一種利用 PyTorch 高級特性的途徑,而無需對現(xiàn)有代碼進(jìn)行進(jìn)一步重構(gòu)。

深入探討下面的「高級特性」之前,要確保模型的訓(xùn)練運(yùn)行時(shí)間、預(yù)測性能與之前相同。

813137bc-39f3-11ee-9e74-dac502259ad0.png

01_pytorch-vit.py 和 03_fabric-vit.py 的比較結(jié)果。

正如前面柱狀圖中所看到的,訓(xùn)練運(yùn)行時(shí)間、準(zhǔn)確率與之前完全相同,正如預(yù)期的那樣。其中,任何波動(dòng)都可以歸因于隨機(jī)性。

在前面的部分中,我們使用 Fabric 修改了 PyTorch 代碼。為什么要費(fèi)這么大的勁呢?接下來將嘗試高級技術(shù),比如混合精度和分布式訓(xùn)練,只需更改一行代碼,把下面的代碼

fabric = Fabric(accelerator="cuda")

改為

fabric=Fabric(accelerator="cuda",precision="bf16-mixed")

814ef98c-39f3-11ee-9e74-dac502259ad0.png

04_fabric-vit-mixed-precision.py 腳本的比較結(jié)果。腳本地址:https://github.com/rasbt/cvpr2023/blob/main/04_fabric-vit-mixed-precision.py

通過混合精度訓(xùn)練,我們將訓(xùn)練時(shí)間從 18 分鐘左右縮短到 6 分鐘,同時(shí)保持相同的預(yù)測性能。這種訓(xùn)練時(shí)間的縮短只需在實(shí)例化 Fabric 對象時(shí)添加參數(shù)「precision="bf16-mixed"」即可實(shí)現(xiàn)。

理解混合精度機(jī)制

混合精度訓(xùn)練實(shí)質(zhì)上使用了 16 位和 32 位精度,以確保不會(huì)損失準(zhǔn)確性。16 位表示中的計(jì)算梯度比 32 位格式快得多,并且還節(jié)省了大量內(nèi)存。這種策略在內(nèi)存或計(jì)算受限的情況下非常有益。

之所以稱為「混合」而不是「低」精度訓(xùn)練,是因?yàn)椴皇菍⑺袇?shù)和操作轉(zhuǎn)換為 16 位浮點(diǎn)數(shù)。相反,在訓(xùn)練過程中 32 位和 16 位操作之間切換,因此稱為「混合」精度。

如下圖所示,混合精度訓(xùn)練涉及步驟如下:

將權(quán)重轉(zhuǎn)換為較低精度(FP16)以加快計(jì)算速度;

計(jì)算梯度;

將梯度轉(zhuǎn)換回較高精度(FP32)以保持?jǐn)?shù)值穩(wěn)定性;

使用縮放后的梯度更新原始權(quán)重。

這種方法在保持神經(jīng)網(wǎng)絡(luò)準(zhǔn)確性和穩(wěn)定性的同時(shí),實(shí)現(xiàn)了高效的訓(xùn)練。

815c189c-39f3-11ee-9e74-dac502259ad0.png

更詳細(xì)的步驟如下:

將權(quán)重轉(zhuǎn)換為 FP16:在這一步中,神經(jīng)網(wǎng)絡(luò)的權(quán)重(或參數(shù))初始時(shí)用 FP32 格式表示,將其轉(zhuǎn)換為較低精度的 FP16 格式。這樣可以減少內(nèi)存占用,并且由于 FP16 操作所需的內(nèi)存較少,可以更快地被硬件處理。

計(jì)算梯度:使用較低精度的 FP16 權(quán)重進(jìn)行神經(jīng)網(wǎng)絡(luò)的前向傳播和反向傳播。這一步計(jì)算損失函數(shù)相對于網(wǎng)絡(luò)權(quán)重的梯度(偏導(dǎo)數(shù)),這些梯度用于在優(yōu)化過程中更新權(quán)重。

將梯度轉(zhuǎn)換回 FP32:在計(jì)算得到 FP16 格式的梯度后,將其轉(zhuǎn)換回較高精度的 FP32 格式。這種轉(zhuǎn)換對于保持?jǐn)?shù)值穩(wěn)定性非常重要,避免使用較低精度算術(shù)時(shí)可能出現(xiàn)的梯度消失或梯度爆炸等問題。

乘學(xué)習(xí)率并更新權(quán)重:以 FP32 格式表示的梯度乘以學(xué)習(xí)率將用于更新權(quán)重(標(biāo)量值,用于確定優(yōu)化過程中的步長)。

步驟 4 中的乘積用于更新原始的 FP32 神經(jīng)網(wǎng)絡(luò)權(quán)重。學(xué)習(xí)率有助于控制優(yōu)化過程的收斂性,對于實(shí)現(xiàn)良好的性能非常重要。

Brain Float 16

前面談到了「float 16-bit」精度訓(xùn)練。需要注意的是,在之前的代碼中,指定了 precision="bf16-mixed",而不是 precision="16-mixed"。這兩個(gè)都是有效的選項(xiàng)。

在這里,"bf16-mixed" 中的「bf16」表示 Brain Floating Point(bfloat16)。谷歌開發(fā)了這種格式,用于機(jī)器學(xué)習(xí)和深度學(xué)習(xí)應(yīng)用,尤其是在張量處理單元(TPU)中。Bfloat16 相比傳統(tǒng)的 float16 格式擴(kuò)展了動(dòng)態(tài)范圍,但犧牲了一定的精度。

81751608-39f3-11ee-9e74-dac502259ad0.png

擴(kuò)展的動(dòng)態(tài)范圍使得 bfloat16 能夠表示非常大和非常小的數(shù)字,使其更適用于深度學(xué)習(xí)應(yīng)用中可能遇到的數(shù)值范圍。然而,較低的精度可能會(huì)影響某些計(jì)算的準(zhǔn)確性,或在某些情況下導(dǎo)致舍入誤差。但在大多數(shù)深度學(xué)習(xí)應(yīng)用中,這種降低的精度對建模性能的影響很小。

雖然 bfloat16 最初是為 TPU 開發(fā)的,但從 NVIDIA Ampere 架構(gòu)的 A100 Tensor Core GPU 開始,已經(jīng)有幾種 NVIDIA GPU 開始支持 bfloat16。

我們可以使用下面的代碼檢查 GPU 是否支持 bfloat16:

>>> torch.cuda.is_bf16_supported() True

如果你的 GPU 不支持 bfloat16,可以將 precision="bf16-mixed" 更改為 precision="16-mixed"。

多 GPU 訓(xùn)練和完全分片數(shù)據(jù)并行

接下來要嘗試修改多 GPU 訓(xùn)練。如果我們有多個(gè) GPU 可供使用,這會(huì)帶來好處,因?yàn)樗梢宰屛覀兊哪P陀?xùn)練速度更快。

這里介紹一種更先進(jìn)的技術(shù) — 完全分片數(shù)據(jù)并行(Fully Sharded Data Parallelism (FSDP)),它同時(shí)利用了數(shù)據(jù)并行性和張量并行性。

817fd156-39f3-11ee-9e74-dac502259ad0.png

在 Fabric 中,我們可以通過下面的方式利用 FSDP 添加設(shè)備數(shù)量和多 GPU 訓(xùn)練策略:

fabric = Fabric( accelerator="cuda", precision="bf16-mixed", devices=4, strategy="FSDP" # new! )

8191ed6e-39f3-11ee-9e74-dac502259ad0.png

06_fabric-vit-mixed-fsdp.py 腳本的輸出。

現(xiàn)在使用 4 個(gè) GPU,我們的代碼運(yùn)行時(shí)間大約為 2 分鐘,是之前僅使用混合精度訓(xùn)練時(shí)的近 3 倍。

理解數(shù)據(jù)并行和張量并行

在數(shù)據(jù)并行中,小批量數(shù)據(jù)被分割,并且每個(gè) GPU 上都有模型的副本。這個(gè)過程通過多個(gè) GPU 的并行工作來加速模型的訓(xùn)練速度。

81bb8854-39f3-11ee-9e74-dac502259ad0.png

如下簡要概述了數(shù)據(jù)并行的工作原理

同一個(gè)模型被復(fù)制到所有的 GPU 上。

每個(gè) GPU 分別接收不同的輸入數(shù)據(jù)子集(不同的小批量數(shù)據(jù))。

所有的 GPU 獨(dú)立地對模型進(jìn)行前向傳播和反向傳播,計(jì)算各自的局部梯度。

收集并對所有 GPU 的梯度求平均值。

平均梯度被用于更新模型的參數(shù)。

每個(gè) GPU 都在并行地處理不同的數(shù)據(jù)子集,通過梯度的平均化和參數(shù)的更新,整個(gè)模型的訓(xùn)練過程得以加速。

這種方法的主要優(yōu)勢是速度。由于每個(gè) GPU 同時(shí)處理不同的小批量數(shù)據(jù),模型可以在更短的時(shí)間內(nèi)處理更多的數(shù)據(jù)。這可以顯著減少訓(xùn)練模型所需的時(shí)間,特別是在處理大型數(shù)據(jù)集時(shí)。

然而,數(shù)據(jù)并行也有一些限制。最重要的是,每個(gè) GPU 必須具有完整的模型和參數(shù)副本。這限制了可以訓(xùn)練的模型大小,因?yàn)槟P捅仨氝m應(yīng)單個(gè) GPU 的內(nèi)存。這對于現(xiàn)代的 ViTs 或 LLMs 來說這是不可行的。

與數(shù)據(jù)并行不同,張量并行將模型本身劃分到多個(gè) GPU 上。并且在數(shù)據(jù)并行中,每個(gè) GPU 都需要適 應(yīng)整個(gè)模型,這在訓(xùn)練較大的模型時(shí)可能成為一個(gè)限制。而張量并行允許訓(xùn)練那些對單個(gè) GPU 而言可能過大的模型,通過將模型分解并分布到多個(gè)設(shè)備上進(jìn)行訓(xùn)練。

81cc648a-39f3-11ee-9e74-dac502259ad0.png

張量并行是如何工作的呢?想象一下矩陣乘法,有兩種方式可以進(jìn)行分布計(jì)算 —— 按行或按列。為了簡單起見,考慮按列進(jìn)行分布計(jì)算。例如,我們可以將一個(gè)大型矩陣乘法操作分解為多個(gè)獨(dú)立的計(jì)算,每個(gè)計(jì)算可以在不同的 GPU 上進(jìn)行,如下圖所示。然后將結(jié)果連接起來以獲取結(jié)果,這有效地分?jǐn)偭擞?jì)算負(fù)載。

81dc51ce-39f3-11ee-9e74-dac502259ad0.png

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報(bào)投訴
  • 數(shù)據(jù)集
    +關(guān)注

    關(guān)注

    4

    文章

    1197

    瀏覽量

    24592
  • 深度學(xué)習(xí)
    +關(guān)注

    關(guān)注

    73

    文章

    5437

    瀏覽量

    120794
  • pytorch
    +關(guān)注

    關(guān)注

    2

    文章

    795

    瀏覽量

    13080

原文標(biāo)題:CVPR 2023 大牛演講:改動(dòng)一行代碼,PyTorch訓(xùn)練三倍提速!這些技術(shù)是關(guān)鍵!

文章出處:【微信號:CVer,微信公眾號:CVer】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。

收藏 人收藏

    評論

    相關(guān)推薦

    分享50條經(jīng)典的Python一行代碼

    今天浩道跟大家分享python學(xué)習(xí)過程中非常經(jīng)典的50條一行代碼,讓大家體驗(yàn)它簡潔而功能強(qiáng)大的特點(diǎn)。同時(shí)給大家分享號主收集到的所有關(guān)于python的電子書籍,所有電子書以網(wǎng)盤打包,免費(fèi)分享給大家學(xué)習(xí)!福利在文末喔~
    發(fā)表于 08-16 15:00 ?957次閱讀

    三倍壓電路的原理圖(方波-帶示波器)

    本帖最后由 fuzhaoguo 于 2012-4-9 09:16 編輯 三倍壓電路的原理圖(方波-帶示波器)
    發(fā)表于 04-04 16:33

    個(gè)多行的字符串如何一行一行的執(zhí)行然后一行一行的顯示出來啊

    要做個(gè)將hex文件轉(zhuǎn)化成bin 文件的labview,結(jié)果發(fā)現(xiàn)不少按一行一行處理的,而是將所有字符串當(dāng)成一行來處理的,就是假如有5二十個(gè)
    發(fā)表于 06-30 14:24

    最簡單同相放大器。求解決不是應(yīng)該放大三倍嗎?怎么是2

    求解決不是應(yīng)該放大三倍嗎?怎么是2
    發(fā)表于 12-04 16:32

    Pytorch模型訓(xùn)練實(shí)用PDF教程【中文】

    ?模型部分?還是優(yōu)化器?只有這樣不斷的通過可視化診斷你的模型,不斷的對癥下藥,才能訓(xùn)練個(gè)較滿意的模型。本教程內(nèi)容及結(jié)構(gòu):本教程內(nèi)容主要為在 PyTorch訓(xùn)練
    發(fā)表于 12-21 09:18

    一行代碼Android第2版-郭霖

    一行代碼Android第2版-郭霖
    發(fā)表于 04-03 12:08

    直流三倍壓電路

    直流三倍壓電路
    發(fā)表于 07-31 08:22 ?2067次閱讀
    直流<b class='flag-5'>三倍</b>壓電路

    低功率直流三倍壓電路

    低功率直流三倍壓電路 這個(gè)采用555
    發(fā)表于 10-10 16:55 ?1972次閱讀
    低功率直流<b class='flag-5'>三倍</b>壓電路

    CES:博通將發(fā)新WiFi芯片 比現(xiàn)有技術(shù)三倍

    Broadcom(博通)公司日前表示,將在下周舉行的國際消費(fèi)電子展上發(fā)布新代Wi-Fi芯片,將比現(xiàn)有技術(shù)三倍,并符合IEEE 802.11ac標(biāo)準(zhǔn)。
    發(fā)表于 01-07 11:39 ?831次閱讀

    一行代碼——Android

    android開發(fā)。第一行代碼開發(fā)入門 。
    發(fā)表于 03-21 11:40 ?0次下載

    一行代碼——Android

    一行代碼——Android
    發(fā)表于 03-19 11:24 ?0次下載

    如何讓PyTorch模型訓(xùn)練變得飛快?

    ),使用這個(gè)清單,步確保你能榨干你模型的所有性能。 本指南從最簡單的結(jié)構(gòu)到最復(fù)雜的改動(dòng)都有,可以使你的網(wǎng)絡(luò)得到最大的好處。我會(huì)給你展示示例Pytorch
    的頭像 發(fā)表于 11-27 10:43 ?1678次閱讀

    pytorch實(shí)現(xiàn)斷電繼續(xù)訓(xùn)練時(shí)需要注意的要點(diǎn)

    本文整理了pytorch實(shí)現(xiàn)斷電繼續(xù)訓(xùn)練時(shí)需要注意的要點(diǎn),附有代碼詳解。
    的頭像 發(fā)表于 08-22 09:50 ?1356次閱讀

    解讀PyTorch模型訓(xùn)練過程

    PyTorch作為個(gè)開源的機(jī)器學(xué)習(xí)庫,以其動(dòng)態(tài)計(jì)算圖、易于使用的API和強(qiáng)大的靈活性,在深度學(xué)習(xí)領(lǐng)域得到了廣泛的應(yīng)用。本文將深入解讀PyTorch模型訓(xùn)練的全過程,包括數(shù)據(jù)準(zhǔn)備、模型
    的頭像 發(fā)表于 07-03 16:07 ?732次閱讀

    pytorch如何訓(xùn)練自己的數(shù)據(jù)

    本文將詳細(xì)介紹如何使用PyTorch框架來訓(xùn)練自己的數(shù)據(jù)。我們將從數(shù)據(jù)準(zhǔn)備、模型構(gòu)建、訓(xùn)練過程、評估和測試等方面進(jìn)行講解。 環(huán)境搭建 首先,我們需要安裝PyTorch??梢酝ㄟ^訪問
    的頭像 發(fā)表于 07-11 10:04 ?376次閱讀