0
  • 聊天消息
  • 系統(tǒng)消息
  • 評(píng)論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫(xiě)文章/發(fā)帖/加入社區(qū)
會(huì)員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識(shí)你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

【BBuf的CUDA筆記】OpenAI Triton入門(mén)筆記一

jf_pmFSk4VX ? 來(lái)源:GiantPandaCV ? 2024-01-23 10:00 ? 次閱讀

0x1. OpenAI Triton介紹閱讀

這里來(lái)看官方的介紹:https://openai.com/research/triton ,從官方的介紹中我們可以看到OpenAI Triton的產(chǎn)生動(dòng)機(jī)以及它的目標(biāo)是什么,還可以看到一些經(jīng)典算法的實(shí)現(xiàn)例子展示。

這里的標(biāo)題是 Introducing Triton: Open-source GPU programming for neural networks ,翻譯就是《介紹 Triton:用于神經(jīng)網(wǎng)絡(luò)的開(kāi)源 GPU 編程語(yǔ)言》。然后下面的一句話翻譯過(guò)來(lái)是:我們發(fā)布了 Triton 1.0,這是一種開(kāi)源的類 Python 編程語(yǔ)言,它使得沒(méi)有 CUDA 經(jīng)驗(yàn)的研究人員能夠編寫(xiě)高效的 GPU 代碼——大多數(shù)情況下,其效能與專家所能編寫(xiě)的代碼相當(dāng)。這里指出了triton的目的,就是讓編寫(xiě)cuda kernrl變得更簡(jiǎn)單。接下來(lái)就逐步看一下介紹里的具體內(nèi)容,為了更加準(zhǔn)確這里會(huì)截圖對(duì)應(yīng)的原文然后放上我的翻譯或者理解。

1240ae7a-b926-11ee-8b88-92fbcf53809c.png

這里的意思是Triton可以使得用戶用較少的努力就寫(xiě)出一個(gè)達(dá)到硬件峰值性能的kernel,比如使用 Triton 可以編寫(xiě) FP16 矩陣乘法的核函數(shù),其性能能夠匹配 cuBLAS,并且這個(gè)代碼不超過(guò)25行。然后研究者已經(jīng)用Triton開(kāi)發(fā)了一些高效的實(shí)現(xiàn),和功能相同的Torch實(shí)現(xiàn)相比,性能可以達(dá)到兩倍提升。后面一段就是強(qiáng)調(diào)了使用CUDA來(lái)把一些原始的PyTorch實(shí)現(xiàn)寫(xiě)一個(gè)算子一般會(huì)更加高效,但是這個(gè)難度不小,并且目前已有工作也不能很好覆蓋這種情況,所以O(shè)penAI Triton誕生。

12582dde-b926-11ee-8b88-92fbcf53809c.png

這里講的是GPU編程的挑戰(zhàn),現(xiàn)代 GPU 的架構(gòu)大致可以分為三個(gè)主要部分——DRAM、SRAM 和 ALU。在優(yōu)化 CUDA 代碼時(shí),必須考慮到這些組件:

從 DRAM 的內(nèi)存?zhèn)鬏敱仨毢喜⒊纱笮褪聞?wù),以利用現(xiàn)代內(nèi)存接口的大總線寬度(內(nèi)存合并訪問(wèn))。

數(shù)據(jù)必須在重復(fù)使用前手動(dòng)存儲(chǔ)到 SRAM 中,并進(jìn)行管理來(lái)最小化bank conflict。

計(jì)算必須仔細(xì)地進(jìn)行劃分和調(diào)度,不僅是在流式多處理器(SMs)之間,還包括在其內(nèi)部,以促進(jìn)指令/線程級(jí)并行性,并利用專用的 ALU(例如,Tensor Cores)。

1280cc12-b926-11ee-8b88-92fbcf53809c.png1293ca2e-b926-11ee-8b88-92fbcf53809c.png

考慮所有這些因素可能對(duì)于擁有多年經(jīng)驗(yàn)的資深 CUDA 程序員來(lái)說(shuō)都是一個(gè)挑戰(zhàn)。Triton 的目的是完全自動(dòng)化這些優(yōu)化,以便開(kāi)發(fā)者能夠更好地專注于他們并行代碼的高層邏輯。Triton 旨在廣泛適用,因此不會(huì)自動(dòng)在流式多處理器(SMs)之間調(diào)度工作——留下一些重要的算法考慮(例如,tiling,跨 SM 同步)由開(kāi)發(fā)者自行決定。

然后給了一個(gè)表格展示cuda的編譯器和triton的區(qū)別。

12aadc5a-b926-11ee-8b88-92fbcf53809c.png12c93416-b926-11ee-8b88-92fbcf53809c.png

在所有可用的領(lǐng)域特定語(yǔ)言和即時(shí)編譯器中,Triton可能和Numba最相似:kernel被定義為一個(gè)裝飾過(guò)的函數(shù),并以不同的 program_id 并行啟動(dòng)在所謂的網(wǎng)格實(shí)例上。然而,正如下面的代碼片段所示,相似之處僅此而已:Triton 通過(guò)對(duì)塊上的操作來(lái)暴露實(shí)例內(nèi)部的并行性——這些小數(shù)組的尺寸是二的冪次方——而不是單指令多線程(SIMT)執(zhí)行模型。這樣做,Triton 有效地抽象出了所有與 CUDA 線程塊內(nèi)部并發(fā)相關(guān)的問(wèn)題(例如,內(nèi)存合并、共享內(nèi)存同步/沖突、Tensor Cores調(diào)度)。

12e5d9b8-b926-11ee-8b88-92fbcf53809c.png13058c90-b926-11ee-8b88-92fbcf53809c.png

wKgZomWvHkeAEQTkAADFqjaD4zM784.jpg

131b4224-b926-11ee-8b88-92fbcf53809c.png

注意,Triton 的即時(shí)編譯器將 X 和 Y 視為指針而不是張量;我們認(rèn)為保留對(duì)內(nèi)存訪問(wèn)的低級(jí)控制對(duì)于處理更復(fù)雜的數(shù)據(jù)結(jié)構(gòu)(例如,塊稀疏張量)是重要的。重要的是,這種特定的 softmax 實(shí)現(xiàn)在整個(gè)標(biāo)準(zhǔn)化過(guò)程中將 X 的行保留在 SRAM 中,這在適用時(shí)最大化了數(shù)據(jù)重用(約 <32K 列)。這與 PyTorch 的內(nèi)部 CUDA 代碼不同,后者使用臨時(shí)內(nèi)存使其更具通用性,但顯著更慢(如下所示)。這里的關(guān)鍵不是 Triton 本質(zhì)上更好,而是它簡(jiǎn)化了專用kernel的開(kāi)發(fā),這些內(nèi)核可能比在通用庫(kù)中找到的內(nèi)核快得多。

1335ab3c-b926-11ee-8b88-92fbcf53809c.png

Torch(v1.9)JIT編譯器的較低性能凸顯了從高級(jí)張量操作序列自動(dòng)生成 CUDA 代碼的難度。

1347b2dc-b926-11ee-8b88-92fbcf53809c.png13620100-b926-11ee-8b88-92fbcf53809c.png

這里是說(shuō)Triton大概只需要25行Python代碼就可以實(shí)現(xiàn)一個(gè)接近峰值的矩陣乘法。(后面有專門(mén)的一大節(jié)講這個(gè)代碼的原理)代碼如下:

@triton.jit
defmatmul(A,B,C,M,N,K,stride_am,stride_ak,
stride_bk,stride_bn,stride_cm,stride_cn,
**META):
#extractmetaparameters
BLOCK_M,GROUP_M=META['BLOCK_M'],META['GROUP_M']
BLOCK_N=META['BLOCK_N']
BLOCK_K=META['BLOCK_K']
#programsaregroupedtogethertoimproveL2hitrate
_pid_m=tl.program_id(0)
_pid_n=tl.program_id(1)
pid_m=_pid_m//GROUP_M
pid_n=(_pid_n*GROUP_M)+(_pid_m%GROUP_M)
#rm(resp.rn)denotesarangeofindices
#forrows(resp.col)ofC
rm=pid_m*BLOCK_M+tl.arange(0,BLOCK_M)
rn=pid_n*BLOCK_N+tl.arange(0,BLOCK_N)
#rkdenotesarangeofindicesforcolumns
#(resp.rows)ofA(resp.B)
rk=tl.arange(0,BLOCK_K)
#thememoryaddressesofelementsinthefirstblockof
#AandBcanbecomputedusingnumpy-stylebroadcasting
A=A+(rm[:,None]*stride_am+rk[None,:]*stride_ak)
B=B+(rk[:,None]*stride_bk+rn[None,:]*stride_bn)
#initializeanditerativelyupdateaccumulator
acc=tl.zeros((BLOCK_M,BLOCK_N),dtype=tl.float32)
forkinrange(K,0,-BLOCK_K):
a=tl.load(A)
b=tl.load(B)
#blocklevelmatrixmultiplication
acc+=tl.dot(a,b)
#incrementpointerssothatthenextblocksofAandB
#areloadedduringthenextiteration
A+=BLOCK_K*stride_ak
B+=BLOCK_K*stride_bk
#fuseleakyReLUifdesired
#acc=tl.where(acc>=0,acc,alpha*acc)
#writebackresult
C=C+(rm[:,None]*stride_cm+rn[None,:]*stride_cn)
mask=(rm[:,None]

手寫(xiě)矩陣乘法kernel的一個(gè)重要優(yōu)勢(shì)是,它們可以根據(jù)需要定制,以適應(yīng)輸入(例如,切片)和輸出(例如,LeakyReLU)的融合轉(zhuǎn)換。如果沒(méi)有像 Triton 這樣的系統(tǒng),沒(méi)有出色的 GPU 編程專長(zhǎng)的開(kāi)發(fā)者將無(wú)法進(jìn)行矩陣乘法內(nèi)核的定制修改。

1385810c-b926-11ee-8b88-92fbcf53809c.png1397ec0c-b926-11ee-8b88-92fbcf53809c.png

這里是說(shuō)Triton 的良好性能源于一個(gè)以 Triton-IR 為中心的模塊化系統(tǒng)架構(gòu),Triton-IR 是一個(gè)基于 LLVM 的中間表示,在這個(gè)系統(tǒng)中,多維值塊(這個(gè)是MLIR的概念)是一等公民。GPT

@triton.jit 裝飾器的工作原理是遍歷提供的 Python 函數(shù)的抽象語(yǔ)法樹(shù)(AST),以便使用常見(jiàn)的 SSA 構(gòu)建算法即時(shí)生成 Triton-IR。然后,編譯器后端會(huì)簡(jiǎn)化、優(yōu)化并自動(dòng)并行化所產(chǎn)生的 IR 代碼,再將其轉(zhuǎn)換為高質(zhì)量的 LLVM-IR —— 最終生成 PTX —— 以在近期的 NVIDIA GPU 上執(zhí)行。目前不支持 CPUAMD GPU,但我們歡迎社區(qū)貢獻(xiàn),旨在解決這一限制。

13c1ebd8-b926-11ee-8b88-92fbcf53809c.png

我們發(fā)現(xiàn),通過(guò) Triton-IR 使用塊級(jí)別程序表示,使我們的編譯器能夠自動(dòng)執(zhí)行各種重要的程序優(yōu)化。例如,可以通過(guò)觀察計(jì)算密集型塊級(jí)操作(例如,tl.dot)的操作數(shù),自動(dòng)將數(shù)據(jù)暫存到共享內(nèi)存中,并使用標(biāo)準(zhǔn)的活性分析技術(shù)進(jìn)行分配和同步。

另一方面,如下所示,Triton 程序可以高效且自動(dòng)地并行化,既可以(1)通過(guò)并發(fā)執(zhí)行不同的kernel實(shí)例在流式多處理器(SMs)間并行,也可以(2)通過(guò)分析每個(gè)塊級(jí)操作的迭代空間,并在不同的 SIMD 單元間適當(dāng)分配,從而在 SMs 內(nèi)部并行。

13d33b9a-b926-11ee-8b88-92fbcf53809c.png

0x2. 教程1 Vector Addition閱讀

13ec7ab0-b926-11ee-8b88-92fbcf53809c.png

意思是這一節(jié)教程會(huì)介紹Triton編程模型定義kernel的基本寫(xiě)法,此外也會(huì)介紹一下怎么實(shí)現(xiàn)一個(gè)良好的benchmark測(cè)試。下面來(lái)看計(jì)算kernel實(shí)現(xiàn),我把注釋改成中文了:

importtorch

importtriton
importtriton.languageastl

@triton.jit
defadd_kernel(x_ptr,#*指針*,指向第一個(gè)輸入向量。
y_ptr,#*指針*,指向第二個(gè)輸入向量。
output_ptr,#*指針*,指向輸出向量。
n_elements,#向量的大小。
BLOCK_SIZE:tl.constexpr,#每個(gè)程序應(yīng)處理的元素?cái)?shù)量。
#注意:`constexpr`這樣可以被用作形狀值。
):
#這里有多個(gè)“程序”處理不同的數(shù)據(jù)。我們?cè)谶@里識(shí)別我們是哪一個(gè)程序:
pid=tl.program_id(axis=0)#我們使用一維啟動(dòng)網(wǎng)格,所以軸是0。
#該程序?qū)⑻幚韽某跏紨?shù)據(jù)偏移的輸入。
#例如,如果你有一個(gè)長(zhǎng)度為256的向量和塊大小為64,那么程序
#將分別訪問(wèn)元素[0:64,64:128,128:192,192:256]。
#注意偏移量是一個(gè)指針列表:
block_start=pid*BLOCK_SIZE
offsets=block_start+tl.arange(0,BLOCK_SIZE)
#創(chuàng)建一個(gè)掩碼以防止內(nèi)存操作越界訪問(wèn)。
mask=offsets

這里還聲明了一個(gè)輔助函數(shù)來(lái)(1)分配z張量,(2)使用適當(dāng)?shù)木W(wǎng)格/塊大小排隊(duì)上面的kernel:

defadd(x:torch.Tensor,y:torch.Tensor):
#我們需要預(yù)分配輸出。
output=torch.empty_like(x)
assertx.is_cudaandy.is_cudaandoutput.is_cuda
n_elements=output.numel()
#SPMD啟動(dòng)網(wǎng)格表示并行運(yùn)行的kernel實(shí)例的數(shù)量。
#它類似于CUDA啟動(dòng)網(wǎng)格。它可以是Tuple[int],也可以是Callable(metaparameters)->Tuple[int]。
#在這種情況下,我們使用一個(gè)1D網(wǎng)格,其大小是塊的數(shù)量:
grid=lambdameta:(triton.cdiv(n_elements,meta['BLOCK_SIZE']),)
#注意:
#-每個(gè)torch.tensor對(duì)象都隱式地轉(zhuǎn)換為指向其第一個(gè)元素的指針。
#-使用`triton.jit`裝飾的函數(shù)可以用一個(gè)啟動(dòng)網(wǎng)格索引來(lái)獲得可調(diào)用的GPU內(nèi)核。
#-不要忘記將元參數(shù)作為關(guān)鍵字參數(shù)傳遞。
add_kernel[grid](x,y,output,n_elements,BLOCK_SIZE=1024)
#我們返回一個(gè)指向z的句柄,但是因?yàn)閌torch.cuda.synchronize()`還沒(méi)有被調(diào)用,所以這時(shí)kernel仍然
#在異步運(yùn)行。
returnoutput

我們現(xiàn)在可以使用上面定義的函數(shù)來(lái)計(jì)算兩個(gè)torch.tensor對(duì)象的逐元素求和,并測(cè)試其正確性:

torch.manual_seed(0)
size=98432
x=torch.rand(size,device='cuda')
y=torch.rand(size,device='cuda')
output_torch=x+y
output_triton=add(x,y)
print(output_torch)
print(output_triton)
print(f'Themaximumdifferencebetweentorchandtritonis'
f'{torch.max(torch.abs(output_torch-output_triton))}')

輸出:

tensor([1.3713,1.3076,0.4940,...,0.6724,1.2141,0.9733],device='cuda:0')
tensor([1.3713,1.3076,0.4940,...,0.6724,1.2141,0.9733],device='cuda:0')
Themaximumdifferencebetweentorchandtritonis0.0
13fa6076-b926-11ee-8b88-92fbcf53809c.png

我們可以對(duì)不同大小的向量進(jìn)行自定義操作的性能基準(zhǔn)測(cè)試,以了解它相對(duì)于PyTorch的表現(xiàn)如何。為了簡(jiǎn)化操作,Triton提供了一系列內(nèi)置工具,使我們能夠簡(jiǎn)潔地繪制出自定義操作在不同問(wèn)題規(guī)模下的性能圖表。

@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['size'],#用作繪圖x軸的參數(shù)名。
x_vals=[2**iforiinrange(12,28,1)],#`x_name`的不同可能值。
x_log=True,#x軸是對(duì)數(shù)的。
line_arg='provider',#其值對(duì)應(yīng)于圖中不同線條的參數(shù)名。
line_vals=['triton','torch'],#`line_arg`的可能值。
line_names=['Triton','Torch'],#線條的標(biāo)簽名稱。
styles=[('blue','-'),('green','-')],#線條樣式。
ylabel='GB/s',#y軸的標(biāo)簽名稱。
plot_name='vector-add-performance',#繪圖的名稱。也用作保存繪圖的文件名。
args={},#不在`x_names`和`y_name`中的函數(shù)參數(shù)的值。
))
defbenchmark(size,provider):
x=torch.rand(size,device='cuda',dtype=torch.float32)
y=torch.rand(size,device='cuda',dtype=torch.float32)
quantiles=[0.5,0.2,0.8]
ifprovider=='torch':
ms,min_ms,max_ms=triton.testing.do_bench(lambda:x+y,quantiles=quantiles)
ifprovider=='triton':
ms,min_ms,max_ms=triton.testing.do_bench(lambda:add(x,y),quantiles=quantiles)
gbps=lambdams:12*size/ms*1e-6
returngbps(ms),gbps(max_ms),gbps(min_ms)

gbps = lambda ms: 12 * size / ms * 1e-6這里的12表示的是數(shù)據(jù)讀寫(xiě)的bit,因?yàn)橛衳和y以及z的存在,所以是3*4=12bit。現(xiàn)在可以運(yùn)行上面的裝飾函數(shù)了。傳遞 print_data=True 參數(shù)來(lái)查看性能數(shù)據(jù),傳遞 show_plots=True 參數(shù)來(lái)繪制圖表,和/或傳遞 save_path='/path/to/results/' 參數(shù)來(lái)將它們連同原始CSV數(shù)據(jù)一起保存到磁盤(pán)上:

benchmark.run(print_data=True,show_plots=True)

140b253c-b926-11ee-8b88-92fbcf53809c.png

可以看到,對(duì)于elementwise任務(wù),Triton的性能幾乎和PyTorch持平,但是Triton寫(xiě)起來(lái)很簡(jiǎn)單。

0x3. 教程2 Fused Softmax閱讀

在這個(gè)教程中,我們將編寫(xiě)一個(gè)融合的softmax操作,這個(gè)操作對(duì)于特定類型的矩陣來(lái)說(shuō)比PyTorch的原生操作要快得多:那些行的大小可以放入GPU的SRAM中的矩陣。

通過(guò)這樣做,我們將學(xué)習(xí)到:

kernel融合對(duì)于帶寬受限操作的好處。

Triton中的reduce操作符。

動(dòng)機(jī)

自定義GPU kernel用于逐元素加法在教育上是有價(jià)值的,但在實(shí)際應(yīng)用中可能作用有限。讓我們考慮一個(gè)簡(jiǎn)單的(數(shù)值穩(wěn)定的)softmax操作的情況:

importtorch

importtriton
importtriton.languageastl

@torch.jit.script
defnaive_softmax(x):
"""使用原生pytorch計(jì)算X的逐行softmax

我們減去最大元素是為了避免溢出。Softmax對(duì)這種偏移是不變的。
"""
#讀取MN個(gè)元素;寫(xiě)入M個(gè)元素
x_max=x.max(dim=1)[0]
#讀取MN+M個(gè)元素;寫(xiě)入MN個(gè)元素
z=x-x_max[:,None]
#讀取MN個(gè)元素;寫(xiě)入MN個(gè)元素
numerator=torch.exp(z)
#讀取MN個(gè)元素;寫(xiě)入M個(gè)元素
denominator=numerator.sum(dim=1)
#讀取MN+M個(gè)元素;寫(xiě)入MN個(gè)元素
ret=numerator/denominator[:,None]
#總計(jì):讀取5MN+2M個(gè)元素;寫(xiě)入3MN+2M個(gè)元素
returnret

1421fea6-b926-11ee-8b88-92fbcf53809c.png

wKgaomWvHqyAFI-mAACokcxGs7Y255.jpg

計(jì)算kernel

我們的softmax kernel的工作方式如下:每個(gè)程序加載輸入矩陣X的一行,對(duì)其進(jìn)行歸一化處理,然后將結(jié)果寫(xiě)回到輸出Y中。需要注意的是,Triton的一個(gè)重要限制是每個(gè)塊必須包含2的冪次方個(gè)元素,因此如果我們想處理任何可能的輸入形狀,我們需要在內(nèi)部對(duì)每行進(jìn)行“pad”以及對(duì)內(nèi)存訪問(wèn)操作進(jìn)行保護(hù)(也就是防止越界):

@triton.jit
defsoftmax_kernel(output_ptr,input_ptr,input_row_stride,output_row_stride,n_cols,BLOCK_SIZE:tl.constexpr):
#softmax的各行是獨(dú)立的,所以我們?cè)谶@些行上進(jìn)行并行處理
row_idx=tl.program_id(0)
#步長(zhǎng)代表我們需要增加多少指針來(lái)前進(jìn)1行
row_start_ptr=input_ptr+row_idx*input_row_stride
#塊大小是大于n_cols的下一個(gè)2的冪次,因此我們可以將每一行放入單個(gè)塊中
col_offsets=tl.arange(0,BLOCK_SIZE)
input_ptrs=row_start_ptr+col_offsets
#將行加載到SRAM中,使用掩碼因?yàn)锽LOCK_SIZE可能大于n_cols
row=tl.load(input_ptrs,mask=col_offsets

解析來(lái)創(chuàng)建一個(gè)輔助函數(shù),該函數(shù)為任何給定的輸入張量排隊(duì)執(zhí)行kernel并且設(shè)置了啟動(dòng)參數(shù)。

defsoftmax(x):
n_rows,n_cols=x.shape
#塊大小是大于`x`中列數(shù)的最小2的冪
BLOCK_SIZE=triton.next_power_of_2(n_cols)
#我們可以使用的另一個(gè)技巧是要求編譯器通過(guò)增加每行分布的warp數(shù)(`num_warps`)來(lái)使用更多的線程。
#在下一個(gè)教程中,你將看到如何以更自然的方式自動(dòng)調(diào)整這個(gè)值,這樣你就不必自己想出手動(dòng)啟發(fā)式方法。
num_warps=4
ifBLOCK_SIZE>=2048:
num_warps=8
ifBLOCK_SIZE>=4096:
num_warps=16
#分配輸出
y=torch.empty_like(x)
#排隊(duì)執(zhí)行內(nèi)核。一維啟動(dòng)網(wǎng)格很簡(jiǎn)單:我們有每行一個(gè)內(nèi)核實(shí)例
#輸入矩陣
softmax_kernel[(n_rows,)](
y,
x,
x.stride(0),
y.stride(0),
n_cols,
num_warps=num_warps,
BLOCK_SIZE=BLOCK_SIZE,
)
returny

14385156-b926-11ee-8b88-92fbcf53809c.png

這里是驗(yàn)證Triton實(shí)現(xiàn)的fuse softmax和PyTorch的naive實(shí)現(xiàn)等價(jià),顯然他們是等價(jià)的。

BenchMark

1449f802-b926-11ee-8b88-92fbcf53809c.png

這里設(shè)定矩陣的行數(shù)為固定的4096來(lái)做benchmark。

@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['N'],#用作繪圖x軸的參數(shù)名
x_vals=[128*iforiinrange(2,100)],#`x_name`的不同可能值
line_arg='provider',#其值對(duì)應(yīng)于圖中不同線條的參數(shù)名
line_vals=[
'triton',
'torch-native',
'torch-jit',
],#`line_arg`的可能值
line_names=[
"Triton",
"Torch(原生)",
"Torch(jit)",
],#線條的標(biāo)簽名稱
styles=[('blue','-'),('green','-'),('green','--')],#線條樣式
ylabel="GB/s",#y軸的標(biāo)簽名稱
plot_name="softmax-performance",#繪圖的名稱。也用作保存繪圖的文件名。
args={'M':4096},#不在`x_names`和`y_name`中的函數(shù)參數(shù)的值
))
defbenchmark(M,N,provider):
x=torch.randn(M,N,device='cuda',dtype=torch.float32)
quantiles=[0.5,0.2,0.8]
ifprovider=='torch-native':
ms,min_ms,max_ms=triton.testing.do_bench(lambda:torch.softmax(x,axis=-1),quantiles=quantiles)
ifprovider=='triton':
ms,min_ms,max_ms=triton.testing.do_bench(lambda:softmax(x),quantiles=quantiles)
ifprovider=='torch-jit':
ms,min_ms,max_ms=triton.testing.do_bench(lambda:naive_softmax(x),quantiles=quantiles)
gbps=lambdams:2*x.nelement()*x.element_size()*1e-9/(ms*1e-3)
returngbps(ms),gbps(max_ms),gbps(min_ms)

benchmark.run(show_plots=True,print_data=True)

14cc8aa6-b926-11ee-8b88-92fbcf53809c.png14de2f68-b926-11ee-8b88-92fbcf53809c.png

這里提到雖然Triton實(shí)現(xiàn)的softmax性能更好并且易于理解和維護(hù),但PyTorch的torch.softmax則更加通用。

0x4. 教程3 Matrix Multiply閱讀

14fc667c-b926-11ee-8b88-92fbcf53809c.png

首先教程指出這里就是要寫(xiě)一個(gè)Block級(jí)別的矩陣乘法,然后這里會(huì)涉及到多維度的指針操作,程序重排以更好的命中l(wèi)2 cache以及自動(dòng)調(diào)優(yōu)。

動(dòng)機(jī)

矩陣乘法是大多數(shù)現(xiàn)代高性能計(jì)算系統(tǒng)的關(guān)鍵構(gòu)建塊。它們眾所周知難以優(yōu)化,因此它們的實(shí)現(xiàn)通常由硬件供應(yīng)商自己作為所謂的“內(nèi)核庫(kù)”(例如,cuBLAS)的一部分來(lái)完成。不幸的是,這些庫(kù)通常是專有的,無(wú)法輕易地定制以適應(yīng)現(xiàn)代深度學(xué)習(xí)工作負(fù)載的需求(例如,融合激活函數(shù))。在這個(gè)教程中,你將學(xué)習(xí)如何使用Triton自己實(shí)現(xiàn)高效的矩陣乘法,這種方法易于定制和擴(kuò)展。

大致來(lái)說(shuō),我們將要編寫(xiě)的內(nèi)核將實(shí)現(xiàn)以下塊級(jí)算法來(lái)乘以一個(gè) (M, K) 矩陣和一個(gè) (K, N) 矩陣:

#Doinparallel
forminrange(0,M,BLOCK_SIZE_M):
#Doinparallel
forninrange(0,N,BLOCK_SIZE_N):
acc=zeros((BLOCK_SIZE_M,BLOCK_SIZE_N),dtype=float32)
forkinrange(0,K,BLOCK_SIZE_K):
a=A[m:m+BLOCK_SIZE_M,k:k+BLOCK_SIZE_K]
b=B[k:k+BLOCK_SIZE_K,n:n+BLOCK_SIZE_N]
acc+=dot(a,b)
C[m:m+BLOCK_SIZE_M,n:n+BLOCK_SIZE_N]=acc

其中,雙重嵌套的for循環(huán)的每次迭代都由一個(gè)專用的Triton program實(shí)例執(zhí)行。

計(jì)算kernel

上述算法實(shí)際上在Triton中相當(dāng)容易實(shí)現(xiàn)。主要的難點(diǎn)來(lái)自于在內(nèi)循環(huán)中計(jì)算必須讀取A和B塊的內(nèi)存位置。為此,我們需要多維指針運(yùn)算。

指針運(yùn)算

對(duì)于一個(gè)2D Tensor X,X[i, j]的內(nèi)存位置為&X[i, j] = X + i*stride_xi + j*stride_xj。因此,對(duì)于A[m : m+BLOCK_SIZE_M, k:k+BLOCK_SIZE_K]和B[k : k+BLOCK_SIZE_K, n : n+BLOCK_SIZE_N]的塊指針可以用下面的偽代碼定義:

&A[m:m+BLOCK_SIZE_M,k:k+BLOCK_SIZE_K]=a_ptr+(m:m+BLOCK_SIZE_M)[:,None]*A.stride(0)+(k:k+BLOCK_SIZE_K)[None,:]*A.stride(1);
&B[k:k+BLOCK_SIZE_K,n:n+BLOCK_SIZE_N]=b_ptr+(k:k+BLOCK_SIZE_K)[:,None]*B.stride(0)+(n:n+BLOCK_SIZE_N)[None,:]*B.stride(1);

這意味著A和B塊的指針可以在Triton中初始化,比如 k=0 如下代碼所示。另外注意,我們需要一個(gè)額外的模運(yùn)算來(lái)處理M不是BLOCK_SIZE_M的倍數(shù)或N不是BLOCK_SIZE_N的倍數(shù)的情況,在這種情況下,我們可以用一些無(wú)用的值填充數(shù)據(jù),這些值不會(huì)對(duì)結(jié)果產(chǎn)生影響。對(duì)于K維度,我們稍后將使用掩碼加載語(yǔ)義來(lái)處理。

offs_am=(pid_m*BLOCK_SIZE_M+tl.arange(0,BLOCK_SIZE_M))%M
offs_bn=(pid_n*BLOCK_SIZE_N+tl.arange(0,BLOCK_SIZE_N))%N
offs_k=tl.arange(0,BLOCK_SIZE_K)
a_ptrs=a_ptr+(offs_am[:,None]*stride_am+offs_k[None,:]*stride_ak)
b_ptrs=b_ptr+(offs_k[:,None]*stride_bk+offs_bn[None,:]*stride_bn)

然后在內(nèi)循環(huán)中按如下方式更新:

a_ptrs+=BLOCK_SIZE_K*stride_ak;
b_ptrs+=BLOCK_SIZE_K*stride_bk;

如上所述,每個(gè)program實(shí)例計(jì)算一個(gè) [BLOCK_SIZE_M, BLOCK_SIZE_N] 大小的C矩陣塊。重要的是要記住,這些塊的計(jì)算順序是很重要的,因?yàn)樗鼤?huì)影響我們程序的L2緩存命中率,不幸的是,一個(gè)簡(jiǎn)單的行優(yōu)先順序是不夠的。

pid=triton.program_id(0);
grid_m=(M+BLOCK_SIZE_M-1)//BLOCK_SIZE_M;
grid_n=(N+BLOCK_SIZE_N-1)//BLOCK_SIZE_N;
pid_m=pid/grid_n;
pid_n=pid%grid_n;

L2 Cache優(yōu)化

如上所述,每個(gè)程序?qū)嵗?jì)算一個(gè) [BLOCK_SIZE_M, BLOCK_SIZE_N] 大小的C矩陣塊。重要的是要記住,這些塊的計(jì)算順序很重要,因?yàn)樗鼤?huì)影響我們程序的L2緩存命中率,不幸的是,一個(gè)簡(jiǎn)單的行主序排序是不夠的。

一個(gè)可能的解決方案是以一種促進(jìn)數(shù)據(jù)重用的順序啟動(dòng)塊。這可以通過(guò)在切換到下一列之前將塊在GROUP_M行的super group中分組來(lái)實(shí)現(xiàn):

#程序ID
pid=tl.program_id(axis=0)
#沿M軸的程序ID數(shù)量
num_pid_m=tl.cdiv(M,BLOCK_SIZE_M)
#沿N軸的程序ID數(shù)量
num_pid_n=tl.cdiv(N,BLOCK_SIZE_N)
#組中的程序數(shù)量
num_pid_in_group=GROUP_SIZE_M*num_pid_n
#該程序所在組的ID
group_id=pid//num_pid_in_group
#組中第一個(gè)程序的行ID
first_pid_m=group_id*GROUP_SIZE_M
#如果`num_pid_m`不能被`GROUP_SIZE_M`整除,最后一個(gè)組更小
group_size_m=min(num_pid_m-first_pid_m,GROUP_SIZE_M)
#*在組內(nèi)*,程序按列主序排列
#程序在*啟動(dòng)網(wǎng)格*中的行ID
pid_m=first_pid_m+(pid%group_size_m)
#程序在*啟動(dòng)網(wǎng)格*中的列ID
pid_n=(pid%num_pid_in_group)//group_size_m

例如,在下面的矩陣乘法中,每個(gè)矩陣由9個(gè)塊乘以9個(gè)塊組成,我們可以看到,如果我們按行主序計(jì)算輸出,我們需要將90個(gè)塊加載到SRAM中以計(jì)算前9個(gè)輸出塊,但如果我們按grouped ordering進(jìn)行計(jì)算,我們只需要加載54個(gè)塊。

15242360-b926-11ee-8b88-92fbcf53809c.png

在實(shí)際應(yīng)用中,這可以在某些硬件架構(gòu)上提高我們矩陣乘法內(nèi)核的性能超過(guò)10%(例如,在A100上從220提升到245 TFLOPS)。

L2 Cache優(yōu)化原理補(bǔ)充講解

上面的group oredering的訪問(wèn)代碼比較難理解,這里來(lái)更詳細(xì)的解析一下。

wKgaomWvHvSAfbS3AAB0PF-ZBkw397.jpg

#程序ID
pid=tl.program_id(axis=0)
#沿M軸的程序ID數(shù)量
num_pid_m=tl.cdiv(M,BLOCK_SIZE_M)
#沿N軸的程序ID數(shù)量
num_pid_n=tl.cdiv(N,BLOCK_SIZE_N)

這里的num_pid_m和num_pid_n就是求分別要在M和N方向循環(huán)多少次。

然后上面圖中的黑色數(shù)字其實(shí)就可以理解為program id,我們可以看到program id增加的方向其實(shí)就代表了遍歷的ordering,對(duì)于row major來(lái)說(shuō)就是在行方向上順序遍歷,而對(duì)于group ordering來(lái)說(shuō)就是按照一個(gè)BLOCK_SIZE_M*BLOCK_SIZE_N這么大的一個(gè)小組來(lái)遍歷。其實(shí)這段代碼就是完成group ordering的遍歷:

num_pid_in_group=GROUP_SIZE_M*num_pid_n
group_id=pid//num_pid_in_group
first_pid_m=group_id*GROUP_SIZE_M
group_size_m=min(num_pid_m-first_pid_m,GROUP_SIZE_M)
pid_m=first_pid_m+(pid%group_size_m)
pid_n=(pid%num_pid_in_group)//group_size_m

以上面圖來(lái)看,num_pid_m=3,num_pid_n=3,num_pid_in_group=group_id * GROUP_SIZE_M=9*3=27,也就是下面的紅色框里面的program個(gè)數(shù),從名字也可以看出來(lái)這個(gè)紅色框劃分的區(qū)域也是一個(gè)group。

1539a208-b926-11ee-8b88-92fbcf53809c.png

group_id 就表示當(dāng)前的這次 "循環(huán)", 是在第幾個(gè)紅色框里,以program 0為例,這里為group_id = pid // num_pid_in_group=0//27=0。而first_pid_m 代表當(dāng)前 group 中的第一個(gè)黃色program在全局的M維度上是第幾個(gè)program ,這里為first_pid_m = group_id * GROUP_SIZE_M=0,group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)這里是考慮到最后一個(gè)group可能占不滿數(shù)據(jù)(存在padding),所以就做一個(gè)截?cái)嗵幚怼?/p>

pid_m=first_pid_m+(pid%group_size_m)
pid_n=(pid%num_pid_in_group)//group_size_m

這兩行代碼計(jì)算當(dāng)前的program處理的黃色小塊坐標(biāo)([pid_m, pid_n]),pid_m這行是在行方向上移動(dòng),pid_n這行則是保證在上面的紅色框里面一定是一列一列來(lái)訪問(wèn)的。

作為對(duì)比,在Row-major的方法中,訪問(wèn)方式應(yīng)該是這樣的:

pid_m=pid//num_pid_n
pid_n=pid%num_pid_n

計(jì)算最后的結(jié)果

有了上面的鋪墊,我們就可以計(jì)算最終的結(jié)果了,下面的代碼展示了完整的Triton 矩陣乘法kernel實(shí)現(xiàn)。

#使用`triton.jit`裝飾的函數(shù)可以通過(guò)`triton.autotune`裝飾器進(jìn)行自動(dòng)調(diào)優(yōu),該裝飾器包括:
#-一系列定義不同配置的`triton.Config`對(duì)象,
#這些配置涉及元參數(shù)(例如`BLOCK_SIZE_M`)和編譯選項(xiàng)(例如`num_warps`)的不同設(shè)置
#-一個(gè)自動(dòng)調(diào)優(yōu)*關(guān)鍵字*,其值的變化將觸發(fā)對(duì)所有
#提供的配置的評(píng)估
@triton.autotune(
configs=[
#每個(gè)Config定義了一組特定的配置參數(shù)和編譯選項(xiàng)
triton.Config({'BLOCK_SIZE_M':128,'BLOCK_SIZE_N':256,'BLOCK_SIZE_K':64,'GROUP_SIZE_M':8},num_stages=3,
num_warps=8),
triton.Config({'BLOCK_SIZE_M':64,'BLOCK_SIZE_N':256,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M':128,'BLOCK_SIZE_N':128,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M':128,'BLOCK_SIZE_N':64,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M':64,'BLOCK_SIZE_N':128,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M':128,'BLOCK_SIZE_N':32,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=4,
num_warps=4),
triton.Config({'BLOCK_SIZE_M':64,'BLOCK_SIZE_N':32,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=5,
num_warps=2),
triton.Config({'BLOCK_SIZE_M':32,'BLOCK_SIZE_N':64,'BLOCK_SIZE_K':32,'GROUP_SIZE_M':8},num_stages=5,
num_warps=2),
],
key=['M','N','K'],#自動(dòng)調(diào)優(yōu)關(guān)鍵字
)
@triton.jit
defmatmul_kernel(
#指向矩陣的指針
a_ptr,b_ptr,c_ptr,
#矩陣維度
M,N,K,
#步長(zhǎng)變量表示在特定維度上移動(dòng)1個(gè)元素時(shí)指針增加的量。
#例如`stride_am`是將`a_ptr`增加多少以獲取下一行的元素(A有M行)。
stride_am,stride_ak,#A矩陣的步長(zhǎng)
stride_bk,stride_bn,#B矩陣的步長(zhǎng)
stride_cm,stride_cn,#C矩陣的步長(zhǎng)
#元參數(shù)
BLOCK_SIZE_M:tl.constexpr,BLOCK_SIZE_N:tl.constexpr,BLOCK_SIZE_K:tl.constexpr,#
GROUP_SIZE_M:tl.constexpr,#
ACTIVATION:tl.constexpr#激活函數(shù)
):
"""用于計(jì)算矩陣乘法C=AxB的內(nèi)核。
A的形狀為(M,K),B的形狀為(K,N),C的形狀為(M,N)。
"""
#-----------------------------------------------------------
#將程序ID`pid`映射到它應(yīng)該計(jì)算的C矩陣的塊。
#這是以groupedordering完成的,以促進(jìn)L2數(shù)據(jù)重用。
#詳細(xì)解釋看一節(jié)
pid=tl.program_id(axis=0)
num_pid_m=tl.cdiv(M,BLOCK_SIZE_M)
num_pid_n=tl.cdiv(N,BLOCK_SIZE_N)
num_pid_in_group=GROUP_SIZE_M*num_pid_n
group_id=pid//num_pid_in_group
first_pid_m=group_id*GROUP_SIZE_M
group_size_m=min(num_pid_m-first_pid_m,GROUP_SIZE_M)
pid_m=first_pid_m+(pid%group_size_m)
pid_n=(pid%num_pid_in_group)//group_size_m

#----------------------------------------------------------
#為A和B的第一個(gè)塊創(chuàng)建指針。
#我們將在K方向移動(dòng)時(shí)推進(jìn)這個(gè)指針并累加
#`a_ptrs`是[BLOCK_SIZE_M,BLOCK_SIZE_K]塊的指針
#`b_ptrs`是[BLOCK_SIZE_K,BLOCK_SIZE_N]塊的指針
#有關(guān)詳細(xì)信息,請(qǐng)參閱上方“指針?biāo)阈g(shù)”部分
offs_am=(pid_m*BLOCK_SIZE_M+tl.arange(0,BLOCK_SIZE_M))%M
offs_bn=(pid_n*BLOCK_SIZE_N+tl.arange(0,BLOCK_SIZE_N))%N
offs_k=tl.arange(0,BLOCK_SIZE_K)
a_ptrs=a_ptr+(offs_am[:,None]*stride_am+offs_k[None,:]*stride_ak)
b_ptrs=b_ptr+(offs_k[:,None]*stride_bk+offs_bn[None,:]*stride_bn)

#-----------------------------------------------------------
#迭代以計(jì)算C矩陣的一個(gè)塊。
#我們將累加到一個(gè)`[BLOCK_SIZE_M,BLOCK_SIZE_N]`塊
#的fp32值以獲得更高的精度。
#`accumulator`在循環(huán)后會(huì)轉(zhuǎn)換回fp16。
accumulator=tl.zeros((BLOCK_SIZE_M,BLOCK_SIZE_N),dtype=tl.float32)
forkinrange(0,tl.cdiv(K,BLOCK_SIZE_K)):
#LoadthenextblockofAandB,generateamaskbycheckingtheKdimension.
#Ifitisoutofbounds,setitto0.
a=tl.load(a_ptrs,mask=offs_k[None,:]=0,x,0.01*x)

我們現(xiàn)在可以創(chuàng)建一個(gè)方便的封裝函數(shù),它只需要兩個(gè)輸入張量,并且會(huì):(1)檢查任何形狀約束;(2)分配輸出;(3)啟動(dòng)上述kernel。

defmatmul(a,b,activation=""):
#Checkconstraints.
asserta.shape[1]==b.shape[0],"Incompatibledimensions"
asserta.is_contiguous(),"MatrixAmustbecontiguous"
assertb.is_contiguous(),"MatrixBmustbecontiguous"
M,K=a.shape
K,N=b.shape
#Allocatesoutput.
c=torch.empty((M,N),device=a.device,dtype=a.dtype)
#1Dlaunchkernelwhereeachblockgetsitsownprogram.
grid=lambdaMETA:(triton.cdiv(M,META['BLOCK_SIZE_M'])*triton.cdiv(N,META['BLOCK_SIZE_N']),)
matmul_kernel[grid](
a,b,c,#
M,N,K,#
a.stride(0),a.stride(1),#
b.stride(0),b.stride(1),#
c.stride(0),c.stride(1),#
ACTIVATION=activation#
)
returnc

計(jì)算過(guò)程的補(bǔ)充說(shuō)明

上面的《L2 Cache優(yōu)化原理補(bǔ)充講解》這一節(jié)明確了kernel的group ordering的訪問(wèn)方式以及實(shí)現(xiàn),現(xiàn)在來(lái)看對(duì)于當(dāng)前的program實(shí)例具體是怎么計(jì)算的?,F(xiàn)在以計(jì)算C中的第一個(gè)Block的(0, 0)為例子,它需要從A和B分別加載9個(gè)黃色的小塊數(shù)據(jù)相乘并累加最后得到C中的(0, 0)位置結(jié)果。如下圖所示:

154fc970-b926-11ee-8b88-92fbcf53809c.png

下面的代碼先把program實(shí)例當(dāng)前要處理A和B的第一個(gè)Block加載上來(lái):

#----------------------------------------------------------
#為A和B的第一個(gè)塊創(chuàng)建指針。
#我們將在K方向移動(dòng)時(shí)推進(jìn)這個(gè)指針并累加
#`a_ptrs`是[BLOCK_SIZE_M,BLOCK_SIZE_K]塊的指針
#`b_ptrs`是[BLOCK_SIZE_K,BLOCK_SIZE_N]塊的指針
#有關(guān)詳細(xì)信息,請(qǐng)參閱上方“指針?biāo)阈g(shù)”部分
offs_am=(pid_m*BLOCK_SIZE_M+tl.arange(0,BLOCK_SIZE_M))%M
offs_bn=(pid_n*BLOCK_SIZE_N+tl.arange(0,BLOCK_SIZE_N))%N
offs_k=tl.arange(0,BLOCK_SIZE_K)
a_ptrs=a_ptr+(offs_am[:,None]*stride_am+offs_k[None,:]*stride_ak)
b_ptrs=b_ptr+(offs_k[:,None]*stride_bk+offs_bn[None,:]*stride_bn)

這里的a_ptr 是整個(gè) A 矩陣第一個(gè)元素的地址,offs_am和offs_bn表示當(dāng)前的program id在M維度和K維度的坐標(biāo),這個(gè)坐標(biāo)是一個(gè)list,用tl.arange(0, BLOCK_SIZE_K)來(lái)獲取。

得到 M 維度 和 K 維度的坐標(biāo)后, 就可以讓它們各自和 M 維度 和 K 維度的 stride 相乘, 然后和 a_ptr 相加, 就可以得到 A 矩陣 9 個(gè) block 中第一個(gè) block 中每個(gè)元素的地址了。 b_ptr也是同理。

最后一部分就是累加了,這里會(huì)在K維度上進(jìn)行累加,每次計(jì)算輸出的一個(gè)塊。

#迭代以計(jì)算C矩陣的一個(gè)塊。
#我們將累加到一個(gè)`[BLOCK_SIZE_M,BLOCK_SIZE_N]`塊
#的fp32值以獲得更高的精度。
#`accumulator`在循環(huán)后會(huì)轉(zhuǎn)換回fp16。
accumulator=tl.zeros((BLOCK_SIZE_M,BLOCK_SIZE_N),dtype=tl.float32)
forkinrange(0,tl.cdiv(K,BLOCK_SIZE_K)):
#LoadthenextblockofAandB,generateamaskbycheckingtheKdimension.
#Ifitisoutofbounds,setitto0.
a=tl.load(a_ptrs,mask=offs_k[None,:]

這行代碼a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)考慮到 K 可能不能被 BLOCK_SIZE_K 整除, 到每一行最后一個(gè) block 的時(shí)候, 實(shí)際大小是不足 BLOCK_SIZE_K 的,所以需要把超出的那部分元素mask掉。

最后這部分代碼是把當(dāng)前的算子和LeakyReLU激活函數(shù)進(jìn)行融合:

#當(dāng)累加器仍然是FP32時(shí),可以融合任意激活函數(shù)
ifACTIVATION=="leaky_relu":
accumulator=leaky_relu(accumulator)
c=accumulator.to(tl.float16)

單元測(cè)試

155dcdb8-b926-11ee-8b88-92fbcf53809c.png

Benchmark

這里使用一個(gè)方陣來(lái)對(duì)比Triton實(shí)現(xiàn)的matmul kernel和cublas的matmul kernel的性能。

@triton.testing.perf_report(
triton.testing.Benchmark(
x_names=['M','N','K'],#用作圖表x軸的參數(shù)名
x_vals=[128*iforiinrange(2,33)],#`x_name`的不同可能值
line_arg='provider',#其值對(duì)應(yīng)于圖表中不同線條的參數(shù)名
#`line_arg`的可能值
line_vals=['cublas','triton'],
#線條的標(biāo)簽名稱
line_names=["cuBLAS","Triton"],
#線條樣式
styles=[('green','-'),('blue','-')],
ylabel="TFLOPS",#y軸的標(biāo)簽名稱
plot_name="matmul-performance",#圖表的名稱,也用作保存圖表的文件名。
args={},#其他參數(shù)
))
defbenchmark(M,N,K,provider):
#初始化張量
a=torch.randn((M,K),device='cuda',dtype=torch.float16)
b=torch.randn((K,N),device='cuda',dtype=torch.float16)
quantiles=[0.5,0.2,0.8]#分位數(shù)
#如果提供者是cublas
ifprovider=='cublas':
ms,min_ms,max_ms=triton.testing.do_bench(lambda:torch.matmul(a,b),quantiles=quantiles)
#如果提供者是triton
ifprovider=='triton':
ms,min_ms,max_ms=triton.testing.do_bench(lambda:matmul(a,b),quantiles=quantiles)
#性能計(jì)算函數(shù)
perf=lambdams:2*M*N*K*1e-12/(ms*1e-3)
returnperf(ms),perf(max_ms),perf(min_ms)

#運(yùn)行基準(zhǔn)測(cè)試,展示圖表和打印數(shù)據(jù)
benchmark.run(show_plots=True,print_data=True)

15816296-b926-11ee-8b88-92fbcf53809c.png

可以看到基于Triton實(shí)現(xiàn)的矩陣乘kernel性能大體可以和高度優(yōu)化的cuBlas持平。





審核編輯:劉清

聲明:本文內(nèi)容及配圖由入駐作者撰寫(xiě)或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點(diǎn)僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場(chǎng)。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問(wèn)題,請(qǐng)聯(lián)系本站處理。 舉報(bào)投訴
  • sram
    +關(guān)注

    關(guān)注

    6

    文章

    760

    瀏覽量

    114555
  • 多處理器
    +關(guān)注

    關(guān)注

    0

    文章

    22

    瀏覽量

    8901
  • Cache
    +關(guān)注

    關(guān)注

    0

    文章

    129

    瀏覽量

    28231
  • python
    +關(guān)注

    關(guān)注

    54

    文章

    4758

    瀏覽量

    84293
  • OpenAI
    +關(guān)注

    關(guān)注

    9

    文章

    1014

    瀏覽量

    6347

原文標(biāo)題:【BBuf的CUDA筆記】十三,OpenAI Triton 入門(mén)筆記一

文章出處:【微信號(hào):GiantPandaCV,微信公眾號(hào):GiantPandaCV】歡迎添加關(guān)注!文章轉(zhuǎn)載請(qǐng)注明出處。

收藏 人收藏

    評(píng)論

    相關(guān)推薦

    嵌入式linux入門(mén)筆記

    嵌入式linux入門(mén)筆記
    發(fā)表于 08-13 16:06

    嵌入式linux入門(mén)筆記

    嵌入式linux入門(mén)筆記
    發(fā)表于 08-20 20:53

    嵌入式入門(mén)筆記

    嵌入式入門(mén)筆記。。。
    發(fā)表于 10-31 23:13

    嵌入式入門(mén)筆記

    嵌入式入門(mén)筆記 ,初學(xué)者可以學(xué)習(xí)下,很值得借鑒
    發(fā)表于 05-21 12:54

    圖書(shū)推薦:IAR-for-AVR入門(mén)學(xué)習(xí)筆記

    IAR-for-AVR入門(mén)學(xué)習(xí)筆記
    發(fā)表于 06-12 13:46

    求CSS入門(mén)的學(xué)習(xí)筆記

    CSS入門(mén) 學(xué)習(xí)筆記4
    發(fā)表于 06-04 15:15

    什么是CUDA

    的時(shí)間盡可能清晰的了解這個(gè)深度學(xué)習(xí)賴以實(shí)現(xiàn)的基礎(chǔ)概念。本文在以下資料的基礎(chǔ)上整理完成,感謝以下前輩提供的資料:CUDA——“從入門(mén)到放棄”我的CUDA學(xué)習(xí)之旅——啟程介紹篇不錯(cuò)的
    發(fā)表于 07-26 06:28

    筆記本如何與投影機(jī)鏈接入門(mén)應(yīng)用小技巧

    筆記本如何與投影機(jī)鏈接入門(mén)應(yīng)用小技巧 、投影機(jī)連接筆記本電腦,無(wú)輸出影像?   答:筆記本電腦外接
    發(fā)表于 01-18 09:50 ?554次閱讀

    嵌入式入門(mén)筆記

    本文提供了嵌入式入門(mén)筆記,希望對(duì)你的學(xué)習(xí)有所幫助!
    發(fā)表于 06-07 16:57 ?0次下載
    嵌入式<b class='flag-5'>入門(mén)</b><b class='flag-5'>筆記</b>

    英飛凌MCU新手入門(mén)應(yīng)用筆記(中文版)

    英飛凌MCU新手入門(mén)應(yīng)用筆記(中文版)
    發(fā)表于 06-25 12:04 ?0次下載
    英飛凌MCU新手<b class='flag-5'>入門(mén)</b>應(yīng)用<b class='flag-5'>筆記</b>(中文版)

    ARM入門(mén)調(diào)試筆記

    ARM入門(mén)調(diào)試筆記
    發(fā)表于 10-13 14:26 ?11次下載
    ARM<b class='flag-5'>入門(mén)</b>調(diào)試<b class='flag-5'>筆記</b>

    CUDA學(xué)習(xí)筆記篇:個(gè)基本的CUDA C程序

    1、CUDA的簡(jiǎn)介 2、GPU架構(gòu)和CUDA介紹3、CUDA架構(gòu)4、開(kāi)發(fā)環(huán)境說(shuō)明和配置5、開(kāi)始第個(gè)Hello CUDA程序????5.1、
    的頭像 發(fā)表于 12-14 23:40 ?840次閱讀

    Xilinx_Vivado_zynq7000入門(mén)筆記

    Xilinx_Vivado_zynq7000入門(mén)筆記說(shuō)明。
    發(fā)表于 04-08 11:48 ?71次下載

    RT-Thread Nano入門(mén)學(xué)習(xí)筆記

    RT-Thread Nano入門(mén)學(xué)習(xí)筆記
    發(fā)表于 11-26 12:36 ?20次下載
    RT-Thread Nano<b class='flag-5'>入門(mén)</b>學(xué)習(xí)<b class='flag-5'>筆記</b>

    入門(mén)級(jí)微波電路(MMIC)的筆記-S 參數(shù)

    入門(mén)級(jí)微波電路(MMIC)的上課筆記-S 參數(shù)
    的頭像 發(fā)表于 07-05 10:13 ?646次閱讀
    <b class='flag-5'>入門(mén)</b>級(jí)微波電路(MMIC)的<b class='flag-5'>筆記</b>-S 參數(shù)