这里解析的是 GitHub 上的 Mask2Former 简化版,该版本解耦了 detectron2。
我真受不了原版和 detectron2 混在一起的版本,根本看不懂啊喂!!!
目录标题
1、init方法
1.1 ADE200kDataset
class ADE200kDataset(BaseDataset):
    """Dataset wrapper that adds ADE-specific options on top of ``BaseDataset``."""

    def __init__(self, odgt, opt, dynamic_batchHW=False, **kwargs):
        """Store dataset options and build the augmentation pipeline.

        Args:
            odgt: Sample list (or path to one); forwarded to ``BaseDataset``.
            opt: Config object. Reads ``DATASETS.ROOT_DIR``,
                ``MODEL.SEM_SEG_HEAD.COMMON_STRIDE`` and
                ``MODEL.MASK_FORMER.NUM_OBJECT_QUERIES``.
            dynamic_batchHW: Whether the batch height/width is adjusted
                dynamically. The original note says cswin_transformer needs
                a fixed image size.
            **kwargs: Forwarded unchanged to ``BaseDataset.__init__``.
        """
        super().__init__(odgt, opt, **kwargs)

        # Root directory that sample paths are resolved against.
        self.root_dataset = opt.DATASETS.ROOT_DIR

        model_cfg = opt.MODEL
        # Down-sampling rate of the segmentation label: how much the network
        # output is shrunk relative to its input.
        self.segm_downsampling_rate = model_cfg.SEM_SEG_HEAD.COMMON_STRIDE
        self.dynamic_batchHW = dynamic_batchHW
        self.num_querys = model_cfg.MASK_FORMER.NUM_OBJECT_QUERIES
        # self.visualize = ADEVisualize()

        # NOTE(review): the pipeline is sampled once here, so every sample
        # appears to reuse the same augmentation choice - confirm intended.
        self.aug_pipe = self.get_data_aug_pipe()
1) odgt:dataset/training.odgt — 数据集列表文件
{"fpath_img": "ADEChallengeData2016/images/training/ADE_train_00000001.jpg", "fpath_segm": "ADEChallengeData2016/annotations/training/ADE_train_00000001.png", "width": 683, "height": 512}
2) 其余初始化参数变量
- self.root_dataset – “data/” – 数据集root地址
- segm_downsampling_rate – 分割标签的下采样率
- dynamic_batchHW – 是否动态调整batchHW
- num_querys – MaskFormer模型的查询数量
- aug_pipe – 数据增强pipeline
3) aug_pipe
# Randomly build an augmentation sequence holding zero, one or two operations.
def get_data_aug_pipe(self):
    """Sample the list of augmentation pipelines to apply.

    Returns:
        An empty list when the first random draw is not above 0.5.
        Otherwise a list with one augmentation pipeline, or two when a
        second one is stacked on top of rotate / flip / shear.
    """
    # Skip augmentation unless the draw exceeds 0.5.
    if random.random() <= 0.5:
        return []

    # Index order: rotate, scale, translate, blur, flip, shear, contrast.
    candidates = (
        pipe_sequential_rotate,
        pipe_sequential_scale,
        pipe_sequential_translate,
        pipe_someof_blur,
        pipe_someof_flip,
        pipe_sometimes_mpshear,
        pipe_someone_contrast,
    )
    # Selection probabilities in the same order: 5/25/20/25/15/5/5 percent.
    weights = [0.05, 0.25, 0.20, 0.25, 0.15, 0.05, 0.05]
    first = np.random.choice(a=[0, 1, 2, 3, 4, 5, 6], p=weights)
    chosen = [candidates[first]]

    # After rotate (0), flip (4) or shear (5) there is a 50% chance of also
    # applying scale, translate or blur.
    if first in (0, 4, 5) and random.random() < 0.5:
        second = np.random.choice(a=[1, 2, 3], p=[0.4, 0.3, 0.3])
        chosen.append(candidates[second])
    return chosen
1.2 BaseDataset
class BaseDataset(torch.utils.data.Dataset):
    """Base dataset holding crop sizes, the sample list and normalisation stats."""

    def __init__(self, odgt, opt, **kwargs):
        """Read options from ``opt`` and optionally load the sample list.

        Args:
            odgt: Sample list or list-file path; when ``None`` no list is parsed.
            opt: Config object. Reads ``INPUT.CROP.SIZE``,
                ``INPUT.CROP.MAX_SIZE``, ``DATASETS.PIXEL_MEAN`` and
                ``DATASETS.PIXEL_STD``.
            **kwargs: Forwarded to ``parse_input_list``.
        """
        crop_cfg = opt.INPUT.CROP
        # The original notes give [224, 320, 480, 512] as an example value.
        self.imgSizes = crop_cfg.SIZE
        # The original notes give [1024, 576] as an example value.
        self.imgMaxSize = crop_cfg.MAX_SIZE

        # Grid constant; the original note says resnet down-samples five
        # times in total, hence 2 ** 5.
        self.padding_constant = 2 ** 5

        # Only parse the sample list when one was supplied.
        if odgt is not None:
            self.parse_input_list(odgt, **kwargs)

        # Per-channel statistics used for standardisation.
        stats_cfg = opt.DATASETS
        self.pixel_mean = np.array(stats_cfg.PIXEL_MEAN)
        self.pixel_std = np.array(stats_cfg.PIXEL_STD)
1)参数初始化
- imgSizes – [224, 320, 480, 512]
- imgMaxSize – [1024, 576]
- padding_constant – 网格参数 – 2**5(resnet总共下采样5次)
- 数据标准化 – pixel_mean & pixel_std
2) parse_input_list
def parse_input_list(self, odgt, max_sample=-1, start_idx=-1, end_idx=-1):
    """Load the sample list and optionally truncate or slice it.

    Sets ``self.list_sample`` and ``self.num_sample``.

    Args:
        odgt: Either a ready-made list of sample records (used as-is), or
            the path of a text file holding one JSON object per line.
        max_sample: When positive, keep only the first ``max_sample`` records.
        start_idx: Together with ``end_idx``; when both are non-negative,
            keep the slice ``[start_idx:end_idx]``.
        end_idx: See ``start_idx``.

    Raises:
        TypeError: If ``odgt`` is neither a list nor a string.
        AssertionError: If no samples remain after filtering.
    """
    if isinstance(odgt, list):
        # An existing list is used directly, without copying.
        samples = odgt
    elif isinstance(odgt, str):
        # One JSON record per line. The context manager closes the file,
        # and blank lines (e.g. a trailing newline) are skipped rather than
        # being passed to json.loads.
        with open(odgt, 'r') as handle:
            stripped_lines = (raw.strip() for raw in handle)
            samples = [json.loads(text) for text in stripped_lines if text]
    else:
        raise TypeError(
            'odgt must be a list of records or a path string, got {}'.format(
                type(odgt).__name__))

    # Cap the number of samples.
    if max_sample > 0:
        samples = samples[0:max_sample]
    # Restrict to an explicit index range.
    if start_idx >= 0 and end_idx >= 0:
        samples = samples[start_idx:end_idx]

    self.list_sample = samples
    self.num_sample = len(samples)
    assert self.num_sample > 0, 'no samples left after parsing the input list'
    print('# samples: {}'.format(self.num_sample))
2、 getitem方法
def __getitem__(self, index):
    """Load one sample and run it through the stored augmentation pipeline.

    Args:
        index: Position of the record in ``self.list_sample``.

    Returns:
        A dict with ``'image'`` (array of the RGB image) and ``'mask'``
        (array of the single-channel label), both after augmentation.
    """
    record = self.list_sample[index]

    # Resolve both paths against the dataset root.
    image_path, segm_path = [
        os.path.join(self.root_dataset, record[key])
        for key in ('fpath_img', 'fpath_segm')
    ]

    # Image is read as 3-channel RGB, the label as 8-bit greyscale.
    rgb = Image.open(image_path).convert('RGB')
    label = Image.open(segm_path).convert('L')

    # Augmentation runs on arrays; image and mask go through each step together.
    img, segm = np.array(rgb), np.array(label)
    for seq in self.aug_pipe:
        img, segm = imgaug_mask(img, segm, seq)

    return {'image': img, 'mask': segm}
3、数据加载器Dataloader
# Training dataset; batch height/width is adjusted dynamically.
dataset_train = ADE200kDataset(cfg.DATASETS.TRAIN, cfg, dynamic_batchHW=True)

# Multi-GPU runs use a distributed sampler; single-GPU runs use none.
train_sampler = (
    torch.utils.data.distributed.DistributedSampler(dataset_train, rank=cfg.local_rank)
    if cfg.ngpus > 1
    else None
)

# Wrap the dataset into an iterable stream of batches.
loader_train = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=cfg.TRAIN.BATCH_SIZE,
    # Shuffle only when no sampler is supplied.
    shuffle=train_sampler is None,
    sampler=train_sampler,
    # Custom batch merge: common size plus joint image/mask handling.
    collate_fn=dataset_train.collate_fn,
    num_workers=cfg.TRAIN.WORKERS,  # number of loader worker processes
    drop_last=True,                 # drop the final incomplete batch
    pin_memory=True,                # use page-locked host memory
)
3.1 自定义合并函数–collate_fn
def collate_fn(self, batch):
    """Merge a list of samples into one batch of equally sized tensors.

    Args:
        batch: Iterable of dicts with ``'image'`` and ``'mask'`` arrays.

    Returns:
        A dict with three entries:
        ``'images'`` - stacked float tensors of the transformed images;
        ``'masks'`` - stacked long tensors of the resized labels;
        ``'raw_img'`` - list of the original ``item['image']`` arrays.
    """
    # Common width/height for every sample in this batch.
    batch_width, batch_height = self.get_batch_size(batch)
    full_size = (batch_width, batch_height)
    rate = self.segm_downsampling_rate
    label_size = (batch_width // rate, batch_height // rate)

    image_tensors, mask_tensors, originals = [], [], []
    for sample in batch:
        # Work on a copy so the raw image kept below stays untouched.
        pil_img = Image.fromarray(deepcopy(sample['image']))
        pil_segm = Image.fromarray(sample['mask'])

        # resize_padding is not shown here; presumably it resizes and pads
        # to the batch size - confirm against its definition.
        pil_img = self.resize_padding(pil_img, full_size)
        normalised = self.img_transform(pil_img)

        # The mask gets the same geometric change, then shrinks by the
        # label down-sampling rate; both steps use nearest-neighbour.
        pil_segm = self.resize_padding(pil_segm, full_size, Image.NEAREST)
        pil_segm = pil_segm.resize(label_size, Image.NEAREST)

        image_tensors.append(torch.from_numpy(normalised).float())
        mask_tensors.append(torch.from_numpy(np.array(pil_segm)).long())
        originals.append(sample['image'])

    return {
        'images': torch.stack(image_tensors),
        'masks': torch.stack(mask_tensors),
        'raw_img': originals,
    }
输出返回out为字典,有三个变量:
- raw_img:原始输入图像,list列表。
- images:将原始输入图像进行缩放填充到固定大小,然后进行数值归一化、标准化,并将h,w,c更换为维度c,h,w。
- masks:将mask进行缩放和填充。
1)得到批量尺寸大小
动态调整批量图像尺寸,使其不超过定义的最大值
2) resize_padding
3)img_transform
def img_transform(self, img):
    """Scale, standardise and reorder an image for the network.

    Args:
        img: Image convertible with ``np.array``; must yield three axes,
            since the result is transposed with ``(2, 0, 1)``.

    Returns:
        The standardised array with the last axis moved to the front
        (``[c, h, w]`` per the original comment).
    """
    # Map 0-255 pixel values into 0-1.
    scaled = np.array(img).astype(np.float32) / 255.
    # Standardise with the dataset statistics.
    standardised = (scaled - self.pixel_mean) / self.pixel_std
    # Move the last axis to the front.
    return standardised.transpose((2, 0, 1))
4)mask进行两次尺寸变换
# First pass: give the mask the same resize/padding as the image, using
# nearest-neighbour resampling.
padded_size = (batch_width, batch_height)
segm = self.resize_padding(segm, padded_size, Image.NEAREST)
# Second pass: shrink the mask by the label down-sampling rate.
rate = self.segm_downsampling_rate
segm = segm.resize((batch_width // rate, batch_height // rate), Image.NEAREST)
- 第一次变换:将mask和图像进行同样的变换,目的是在空间上进行对齐
- 第二次变换:将mask进行下采样,使其大小和model的输出大小一致