# train.py
import os
import math
import argparse
import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler
from model import shufflenet_v2_x1_0
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate
def main(args):
    """Train shufflenet_v2_x1_0 on an image-classification dataset.

    Splits the images under args.data_path into train/val, optionally loads
    pretrained weights and freezes everything except the classifier head,
    then trains for args.epochs epochs with SGD + cosine LR decay, logging
    loss/accuracy/lr to TensorBoard and saving a checkpoint per epoch
    under ./weights.
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")

    tb_writer = SummaryWriter()
    if os.path.exists("./weights") is False:
        os.makedirs("./weights")

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    data_transform = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
        ])
    }

    # bug fix: MyDataSet's keyword is `transform` (singular); `transforms=`
    # raised TypeError at construction.
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    # number of dataloader workers, capped by CPU count, batch size and 8
    nw = min([os.cpu_count(), batch_size if batch_size > 1 else 0, 8])

    # bug fix: the class is DataLoader, not Dataloader
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=train_dataset.collate_fn)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=val_dataset.collate_fn)

    model = shufflenet_v2_x1_0(num_classes=args.num_classes).to(device)

    if args.weights != "":
        if os.path.exists(args.weights):
            weights_dict = torch.load(args.weights, map_location=device)
            # keep only tensors whose element count matches the current model
            # (drops e.g. the fc layer when num_classes differs from pretraining)
            # bug fix: dict.items(), not dict.item()
            load_weights_dict = {k: v for k, v in weights_dict.items()
                                 if model.state_dict()[k].numel() == v.numel()}
            # bug fix: the filtered dict was built but never applied to the model
            model.load_state_dict(load_weights_dict, strict=False)
        else:
            raise FileNotFoundError("not found weights file: {}".format(args.weights))

    if args.freeze_layers:
        # bug fix: the method is named_parameters, not names_parameters
        for name, para in model.named_parameters():
            # freeze everything except the classifier head
            if "fc" not in name:
                para.requires_grad_(False)

    # optimize only the parameters that still require gradients
    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=4E-5)
    # cosine schedule from lr down to lr * args.lrf over args.epochs epochs
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)
        scheduler.step()

        acc = evaluate(model=model,
                       data_loader=val_loader,
                       device=device)

        tags = ["loss", "acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], mean_loss, epoch)
        tb_writer.add_scalar(tags[1], acc, epoch)
        tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)

        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lrf', type=float, default=0.1)
    parser.add_argument('--data_path', type=str, default="/data/flower_photos")
    parser.add_argument('--weights', type=str, default='./shufflenetv2_x1.pth',
                        help='initial weights path')
    # bug fix: argparse `type=bool` treats ANY non-empty string (even "False")
    # as True; a store_true flag has the intended default-False semantics
    parser.add_argument('--freeze-layers', action='store_true',
                        help='freeze all layers except the classifier head')
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')

    opt = parser.parse_args()
    main(opt)
# my_dataset.py
from PIL import Image
import torch
from torch.utils.data import Dataset
class MyDataSet(Dataset):
    """Custom dataset mapping image file paths to (image, class-index) pairs."""

    def __init__(self, images_path: list, images_class: list, transform=None):
        self.images_path = images_path    # list of image file paths
        self.images_class = images_class  # class index for each path
        self.transform = transform        # optional torchvision transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        img = Image.open(self.images_path[item])
        label = self.images_class[item]

        if self.transform is not None:
            img = self.transform(img)

        return img, label

    @staticmethod
    def collate_fn(batch):
        # bug fix: must be a @staticmethod — as a plain method, DataLoader's
        # call through the bound `dataset.collate_fn` would receive the
        # dataset instance as `batch` and crash.
        images, labels = tuple(zip(*batch))

        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels
# utils.py (internal helper functions)
import os
import sys
import json
import pickle
import random
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
def read_split_data(root: str, val_rate: float = 0.2):
    """Scan `root` (one sub-directory per class) and split the image paths.

    Writes the class-name -> index mapping to ./class_indices.json and returns
    (train_images_path, train_images_label, val_images_path, val_images_label),
    where `val_rate` of each class's images (seeded, reproducible) go to val.
    """
    random.seed(0)  # reproducible split across runs
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # each sub-directory of root represents one class
    flower_class = [cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla))]
    flower_class.sort()
    class_indices = dict((k, v) for v, k in enumerate(flower_class))
    # bug fix: dict.items(), not dict.item()
    json_str = json.dumps(dict((val, key) for key, val in class_indices.items()), indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []
    train_images_label = []
    val_images_path = []
    val_images_label = []
    every_class_num = []
    supported = [".jpg", ".JPG", ".png", ".PNG"]

    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        images = [os.path.join(root, cla, i) for i in os.listdir(cla_path)
                  if os.path.splitext(i)[-1] in supported]
        image_class = class_indices[cla]
        every_class_num.append(len(images))
        # set for O(1) membership tests in the loop below
        val_path = set(random.sample(images, k=int(len(images) * val_rate)))

        for img_path in images:
            if img_path in val_path:
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    return train_images_path, train_images_label, val_images_path, val_images_label
def plot_data_loader_image(data_loader):
    """Denormalize up to 4 images per batch from `data_loader` for inspection.

    Assumes images were normalized with the ImageNet mean/std; the class-name
    mapping is read from ./class_indices.json (written by read_split_data).
    NOTE(review): the actual plotting calls appear to be missing here — this
    currently only computes the denormalized arrays and labels.
    """
    batch_size = data_loader.batch_size
    plot_num = min(batch_size, 4)

    json_path = './class_indices.json'
    # fix: use a context manager so the file handle is closed
    with open(json_path, 'r') as json_file:
        class_indices = json.load(json_file)

    for data in data_loader:
        images, labels = data
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C] for image display
            img = images[i].numpy().transpose(1, 2, 0)
            # undo ImageNet normalization and rescale to [0, 255]
            img = (img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]) * 255
            # bug fix: do not rebind `labels` to a scalar — the original
            # crashed on the second image (`labels[1]` on a float)
            label = labels[i].item()
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Run one training epoch and return the running mean loss as a float.

    Performs one optimizer step per batch (zeroing gradients after each step)
    and tracks a numerically-stable running mean of the per-batch loss.
    """
    model.train()
    # typo fix: loss_function (was loss_fuction)
    loss_function = torch.nn.CrossEntropyLoss()
    # bug fix: torch.zeros, not torch.zero
    mean_loss = torch.zeros(1).to(device)
    optimizer.zero_grad()

    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        pred = model(images.to(device))

        loss = loss_function(pred, labels.to(device))
        loss.backward()
        # bug fix: running mean is (mean * step + loss) / (step + 1); the
        # original divided only the loss term, corrupting the average
        mean_loss = (mean_loss * step + loss.detach()) / (step + 1)

        optimizer.step()
        optimizer.zero_grad()

    return mean_loss.item()
@torch.no_grad()
def evaluate(model, data_loader, device):
    """Return top-1 accuracy of `model` over `data_loader` on `device`.

    Decorated with @torch.no_grad(): evaluation must not build autograd
    graphs (the original leaked memory by tracking gradients here).
    """
    model.eval()

    # total number of validation samples
    total_num = len(data_loader.dataset)
    # accumulator for the number of correct predictions
    sum_num = torch.zeros(1).to(device)

    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, data in enumerate(data_loader):
        images, labels = data
        pred = model(images.to(device))
        pred = torch.max(pred, dim=1)[1]  # argmax over class logits
        sum_num += torch.eq(pred, labels.to(device)).sum()

    return sum_num.item() / total_num
# predict.py
import os
import json
import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt
from model import shufflenet_v2_x1_0
def main():
    """Classify a single image with a trained shufflenet_v2_x1_0 checkpoint.

    Loads ../tulip.jpg, applies the validation preprocessing, runs inference
    and prints the predicted class name and probability.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    data_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    img_path = "../tulip.jpg"
    assert os.path.exists(img_path), "file: '{}' dose not exist.".format(img_path)
    img = Image.open(img_path)
    # robustness: Normalize expects 3 channels; guard against RGBA/grayscale
    img = img.convert("RGB")
    img = data_transform(img)
    # add the batch dimension expected by the model: [C,H,W] -> [1,C,H,W]
    img = torch.unsqueeze(img, dim=0)

    # index -> class-name mapping written during training
    json_path = './class_indices.json'
    with open(json_path, 'r') as f:
        class_indict = json.load(f)

    model = shufflenet_v2_x1_0(num_classes=5).to(device)
    model_weight_path = "./weights/model-29.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).numpy()
    # bug fix: the prediction was computed but never reported
    print("class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                           predict[predict_cla].numpy()))
# model.py
from typing import List, Callable
import torch
from torch import Tensor
import torch.nn as nn
def channel_shuffle(x: Tensor, groups: int) -> Tensor:
    """Interleave the channels of `x` across `groups` groups (ShuffleNet op).

    Reshapes [N, C, H, W] -> [N, groups, C // groups, H, W], swaps the two
    channel axes, and flattens back so information mixes between groups.
    """
    n, c, h, w = x.size()
    per_group = c // groups

    grouped = x.view(n, groups, per_group, h, w)
    shuffled = grouped.transpose(1, 2).contiguous()

    return shuffled.view(n, -1, h, w)
class InvertedResidual(nn.Module):
    """ShuffleNetV2 building block.

    stride 1: channel-split unit — the first half of the channels passes
    through untouched, the second half goes through branch2.
    stride 2: spatial-downsampling unit — both branches process the full
    input and their outputs are concatenated (doubling channel count).
    A channel shuffle over 2 groups follows in both cases.
    """

    def __init__(self, input_c: int, output_c: int, stride: int):
        super(InvertedResidual, self).__init__()

        if stride not in [1, 2]:
            raise ValueError("illegal stride value.")
        self.stride = stride

        # each branch contributes half of the output channels
        assert output_c % 2 == 0
        branch_features = output_c // 2
        # when stride is 1, input_channel must be twice branch_features
        # ('<<' is a bit shift — a quick way to compute x2)
        assert (self.stride != 1) or (input_c == branch_features << 1)

        if self.stride == 2:
            # downsampling branch: depthwise 3x3 (stride 2) then pointwise 1x1
            self.branch1 = nn.Sequential(
                self.depthwise_conv(input_c, input_c, kernel_s=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(input_c),
                nn.Conv2d(input_c, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
                nn.BatchNorm2d(branch_features),
                nn.ReLU(inplace=True)
            )
        else:
            # stride 1: branch1 is the identity on the first channel half
            self.branch1 = nn.Sequential()

        # pointwise 1x1 -> depthwise 3x3 -> pointwise 1x1
        self.branch2 = nn.Sequential(
            nn.Conv2d(input_c if self.stride > 1 else branch_features, branch_features, kernel_size=1,
                      stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True),
            self.depthwise_conv(branch_features, branch_features, kernel_s=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(branch_features),
            nn.Conv2d(branch_features, branch_features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(branch_features),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def depthwise_conv(input_c: int,
                       output_c: int,
                       kernel_s: int,
                       stride: int = 1,
                       padding: int = 0,
                       bias: bool = False) -> nn.Conv2d:
        # groups=input_c makes this a depthwise convolution
        return nn.Conv2d(in_channels=input_c, out_channels=output_c, kernel_size=kernel_s,
                         stride=stride, padding=padding, bias=bias, groups=input_c)

    def forward(self, x: Tensor) -> Tensor:
        if self.stride == 1:
            # split channels in two; only the second half is transformed
            x1, x2 = x.chunk(2, dim=1)
            out = torch.cat((x1, self.branch2(x2)), dim=1)
        else:
            out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)

        # mix information between the two concatenated halves
        out = channel_shuffle(out, 2)

        return out
class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 backbone + linear classifier.

    Architecture: conv1 (3x3, stride 2) -> maxpool -> stage2/3/4 of
    InvertedResidual blocks (each stage starts with one stride-2 block) ->
    conv5 (1x1) -> global average pool -> fc.

    :param stages_repeats: number of blocks in each of the 3 stages
    :param stages_out_channels: output channels of conv1, the 3 stages, and conv5
    :param num_classes: size of the classifier output
    :param inverted_residual: block class used to build the stages
    """

    def __init__(self,
                 stages_repeats: List[int],
                 stages_out_channels: List[int],
                 num_classes: int = 1000,
                 inverted_residual: Callable[..., nn.Module] = InvertedResidual):
        super(ShuffleNetV2, self).__init__()

        if len(stages_repeats) != 3:
            raise ValueError("expected stages_repeats as list of 3 positive ints")
        if len(stages_out_channels) != 5:
            raise ValueError("expected stages_out_channels as list of 5 positive ints")
        self._stage_out_channels = stages_out_channels

        # input RGB image
        input_channels = 3
        output_channels = self._stage_out_channels[0]

        self.conv1 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True)
        )
        input_channels = output_channels

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Static annotations for mypy
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential

        # build stage2/3/4: each stage is one stride-2 (downsampling) block
        # followed by (repeats - 1) stride-1 blocks
        stage_names = ["stage{}".format(i) for i in [2, 3, 4]]
        for name, repeats, output_channels in zip(stage_names, stages_repeats,
                                                  self._stage_out_channels[1:]):
            seq = [inverted_residual(input_channels, output_channels, 2)]
            for i in range(repeats - 1):
                seq.append(inverted_residual(output_channels, output_channels, 1))
            setattr(self, name, nn.Sequential(*seq))
            input_channels = output_channels

        output_channels = self._stage_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(output_channels),
            nn.ReLU(inplace=True)
        )

        self.fc = nn.Linear(output_channels, num_classes)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.maxpool(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = self.conv5(x)
        x = x.mean([2, 3])  # global pool
        x = self.fc(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)
def shufflenet_v2_x0_5(num_classes=1000):
    """
    Constructs a ShuffleNetV2 with 0.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`.

    weight: https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth

    :param num_classes: size of the classifier output
    :return: a ShuffleNetV2 instance
    """
    return ShuffleNetV2(stages_repeats=[4, 8, 4],
                        stages_out_channels=[24, 48, 96, 192, 1024],
                        num_classes=num_classes)
def shufflenet_v2_x1_0(num_classes=1000):
    """
    Constructs a ShuffleNetV2 with 1.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`.

    weight: https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth

    :param num_classes: size of the classifier output
    :return: a ShuffleNetV2 instance
    """
    return ShuffleNetV2(stages_repeats=[4, 8, 4],
                        stages_out_channels=[24, 116, 232, 464, 1024],
                        num_classes=num_classes)
def shufflenet_v2_x1_5(num_classes=1000):
    """
    Constructs a ShuffleNetV2 with 1.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`.

    weight: https://download.pytorch.org/models/shufflenetv2_x1_5-3c479a10.pth

    :param num_classes: size of the classifier output
    :return: a ShuffleNetV2 instance
    """
    model = ShuffleNetV2(stages_repeats=[4, 8, 4],
                         stages_out_channels=[24, 176, 352, 704, 1024],
                         num_classes=num_classes)

    return model
def shufflenet_v2_x2_0(num_classes=1000):
    """
    Constructs a ShuffleNetV2 with 2.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`.

    weight: https://download.pytorch.org/models/shufflenetv2_x2_0-8be3c8ee.pth

    :param num_classes: size of the classifier output
    :return: a ShuffleNetV2 instance
    """
    model = ShuffleNetV2(stages_repeats=[4, 8, 4],
                         stages_out_channels=[24, 244, 488, 976, 2048],
                         num_classes=num_classes)

    return model