light-reid 训练代码备忘
环境:PyTorch 1.1.0 + torchvision 0.3.0 + CUDA 9.0
conda install pytorch==1.1.0 torchvision==0.3.0 cudatoolkit=9.0 -c pytorch
import torch
import lightreid
from lightreid.evaluations import build_evaluator
# Directory passed to the engine as `results_dir`.
modelDir = '/home/.../light-reid-master'

# Build the data manager: 'market1501' is used both as the training source
# and as the test target; batches are drawn with the 'pk' sampler (p=16, k=4).
dataManager = lightreid.data.DataManager(
    sources=lightreid.data.build_train_dataset(['market1501']),
    target=lightreid.data.build_test_dataset('market1501'),
    # Training transforms: 256x128 images with flip / pad-crop / 'rea' entries.
    transforms_train=lightreid.data.build_transforms(
        img_size=[256, 128],
        transforms_list=['randomflip', 'padcrop', 'rea'],
    ),
    # Test transforms: same image size, empty transform list.
    transforms_test=lightreid.data.build_transforms(
        img_size=[256, 128],
        transforms_list=[],
    ),
    sampler='pk',
    p=16,
    k=4,
)
# Build the model from three parts: a pretrained ResNet-50 backbone
# (last_stride_one=True), an adaptive average pooling layer with output size 1,
# and a BNHead sized from the backbone's `dim` and the data manager's
# `class_num`.
backbone = lightreid.models.backbones.resnet50(
    pretrained=True,
    last_stride_one=True,
)
pooling = torch.nn.AdaptiveAvgPool2d(1)
head = lightreid.models.BNHead(
    in_dim=backbone.dim,
    class_num=dataManager.class_num,
)
model = lightreid.models.BaseReIDModel(
    backbone=backbone,
    pooling=pooling,
    head=head,
)
# Build the loss: two terms, each with weight 1.0. The 'inputs' mapping of
# each term pairs the loss's argument names with string keys ('logits',
# 'feats', 'pids') -- presumably keys of the model/batch outputs; confirm
# against lightreid.losses.Criterion.
criterion = lightreid.losses.Criterion([
    {
        'criterion': lightreid.losses.CrossEntropyLabelSmooth(
            num_classes=dataManager.class_num,
        ),
        'weight': 1.0,
        'inputs': {'inputs': 'logits', 'targets': 'pids'},
    },
    {
        'criterion': lightreid.losses.TripletLoss(
            margin=0.3,
            metric='euclidean',
        ),
        'weight': 1.0,
        'inputs': {'emb': 'feats', 'label': 'pids'},
    },
])
# Build the optimizer: Adam over all model parameters, a warmup multi-step
# learning-rate scheduler (milestones at 40 and 70), and finally lightreid's
# Optimizer wrapper around both.
# NOTE: `optimizer` is rebound on the last statement -- it first names the
# torch Adam instance, then the lightreid wrapper that contains it.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.00035,
    weight_decay=5e-4,
)
lr_scheduler = lightreid.optim.WarmupMultiStepLR(
    optimizer,
    milestones=[40, 70],
    gamma=0.1,
    warmup_factor=0.01,
    warmup_epochs=10,
)
optimizer = lightreid.optim.Optimizer(
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    max_epochs=120,
)
# Build the evaluator directly from its settings: the 'cmc_map_eval'
# evaluator with the 'cosine' metric in 'inter-camera' mode.
evaluator = build_evaluator(
    name='cmc_map_eval',
    metric='cosine',
    mode='inter-camera',
)
# Build the engine from the data manager, model, loss, optimizer and
# evaluator; results go to `modelDir`, GPU use is enabled, and all three
# light_* options are switched off.
solver = lightreid.engine.Engine(
    results_dir=modelDir,
    datamanager=dataManager,
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    use_gpu=True,
    light_model=False,
    light_feat=False,
    light_search=False,
    evaluator=evaluator,
)

# Train (eval_freq=10 -- presumably evaluates every 10 epochs; confirm).
solver.train(eval_freq=10)

# Test: resume the latest saved model, then evaluate with onebyone=True.
solver.resume_latest_model()
solver.eval(onebyone=True)

# Visualize.
solver.visualize()
参考:
https://zhuanlan.zhihu.com/p/348214785 作者的知乎博客地址(包含代码以及论文地址)
https://github.com/wangguanan/light-reid/issues/26 参考其中代码