OpenAI提出新的神經(jīng)網(wǎng)絡模型“稀疏Transformer”,能夠預測文本、圖像和聲音等序列的后續(xù)內(nèi)容,該模型是對注意力機制的一個改進,預測長度達到之前最佳水平的30倍。
目前人工智能研究的一大挑戰(zhàn)是對復雜數(shù)據(jù)(如圖像,視頻或聲音)中的大范圍微妙的相互依賴性進行建模。稀疏Transformer降低了傳統(tǒng)注意力機制模型的計算復雜度,將其直接應用于不同的數(shù)據(jù)類型中。以前,在這些數(shù)據(jù)上使用的模型是針對某個專門領(lǐng)域設計的,難以擴展到超過幾千個元素的序列規(guī)模上應用。
此次OpenAI提出的模型可以使用數(shù)百個層對數(shù)萬個元素的序列進行建模,在多個域中實現(xiàn)最先進的性能。稀疏Transformer能夠幫助我們構(gòu)建具有更強的理解世界能力的AI系統(tǒng)。
深度注意力機制
在稀疏Transformer中,每個輸出元素都與每個輸入元素相連,它們之間的權(quán)重是根據(jù)環(huán)境動態(tài)計算的,這個過程稱為注意力。雖然這樣會讓模型比固定連接模式的模型更加靈活,但在實踐中需要為每個層和注意力頭N×N注意力矩陣,面對元素數(shù)量眾多的數(shù)據(jù)類型時會消耗大量的內(nèi)存,比如圖像或原始音頻數(shù)據(jù)。
當矩陣存儲在內(nèi)存中或在后向傳遞期間重新計算時,深度Transformer的內(nèi)存消耗情況(64層、4個注意力頭)。作為參考,用于深度學習的標準GPU通常配備12-32GB的內(nèi)存
減少內(nèi)存消耗一種方法是在反向傳播期間從檢查點重新計算注意力矩陣,這是深度學習中的一種成熟技術(shù),以增加計算量為代價來減少內(nèi)存使用。在計算Transformer的注意力矩陣時,意味著最大的內(nèi)存成本與層數(shù)無關(guān),這使我們能夠以比以前更大的深度訓練神經(jīng)網(wǎng)絡。
實際上,我們發(fā)現(xiàn)深度達128層的Transformer在常用數(shù)據(jù)集基準任務(如CIFAR-10)上的表現(xiàn)優(yōu)于較淺層的網(wǎng)絡。
為了更深入地訓練這些模型,我們對Transformer中的操作順序進行了幾次調(diào)整,并修改了初始方案。
稀疏注意力機制:顯著降低計算復雜度
然而,即使是計算單個注意力矩陣,對于非常大的輸入也是不切實際。因此我們使用稀疏注意力模式,即每個輸出位置僅計算來自輸入位置子集的權(quán)重。當子集相對于整個輸入集較小時,即使對于非常長的序列,所得到的注意力計算也是容易處理的,算法復雜度為O(N *sqrt {N}),而不是O(N^2)。
為了評估該方法的可行性,我們首先將深度Transformer在圖像上的學習注意模式進行可視化,發(fā)現(xiàn)許多模型表現(xiàn)出可解釋和結(jié)構(gòu)化的稀疏模式。下面的每個圖像顯示給定的注意頭處理哪些輸入像素(以白色突出顯示)以便預測圖像中的下一個值。
當輸入部分聚焦在小的子集上并顯示出高度的規(guī)則性時,該層就是易于稀疏化的。下圖為CIFAR-10圖像上的128層模型示例。
左圖為19層,右圖為20層
學習后的128層CIFAR-10網(wǎng)絡的多個層的注意力模式(白色高亮部分)。這些層學會將注意力分散在兩個維度上。其中第19層總結(jié)了每一行的信息,第20層則按列聚合這些信息,從而能夠?qū)θ孀⒁饬Σ僮鬟M行有效分解。
左圖為第6層,右圖為第36層
一些層學會了訪問位置存儲器,無論輸入數(shù)據(jù)或時間步長如何,通常都會訪問類似的位置(第6層)。還有的層學習了高度依賴數(shù)據(jù)的訪問模式(第36層)。
雖然許多圖層顯示出了稀疏結(jié)構(gòu),某些層還清晰地顯示出在整個圖像上延伸的動態(tài)注意力。為了讓網(wǎng)絡保持學習這些模式的能力,我們進行了注意力矩陣的二維分解,網(wǎng)絡可以通過兩個稀疏注意力步驟來關(guān)注所有位置。
(左)普通transformer,(中)范圍注意力,(右)固定注意力
第一個版本,大范圍注意力,大致相當于參與其行和列的每個位置,并且類似于上面的網(wǎng)絡學習的注意力模式。(注意,列注意力可以等效地表示成轉(zhuǎn)置矩陣的行注意力)。第二個版本是固定注意力,注意固定列和最新列元素之后的元素,我們發(fā)現(xiàn)這種模式在數(shù)據(jù)不適合二維結(jié)構(gòu)(如文本)時很有用。
實驗結(jié)果:創(chuàng)造多個數(shù)據(jù)集上的新紀錄
稀疏Transformer在CIFAR-10,Enwik8和Imagenet 64上創(chuàng)造了密度估計的最新記錄。如下表所示:
CIFAR-10 | BITS PER DIM |
PixelCNN++ (Oord et al, 2016) | 2.92 |
Image Transformer (Parmar et. al, 2018) | 2.90 |
PixelSNAIL (Chen et al., 2017) | 2.85 |
Sparse Transformer 59M (256W, 128L, 2H) | 2.80 |
ENWIK8 | BITS PER BYTE |
Deeper Self-Attention (Al-Rfou et al, 2018) | 1.06 |
Transformer-XL 88M (Dai et al., 2018) | 1.03 |
Transformer-XL 277M (Dai et al., 2018) | 0.99 |
Sparse Transformer 95M (512W, 30L, 8H) | 0.99 |
IMAGENET 64X64 | BITS PER DIM |
PixelCNN++ (Oord et al, 2016) | 3.57 |
Parallel Multiscale (Reed et al, 2017) | 3.7 |
SPN 150M (Menick & Kalchbrenner, 2018) | 3.52 |
Sparse Transformer 152M (512W, 48L, 16H) | 3.44 |
在一系列數(shù)據(jù)集上的密度建模表現(xiàn),M為網(wǎng)絡中使用的參數(shù)數(shù)量(百萬),W為網(wǎng)絡寬度,L為層數(shù),H為注意力頭數(shù)量。
我們還發(fā)現(xiàn),除了速度明顯更快之外,稀疏注意力模型的損失也要低于完全注意力模型。這可能表明我們的稀疏模式存在有用的歸納偏差,或是密集關(guān)注的潛在優(yōu)化問題。
使用稀疏注意力的Transformer似乎有一個全局結(jié)構(gòu)的概念,可以通過查看圖像完成來定性評估。我們對64×64 ImageNet上訓練的模型進行了可視化,如下圖所示:
Prompt
Completions
Ground truth
我們還利用未調(diào)整的softmax temperature 1.0下生成了完全無條件的樣圖。這些模型使用最大似然目標進行訓練,眾所周知,這類訓練的目標是覆蓋所有數(shù)據(jù)模式(包括可能不存在的數(shù)據(jù)),而不是增加小部分數(shù)據(jù)的保真度。從這些具有未調(diào)整溫度的模型中生成樣圖,可以讓我們看到模型認為存在于真實世界中圖像的完整分布。結(jié)果,一些樣本看起來很奇怪。
模型采樣
真實數(shù)據(jù)
生成原始音頻波形
稀疏Transformer也可以通過簡單地改變位置嵌入,自適應地生成原始音頻。隨著深度學習擴展到新型數(shù)據(jù)類型,可以使用這類網(wǎng)絡作為確定歸納偏差的有用工具。
該模型在原始古典音樂剪輯上進行訓練,并使用稀疏注意力生成長度為65000的序列,相當于大約5秒的原始音頻,我們在每個片段中將幾個樣本連接在了一起。
關(guān)于代碼發(fā)布和開源
通常,實現(xiàn)稀疏注意力將涉及在數(shù)據(jù)塊中將查詢和關(guān)鍵矩陣單獨“切片”,因此為了簡化實驗,我們實現(xiàn)了一組塊稀疏內(nèi)核,這些內(nèi)核可以在GPU上高效執(zhí)行這些操作。我們開源了這些內(nèi)核,并在Github上提供示例稀疏注意函數(shù)。
未來方向和局限
我們提出的稀疏注意力模式只是長序列高效建模方向的初步模式。我們認為,探索稀疏性的不同模式和組合的用途不僅于此,學習稀疏模式對于下一代神經(jīng)網(wǎng)絡體系結(jié)構(gòu)來說是一個很有前途的方向。
即使經(jīng)過改進,自回歸序列生成對于非常高分辨率的圖像或視頻來說仍然是不切實際的。不過,我們提出的優(yōu)化注意力操作可能是一次有益的探索,可以和其他(如多尺度方法)方法相結(jié)合來對高維數(shù)據(jù)進行建模。
-
神經(jīng)網(wǎng)絡
+關(guān)注
關(guān)注
42文章
4726瀏覽量
100318 -
圖像
+關(guān)注
關(guān)注
2文章
1078瀏覽量
40346 -
模型
+關(guān)注
關(guān)注
1文章
3065瀏覽量
48579
原文標題:OpenAI提出Sparse Transformer,文本、圖像、聲音都能預測,序列長度提高30倍
文章出處:【微信號:AI_era,微信公眾號:新智元】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。
發(fā)布評論請先 登錄
相關(guān)推薦
評論