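Like most of our from-scratch implementations, Section 9.5 was designed to provide insight into how each component works. But when you use RNNs every day or write production code, you will want to rely more on libraries that cut down on both implementation time (by supplying library code for common models and functions) and computation time (by optimizing these library implementations). This section shows you how to implement the same language model more efficiently using the high-level API provided by your deep learning framework. As before, we begin by loading The Time Machine dataset.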
# PyTorch
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
# TensorFlow
import tensorflow as tf
from d2l import tensorflow as d2l
9.6.1. Defining the Model
We define the following class using the RNN implemented by the high-level API.
Specifically, to initialize the hidden state in the MXNet implementation below, we invoke the member method begin_state. This returns a list that contains an initial hidden state for each example in the minibatch, whose shape is (number of hidden layers, batch size, number of hidden units). For some models to be introduced later (e.g., long short-term memory), this list will also contain other information.
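To make these shapes concrete, here is a minimal NumPy sketch of what a single-layer vanilla RNN layer computes over time-major inputs, using the update H_t = tanh(X_t W_xh + H_{t-1} W_hh + b_h). The helpers begin_state and rnn_layer are hypothetical stand-ins, not framework APIs. For instance, this begin_state returns a single array rather than a list. The framework layers add multiple layers, parameter management and optimized kernels on top of this loop.

```python
import numpy as np

def begin_state(num_layers, batch_size, num_hiddens):
    # Hypothetical helper: zero state of shape
    # (number of hidden layers, batch size, number of hidden units)
    return np.zeros((num_layers, batch_size, num_hiddens))

def rnn_layer(inputs, H, W_xh, W_hh, b_h):
    """Single-layer tanh RNN; inputs: (num_steps, batch_size, num_inputs)."""
    h = H[0]  # take the state of the only layer: (batch_size, num_hiddens)
    outputs = []
    for X_t in inputs:  # iterate over the time axis
        h = np.tanh(X_t @ W_xh + h @ W_hh + b_h)
        outputs.append(h)
    # Per-step hidden states, plus the final state in (1, batch, hidden) form
    return np.stack(outputs), h[np.newaxis]

rng = np.random.default_rng(0)
num_steps, batch_size, num_inputs, num_hiddens = 5, 2, 4, 3
inputs = rng.normal(size=(num_steps, batch_size, num_inputs))
W_xh = rng.normal(scale=0.1, size=(num_inputs, num_hiddens))
W_hh = rng.normal(scale=0.1, size=(num_hiddens, num_hiddens))
b_h = np.zeros(num_hiddens)

H0 = begin_state(1, batch_size, num_hiddens)
outputs, H = rnn_layer(inputs, H0, W_xh, W_hh, b_h)
print(outputs.shape, H.shape)  # (5, 2, 3) (1, 2, 3)
```

The last time step of outputs coincides with the returned state. This is why the high-level layers hand back both: the full sequence feeds the output layer, and the final state seeds the next call.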
# MXNet
from mxnet import npx
from mxnet.gluon import rnn
from d2l import mxnet as d2l
npx.set_np()

class RNN(d2l.Module):  #@save
    """The RNN model implemented with high-level APIs."""
    def __init__(self, num_hiddens):
        super().__init__()
        self.save_hyperparameters()
        self.rnn = rnn.RNN(num_hiddens)

    def forward(self, inputs, H=None):
        # begin_state returns a list; unpack its single hidden state
        if H is None:
            H, = self.rnn.begin_state(inputs.shape[1], ctx=inputs.ctx)
        outputs, (H, ) = self.rnn(inputs, (H, ))
        return outputs, H
At the time of writing, Flax does not provide an RNNCell for a concise implementation of vanilla RNNs. More advanced RNN variants, such as LSTMs and GRUs, are available in the Flax linen API.
# TensorFlow
class RNN(d2l.Module):  #@save
    """The RNN model implemented with high-level APIs."""
    def __init__(self, num_hiddens):
        super().__init__()
        self.save_hyperparameters()
        # time_major=True: inputs have shape (num_steps, batch_size, num_inputs);
        # return_state=True also returns the final hidden state
        self.rnn = tf.keras.layers.SimpleRNN(
            num_hiddens, return_sequences=True, return_state=True,
            time_major=True)

    def forward(self, inputs, H=None):
        outputs, H = self.rnn(inputs, H)
        return outputs, H
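The listings above cover MXNet and TensorFlow, but the PyTorch RNNLM that follows also needs an RNN module to wrap. Below is a minimal sketch of that wrapper, assuming torch.nn.RNN. It uses a plain nn.Module instead of d2l.Module so it runs without d2l, and the constructor arguments num_inputs and num_hiddens are our choice for illustration.

```python
import torch
from torch import nn

class RNN(nn.Module):
    """Sketch of a concise RNN wrapper around torch.nn.RNN (not d2l.Module)."""
    def __init__(self, num_inputs, num_hiddens):
        super().__init__()
        self.num_hiddens = num_hiddens
        # nn.RNN is time-major by default (batch_first=False):
        # inputs have shape (num_steps, batch_size, num_inputs)
        self.rnn = nn.RNN(num_inputs, num_hiddens)

    def forward(self, inputs, H=None):
        # When H is None, nn.RNN starts from a zero hidden state,
        # so no explicit begin_state call is needed
        return self.rnn(inputs, H)

torch.manual_seed(0)
num_steps, batch_size, num_inputs, num_hiddens = 5, 2, 28, 32
rnn = RNN(num_inputs, num_hiddens)
X = torch.randn(num_steps, batch_size, num_inputs)
outputs, H = rnn(X)
print(outputs.shape, H.shape)  # torch.Size([5, 2, 32]) torch.Size([1, 2, 32])
```

Unlike the MXNet version, no explicit state initialization is required, because passing None for the hidden state is accepted by nn.RNN.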
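Inheriting from the RNNLMScratch class in Section 9.5, the following RNNLM class defines a complete RNN-based language model. Note that we need to create a separate fully connected output layer.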
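# PyTorch
class RNNLM(d2l.RNNLMScratch):  #@save
    """The RNN-based language model implemented with high-level APIs."""
    def init_params(self):
        # LazyLinear infers its input size (num_hiddens) on the first call
        self.linear = nn.LazyLinear(self.vocab_size)

    def output_layer(self, hiddens):
        # (num_steps, batch_size, vocab_size) -> (batch_size, num_steps, vocab_size)
        return self.linear(hiddens).swapaxes(0, 1)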
# JAX
from flax import linen as nn
from d2l import jax as d2l

class RNNLM(d2l.RNNLMScratch):  #@save
    """The RNN-based language model implemented with high-level APIs."""
    training: bool = True

    def setup(self):
        self.linear = nn.Dense(self.vocab_size)

    def output_layer(self, hiddens):
        return self.linear(hiddens).swapaxes(0, 1)

    def forward(self, X, state=None):
        embs = self.one_hot(X)
        rnn_outputs, _ = self.rnn(embs, state, self.training)
        return self.output_layer(rnn_outputs)
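The shape bookkeeping in RNNLM can be traced with a small NumPy sketch. Token indices of shape (batch_size, num_steps) are one-hot encoded in time-major order. They then pass through the recurrent layer and the separate fully connected output layer, and are finally swapped back to batch-major order. The helpers one_hot, rnn and output_layer here are hypothetical; we assume one_hot transposes its input to time-major order as the scratch implementation's one-hot step does.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, num_steps, vocab_size, num_hiddens = 2, 5, 10, 8

def one_hot(X, vocab_size):
    # (batch_size, num_steps) indices -> (num_steps, batch_size, vocab_size)
    return np.eye(vocab_size)[X.T]

def rnn(inputs, W_xh, W_hh, b_h):
    # Vanilla tanh RNN over time-major inputs, starting from a zero state
    h = np.zeros((inputs.shape[1], W_hh.shape[0]))
    outs = []
    for X_t in inputs:
        h = np.tanh(X_t @ W_xh + h @ W_hh + b_h)
        outs.append(h)
    return np.stack(outs)  # (num_steps, batch_size, num_hiddens)

def output_layer(hiddens, W_hq, b_q):
    # Fully connected layer, then swap axes 0 and 1:
    # (num_steps, batch_size, vocab_size) -> (batch_size, num_steps, vocab_size)
    return (hiddens @ W_hq + b_q).swapaxes(0, 1)

X = rng.integers(0, vocab_size, size=(batch_size, num_steps))
W_xh = rng.normal(scale=0.1, size=(vocab_size, num_hiddens))
W_hh = rng.normal(scale=0.1, size=(num_hiddens, num_hiddens))
b_h = np.zeros(num_hiddens)
W_hq = rng.normal(scale=0.1, size=(num_hiddens, vocab_size))
b_q = np.zeros(vocab_size)

embs = one_hot(X, vocab_size)
logits = output_layer(rnn(embs, W_xh, W_hh, b_h), W_hq, b_q)
print(embs.shape, logits.shape)  # (5, 2, 10) (2, 5, 10)
```

The final swapaxes(0, 1) restores the batch-major layout of the labels, so each position in logits lines up with the corresponding target token.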
9.6.2. Training and Prediction
Before training the model, let's make a prediction with a model initialized with random weights. Since we have not trained the network yet, it will generate nonsensical predictions.
'it hasgggggggggggggggggggg'
'it hasxlxlxlxlxlxlxlxlxlxl'
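The generation loop behind such predictions can be sketched in NumPy. We assume it follows the procedure of Section 9.5: the prefix is fed in first to warm up the hidden state, then the most likely next character (greedy argmax) is appended and fed back. The function predict, the toy vocabulary and the parameter layout are hypothetical. With random weights the argmax tends to settle on arbitrary characters, which is consistent with the repetitive strings above.

```python
import numpy as np

def predict(prefix, num_preds, vocab, params):
    """Greedy character generation with a vanilla RNN (NumPy sketch)."""
    W_xh, W_hh, b_h, W_hq, b_q = params
    idx = {ch: i for i, ch in enumerate(vocab)}
    h = np.zeros(W_hh.shape[0])
    outputs = [idx[prefix[0]]]
    for i in range(len(prefix) + num_preds - 1):
        x = np.eye(len(vocab))[outputs[-1]]  # one-hot encode the last token
        h = np.tanh(x @ W_xh + h @ W_hh + b_h)
        if i < len(prefix) - 1:
            # Warm-up: feed the next prefix character, ignore the prediction
            outputs.append(idx[prefix[i + 1]])
        else:
            # Predict: append the most likely next character
            outputs.append(int(np.argmax(h @ W_hq + b_q)))
    return ''.join(vocab[i] for i in outputs)

vocab = list(' abcdefghijklmnopqrstuvwxyz')  # toy vocabulary (assumption)
rng = np.random.default_rng(0)
V, num_hiddens = len(vocab), 32
params = (rng.normal(scale=0.1, size=(V, num_hiddens)),
          rng.normal(scale=0.1, size=(num_hiddens, num_hiddens)),
          np.zeros(num_hiddens),
          rng.normal(scale=0.1, size=(num_hiddens, V)),
          np.zeros(V))
text = predict('it has', 20, vocab, params)
print(text)
```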
Next, we train our model, leveraging the high-level API.
Compared with Section 9.5, this model achieves comparable perplexity but runs faster thanks to the optimized implementations. As before, we can generate predicted tokens following a specified prefix string.
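Perplexity, the metric being compared here, is the exponential of the average per-token cross-entropy. A small NumPy sketch follows, in which perplexity is a hypothetical helper. It uses two reference points: uniform predictions give a perplexity equal to the vocabulary size, while confident correct predictions push it toward 1.

```python
import numpy as np

def perplexity(logits, labels):
    """exp(mean cross-entropy); logits: (N, vocab_size), labels: (N,)."""
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    nll = -log_probs[np.arange(len(labels)), labels]  # per-token loss
    return float(np.exp(nll.mean()))

# Uniform logits over 28 tokens: perplexity equals the vocabulary size
uniform = perplexity(np.zeros((4, 28)), np.array([0, 5, 9, 27]))
# Large margin on the correct token: perplexity close to 1
confident = perplexity(np.eye(3) * 50.0, np.array([0, 1, 2]))
print(uniform, confident)
```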