[Growth Day] NeRF Study Notes 2: Dissecting the NeRF Source Code

This post walks through the full NeRF training code. (If you only render instead of training, the sole difference is that the final step of computing the loss and backpropagating is skipped.)
The overall code flow is:

  1. Parameter setup ------------------ config_parser
  2. Data loading --------------------- load_llff_data, etc.
  3. NeRF network construction -------- create_nerf
  4. Render-only path ----------------- render
  5. Building the ray batch tensor ----- get_rays_np
  6. Core rendering ------------------- render
  7. Loss computation ----------------- img2mse

It is best read alongside the previous post, so that code and theory go hand in hand.

1. Parameter setup: config_parser()

Essentially every parameter used for data processing, network training, and rendering is defined in config_parser(). Most of them rarely need touching; in practice you usually only change the handful of settings exposed in the official config txt files.
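As a reference point, here is a heavily trimmed sketch of what config_parser() looks like, assuming the configargparse package the official repo relies on; the real function defines many more flags and the defaults below are only illustrative:

import configargparse

def config_parser():
    parser = configargparse.ArgumentParser()
    parser.add_argument('--config', is_config_file=True, help='config file path')
    parser.add_argument('--expname', type=str, help='experiment name')
    parser.add_argument('--datadir', type=str, help='input data directory')
    # training options
    parser.add_argument('--netdepth', type=int, default=8, help='layers in the network')
    parser.add_argument('--netwidth', type=int, default=256, help='channels per layer')
    parser.add_argument('--N_rand', type=int, default=32*32*4, help='rays per gradient step')
    parser.add_argument('--lrate', type=float, default=5e-4, help='learning rate')
    # rendering options
    parser.add_argument('--N_samples', type=int, default=64, help='coarse samples per ray')
    parser.add_argument('--N_importance', type=int, default=0, help='extra fine samples per ray')
    parser.add_argument('--use_viewdirs', action='store_true', help='use the full 5D input')
    return parser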

2. Data loading: load_llff_data, etc.

(1) Data reading: _load_data()
Outputs: poses [3, 5, N], bds [2, N], images [H, W, C, N]
Process: read all images to be rendered, permute the pose dimensions, and downsample if requested (the downsampling function is _minify()).
(2) Post-processing: recenter_poses()
With the data above, the poses are transformed as required and the image dimension N is moved to axis 0;
then the bounds bds and the translation vectors t are rescaled.
recenter_poses(): computes the average of all poses and applies the inverse of that average pose to every pose. In short, it redefines the world coordinate system so that the origin sits roughly at the center of the captured object.
(3) render_path_spiral(): generates the poses of the spiral path used for rendering.
(4) Final outputs of load_llff_data:
poses: [N, 3, 5], where N is the number of images; within each 3x5 matrix, the first three columns are the rotation, the 4th column is the translation t, and the 5th column holds [H, W, f]
images: [N, H, W, C]
bds: [N, 2], the sampling depth range along each ray
render_poses: the poses of the spiral rendering path
i_test: index of the image closest to the average pose, held out for testing
(5) Preprocessing before building the network
i_val (taken from i_test above) is used for validation and the remaining images for training; the intrinsics K are then computed, the log directory is created, all training arguments from args are saved, and the config file is copied. A rough sketch of this step follows.
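The sketch below mirrors the corresponding lines of the official train() function; variable names (i_test, hwf, args.llffhold) follow that code, but treat the details as approximate:

# hold out every llffhold-th image for testing, and reuse the same indices for validation
if args.llffhold > 0:
    i_test = np.arange(images.shape[0])[::args.llffhold]
i_val = i_test
i_train = np.array([i for i in np.arange(int(images.shape[0]))
                    if (i not in i_test and i not in i_val)])

# build the pinhole intrinsics K from H, W and the focal length
H, W, focal = hwf
H, W = int(H), int(W)
K = np.array([
    [focal, 0, 0.5 * W],
    [0, focal, 0.5 * H],
    [0, 0, 1]
])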

3. NeRF network construction: create_nerf

(1) Positional encoding: get_embedder
Both the xyz position and the viewing direction are positionally encoded.
Input: the 3D xyz position, or the viewing direction
Output: a high-dimensional feature with input_ch = 63 for xyz, or 27 for the viewing direction
Implementation: follows the positional-encoding formula; with L = 10 frequencies for xyz the output is 3 + 3·2·10 = 63 channels, and with L = 4 for the direction it is 3 + 3·2·4 = 27.
The network depth and per-layer width are then passed in; note that the input width is not the 5 of "5D" but the 63 channels produced by the positional encoding.
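A minimal standalone sketch of the frequency encoding (the official Embedder class is more configurable; positional_encoding here is a hypothetical helper name used only for illustration):

import torch

def positional_encoding(x, L):
    # x: [..., 3]; append sin/cos at L frequencies -> [..., 3 + 3*2*L]
    out = [x]
    for i in range(L):
        freq = 2.0 ** i
        out.append(torch.sin(freq * x))
        out.append(torch.cos(freq * x))
    return torch.cat(out, dim=-1)

xyz = torch.rand(1024, 3)
print(positional_encoding(xyz, 10).shape)  # torch.Size([1024, 63]) -> xyz encoding
print(positional_encoding(xyz, 4).shape)   # torch.Size([1024, 27]) -> view-direction encoding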
(2) Model initialization: instantiating NeRF
Inputs: 8 layers with 256 channels each, 63 input channels for xyz, 27 for the viewing direction, plus skips, the indices of the layers that receive the encoded input again (skip connections).
Outputs: feature_linear produces a 256-dim feature, alpha_linear maps the 256-dim feature to the 1-channel density (alpha), and rgb_linear maps the 128-dim view-dependent feature to the 3-channel RGB.

The above is the coarse network; the fine network is built in the same way, the only possible difference being the depth and per-layer width (although in practice they are often identical). A condensed sketch of the model follows.
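To make the channel counts above concrete, here is a condensed sketch of the MLP; names follow run_nerf_helpers.py, but the no-viewdirs path and some details are omitted, so treat it as an abridged version rather than the exact class:

import torch
import torch.nn as nn
import torch.nn.functional as F

class NeRF(nn.Module):
    def __init__(self, D=8, W=256, input_ch=63, input_ch_views=27, skips=[4]):
        super().__init__()
        self.input_ch, self.input_ch_views, self.skips = input_ch, input_ch_views, skips
        # D fully connected layers on the encoded xyz; the layer after each skip re-reads the encoded input
        self.pts_linears = nn.ModuleList(
            [nn.Linear(input_ch, W)] +
            [nn.Linear(W + input_ch, W) if i in skips else nn.Linear(W, W) for i in range(D - 1)])
        self.feature_linear = nn.Linear(W, W)        # 256 -> 256 feature
        self.alpha_linear = nn.Linear(W, 1)          # 256 -> 1 density (alpha)
        self.views_linears = nn.ModuleList([nn.Linear(W + input_ch_views, W // 2)])  # 256+27 -> 128
        self.rgb_linear = nn.Linear(W // 2, 3)       # 128 -> 3 RGB

    def forward(self, x):
        input_pts, input_views = torch.split(x, [self.input_ch, self.input_ch_views], dim=-1)
        h = input_pts
        for i, layer in enumerate(self.pts_linears):
            h = F.relu(layer(h))
            if i in self.skips:
                h = torch.cat([input_pts, h], -1)    # skip connection: concatenate the encoded xyz again
        alpha = self.alpha_linear(h)                 # view-independent density
        feature = self.feature_linear(h)
        h = torch.cat([feature, input_views], -1)    # inject the encoded viewing direction
        for layer in self.views_linears:
            h = F.relu(layer(h))
        rgb = self.rgb_linear(h)
        return torch.cat([rgb, alpha], -1)           # [..., 4] = RGB + density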
(3) The batched network query function
Positional encoding: embedded = embed_fn(inputs_flat)
The forward pass is run in smaller chunks of size netchunk: outputs_flat = batchify(fn, netchunk)(embedded)
The output is reshaped to [1024, 64, 4] (the 4 being RGB + alpha). The chunking wrapper itself is tiny, as sketched below.
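Roughly, the chunking wrapper (close to the official batchify helper) looks like this:

def batchify(fn, chunk):
    # run fn on the inputs in chunks of size `chunk` to bound peak memory usage
    if chunk is None:
        return fn
    def ret(inputs):
        return torch.cat([fn(inputs[i:i + chunk]) for i in range(0, inputs.shape[0], chunk)], 0)
    return ret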
(4) Optimizer: Adam
optimizer = torch.optim.Adam(params=grad_vars, lr=args.lrate, betas=(0.9, 0.999))
If a checkpoint exists, its weights are loaded into the model and passed to the optimizer.
Parameters needed for training:
The final outputs of create_nerf:

render_kwargs_train = {
        'network_query_fn' : network_query_fn,  # function that queries the network
        'perturb' : args.perturb,               # stratified-sampling perturbation
        'N_importance' : args.N_importance,     # number of extra samples for the fine network
        'network_fine' : model_fine,            # fine network
        'N_samples' : args.N_samples,
        'network_fn' : model,
        'use_viewdirs' : args.use_viewdirs,
        'white_bkgd' : args.white_bkgd,
        'raw_noise_std' : args.raw_noise_std,
    }
    # NDC only good for LLFF-style forward facing data
    if args.dataset_type != 'llff' or args.no_ndc:
        print('Not ndc!')
        render_kwargs_train['ndc'] = False
        render_kwargs_train['lindisp'] = args.lindisp
    render_kwargs_test = {k : render_kwargs_train[k] for k in render_kwargs_train}
    render_kwargs_test['perturb'] = False
    render_kwargs_test['raw_noise_std'] = 0.
    # start: iteration to resume from; 0 initially, otherwise read from the checkpoint
    # render_kwargs_train: rendering arguments for training
    # render_kwargs_test: rendering arguments for testing
    # optimizer: the optimizer
    # grad_vars: the model parameters (not used further by the caller)
    return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer

4. Render only: if args.render_only

If we only want to render, the render function is called at this point and the script returns once it is done. (During training the same function is executed later anyway.)
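The render-only branch roughly looks like the fragment below (it sits inside train() and relies on os/imageio already being imported; render_path and to8b are helpers from the official script, and the exact arguments should be treated as approximate):

if args.render_only:
    print('RENDER ONLY')
    with torch.no_grad():
        testsavedir = os.path.join(basedir, expname, 'renderonly_path_{:06d}'.format(start))
        os.makedirs(testsavedir, exist_ok=True)
        # render along the precomputed spiral path with the test-time kwargs (no perturbation, no noise)
        rgbs, _ = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test,
                              savedir=testsavedir, render_factor=args.render_factor)
        imageio.mimwrite(os.path.join(testsavedir, 'video.mp4'), to8b(rgbs), fps=30, quality=8)
    return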

5. Building the ray batch tensor

(1) If use_batching is true, the rays of all training images are precomputed and N_rand of them are drawn at random for each iteration; if false, rays are sampled from a single image each time.
(2) get_rays_np

def get_rays_np(H, W, K, c2w):
    # put every pixel's u coordinate into i and every v coordinate into j
    i, j = np.meshgrid(np.arange(W, dtype=np.float32), np.arange(H, dtype=np.float32), indexing='xy')
    # from 2D pixel coordinates to normalized camera coordinates: [x, y, z] = [(u-cx)/fx, -(v-cy)/fy, -1]
    # y and z are negated because NeRF uses a frame with x to the right, y up, and z pointing towards the viewer (a right-handed frame)
    dirs = np.stack([(i-K[0][2])/K[0][0], -(j-K[1][2])/K[1][1], -np.ones_like(i)], -1)
    # Rotate ray directions from camera frame to the world frame
    rays_d = np.sum(dirs[..., np.newaxis, :] * c2w[:3,:3], -1)  # dot product, equals to: [c2w.dot(dir) for dir in dirs]
    # Translate camera frame's origin to the world frame. It is the origin of all rays.
    # i.e. the camera origin in world coordinates, shared by every ray of this camera
    rays_o = np.broadcast_to(c2w[:3,-1], np.shape(rays_d))
    # rays_o: ray origins; rays_d: ray directions
    return rays_o, rays_d

The rays belonging to the training images are then gathered into rays_rgb and shuffled, so that the randomly drawn training batches are more robust. A sketch of this batching step is given below, followed by the core of the training loop: rendering.
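The batching step mirrors the use_batching branch of train(); shapes in the comments assume N images of size H×W, and the details should be read as approximate:

# rays for every image: [N, ro+rd, H, W, 3]
rays = np.stack([get_rays_np(H, W, K, p) for p in poses[:, :3, :4]], 0)
# append the ground-truth colors: [N, ro+rd+rgb, H, W, 3]
rays_rgb = np.concatenate([rays, images[:, None]], 1)
# -> [N, H, W, ro+rd+rgb, 3]
rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4])
# keep only the training images and flatten to one ray per row: [N_train*H*W, 3, 3]
rays_rgb = np.stack([rays_rgb[i] for i in i_train], 0)
rays_rgb = np.reshape(rays_rgb, [-1, 3, 3]).astype(np.float32)
# shuffle so that consecutive batches mix rays from all images
np.random.shuffle(rays_rgb)
rays_rgb = torch.Tensor(rays_rgb).to(device)
i_batch = 0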

6. Rendering: render

Training starts and the iterations begin:

start = start + 1
    for i in trange(start, N_iters):
        time0 = time.time()
        # Sample random ray batch
        if use_batching:
            # Random over all images: take N_rand rays from the pool of all rays, reshuffling after every full pass
            batch = rays_rgb[i_batch:i_batch+N_rand] # [B, 2+1, 3*?]
            batch = torch.transpose(batch, 0, 1)
            # batch_rays: ray origins and directions; target_s: ground-truth RGB values
            batch_rays, target_s = batch[:2], batch[2]
            i_batch += N_rand
            if i_batch >= rays_rgb.shape[0]:
                print("Shuffle data after an epoch!")
                # reshuffle the rays after every epoch
                rand_idx = torch.randperm(rays_rgb.shape[0])
                rays_rgb = rays_rgb[rand_idx]
                i_batch = 0
        else:
            # Random from one image
            # pick one image at random and sample a batch of rays from it (sketched below)

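The body of that else branch is abridged in the snippet above; roughly, it does the following (based on the official train(), with the precrop logic for early iterations omitted):

img_i = np.random.choice(i_train)                    # pick one training image at random
target = torch.Tensor(images[img_i]).to(device)      # its ground-truth RGB, [H, W, 3]
pose = poses[img_i, :3, :4]
rays_o, rays_d = get_rays(H, W, K, torch.Tensor(pose))       # all rays of that image, [H, W, 3] each
# all pixel coordinates, then pick N_rand of them without replacement
coords = torch.stack(torch.meshgrid(torch.linspace(0, H - 1, H),
                                    torch.linspace(0, W - 1, W)), -1)  # [H, W, 2]
coords = torch.reshape(coords, [-1, 2])
select_inds = np.random.choice(coords.shape[0], size=[N_rand], replace=False)
select_coords = coords[select_inds].long()
rays_o = rays_o[select_coords[:, 0], select_coords[:, 1]]    # [N_rand, 3]
rays_d = rays_d[select_coords[:, 0], select_coords[:, 1]]    # [N_rand, 3]
batch_rays = torch.stack([rays_o, rays_d], 0)                # [2, N_rand, 3]
target_s = target[select_coords[:, 0], select_coords[:, 1]]  # [N_rand, 3]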
Once the ray information has been assembled as above, render is called:


#####  Core optimization loop  #####
        rgb, disp, acc, extras = render(H, W, K, chunk=args.chunk, rays=batch_rays,
                                                verbose=i < 10, retraw=True,
                                                **render_kwargs_train)

def render(H, W, K, chunk=1024*32, rays=None, c2w=None, ndc=True,
                  near=0., far=1.,
                  use_viewdirs=False, c2w_staticcam=None,
                  **kwargs):
    """Render rays
    Args:
      H: int. Height of image in pixels.
      W: int. Width of image in pixels.
      K: array of shape [3, 3]. Camera intrinsics (pinhole) matrix.
      chunk: int. Maximum number of rays to process simultaneously. Used to
        control maximum memory usage. Does not affect final results.
      rays: array of shape [2, batch_size, 3]. Ray origin and direction for
        each example in batch.
      c2w: array of shape [3, 4]. Camera-to-world transformation matrix.
      ndc: bool. If True, represent ray origin, direction in NDC coordinates.
      near: float or array of shape [batch_size]. Nearest distance for a ray.
      far: float or array of shape [batch_size]. Farthest distance for a ray.
      use_viewdirs: bool. If True, use viewing direction of a point in space in model.
      c2w_staticcam: array of shape [3, 4]. If not None, use this transformation matrix for
       camera while using other c2w argument for viewing directions.
    Returns:
      rgb_map: [batch_size, 3]. Predicted RGB values for rays.
      disp_map: [batch_size]. Disparity map. Inverse of depth.
      acc_map: [batch_size]. Accumulated opacity (alpha) along a ray.
      extras: dict with everything returned by render_rays().
    """
    # 1. Get the ray origins/directions and, optionally, the viewing directions
    if c2w is not None:
        # special case to render full image
        rays_o, rays_d = get_rays(H, W, K, c2w)
    else:
        # use provided ray batch
        rays_o, rays_d = rays
    if use_viewdirs:
        # provide ray directions as input
        viewdirs = rays_d  # use the ray directions as viewing directions
        if c2w_staticcam is not None:
            # special case to visualize the effect of viewdirs (useful for analysis, e.g. probing a corner case)
            rays_o, rays_d = get_rays(H, W, K, c2w_staticcam)
        viewdirs = viewdirs / torch.norm(viewdirs, dim=-1, keepdim=True)  # normalize
        viewdirs = torch.reshape(viewdirs, [-1,3]).float()
    sh = rays_d.shape # [..., 3]
    if ndc:  # NDC parameterization, used for forward-facing (LLFF) scenes
        # for forward facing scenes
        rays_o, rays_d = ndc_rays(H, W, K[0][0], 1., rays_o, rays_d)
    # 2. Create the ray batch
    rays_o = torch.reshape(rays_o, [-1,3]).float()
    rays_d = torch.reshape(rays_d, [-1,3]).float()
    near, far = near * torch.ones_like(rays_d[...,:1]), far * torch.ones_like(rays_d[...,:1])
    # the last dimension becomes 8 = 3 (origin) + 3 (direction) + 1 (near) + 1 (far), e.g. [1024, 8]
    rays = torch.cat([rays_o, rays_d, near, far], -1)
    if use_viewdirs:
        rays = torch.cat([rays, viewdirs], -1)
    # 3. Render and reshape
    all_ret = batchify_rays(rays, chunk, **kwargs)
    for k in all_ret:
        k_sh = list(sh[:-1]) + list(all_ret[k].shape[1:])
        all_ret[k] = torch.reshape(all_ret[k], k_sh)
    k_extract = ['rgb_map', 'disp_map', 'acc_map']
    ret_list = [all_ret[k] for k in k_extract]
    ret_dict = {k : all_ret[k] for k in all_ret if k not in k_extract}
    return ret_list + [ret_dict]

That is the overall rendering flow. To keep memory usage bounded, render splits the rays into smaller batches via batchify_rays, a minimal sketch of which is shown below; the per-chunk rendering itself is implemented in render_rays, reproduced after the sketch:
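batchify_rays simply loops over the flattened rays in chunks and concatenates the per-chunk results; this sketch is close to the official helper but slightly abridged:

def batchify_rays(rays_flat, chunk=1024*32, **kwargs):
    """Render rays in smaller minibatches to avoid running out of memory."""
    all_ret = {}
    for i in range(0, rays_flat.shape[0], chunk):
        ret = render_rays(rays_flat[i:i + chunk], **kwargs)
        for k in ret:
            if k not in all_ret:
                all_ret[k] = []
            all_ret[k].append(ret[k])
    # concatenate the per-chunk outputs back into full-batch tensors
    all_ret = {k: torch.cat(all_ret[k], 0) for k in all_ret}
    return all_ret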


def render_rays(ray_batch,          # everything needed to trace a ray: origin, direction, near/far bounds, unit viewing direction
                network_fn,         # the NeRF (coarse) network predicting RGB and density at each point in space
                network_query_fn,   # function that passes queries to network_fn
                N_samples,          # number of coarse samples per ray
                retraw=False,       # if True, also return the raw, unprocessed predictions
                lindisp=False,      # if True, sample linearly in inverse depth (disparity)
                perturb=0.,         # stratified-sampling perturbation
                N_importance=0,     # number of additional samples for the fine network
                network_fine=None,  # the fine network
                white_bkgd=False,   # if True, assume a white background
                raw_noise_std=0.,   # noise added to the predicted density
                verbose=False,      # print debug info
                pytest=False):
    """Volumetric rendering.
    Args:
      ray_batch: array of shape [batch_size, ...]. All information necessary
        for sampling along a ray, including: ray origin, ray direction, min
        dist, max dist, and unit-magnitude viewing direction.
      network_fn: function. Model for predicting RGB and density at each point
        in space.
      network_query_fn: function used for passing queries to network_fn.
      N_samples: int. Number of different times to sample along each ray.
      retraw: bool. If True, include model's raw, unprocessed predictions.
      lindisp: bool. If True, sample linearly in inverse depth rather than in depth.
      perturb: float, 0 or 1. If non-zero, each ray is sampled at stratified
        random points in time.
      N_importance: int. Number of additional times to sample along each ray.
        These samples are only passed to network_fine.
      network_fine: "fine" network with same spec as network_fn.
      white_bkgd: bool. If True, assume a white background.
      raw_noise_std: ...
      verbose: bool. If True, print more debugging info.
    Returns:
      rgb_map: [num_rays, 3]. Estimated RGB color of a ray. Comes from fine model.
      disp_map: [num_rays]. Disparity map. 1 / depth.
      acc_map: [num_rays]. Accumulated opacity along each ray. Comes from fine model.
      raw: [num_rays, num_samples, 4]. Raw predictions from model (before volume rendering).
      # the values below come from the coarse model; the maps above come from the fine model
      rgb0: See rgb_map. Output for coarse model.
      disp0: See disp_map. Output for coarse model.
      acc0: See acc_map. Output for coarse model.
      # standard deviation of the fine sample depths
      z_std: [num_rays]. Standard deviation of distances along ray for each
        sample.
    """
    # 0. Unpack the ray batch
    N_rays = ray_batch.shape[0]
    rays_o, rays_d = ray_batch[:,0:3], ray_batch[:,3:6] # [N_rays, 3] each
    viewdirs = ray_batch[:,-3:] if ray_batch.shape[-1] > 8 else None
    bounds = torch.reshape(ray_batch[...,6:8], [-1,1,2])
    near, far = bounds[...,0], bounds[...,1] # [-1,1]
    # 0.1 sample N_samples depth values linearly in [0, 1] and map them to [near, far] (these are z values)
    t_vals = torch.linspace(0., 1., steps=N_samples)
    if not lindisp:
        z_vals = near * (1.-t_vals) + far * (t_vals)
    else:
        z_vals = 1./(1./near * (1.-t_vals) + 1./far * (t_vals))
    z_vals = z_vals.expand([N_rays, N_samples])
    if perturb > 0.:  # stratified sampling: jitter each sample within its interval
        # get intervals between samples
        mids = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        upper = torch.cat([mids, z_vals[...,-1:]], -1)
        lower = torch.cat([z_vals[...,:1], mids], -1)
        # stratified samples in those intervals
        t_rand = torch.rand(z_vals.shape)
        # Pytest, overwrite u with numpy's fixed random numbers
        if pytest:
            np.random.seed(0)
            t_rand = np.random.rand(*list(z_vals.shape))
            t_rand = torch.Tensor(t_rand)
        z_vals = lower + (upper - lower) * t_rand
    # 0.2 compute the 3D coordinates of every sample point
    pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples, 3]
#     raw = run_network(pts)
    # 1. Coarse network inference: network_query_fn runs the forward pass
    raw = network_query_fn(pts, viewdirs, network_fn)  # [N_rays, N_samples, 4], e.g. [1024, 64, 4], where 4 = RGB + alpha
    # 2. Convert the coarse raw output into RGB, alpha, etc. (Eq. 3 of the paper)
    # returns: rgb map, disparity map, accumulated weights (opacity), per-sample weights, depth map
    rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
    if N_importance > 0:
        rgb_map_0, disp_map_0, acc_map_0 = rgb_map, disp_map, acc_map
        z_vals_mid = .5 * (z_vals[...,1:] + z_vals[...,:-1])
        # 3. Sample the additional points needed by the fine network
        z_samples = sample_pdf(z_vals_mid, weights[...,1:-1], N_importance, det=(perturb==0.), pytest=pytest)
        z_samples = z_samples.detach()
        z_vals, _ = torch.sort(torch.cat([z_vals, z_samples], -1), -1)
        pts = rays_o[...,None,:] + rays_d[...,None,:] * z_vals[...,:,None] # [N_rays, N_samples + N_importance, 3]
        run_fn = network_fn if network_fine is None else network_fine
#         raw = run_network(pts, fn=run_fn)
        # 4. Fine network inference
        raw = network_query_fn(pts, viewdirs, run_fn)
        # 5. Convert the raw output again
        rgb_map, disp_map, acc_map, weights, depth_map = raw2outputs(raw, z_vals, rays_d, raw_noise_std, white_bkgd, pytest=pytest)
    ret = {'rgb_map' : rgb_map, 'disp_map' : disp_map, 'acc_map' : acc_map}
    if retraw:
        ret['raw'] = raw
    if N_importance > 0:
        ret['rgb0'] = rgb_map_0
        ret['disp0'] = disp_map_0
        ret['acc0'] = acc_map_0
        ret['z_std'] = torch.std(z_samples, dim=-1, unbiased=False)  # [N_rays]
    for k in ret:
        if (torch.isnan(ret[k]).any() or torch.isinf(ret[k]).any()) and DEBUG:
            print(f"! [Numerical Error] {k} contains nan or inf.")
    return ret
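Both raw2outputs calls above convert the raw network output (RGB + density) into the rendered maps, but that function is not reproduced in this post. A sketch following the standard volume-rendering weighting (the pytest branch is omitted and the disparity clamping is slightly simplified):

import torch
import torch.nn.functional as F

def raw2outputs(raw, z_vals, rays_d, raw_noise_std=0., white_bkgd=False):
    # alpha_i = 1 - exp(-sigma_i * delta_i)
    raw2alpha = lambda raw, dists: 1. - torch.exp(-F.relu(raw) * dists)

    # distances between adjacent samples; the last interval is treated as (almost) infinite
    dists = z_vals[..., 1:] - z_vals[..., :-1]
    dists = torch.cat([dists, torch.full_like(dists[..., :1], 1e10)], -1)
    dists = dists * torch.norm(rays_d[..., None, :], dim=-1)   # scale by the ray direction norm

    rgb = torch.sigmoid(raw[..., :3])                          # [N_rays, N_samples, 3]
    noise = 0.
    if raw_noise_std > 0.:
        noise = torch.randn(raw[..., 3].shape) * raw_noise_std # regularizing noise on the density
    alpha = raw2alpha(raw[..., 3] + noise, dists)              # [N_rays, N_samples]

    # w_i = alpha_i * prod_{j<i} (1 - alpha_j): transmittance times alpha
    trans = torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]), 1. - alpha + 1e-10], -1), -1)[..., :-1]
    weights = alpha * trans

    rgb_map = torch.sum(weights[..., None] * rgb, -2)          # [N_rays, 3]
    depth_map = torch.sum(weights * z_vals, -1)
    acc_map = torch.sum(weights, -1)                           # accumulated opacity
    disp_map = 1. / torch.max(1e-10 * torch.ones_like(depth_map),
                              depth_map / torch.clamp(acc_map, min=1e-10))
    if white_bkgd:
        rgb_map = rgb_map + (1. - acc_map[..., None])          # composite onto a white background
    return rgb_map, disp_map, acc_map, weights, depth_map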

After step 2 in the comments above, the points needed by the fine network are sampled. Following the hierarchical volume sampling described in the previous post, the coarse network has already been evaluated at the coarse sample locations; the weight of each sample is computed and normalized into a probability density function, and only then are additional, more precise samples placed along each ray where the weights are large. This is implemented as follows:

# Hierarchical sampling (section 5.2): inverse transform sampling
# build the cumulative distribution function (CDF) from the probability density (PDF),
# draw uniform samples in [0, 1], and map the sampled CDF values back to coordinate values
def sample_pdf(bins, weights, N_samples, det=False, pytest=False):
    # Get pdf: compute the probability density (Eq. 5)
    weights = weights + 1e-5  # prevent NaNs in the division below, shape [1024, 62]
    pdf = weights / torch.sum(weights, -1, keepdim=True)  # normalize
    # cumulative sum along the last dimension, [1024, 62]
    cdf = torch.cumsum(pdf, -1)
    cdf = torch.cat([torch.zeros_like(cdf[...,:1]), cdf], -1)  # (batch, len(bins))
    # Take uniform samples
    if det:  # no perturbation: generate linearly spaced, deterministic sample points
        u = torch.linspace(0., 1., steps=N_samples)
        u = u.expand(list(cdf.shape[:-1]) + [N_samples])
    else:  # uniformly random sample points
        u = torch.rand(list(cdf.shape[:-1]) + [N_samples])
    # Pytest, overwrite u with numpy's fixed random numbers
    if pytest:
        np.random.seed(0)
        new_shape = list(cdf.shape[:-1]) + [N_samples]
        if det:
            u = np.linspace(0., 1., N_samples)
            u = np.broadcast_to(u, new_shape)
        else:
            u = np.random.rand(*new_shape)
        u = torch.Tensor(u)
    # Invert CDF; make the tensor contiguous in memory, [1024, 64]
    u = u.contiguous()
    # use the batched searchsorted to find, for every sample, its position index within the CDF
    inds = torch.searchsorted(cdf, u, right=True)
    # for every sample, find the CDF interval it falls into
    below = torch.max(torch.zeros_like(inds-1), inds-1)
    above = torch.min((cdf.shape[-1]-1) * torch.ones_like(inds), inds)
    inds_g = torch.stack([below, above], -1)  # (batch, N_samples, 2)
    # cdf_g = tf.gather(cdf, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    # bins_g = tf.gather(bins, inds_g, axis=-1, batch_dims=len(inds_g.shape)-2)
    matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
    cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
    bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
    denom = (cdf_g[...,1]-cdf_g[...,0])
    denom = torch.where(denom<1e-5, torch.ones_like(denom), denom)
    t = (u-cdf_g[...,0])/denom
    samples = bins_g[...,0] + t * (bins_g[...,1]-bins_g[...,0])
    # the sample points drawn according to the given PDF
    return samples

7. Loss computation and backpropagation

optimizer.zero_grad()
        # compute the loss, using only the RGB values
        img_loss = img2mse(rgb, target_s)
        trans = extras['raw'][...,-1]
        loss = img_loss
        psnr = mse2psnr(img_loss)  # compute the PSNR
        if 'rgb0' in extras:
            img_loss0 = img2mse(extras['rgb0'], target_s)
            loss = loss + img_loss0
            psnr0 = mse2psnr(img_loss0)
        # backpropagation
        loss.backward()
        optimizer.step()
        # NOTE: IMPORTANT!
        ###   update the learning rate   ###
        decay_rate = 0.1
        decay_steps = args.lrate_decay * 1000
        new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
        for param_group in optimizer.param_groups:
            param_group['lr'] = new_lrate
        ################################
        dt = time.time()-time0
        # print(f"Step: {global_step}, Loss: {loss}, Time: {dt}")
        #####           end            #####
        # Rest is logging: save checkpoints, logs, etc.
        if i%args.i_weights==0:
            path = os.path.join(basedir, expname, '{:06d}.tar'.format(i))
            torch.save({
                'global_step': global_step,  # iteration count
                'network_fn_state_dict': render_kwargs_train['network_fn'].state_dict(),      # coarse network weights
                'network_fine_state_dict': render_kwargs_train['network_fine'].state_dict(),  # fine network weights
                'optimizer_state_dict': optimizer.state_dict(),  # optimizer state
            }, path)
            print('Saved checkpoints at', path)
        if i%args.i_video==0 and i > 0:  # render a video
            # Turn on testing mode
            with torch.no_grad():
                rgbs, disps = render_path(render_poses, hwf, K, args.chunk, render_kwargs_test)
            print('Done, saving', rgbs.shape, disps.shape)
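For reference, img2mse, mse2psnr, and to8b used above are one-line helpers, roughly as defined in run_nerf_helpers.py:

img2mse = lambda x, y: torch.mean((x - y) ** 2)                            # mean squared error over RGB
mse2psnr = lambda x: -10. * torch.log(x) / torch.log(torch.Tensor([10.]))  # PSNR from MSE
to8b = lambda x: (255 * np.clip(x, 0, 1)).astype(np.uint8)                 # float image -> uint8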