譯者 | Sambodhi??
生成對(duì)抗網(wǎng)絡(luò)(Generative Adversarial Network,GAN)由 Goodfellow 等人在 2014 年提出,它徹底改變了計(jì)算機(jī)視覺(jué)中的圖像生成領(lǐng)域:沒(méi)有人能夠相信這些令人驚嘆而生動(dòng)的圖像實(shí)際上是純粹由機(jī)器生成的。
事實(shí)上,人們?cè)?jīng)認(rèn)為生成的任務(wù)是不可能的,并且被 GAN 的力量所震驚,因?yàn)閭鹘y(tǒng)上,根本沒(méi)有任何事實(shí)可以比較我們生成的圖像。
本文介紹了創(chuàng)建 GAN 背后的簡(jiǎn)單直覺(jué),然后介紹了通過(guò) PyTorch 實(shí)現(xiàn)的卷積 GAN 及其訓(xùn)練過(guò)程。
GAN 背后的直覺(jué)
不同于傳統(tǒng)分類方法,我們的網(wǎng)絡(luò)預(yù)測(cè)可以直接與事實(shí)的正確答案相比較,而生成圖像的“正確性”是很難定義和衡量的。Goodfellow 等人在他們的原創(chuàng)論文《生成對(duì)抗網(wǎng)絡(luò)》(Generative Adversarial Network)中提出了一個(gè)有趣的想法:使用經(jīng)過(guò)訓(xùn)練的分類器來(lái)區(qū)分生成的圖像和實(shí)際圖像。如果存在這樣的分類器,我們可以創(chuàng)建并訓(xùn)練一個(gè)生成器網(wǎng)絡(luò),直到它輸出的圖像能完全騙過(guò)分類器。
圖 1 GAN 管道
GAN 是這一過(guò)程的產(chǎn)物:它包含一個(gè)根據(jù)給定的數(shù)據(jù)集生成圖像的生成器,以及一個(gè)區(qū)分圖像是真實(shí)的還是生成的判別器(分類器)。GAN 的詳細(xì)管道見(jiàn)圖 1。
損失函數(shù)
對(duì)生成器和判別器進(jìn)行優(yōu)化都很困難,因?yàn)檎缒闼胂蟮哪菢?,這兩個(gè)網(wǎng)絡(luò)的目標(biāo)完全相反:生成器希望盡可能地創(chuàng)造出真實(shí)的東西,但判別器希望區(qū)分生成的材料。
為了說(shuō)明這一點(diǎn),我們讓 D(x) 是判別器的輸出,也就是 x 是真實(shí)圖像的概率,而 G(z) 是我們的生成器的輸出。判別器類似于一個(gè)二元分類器,因此判別器的目標(biāo)是使函數(shù)最大化:
本質(zhì)上是二元交叉熵?fù)p失,沒(méi)有開(kāi)頭的負(fù)號(hào)。另一方面,生成器的目標(biāo)是使判別器做出正確判斷的機(jī)會(huì)最小化,因此它的目標(biāo)是最小化函數(shù)。所以,最終的損失函數(shù)將是兩個(gè)分類器之間的一個(gè)極小極大博弈(minimax game),具體如下:
從理論上講,這將收斂到判別器,預(yù)測(cè)所有事件的概率為 0.5。
但在實(shí)踐中,極小極大博弈往往會(huì)導(dǎo)致網(wǎng)絡(luò)無(wú)法收斂,因此仔細(xì)調(diào)整訓(xùn)練過(guò)程非常重要。像學(xué)習(xí)率這樣的超參數(shù)對(duì)于訓(xùn)練 GAN 時(shí)顯然更為重要:一個(gè)微小的變化會(huì)導(dǎo)致 GAN 產(chǎn)生一個(gè)輸出,而與輸入噪聲無(wú)關(guān)。
運(yùn)算環(huán)境
庫(kù)
我們通過(guò) PyTorch 庫(kù)(包括 torchvision)來(lái)構(gòu)建整個(gè)程序。GAN 的生成結(jié)果的可視化是通過(guò) Matplotlib 庫(kù)繪制的。下面的代碼導(dǎo)入了所有的庫(kù):
importGAN.py
""" Import necessary libraries to create a generative adversarial network The code is mainly developed using the PyTorch library """ import time import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torchvision import datasets from torchvision.transforms import transforms from model import discriminator, generator import numpy as np import matplotlib.pyplot as plt
數(shù)據(jù)集
在 GAN 訓(xùn)練中,數(shù)據(jù)集是一個(gè)重要方面。圖像的非結(jié)構(gòu)化性質(zhì)意味著任何給定的類別(如狗、貓或手寫(xiě)的數(shù)字)都可以有一個(gè)可能的數(shù)據(jù)分布,而這種分布最終是 GAN 生成內(nèi)容的基礎(chǔ)。
為了演示,本文將使用最簡(jiǎn)單的 MNIST 數(shù)據(jù)集,其中包含 60000 張從 0 到 9 的手寫(xiě)數(shù)字圖像。事實(shí)上,像 MNIST 這樣的非結(jié)構(gòu)化數(shù)據(jù)集可以在 Graviti 上找到。這是一家年輕的創(chuàng)業(yè)公司,他們希望通過(guò)非結(jié)構(gòu)化數(shù)據(jù)集為社區(qū)提供幫助,在他們的 平臺(tái) 上有一些最好的公共非結(jié)構(gòu)化數(shù)據(jù)集,包括 MNIST。
硬件要求
最好的方法是用 GPU 訓(xùn)練神經(jīng)網(wǎng)絡(luò),它可以顯著地提高訓(xùn)練速度。但是,如果只有 CPU 可用,你仍然可以測(cè)試程序。要使你的程序能夠自行確定硬件,你可以使用以下方法:
torchDevice.py
""" Determine if any GPUs are available """ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
實(shí)施
網(wǎng)絡(luò)架構(gòu)
由于數(shù)字的簡(jiǎn)單性,這兩種架構(gòu)——判別器和生成器,都是由全連接層構(gòu)建的。請(qǐng)注意,在某些情況下,全連接的 GAN 也比 DCGAN 略微容易收斂。
以下是兩種架構(gòu)的 PyTorch 實(shí)現(xiàn):
GANArchitecture.py
""" Network Architectures The following are the discriminator and generator architectures """ class discriminator(nn.Module): def __init__(self): super(discriminator, self).__init__() self.fc1 = nn.Linear(784, 512) self.fc2 = nn.Linear(512, 1) self.activation = nn.LeakyReLU(0.1) def forward(self, x): x = x.view(-1, 784) x = self.activation(self.fc1(x)) x = self.fc2(x) return nn.Sigmoid()(x) class generator(nn.Module): def __init__(self): super(generator, self).__init__() self.fc1 = nn.Linear(128, 1024) self.fc2 = nn.Linear(1024, 2048) self.fc3 = nn.Linear(2048, 784) self.activation = nn.ReLU() def forward(self, x): x = self.activation(self.fc1(x)) x = self.activation(self.fc2(x)) x = self.fc3(x) x = x.view(-1, 1, 28, 28) return nn.Tanh()(x)訓(xùn)練
在訓(xùn)練 GAN 時(shí),我們優(yōu)化了判別器的結(jié)果,同時(shí)也改進(jìn)了我們的生成器。這樣,在每次迭代過(guò)程中會(huì)有兩個(gè)相互矛盾的損失來(lái)同時(shí)優(yōu)化它們。我們送入生成器的是隨機(jī)噪聲,而生成器理應(yīng)根據(jù)給定噪聲的微小差異來(lái)生成圖像:
trainGAN.py
""" Network training procedure Every step both the loss for disciminator and generator is updated Discriminator aims to classify reals and fakes Generator aims to generate images as realistic as possible """ for epoch in range(epochs): for idx, (imgs, _) in enumerate(train_loader): idx += 1 # Training the discriminator # Real inputs are actual images of the MNIST dataset # Fake inputs are from the generator # Real inputs should be classified as 1 and fake as 0 real_inputs = imgs.to(device) real_outputs = D(real_inputs) real_label = torch.ones(real_inputs.shape[0], 1).to(device) noise = (torch.rand(real_inputs.shape[0], 128) - 0.5) / 0.5 noise = noise.to(device) fake_inputs = G(noise) fake_outputs = D(fake_inputs) fake_label = torch.zeros(fake_inputs.shape[0], 1).to(device) outputs = torch.cat((real_outputs, fake_outputs), 0) targets = torch.cat((real_label, fake_label), 0) D_loss = loss(outputs, targets) D_optimizer.zero_grad() D_loss.backward() D_optimizer.step() # Training the generator # For generator, goal is to make the discriminator believe everything is 1 noise = (torch.rand(real_inputs.shape[0], 128)-0.5)/0.5 noise = noise.to(device) fake_inputs = G(noise) fake_outputs = D(fake_inputs) fake_targets = torch.ones([fake_inputs.shape[0], 1]).to(device) G_loss = loss(fake_outputs, fake_targets) G_optimizer.zero_grad() G_loss.backward() G_optimizer.step() if idx % 100 == 0 or idx == len(train_loader): print('Epoch {} Iteration {}: discriminator_loss {:.3f} generator_loss {:.3f}'.format(epoch, idx, D_loss.item(), G_loss.item())) if (epoch+1) % 10 == 0: torch.save(G, 'Generator_epoch_{}.pth'.format(epoch)) print('Model saved.')?結(jié)? 果
當(dāng) 100 個(gè)輪數(shù)(epoch)之后,我們可以繪制數(shù)據(jù)集,并看到從隨機(jī)噪音中生成的數(shù)字的結(jié)果:
圖 2:GAN 生成的結(jié)
如上圖所示,生成的結(jié)果看起來(lái)確實(shí)相當(dāng)像真實(shí)的結(jié)果。鑒于網(wǎng)絡(luò)非常簡(jiǎn)單,所以結(jié)果看起來(lái)確實(shí)很有希望!
超越單純的內(nèi)容創(chuàng)作
GAN 的創(chuàng)造與計(jì)算機(jī)視覺(jué)領(lǐng)域的先前工作如此不同。隨后的眾多應(yīng)用使學(xué)術(shù)界對(duì)深度網(wǎng)絡(luò)的能力感到驚訝。下面將介紹一些令人驚訝的工作。
CycleGAN
Zhu 等人的 CycleGAN 引入了一種概念,它無(wú)需配對(duì)樣本就可以將圖像從 X 域翻譯成 Y 域。馬被轉(zhuǎn)化為斑馬,夏日的陽(yáng)光被轉(zhuǎn)化為暴風(fēng)雪,CycleGAN 的結(jié)果令人驚訝且準(zhǔn)確。
GauGAN
Nvidia 利用 GAN 的力量,把簡(jiǎn)單的繪畫(huà),根據(jù)畫(huà)筆的語(yǔ)義,轉(zhuǎn)換成優(yōu)雅而逼真的照片。盡管訓(xùn)練資源的計(jì)算成本很高,但它創(chuàng)造了一個(gè)全新的研究和應(yīng)用領(lǐng)域。
AdvGAN
GAN 還擴(kuò)展到清理對(duì)抗性圖像,并將其轉(zhuǎn)化為不會(huì)欺騙分類器的干凈樣本。關(guān)于對(duì)抗性攻擊和防御的更多信息可以在 這里 到。
結(jié)? 語(yǔ)
所以,你已經(jīng)擁有了它!希望這篇文章對(duì)如何構(gòu)建 GAN 提供了一個(gè)概覽。
作者簡(jiǎn)介:
Ta-ying Cheng,中國(guó)香港人,牛津大學(xué)哲學(xué)博士新生,愛(ài)好 3D 視覺(jué)、深度學(xué)習(xí)。
編輯:黃飛
?
評(píng)論
查看更多