0
  • 聊天消息
  • 系統(tǒng)消息
  • 評論與回復(fù)
登錄后你可以
  • 下載海量資料
  • 學(xué)習(xí)在線課程
  • 觀看技術(shù)視頻
  • 寫文章/發(fā)帖/加入社區(qū)
會員中心
創(chuàng)作中心

完善資料讓更多小伙伴認(rèn)識你,還能領(lǐng)取20積分哦,立即完善>

3天內(nèi)不再提示

利用PyTorch實現(xiàn)NeRF代碼詳解

3D視覺工坊 ? 來源:3DCV ? 2023-10-21 09:46 ? 次閱讀

作者:大森林| 來源:3DCV

1. NeRF定義

神經(jīng)輻射場(NeRF)是一種利用神經(jīng)網(wǎng)絡(luò)來表示和渲染復(fù)雜的三維場景的方法。它可以從一組二維圖片中學(xué)習(xí)出一個連續(xù)的三維函數(shù),這個函數(shù)可以給出空間中任意位置和方向上的顏色和密度。通過體積渲染的技術(shù),NeRF可以從任意視角合成出逼真的圖像,包括透明和半透明物體,以及復(fù)雜的光線傳播效果。

2. NeRF優(yōu)勢

NeRF模型相比于其他新的視圖合成和場景表示方法有以下幾個優(yōu)勢:

1)NeRF不需要離散化的三維表示,如網(wǎng)格或體素,因此可以避免模型精度和細(xì)節(jié)程度受到限制。NeRF也可以自適應(yīng)地處理不同形狀和大小的場景,而不需要人工調(diào)整參數(shù)。

2)NeRF使用位置編碼的方式將位置和角度信息映射到高頻域,使得網(wǎng)絡(luò)能夠更好地捕捉場景的細(xì)微結(jié)構(gòu)和變化。NeRF還使用視角相關(guān)的顏色預(yù)測,能夠生成不同視角下不同的光照效果。

3)NeRF使用分段隨機采樣的方式來近似體積渲染的積分,這樣可以保證采樣位置的連續(xù)性,同時避免網(wǎng)絡(luò)過擬合于離散點的信息。NeRF還使用多層級體素采樣的技巧,以提高渲染效率和質(zhì)量。

3. NeRF實現(xiàn)步驟

1)定義一個全連接的神經(jīng)網(wǎng)絡(luò),它的輸入是空間位置和視角方向,輸出是顏色和密度。

2)使用位置編碼的方式將輸入映射到高頻域,以便網(wǎng)絡(luò)能夠捕捉細(xì)微的結(jié)構(gòu)和變化。

3)使用分段隨機采樣的方式從每條光線上采樣一些點,然后用神經(jīng)網(wǎng)絡(luò)預(yù)測這些點的顏色和密度。

4)使用體積渲染的公式計算每條光線上的顏色和透明度,作為最終的圖像輸出。

5)使用渲染損失函數(shù)來優(yōu)化神經(jīng)網(wǎng)絡(luò)的參數(shù),使得渲染的圖像與輸入的圖像盡可能接近。

importtorch
importtorch.nnasnn
importtorch.nn.functionalasF

#定義一個全連接的神經(jīng)網(wǎng)絡(luò),它的輸入是空間位置和視角方向,輸出是顏色和密度。
classNeRF(nn.Module):
def__init__(self,D=8,W=256,input_ch=3,input_ch_views=3,output_ch=4,skips=[4]):
super().__init__()
#定義位置編碼后的位置信息的線性層,如果層數(shù)在skips列表中,則將原始位置信息與隱藏層拼接
self.pts_linears=nn.ModuleList(
[nn.Linear(input_ch,W)]+[nn.Linear(W,W)ifinotinskipselsenn.Linear(W+input_ch,W)foriinrange(D-1)])
#定義位置編碼后的視角方向信息的線性層
self.views_linears=nn.ModuleList([nn.Linear(W+input_ch_views,W//2)]+[nn.Linear(W//2,W//2)foriinrange(1)])
#定義特征向量的線性層
self.feature_linear=nn.Linear(W//2,W)
#定義透明度(alpha)值的線性層
self.alpha_linear=nn.Linear(W,1)
#定義RGB顏色的線性層
self.rgb_linear=nn.Linear(W+input_ch_views,3)

defforward(self,x):
#x:(B,input_ch+input_ch_views)
#提取位置和視角方向信息
p=x[:,:3]#(B,3)
d=x[:,3:]#(B,3)

#對輸入進(jìn)行位置編碼,將低頻信號映射到高頻域
p=positional_encoding(p)#(B,input_ch)
d=positional_encoding(d)#(B,input_ch_views)

#將位置信息輸入網(wǎng)絡(luò)
h=p
fori,linenumerate(self.pts_linears):
h=l(h)
h=F.relu(h)
ifiinskips:
h=torch.cat([h,p],-1)#如果層數(shù)在skips列表中,則將原始位置信息與隱藏層拼接

#將視角方向信息與隱藏層拼接,并輸入網(wǎng)絡(luò)
h=torch.cat([h,d],-1)
fori,linenumerate(self.views_linears):
h=l(h)
h=F.relu(h)

#預(yù)測特征向量和透明度(alpha)值
feature=self.feature_linear(h)#(B,W)
alpha=self.alpha_linear(feature)#(B,1)

#使用特征向量和視角方向信息預(yù)測RGB顏色
rgb=torch.cat([feature,d],-1)
rgb=self.rgb_linear(rgb)#(B,3)

returntorch.cat([rgb,alpha],-1)#(B,4)

#定義位置編碼函數(shù)
defpositional_encoding(x):
#x:(B,C)
B,C=x.shape
L=int(C//2)#計算位置編碼的長度
freqs=torch.logspace(0.,L-1,steps=L).to(x.device)*math.pi#計算頻率系數(shù),呈指數(shù)增長
freqs=freqs[None].repeat(B,1)#(B,L)
x_pos_enc_low=torch.sin(x[:,:L]*freqs)#對前一半的輸入進(jìn)行正弦變換,得到低頻部分(B,L)
x_pos_enc_high=torch.cos(x[:,:L]*freqs)#對前一半的輸入進(jìn)行余弦變換,得到高頻部分(B,L)
x_pos_enc=torch.cat([x_pos_enc_low,x_pos_enc_high],dim=-1)#將低頻和高頻部分拼接,得到位置編碼后的輸入(B,C)
returnx_pos_enc

#定義體積渲染函數(shù)
defvolume_rendering(rays_o,rays_d,model):
#rays_o:(B,3),每條光線的起點
#rays_d:(B,3),每條光線的方向
B=rays_o.shape[0]

#在每條光線上采樣一些點
near,far=0.,1.#近平面和遠(yuǎn)平面
N_samples=64#每條光線的采樣數(shù)
t_vals=torch.linspace(near,far,N_samples).to(rays_o.device)#(N_samples,)
t_vals=t_vals.expand(B,N_samples)#(B,N_samples)
z_vals=near*(1.-t_vals)+far*t_vals#計算每個采樣點的深度值(B,N_samples)
z_vals=z_vals.unsqueeze(-1)#(B,N_samples,1)
pts=rays_o.unsqueeze(1)+rays_d.unsqueeze(1)*z_vals#計算每個采樣點的空間位置(B,N_samples,3)

#將采樣點和視角方向輸入網(wǎng)絡(luò)
pts_flat=pts.reshape(-1,3)#(B*N_samples,3)
rays_d_flat=rays_d.unsqueeze(1).expand(-1,N_samples,-1).reshape(-1,3)#(B*N_samples,3)
x_flat=torch.cat([pts_flat,rays_d_flat],-1)#(B*N_samples,6)
y_flat=model(x_flat)#(B*N_samples,4)
y=y_flat.reshape(B,N_samples,4)#(B,N_samples,4)

#提取RGB顏色和透明度(alpha)值
rgb=y[...,:3]#(B,N_samples,3)
alpha=y[...,3]#(B,N_samples)

#計算每個采樣點的權(quán)重
dists=torch.cat([z_vals[...,1:]-z_vals[...,:-1],torch.tensor([1e10]).to(z_vals.device).expand(B,1)],-1)#計算相鄰采樣點之間的距離,最后一個距離設(shè)為很大的值(B,N_samples)
alpha=1.-torch.exp(-alpha*dists)#計算每個采樣點的不透明度,即1減去透明度的指數(shù)衰減(B,N_samples)
weights=alpha*torch.cumprod(torch.cat([torch.ones((B,1)).to(alpha.device),1.-alpha+1e-10],-1),-1)[:,:-1]#計算每個采樣點的權(quán)重,即不透明度乘以之前所有采樣點的透明度累積積,最后一個權(quán)重設(shè)為0(B,N_samples)

#計算每條光線的最終顏色和透明度
rgb_map=torch.sum(weights.unsqueeze(-1)*rgb,-2)#加權(quán)平均每個采樣點的RGB顏色,得到每條光線的顏色(B,3)
depth_map=torch.sum(weights*z_vals.squeeze(-1),-1)#加權(quán)平均每個采樣點的深度值,得到每條光線的深度(B,)
acc_map=torch.sum(weights,-1)#累加每個采樣點的權(quán)重,得到每條光線的不透明度(B,)

returnrgb_map,depth_map,acc_map

#定義渲染損失函數(shù)
defrendering_loss(rgb_map_pred,rgb_map_gt):
return((rgb_map_pred-rgb_map_gt)**2).mean()#計算預(yù)測的顏色與真實顏色之間的均方誤差

綜上所述,本代碼實現(xiàn)了NeRF的核心結(jié)構(gòu),具體實現(xiàn)內(nèi)容包括以下四個部分。

1)定義了NeRF網(wǎng)絡(luò)結(jié)構(gòu),包含位置編碼和多層全連接網(wǎng)絡(luò),輸入是位置和視角,輸出是顏色和密度。

2)實現(xiàn)了位置編碼函數(shù),通過正弦和余弦變換引入高頻信息。

3)實現(xiàn)了體積渲染函數(shù),在光線上采樣點,查詢NeRF網(wǎng)絡(luò)預(yù)測顏色和密度,然后通過加權(quán)平均實現(xiàn)整體渲染。

4)定義了渲染損失函數(shù),計算預(yù)測顏色和真實顏色的均方誤差。

當(dāng)然,本方案只是實現(xiàn)NeRF的一個基礎(chǔ)方案,更多的細(xì)節(jié)還需要進(jìn)行優(yōu)化。

當(dāng)然,為了方便下載,我們已經(jīng)將上述兩個源代碼打包好了。

審核編輯:湯梓紅

聲明:本文內(nèi)容及配圖由入駐作者撰寫或者入駐合作網(wǎng)站授權(quán)轉(zhuǎn)載。文章觀點僅代表作者本人,不代表電子發(fā)燒友網(wǎng)立場。文章及其配圖僅供工程師學(xué)習(xí)之用,如有內(nèi)容侵權(quán)或者其他違規(guī)問題,請聯(lián)系本站處理。 舉報投訴
  • 神經(jīng)網(wǎng)絡(luò)

    關(guān)注

    42

    文章

    4724

    瀏覽量

    100311
  • 函數(shù)
    +關(guān)注

    關(guān)注

    3

    文章

    4258

    瀏覽量

    62227
  • 代碼
    +關(guān)注

    關(guān)注

    30

    文章

    4695

    瀏覽量

    68080
  • pytorch
    +關(guān)注

    關(guān)注

    2

    文章

    795

    瀏覽量

    13079

原文標(biāo)題:一文帶你入門NeRF:利用PyTorch實現(xiàn)NeRF代碼詳解(附代碼)

文章出處:【微信號:3D視覺工坊,微信公眾號:3D視覺工坊】歡迎添加關(guān)注!文章轉(zhuǎn)載請注明出處。

收藏 人收藏

    評論

    相關(guān)推薦

    Image Style Transfer pytorch方式實現(xiàn)的主要思路

    深度學(xué)總結(jié):Image Style Transfer pytorch方式實現(xiàn),這個是非基于autoencoder和domain adversrial方式
    發(fā)表于 06-20 10:58

    PyTorch如何入門

    PyTorch 入門實戰(zhàn)(一)——Tensor
    發(fā)表于 06-01 09:58

    Pytorch代碼移植嵌入式開發(fā)筆記,錯過絕對后悔

    @[TOC]Pytorch 代碼移植嵌入式開發(fā)筆記目前在做開發(fā)完成后的AI模型移植到前端的工作。 由于硬件設(shè)施簡陋,需要把代碼和算法翻譯成基礎(chǔ)加乘算法并輸出每個環(huán)節(jié)參數(shù)。記錄幾點實用技巧以及項目
    發(fā)表于 11-08 08:24

    單片機點燈的基本語法代碼詳解

    【單片機】點燈基本語法代碼詳解代碼詳解#include #include //功能:實現(xiàn)P1口左移#define uchar unsigne
    發(fā)表于 02-16 06:34

    PyTorch官網(wǎng)教程PyTorch深度學(xué)習(xí):60分鐘快速入門中文翻譯版

    PyTorch 深度學(xué)習(xí):60分鐘快速入門”為 PyTorch 官網(wǎng)教程,網(wǎng)上已經(jīng)有部分翻譯作品,隨著PyTorch1.0 版本的公布,這個教程有較大的代碼改動,本人對教程進(jìn)行重新翻
    的頭像 發(fā)表于 01-13 11:53 ?1w次閱讀

    Pytorch 1.1.0,來了!

    許多用戶已經(jīng)轉(zhuǎn)向使用標(biāo)準(zhǔn)PyTorch運算符編寫自定義實現(xiàn),但是這樣的代碼遭受高開銷:大多數(shù)PyTorch操作在GPU上啟動至少一個內(nèi)核,并且RNN由于其重復(fù)性質(zhì)通常運行許多操作。但是
    的頭像 發(fā)表于 05-05 10:02 ?5852次閱讀
    <b class='flag-5'>Pytorch</b> 1.1.0,來了!

    詳解Tutorial代碼的學(xué)習(xí)過程與準(zhǔn)備

    導(dǎo)讀:本文主要解析Pytorch Tutorial中BiLSTM_CRF代碼,幾乎注釋了每行代碼,希望本文能夠幫助大家理解這個tutorial,除此之外借助代碼和圖解也對理解條件隨機場
    的頭像 發(fā)表于 04-03 16:50 ?1811次閱讀
    <b class='flag-5'>詳解</b>Tutorial<b class='flag-5'>代碼</b>的學(xué)習(xí)過程與準(zhǔn)備

    Pytorch實現(xiàn)MNIST手寫數(shù)字識別

    Pytorch 實現(xiàn)MNIST手寫數(shù)字識別
    發(fā)表于 06-16 14:47 ?7次下載

    pytorch實現(xiàn)斷電繼續(xù)訓(xùn)練時需要注意的要點

    本文整理了pytorch實現(xiàn)斷電繼續(xù)訓(xùn)練時需要注意的要點,附有代碼詳解。
    的頭像 發(fā)表于 08-22 09:50 ?1354次閱讀

    PyTorch教程3.2之面向?qū)ο蟮脑O(shè)計實現(xiàn)

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程3.2之面向?qū)ο蟮脑O(shè)計實現(xiàn).pdf》資料免費下載
    發(fā)表于 06-05 15:48 ?0次下載
    <b class='flag-5'>PyTorch</b>教程3.2之面向?qū)ο蟮脑O(shè)計<b class='flag-5'>實現(xiàn)</b>

    PyTorch教程3.5之線性回歸的簡潔實現(xiàn)

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程3.5之線性回歸的簡潔實現(xiàn).pdf》資料免費下載
    發(fā)表于 06-05 11:28 ?0次下載
    <b class='flag-5'>PyTorch</b>教程3.5之線性回歸的簡潔<b class='flag-5'>實現(xiàn)</b>

    PyTorch教程13.6之多個GPU的簡潔實現(xiàn)

    電子發(fā)燒友網(wǎng)站提供《PyTorch教程13.6之多個GPU的簡潔實現(xiàn).pdf》資料免費下載
    發(fā)表于 06-05 14:21 ?0次下載
    <b class='flag-5'>PyTorch</b>教程13.6之多個GPU的簡潔<b class='flag-5'>實現(xiàn)</b>

    [源代碼]Python算法詳解

    [源代碼]Python算法詳解[源代碼]Python算法詳解
    發(fā)表于 06-06 17:50 ?0次下載

    TorchFix:基于PyTorch代碼靜態(tài)分析

    TorchFix是我們最近開發(fā)的一個新工具,旨在幫助PyTorch用戶維護(hù)健康的代碼庫并遵循PyTorch的最佳實踐。首先,我想要展示一些我們努力解決的問題的示例。
    的頭像 發(fā)表于 12-18 15:20 ?999次閱讀

    pytorch怎么在pycharm中運行

    第一部分:PyTorch和PyCharm的安裝 1.1 安裝PyTorch PyTorch是一個開源的機器學(xué)習(xí)庫,用于構(gòu)建和訓(xùn)練神經(jīng)網(wǎng)絡(luò)。要在PyCharm中使用PyTorch,首先需
    的頭像 發(fā)表于 08-01 16:22 ?872次閱讀