本文会对NeRF整体的训练代码进行解析;(如果不训练,只是渲染的话,与训练的区别就是有无最后一步的计算loss并进行反向传播)
整体的代码流程如下:
- 参数设置------------------confi_parser
- 数据加载------------------load_llff_data…
- NeRF网络构建------------create_nerf
- 仅渲染--------------------render
- 构建raybatch tensor-----get_rays_np
- 渲染的核心过程-----------render
- 计算loss------------------img2mse
可以结合上一篇博客学习,代码和理论结合使用;
1. 参数设置confi_parser()
数据的处理、网络训练或者渲染中所用到的参数基本都在函数config_parser()里面进行了介绍,但是很多参数都是不太用得到的,每次只需要改的就是官方config中的各个txt文件中提及的那几个;
2. 数据加载load_llff_data …
(1)数据读取_load_data()
输出:pose[3,5,N], bsd[2,N], images[h,w,c,N]
过程:读取要进行渲染的所有图片,进行pose的维度变换,并根据是否要进行下采样进行操作。下采样函数为_minify()
(2)数据后处理recenter_poses()
拿到上面的数据后,根据要求变换位姿,并将表示图像维度的N放到第0维;
紧接着进行边界和平移向量t的缩放;
recenter_poses():计算所有pose的均值,将所有pose做均值逆转换,简单来说就是重新定义世界坐标系,原点期望放在被测物体的中心
(3)render_path_spiral():生成用来渲染的螺旋路径的位姿
(4)函数load_llff_data最后的输出:
pose:[N,3,5],N为图像个数,3x5中,前3为旋转,第4列为t,第5列为[h,w,f]
images: [N,h,w,c]
bds:[2,N]采样深度范围
render_poses:螺旋路径的位姿
i_test:距离最小的id,作为测试
(5)网络构建前的数据预处理
上面得到的i_val作为验证,其余作为训练集;紧接着计算内参K;然后创建log路径,保存训练用的所有参数到args,复制config参数并保存
3. NeRF网络构建create_nerf
(1)位置编码get_embedder
对xyz以及view方向的都进行位置编码
输入:xyz三维或者view
输出:input_ch=63高维的特征或者对应view的27维
实现:对应公式
接下来输入网络深度和每层宽度,且输入宽度不是5d的5,而是位置编码后的通道数63
(2)模型初始化NeRF,实例化
输入:8层网络,每层设置的通道数为256,xyz对应的输入通道数63,view对应的输入27,再多一次输入的层序号skips
输出:feature_linear输出特征256维,不透明度alpha_linear256维,Rgb128为3通道rgb_linear
上面是coarse网络,接下来refine网络也是类似,唯一的区别就是网络深度和每层的通道数不同(但好多时候其实也是一样的)
(3)模型批量处理数据函数
位置编码:embedded = embed_fn(inputs_flat)
以更小的patch-netchunk去跑网络前向:outputs_flat = batchify(fn, netchunk)(embedded)
将output进行reshape:[1024,64,4](这里面的4表示rgb+alpha)
(4)定义优化器Adam
optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
加载已有模型参数,传递给优化器;加载已有模型
训练需要的参数:
create_nerf最后的输出:
render_kwargs_train = {
'network_query_fn' : network_query_fn,#网络
'perturb' : args.perturb,#扰动
'N_importance' : args.N_importance,#refine采样点
'network_fine' : model_fine,#refine网络
'N_samples' : args.N_samples,
'network_fn' : model,
'use_viewdirs' : args.use_viewdirs,
'white_bkgd' : args.white_bkgd,
'raw_noise_std' : args.raw_noise_std,
}
# NDC only good for LLFF-style forward facing data
if args.dataset_type != 'llff' or args.no_ndc:
print('Not ndc!')
render_kwargs_train['ndc'] = False
render_kwargs_train['lindisp'] = args.lindisp
render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
render_kwargs_test['perturb'] = False
render_kwargs_test['raw_noise_std'] = 0.
# start:下一次迭代的起始点,最开始为0,从ckpt里读取
# render_kwargs_train:训练参数
# render_kwargs_test:测试参数
# optimizer:优化器
# grad_vars:未使用
return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
4. 渲染 if args.render_only
如果只是进行渲染,那运行到这里就调用render函数,完成后就return;(在后面的训练中,还是会执行这个函数)
5. 构建raybatch tensor
(1)use_batching为true的话:就会将所有图里的ray都算出来,然后每次随机取N_rand去训练;如果为false,就从一张图里去选;
(2)get_rays_np
def get_rays_np(H, W, K, c2w):
# 像素的ux全部放到i中,uy全放j中
i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
# 2D点到3D点的计算,也就是从像素坐标到相机归一化坐标的过程:[x,y,z] = [(u-cx)/fx,-(v-cy)/fy,-1]
# 这里存在不同的就是,在y和z都取了反,原因是呢人防中使用的坐标系是x向右,y向上,z朝我们(正常的右手系)
dirs = np.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -np.ones_like(i)], -1)
# Rotate ray directions from camera frame to the world frame
# 这里再用pose把相机坐标系下的光线向量转到世界坐标系
rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1) # dot product, equals to: [c2w.dot(dir) for dir in dirs]
# Translate camera frame's origin to the world frame. It is the origin of all rays.
# 相机原点在世界坐标系中的坐标,也是同一个相机所有ray的起点
rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d))
# rays_o:ray的起点;rays_d:ray的方向
return rays_o, rays_d
然后把对应是属于训练图像的rays取出来为rays_rgb,并打乱顺序,使得随机取去训练的时候更鲁棒。接下来就是训练步骤中的核心,也就是渲染:
6. 渲染render
开始训练,进行迭代:
start = start + 1
for i in trange(start, N_iters):
time0 = time.time()
# Sample random ray batch
if use_batching:
# Random over all images,每次从所有图像的ray中取N_rand个ray,每遍历一次都要打乱顺序
batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?]
batch = torch.transpose(batch, 0, 1)
# batch_rays:ray的起点和方向;target_s:rgb值,拿它当真值
batch_rays, target_s = batch[:2], batch[2]
i_batch += N_rand
if i_batch >= rays_rgb.shape[0]:
print("Shuffle data after an epoch!")
# 每个epoch后都会重新打乱ray的分布
rand_idx = torch.randperm(rays_rgb.shape[0])
rays_rgb = rays_rgb[rand_idx]
i_batch = 0
else:
# Random from one image
# 每次随机抽取一张图像,抽取一个batch的ray
在经过上述步骤拿到ray的信息后,开始调用render
##### Core optimization loop #####
rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
verbose=i < 10, retraw=True,
**render_kwargs_train)
def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
near=0., far=1.,
use_viewdirs=False, c2w_staticcam=None,
**kwargs):
"""Render rays
Args:
H: int. Height of image in pixels.高度
W: int. Width of image in pixels.宽度
focal: float. Focal length of pinhole camera.焦距
chunk: int. Maximum number of rays to process simultaneously. Used to
control maximum memory usage. Does not affect final results.并行处理的最大光线数
rays: array of shape [2, batch_size, 3]. Ray origin and direction for
each example in batch.每个batch的ray的原点和方向
c2w: array of shape [3, 4]. Camera-to-world transformation matrix.相机到世界的转换矩阵
ndc: bool. If True, represent ray origin, direction in NDC coordinates.NDC坐标
near: float or array of shape [batch_size]. Nearest distance for a ray.光线最近距离
far: float or array of shape [batch_size]. Farthest distance for a ray.光线最远距离
use_viewdirs: bool. If True, use viewing direction of a point in space in model.是否使用view方向
c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
camera while using other c2w argument for viewing directions.相机参数
Returns:
rgb_map: [batch_size, 3]. Predicted RGB values for rays.预测的rgb图
disp_map: [batch_size]. Disparity map. Inverse of depth.视差图
acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.深度图
extras: dict with everything returned by render_rays().
"""
# 1. 确定有ray和view处理
if c2w is not None:
# special case to render full image
rays_o, rays_d = get_rays(H, W, K, c2w)
else:
# use provided ray batch
rays_o, rays_d = rays
if use_viewdirs:
# provide ray directions as input
viewdirs = rays_d # 方向赋值
if c2w_staticcam is not None:
# special case to visualize effect of viewdirs,用来分析的,例如给一个corner case
rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)#归一化
viewdirs = torch.reshape(viewdirs, [-1,3]).float()
sh = rays_d.shape # [..., 3]
if ndc:#未使用
# for forward facing scenes
rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
# 2. Create ray batch构建rays batch
rays_o = torch.reshape(rays_o, [-1,3]).float()
rays_d = torch.reshape(rays_d, [-1,3]).float()
near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
# 这里维度变成[1024x8];8 = 3 + 3 + 1 + 1
rays = torch.cat([rays_o, rays_d, near, far], -1)
if use_viewdirs:
rays = torch.cat([rays, viewdirs], -1)
# 3. Render and reshape 渲染
all_ret = batchify_rays(rays, chunk, **kwargs)
for k in all_ret:
k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
all_ret[k] = torch.reshape(all_ret[k], k_sh)
k_extract = ['rgb_map', 'disp_map', 'acc_map']
ret_list = [all_ret[k] for k in k_extract]
ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
return ret_list + [ret_dict]
以上的就是渲染的整体流程,其中为了在更小的batch上渲染,避免超出内存,所以会调用batchify_rays函数;
而为了更进一步细小化的batch上渲染,所以最终每步的渲染是在render_rays中实现:
def render_rays(ray_batch,# 用来view ray的所有数据:ray原点、方向,最大最小距离,单位方向
network_fn,#NeRF网络,用来预测空间中每个点rgb和透明度
network_query_fn,# 将查询传递给network_fn的函数
N_samples,#coarse采样点数
retraw=False,#如果为真,返回数据无压缩
lindisp=False,#如果为真,在深度图的逆上面进行线性采样
perturb=0.,#扰动值
N_importance=0,#fine网络增加的采样点数
network_fine=None,#fine网络
white_bkgd=False,#若为true,则认为是白色背景
raw_noise_std=0.,#噪声
verbose=False,#打印debug信息
pytest=False):
"""Volumetric rendering.
Args:
ray_batch: array of shape [batch_size, ...]. All information necessary
for sampling along a ray, including: ray origin, ray direction, min
dist, max dist, and unit-magnitude viewing direction.
network_fn: function. Model for predicting RGB and density at each point
in space.
network_query_fn: function used for passing queries to network_fn.
N_samples: int. Number of different times to sample along each ray.
retraw: bool. If True, include model's raw, unprocessed predictions.
lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
random points in time.
N_importance: int. Number of additional times to sample along each ray.
These samples are only passed to network_fine.
network_fine: "fine" network with same spec as network_fn.
white_bkgd: bool. If True, assume a white background.
raw_noise_std: ...
verbose: bool. If True, print more debugging info.
Returns:
rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
disp_map: [num_rays]. Disparity map. 1 / depth.
acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
raw: [num_rays, num_samples, 4]. Raw predictions from model.体素渲染之前的数据
# coarse网络得到的值,上面几个是fine网络得到的值
rgb0: See rgb_map. Output for coarse model.
disp0: See disp_map. Output for coarse model.
acc0: See acc_map. Output for coarse model.
#标准差
z_std: [num_rays]. Standard deviation of distances along ray for each
sample.
"""
# 0. 接受数据
N_rays = ray_batch.shape[0]
rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None
bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
near, far = bounds[...,0], bounds[...,1] # [-1,1]
# 0.1 0-1线性采样N_samples个点,取的都是z值
t_vals = torch.linspace(0., 1., steps=N_samples)
if not lindisp:
z_vals = near * (1.-t_vals) + far * (t_vals)
else:
z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
z_vals = z_vals.expand([N_rays, N_samples])
if perturb > 0.:#如果加了扰动
# get intervals between samples
mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
upper = torch.cat([mids, z_vals[...,-1:]], -1)
lower = torch.cat([z_vals[...,:1], mids], -1)
# stratified samples in those intervals
t_rand = torch.rand(z_vals.shape)
# Pytest, overwrite u with numpy's fixed random numbers
if pytest:
np.random.seed(0)
t_rand = np.random.rand(*list(z_vals.shape))
t_rand = torch.Tensor(t_rand)
z_vals = lower + (upper - lower) * t_rand
# 0.2 获取每个采样点的3D坐标
pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]
# raw = run_network(pts)
# 1. coarse网络推断 network_query_fn得到前向结果:
raw = network_query_fn(pts, viewdirs, network_fn)#[ray,batch,N_samples,4][1024,64,4]其中4=rgb+alpha
# 2. coarse raw数据转换成rgb alpha等(公式3等)
# rgb,视差图,权重加和(也就是不透明度值),权重,深度图
rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
if N_importance > 0:
rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
# 3. finew网络需要的点采样
z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)
z_samples = z_samples.detach()
z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3]
run_fn = network_fn if network_fine is None else network_fine
# raw = run_network(pts, fn=run_fn)
# 4. fine网络推断
raw = network_query_fn(pts, viewdirs, run_fn)
# 5. 数据转换
rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map}
if retraw:
ret['raw'] = raw
if N_importance > 0:
ret['rgb0'] = rgb_map_0
ret['disp0'] = disp_map_0
ret['acc0'] = acc_map_0
ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False) # [N_rays]
for k in ret:
if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
print(f"! [Numerical Error] {k} contains nan or inf.")
return ret
其中,在完成上面代码中注释的第2步以后,要开始进行fine网络需要的点采样,此时要遵循分层体积采样,可由上一篇博客中所述,需要评估这些采样点位置的coarse网络,计算每个采样点的权重,并进行归一化处理,得到概率密度函数;最后才能沿着每条射线对权重更大的点进行更精确的采样,具体过程由如下代码实现:
# Hierarchical sampling (section 5.2),逆变换采样
# 根据概率密度分布PDF计算累积分布函数CDF,在[0,1]内,对CDF值用均匀分布进行采样
# 将采样到的CDF值映射回坐标值
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
# Get pdf,计算概率密度(公式5)
weights = weights + 1e-5 # prevent nans,防止做除法时产生nan值,[1024,62]
pdf = weights / torch.sum(weights, -1, keepdim=True)# 归一化
# 返回输入元素的总和sum,[1024,62]
cdf = torch.cumsum(pdf, -1)
cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1) # (batch, len(bins))
# Take uniform samples
if det:#未加扰动,申城线性间隔的均匀样本点集合
u = torch.linspace(0., 1., steps=N_samples)
u = u.expand(list(cdf.shape[:-1]) + [N_samples])
else:# 随机均匀分布的样本点集合
u = torch.rand(list(cdf.shape[:-1]) + [N_samples])
# Pytest, overwrite u with numpy's fixed random numbers
if pytest:
np.random.seed(0)
new_shape = list(cdf.shape[:-1]) + [N_samples]
if det:
u = np.linspace(0., 1., N_samples)
u = np.broadcast_to(u, new_shape)
else:
u = np.random.rand(*new_shape)
u = torch.Tensor(u)
# Invert CDF,逆CDF变换,把tensor变成内存连续分布形式[1024,64]
u = u.contiguous()
# 用高维的searchsorted算子去寻找坐标值的位置索引
inds = torch.searchsorted(cdf, u, right=True)#找到每个样本点在CDF中的位置索引
# 对于每个样本点,找到其对应的CDF值区间
below = torch.max(torch.zeros_like(inds-1), inds-1)
above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
# cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
# bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
denom = (cdf_g[...,1]-cdf_g[...,0])
denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
t = (u-cdf_g[...,0])/denom
samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
# 根据给定的PDF生成的样本点集合
return samples
7. loss计算并进行反向传播
optimizer.zero_grad()
# 计算损失函数,只使用rgb值
img_loss = img2mse(rgb, target_s)
trans = extras['raw'][...,-1]
loss = img_loss
psnr = mse2psnr(img_loss)# 计算psnr
if 'rgb0' in extras:
img_loss0 = img2mse(extras['rgb0'], target_s)
loss = loss + img_loss0
psnr0 = mse2psnr(img_loss0)
# 反向传播
loss.backward()
optimizer.step()
# NOTE: IMPORTANT!
### update learning rate调整学习率 ###
decay_rate = 0.1
decay_steps = args.lrate_decay * 1000
new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
for param_group in optimizer.param_groups:
param_group['lr'] = new_lrate
################################
dt = time.time()-time0
# print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
##### end #####
# Rest is logging,保存log、
if i%args.i_weights==0:
path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
torch.save({
'global_step': global_step,# 迭代次数
'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),# coarse网络参数
'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),# fine网络参数
'optimizer_state_dict': optimizer.state_dict(),#优化器参数
}, path)
print('Saved checkpoints at', path)
if i%args.i_video==0 and i > 0:# 视频渲染
# Turn on testing mode
with torch.no_grad():
rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test)
print('Done, saving', rgbs.shape, disps.shape)