论文地址:https://ieeexplore.ieee.org/abstract/document/9981070
这篇论文提出了一种名为 HighlightNet 的低光增强器,旨在解决低光环境下无人机(UAV)跟踪面临的挑战,以下是其关键点总结:
一、研究背景与问题
- 低光环境的挑战:在低光环境下,现有最先进的跟踪器性能下降,因为潜在的图像特征难以提取;同时,由于能见度低,人工操作员难以在地面控制站准确初始化无人机跟踪。
- 现有增强器的不足:增强暗区物体可能导致光照良好区域过曝光,破坏潜在特征;难以区分噪声与有效特征;多数针对稳定摄影场景,缺乏应对无人机场景中全局光照快速变化的调整机制;且对无人机跟踪任务的适应性不足。
二、HighlightNet 的设计
- 核心目标:为人工操作员和无人机跟踪器照亮潜在物体,同时兼顾人类感知和无人机在夜间面临的挑战。
- 关键模块
- 逐像素低光增强模块:通过 CNN 获取与输入灰度图像分辨率相同的范围掩码,结合 Transformer 生成的约束 α,利用伽马变换实现逐像素的增强,使增强器聚焦于跟踪对象和无光源区域的增强,且计算高效。
- 基于 Transformer 的参数调整模块:处理下采样的灰度图像,生成约束 α 和截断阈值 β,能根据全局光照信息动态调节参数,以适应无人机高机动性带来的场景快速变化和全局光照变化。
- 软截断模块:通过由截断阈值 β 确定的三次函数构成的截断函数,生成反噪声掩码,过滤过度增强的噪声,其增强减弱强度与灰度值负相关,且计算简单。
- 损失函数:采用非参考损失函数,包括暗区噪声损失、空间一致性损失、曝光控制损失和光照平滑损失,使网络可使用非配对数据训练,避免准备配对数据集的高昂成本。
三、实验与结果
- 实验设置:在 SICE 数据集的 Part1 进行训练,在 Part2 子集和 UAVDark135 基准等上进行测试,对比了多种最先进的低光增强器,并在不同的 SOTA 跟踪器上进行了验证,还进行了消融研究和真实世界测试。
- 主要结果
- 在 SICE 数据集 Part2 子集上,HighlightNet 的 PSNR 和 SSIM 得分最高,在促进人类感知方面有优势。
- 在 UAVDark135 基准上,与其他增强器相比,HighlightNet 使跟踪器的成功率和精度提升更显著,对光照变化、快速运动等典型无人机挑战的应对能力增强。
- 消融研究验证了各模块的有效性,范围掩码、Transformer-based 参数调整和软截断模块均对性能有贡献。
- 真实世界测试中,在 NVIDIA Jetson AGX Xavier 上实现了 32.2 FPS 的实时性能,跟踪准确可靠。
四、结论与意义
- 结论:HighlightNet 通过三个新颖模块,突出了潜在特征,减少了快速光照变化、人工光源、小目标和图像噪声的影响,在在线目标选择和跟踪中均有效,且计算资源消耗少。
- 意义:有助于将无人机跟踪应用扩展到夜间环境。
这段代码主要实现了一个低光照图像增强的深度学习模型的训练和测试过程。下面我们将对各个文件的功能进行详细分析:
1. model.py
该文件定义了低光照图像增强模型的网络结构,主要包含以下几个部分:
TF
类:这是一个自定义的 Transformer 模块,继承自torch.nn.Module
。它包含多头注意力机制(MultiheadAttention
)和前馈神经网络(由两个线性层和一个 ReLU 激活函数组成),用于对输入的特征进行特征提取和转换。discriminator_block
函数:用于构建判别器的下采样层,包含卷积层、LeakyReLU 激活函数和可选的实例归一化层。enhance_net_nopool
类:这是核心的图像增强网络,继承自torch.nn.Module
。它包含多个卷积层和一个 Transformer 模块,用于对输入的低光照图像进行增强处理。在forward
方法中,首先将输入图像拆分为红、绿、蓝三个通道,计算其平均值,然后通过一系列卷积层和 Transformer 模块进行特征提取和转换,最后根据提取的特征对图像进行增强处理,输出增强后的图像、一个权重矩阵A
和一个掩码矩阵t
。
2. Myloss.py
该文件定义了训练过程中使用的各种损失函数,主要包含以下几个部分:
L_spa
类:空间损失函数,用于衡量增强图像和原始图像在空间上的差异。它通过计算图像在不同方向上的梯度差异,并结合一个权重因子来计算损失。L_exp
类:曝光损失函数,用于衡量增强图像的平均亮度与目标亮度之间的差异。它通过对图像进行平均池化,然后计算平均亮度与目标亮度的均方误差来计算损失。L_TV
类:总变差损失函数,用于减少图像中的噪声和纹理,提高图像的平滑度。它通过计算图像在水平和垂直方向上的梯度平方和来计算损失。Sa_Loss
类:色彩饱和度损失函数,用于衡量图像的色彩饱和度。它通过计算图像中每个像素的 RGB 通道与平均 RGB 值的差异,然后计算这些差异的平方和的平方根,并取平均值来计算损失。perception_loss
类:感知损失函数,用于衡量增强图像和原始图像在特征空间上的差异。它使用预训练的 VGG16 网络提取图像的特征,然后计算这些特征之间的差异来计算损失。
3. lowlight_train_onlylow.py
该文件是训练脚本,用于训练低光照图像增强模型,主要包含以下几个部分:
weights_init
函数:用于初始化模型的权重,对于卷积层、批归一化层和线性层,分别使用不同的初始化方法。train
函数:训练模型的主函数,主要步骤包括:- 初始化模型,并将其移动到 GPU 上。
- 加载训练数据集,并创建数据加载器。
- 定义各种损失函数和优化器。
- 开始训练循环,在每个 epoch 中,遍历训练数据,计算损失并进行反向传播和参数更新。
- 定期打印损失值,并保存模型的权重。
- 主程序:解析命令行参数,创建保存模型权重的文件夹,并调用
train
函数开始训练。
4. dataloader.py
该文件定义了数据加载器,用于加载低光照图像数据集,主要包含以下几个部分:
populate_train_list
函数:用于获取指定路径下的所有低光照图像文件列表,并随机打乱顺序。lowlight_loader
类:继承自torch.utils.data.Dataset
,用于加载和处理低光照图像。在__getitem__
方法中,读取图像文件,将其调整为指定大小,将像素值归一化到 [0, 1] 范围内,并转换为 PyTorch 张量。
5. lowlight_test.py
该文件是测试脚本,用于对训练好的模型进行测试,主要包含以下几个部分:
lowlight
函数:对单张低光照图像进行增强处理的函数,主要步骤包括:- 读取图像文件,将其像素值归一化到 [0, 1] 范围内,并转换为 PyTorch 张量。
- 加载训练好的模型,并将其设置为评估模式。
- 对输入图像进行增强处理,记录处理时间。
- 将增强后的图像保存到指定路径。
- 主程序:遍历指定文件夹下的所有图像文件,调用
lowlight
函数对每张图像进行增强处理。
综上所述,这些代码实现了一个完整的低光照图像增强系统,包括模型定义、损失函数定义、数据加载、模型训练和模型测试等功能。
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional
from torch import Tensor
from torch.nn import Module
from torch.nn import MultiheadAttention
from torch.nn import Dropout
#作用:实现一个 Transformer 块,用于处理序列数据。包含多头自注意力机制和前馈神经网络。
# 关键步骤:
# 初始化多头自注意力层、线性层、归一化层和丢弃层。
# 在 forward 方法中,首先进行多头自注意力计算,然后进行残差连接和归一化。
# 接着通过前馈神经网络进行特征变换,再次进行残差连接和归一化。
class TF(Module):
def __init__(self, d_model, nhead, dim_feedforward=128, dropout=0.1, activation="relu"):
super(TF, self).__init__()
self.self_attn = MultiheadAttention(d_model, nhead, dropout=dropout)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm0 = nn.LayerNorm(d_model)
self.norm1 = nn.LayerNorm(d_model)
self.norm2 = nn.LayerNorm(d_model)
self.dropout1 = Dropout(dropout)
self.dropout2 = Dropout(dropout)
self.relu = nn.ReLU(inplace=True)
def __setstate__(self, state):
if 'activation' not in state:
state['activation'] = F.relu
super(TF, self).__setstate__(state)
def forward(self, src: Tensor, src_mask: Optional[Tensor] = None, src_key_padding_mask: Optional[Tensor] = None) -> Tensor:
b,c,s=src.permute(1,2,0).size()
src2 = self.self_attn(src, src, src, attn_mask=src_mask,
key_padding_mask=src_key_padding_mask)[0]
src = src + self.dropout1(src2)
src = self.norm1(src)
src2 = self.linear2(self.dropout(self.relu(self.linear1(src))))
src = src + self.dropout2(src2)
src = self.norm2(src)
return src
# 作用:返回鉴别器块的下采样层。包含卷积层、LeakyReLU 激活函数和可选的实例归一化层。
# 关键步骤:
# 创建一个卷积层,步长为 2,用于下采样。
# 添加 LeakyReLU 激活函数。
# 如果 normalization 为 True,则添加实例归一化层。
def discriminator_block(in_filters, out_filters, normalization=False):
"""Returns downsampling layers of each discriminator block"""
layers = [nn.Conv2d(in_filters, out_filters, 3, stride=2, padding=1)]
layers.append(nn.LeakyReLU(0.2))
if normalization:
layers.append(nn.InstanceNorm2d(out_filters, affine=True))
return layers
# 作用:实现一个图像增强网络,不使用池化层。包含卷积层、Transformer 块和一些辅助操作。
# 关键步骤:
# 初始化卷积层、Transformer 块、池化层和上采样层。
# 在 forward 方法中,首先将输入图像拆分为三个通道,计算亮度值 v。
# 通过一系列卷积层提取特征,计算增强系数 v_r。
# 对亮度值进行下采样,通过卷积层和 Transformer 块进行特征变换,计算亮度和对比度调整参数 level。
# 根据调整参数对亮度值进行幂运算和偏移操作,得到增强后的亮度值 ev。
# 根据增强后的亮度值对三个通道进行调整,得到增强后的图像。
class enhance_net_nopool(nn.Module):
def __init__(self):
super(enhance_net_nopool, self).__init__()
self.relu = nn.ReLU(inplace=True)
number_f = 4
self.e_conv1 = nn.Conv2d(1,number_f,3,1,1,bias=True)
self.e_conv2 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
self.e_conv3 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
self.e_conv4 = nn.Conv2d(number_f,number_f,3,1,1,bias=True)
self.e_conv5 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
self.e_conv6 = nn.Conv2d(number_f*2,number_f,3,1,1,bias=True)
self.e_conv7 = nn.Conv2d(number_f*2,1,3,1,1,bias=True)
self.model = nn.Sequential(
nn.Conv2d(1, 16, 3, stride=2, padding=1),
nn.LeakyReLU(0.2),
nn.InstanceNorm2d(16, affine=True),
)
self.TF = TF(16, 8)#transform
self.F = nn.Conv2d(16, 2, 16, padding=0)
self.maxpool = nn.MaxPool2d(2, stride=2, return_indices=False, ceil_mode=False)
self.upsample = nn.UpsamplingBilinear2d(scale_factor=2)
self.adapt = nn.AdaptiveAvgPool2d((32,32))
def forward(self, x):
batch,w,h,b = x.shape
red,green,blue = torch.split(x ,1,dim = 1)
v = (red + green + blue)/3
x1 = self.relu(self.e_conv1(v))
x2 = self.relu(self.e_conv2(x1))
x3 = self.relu(self.e_conv3(x2))
x4 = self.relu(self.e_conv4(x3))
x5 = self.relu(self.e_conv5(torch.cat([x3,x4],1)))
x6 = self.relu(self.e_conv6(torch.cat([x2,x5],1)))
v_r = torch.sigmoid(self.e_conv7(torch.cat([x1,x6],1)))
zero = 0.000001*torch.ones_like(v)
one = 0.999999*torch.ones_like(v)
v0 = torch.where(v>0.999999,one,v)
v0 = torch.where(v<0.000001,zero,v0)
r = v_r
v32 = F.interpolate(v,size=(32,32),mode='nearest')
v1 = self.model(v32)
bb, cc, ww, hh = v1.size()
v2 = v1.view(bb,cc,-1).permute(2, 0, 1)
v3 = self.TF(v2)
v4 = v3.permute(1,2,0)
v5 = v4.view(bb,cc,ww,hh)
level = torch.sigmoid(self.F(v5))
g1 = level[0,0].item()
b1 = level[0,1].item()
g = 0.1*g1+0.2
b = 0.04*b1+0.06
for i in range(batch):
if(i == 0):
r0 = torch.pow (0.1*level[i,0].item()+0.2,torch.unsqueeze(r[i,:,:,:],0))
else:
r1 = torch.pow (0.1*level[i,0].item()+0.2,torch.unsqueeze(r[i,:,:,:],0))
r0 = torch.cat([r0,r1],0)
ev0 = torch.pow(v0,r0)
for i in range(batch):
if(i == 0):
L = 400*torch.pow((0.04*level[i,1].item()+0.06 - torch.unsqueeze(v[i,:,:,:],0)),3)
else:
L0 = 400*torch.pow((0.04*level[i,1].item()+0.06 - torch.unsqueeze(v[i,:,:,:],0)),3)
L = torch.cat([L,L0],0)
L = torch.where(L<0.00001,zero,L)
ev = ev0 - L
v = v + 0.000001
red1 = red/v
green1 = green/v
blue1 = blue/v
red0 = red1*ev
green0 = green1*ev
blue0 = blue1*ev
enhance_image = torch.cat([red0,green0,blue0],1)
zero1 = torch.zeros_like(x)
vvv = torch.cat([v,v,v],1)
t = torch.where(vvv>0.04,zero1,enhance_image)
A = r
return enhance_image,A,t