PyTorch 入门编程:test7_shufflenet

train.py

import os
import math
import argparse

import torch
import torch.optim as optim
from torch.utils.tensorboard import SummaryWriter
from torchvision import transforms
import torch.optim.lr_scheduler as lr_scheduler

from model import shufflenet_v2_x1_0
from my_dataset import MyDataSet
from utils import read_split_data, train_one_epoch, evaluate


def main(args):
    """Train ShuffleNetV2 (x1.0) on a one-folder-per-class image dataset.

    Loss, validation accuracy and learning rate are logged to TensorBoard each
    epoch, and the model state_dict is saved to ./weights/model-{epoch}.pth.

    Args:
        args: parsed command-line namespace providing device, data_path,
            batch_size, num_classes, weights, freeze_layers, lr, lrf, epochs.

    Raises:
        FileNotFoundError: if ``args.weights`` is non-empty and does not exist.
    """
    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    tb_writer = SummaryWriter()
    os.makedirs("./weights", exist_ok=True)

    train_images_path, train_images_label, val_images_path, val_images_label = read_split_data(args.data_path)

    data_transform = {
        "train": transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
        "val": transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]),
    }

    # MyDataSet's constructor keyword is `transform` (singular).
    train_dataset = MyDataSet(images_path=train_images_path,
                              images_class=train_images_label,
                              transform=data_transform["train"])
    val_dataset = MyDataSet(images_path=val_images_path,
                            images_class=val_images_label,
                            transform=data_transform["val"])

    batch_size = args.batch_size
    # os.cpu_count() can return None; fall back to 1 so min() does not raise.
    nw = min([os.cpu_count() or 1, batch_size if batch_size > 1 else 0, 8])
    # collate_fn is taken from the class: it takes only `batch`, so it must not
    # be accessed through an instance (which would bind the dataset as `batch`).
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               shuffle=True,
                                               pin_memory=True,
                                               num_workers=nw,
                                               collate_fn=MyDataSet.collate_fn)
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             shuffle=False,
                                             pin_memory=True,
                                             num_workers=nw,
                                             collate_fn=MyDataSet.collate_fn)

    model = shufflenet_v2_x1_0(num_classes=args.num_classes).to(device)
    if args.weights != "":
        if not os.path.exists(args.weights):
            raise FileNotFoundError("not found weights file: {}".format(args.weights))
        # NOTE: torch.load unpickles the file - only load checkpoints from trusted sources.
        weights_dict = torch.load(args.weights, map_location=device)
        model_state = model.state_dict()
        # Keep only tensors that exist in this model with a matching element count,
        # e.g. drop a classifier head trained for a different number of classes.
        load_weights_dict = {k: v for k, v in weights_dict.items()
                             if k in model_state and model_state[k].numel() == v.numel()}
        print(model.load_state_dict(load_weights_dict, strict=False))

    if args.freeze_layers:
        # Train only the final fully connected layer.
        for name, para in model.named_parameters():
            if "fc" not in name:
                para.requires_grad_(False)

    pg = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.SGD(pg, lr=args.lr, momentum=0.9, weight_decay=4E-5)
    # Cosine decay of the LR factor from 1 down to args.lrf over args.epochs.
    lf = lambda x: ((1 + math.cos(x * math.pi / args.epochs)) / 2) * (1 - args.lrf) + args.lrf
    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lf)

    for epoch in range(args.epochs):
        mean_loss = train_one_epoch(model=model,
                                    optimizer=optimizer,
                                    data_loader=train_loader,
                                    device=device,
                                    epoch=epoch)
        scheduler.step()

        acc = evaluate(model=model,
                       data_loader=val_loader,
                       device=device)
        tags = ["loss", "acc", "learning_rate"]
        tb_writer.add_scalar(tags[0], mean_loss, epoch)
        tb_writer.add_scalar(tags[1], acc, epoch)
        tb_writer.add_scalar(tags[2], optimizer.param_groups[0]["lr"], epoch)
        torch.save(model.state_dict(), "./weights/model-{}.pth".format(epoch))

    tb_writer.close()
if __name__ == '__main__':
    def _str2bool(value):
        """Parse a command-line boolean.

        argparse's ``type=bool`` is a trap: ``bool("False")`` is True, so any
        non-empty value would switch the option on.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ("1", "true", "t", "yes", "y", "on"):
            return True
        if lowered in ("", "0", "false", "f", "no", "n", "off"):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(value))

    parser = argparse.ArgumentParser()
    parser.add_argument('--num_classes', type=int, default=5)
    parser.add_argument('--epochs', type=int, default=30)
    parser.add_argument('--batch_size', type=int, default=16)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--lrf', type=float, default=0.1)
    parser.add_argument('--data_path', type=str, default="/data/flower_photos")
    parser.add_argument('--weights', type=str, default='./shufflenetv2_x1.pth',
                        help='initial weights path')
    # A bare `--freeze-layers` enables freezing; `--freeze-layers True/False` also works.
    parser.add_argument('--freeze-layers', type=_str2bool, nargs='?', const=True, default=False)
    parser.add_argument('--device', default='cuda:0', help='device id (i.e. 0 or 0,1 or cpu)')
    opt = parser.parse_args()
    main(opt)

my_dataset.py:

from PIL import Image
import torch
from torch.utils.data import Dataset

class MyDataSet(Dataset):
    """Image-classification dataset over explicit file paths and labels.

    Args:
        images_path: list of image file paths.
        images_class: list of class indices, aligned with ``images_path``.
        transform: optional callable applied to each loaded PIL image
            (e.g. a ``torchvision.transforms.Compose``).

    Raises:
        ValueError: if ``images_path`` and ``images_class`` differ in length.
    """

    def __init__(self, images_path: list, images_class: list, transform=None):
        if len(images_path) != len(images_class):
            raise ValueError("images_path and images_class must have the same length, "
                             "got {} and {}".format(len(images_path), len(images_class)))
        self.images_path = images_path
        self.images_class = images_class
        self.transform = transform

    def __len__(self):
        return len(self.images_path)

    def __getitem__(self, item):
        # Force 3-channel RGB: grayscale / palette / RGBA files (e.g. PNGs) would
        # otherwise not match a 3-channel Normalize or stack with RGB images.
        # The `with` block closes the file; convert() returns an in-memory copy.
        with Image.open(self.images_path[item]) as f:
            img = f.convert("RGB")
        label = self.images_class[item]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

    @staticmethod
    def collate_fn(batch):
        """Stack a list of ``(image_tensor, label)`` pairs into ``(images, labels)`` tensors.

        A staticmethod so it works both as ``MyDataSet.collate_fn`` and as
        ``dataset.collate_fn`` (DataLoader calls it with the batch only).
        """
        images, labels = tuple(zip(*batch))
        images = torch.stack(images, dim=0)
        labels = torch.as_tensor(labels)
        return images, labels

工具函数 utils.py

import os
import sys
import json
import pickle
import random

import torch
from tqdm import tqdm

import matplotlib.pyplot as plt

def read_split_data(root: str, val_rate: float = 0.2):
    """Scan a one-folder-per-class image directory and split it into train / val.

    Every sub-directory of ``root`` is one class; classes are indexed in the
    sorted order of their directory names. The ``{index: class_name}`` mapping
    is written to ``class_indices.json`` in the current working directory.
    For each class, ``int(len(images) * val_rate)`` images are sampled
    (reproducibly, seed 0) into the validation split.

    Only ``.jpg`` / ``.png`` files (any letter case) are used.

    Args:
        root: dataset root directory.
        val_rate: fraction of each class's images placed in the validation split.

    Returns:
        ``(train_images_path, train_images_label, val_images_path, val_images_label)``

    Raises:
        AssertionError: if ``root`` does not exist.
    """
    # Private RNG seeded with 0: reproducible split without resetting the
    # global `random` state for the rest of the program.
    rng = random.Random(0)
    assert os.path.exists(root), "dataset root: {} does not exist.".format(root)

    # Each sub-folder is one class.
    flower_class = sorted(cla for cla in os.listdir(root) if os.path.isdir(os.path.join(root, cla)))
    class_indices = {name: idx for idx, name in enumerate(flower_class)}
    json_str = json.dumps({idx: name for name, idx in class_indices.items()}, indent=4)
    with open('class_indices.json', 'w') as json_file:
        json_file.write(json_str)

    train_images_path = []
    train_images_label = []
    val_images_path = []
    val_images_label = []
    supported = {".jpg", ".png"}
    for cla in flower_class:
        cla_path = os.path.join(root, cla)
        # Sorted so the split does not depend on the filesystem's listing order.
        images = sorted(os.path.join(cla_path, i) for i in os.listdir(cla_path)
                        if os.path.splitext(i)[-1].lower() in supported)
        image_class = class_indices[cla]
        # A set makes the membership test below O(1) per image.
        val_path = set(rng.sample(images, k=int(len(images) * val_rate)))

        for img_path in images:
            if img_path in val_path:
                val_images_path.append(img_path)
                val_images_label.append(image_class)
            else:
                train_images_path.append(img_path)
                train_images_label.append(image_class)

    return train_images_path, train_images_label, val_images_path, val_images_label

def plot_data_loader_image(data_loader, mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
    """Show up to 4 images (labelled with class names) from each batch of ``data_loader``.

    A visual sanity check of the data pipeline. Class names are read from
    ``./class_indices.json`` (as written by ``read_split_data``).

    Args:
        data_loader: iterable of ``(images, labels)`` batches; each ``images[i]``
            is a normalized float tensor laid out as [C, H, W].
        mean, std: per-channel statistics that were passed to
            ``transforms.Normalize``; the defaults match the
            ``Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])`` used in train.py.
    """
    json_path = './class_indices.json'
    with open(json_path, 'r') as json_file:
        class_indices = json.load(json_file)

    for images, labels in data_loader:
        # Use the real batch length: the last batch may be shorter than batch_size.
        plot_num = min(len(images), 4)
        for i in range(plot_num):
            # [C, H, W] -> [H, W, C] as expected by imshow
            img = images[i].numpy().transpose(1, 2, 0)
            # Undo Normalize (x * std + mean), then map [0, 1] -> [0, 255].
            img = (img * list(std) + list(mean)) * 255
            label = labels[i].item()
            plt.subplot(1, plot_num, i + 1)
            plt.xlabel(class_indices[str(label)])
            plt.xticks([])
            plt.yticks([])
            plt.imshow(img.clip(0, 255).astype('uint8'))
        plt.show()
def train_one_epoch(model, optimizer, data_loader, device, epoch):
    """Run one training pass over ``data_loader`` with cross-entropy loss.

    Args:
        model: network to train; switched to train mode.
        optimizer: optimizer over the trainable parameters of ``model``.
        data_loader: iterable of ``(images, labels)`` batches.
        device: device the batches are moved to.
        epoch: epoch number, shown in the progress bar.

    Returns:
        Mean of the per-batch losses over the epoch, as a Python float
        (0.0 if the loader is empty).

    Raises:
        FloatingPointError: if a batch produces a non-finite loss.
    """
    model.train()
    loss_function = torch.nn.CrossEntropyLoss()
    mean_loss = torch.zeros(1).to(device)
    optimizer.zero_grad()

    data_loader = tqdm(data_loader, file=sys.stdout)
    for step, (images, labels) in enumerate(data_loader):
        pred = model(images.to(device))
        loss = loss_function(pred, labels.to(device))
        # Stop before a NaN/inf gradient corrupts the weights.
        if not torch.isfinite(loss):
            raise FloatingPointError(
                "non-finite loss at epoch {} step {}: {}".format(epoch, step, loss.item()))
        loss.backward()
        # Incremental mean over the steps seen so far: (old_mean * n + x) / (n + 1).
        mean_loss = (mean_loss * step + loss.detach()) / (step + 1)
        data_loader.desc = "[epoch {}] mean loss {}".format(epoch, round(mean_loss.item(), 3))

        optimizer.step()
        optimizer.zero_grad()

    return mean_loss.item()

@torch.no_grad()
def evaluate(model, data_loader, device):
    """Return top-1 accuracy of ``model`` over ``data_loader``.

    Runs in eval mode without building autograd graphs (``torch.no_grad``).

    Args:
        model: classification network.
        data_loader: loader of ``(images, labels)`` batches exposing ``.dataset``.
        device: device the batches are moved to.

    Returns:
        Correct predictions divided by ``len(data_loader.dataset)``; 0.0 for an
        empty dataset.
    """
    model.eval()
    total_num = len(data_loader.dataset)
    sum_num = torch.zeros(1).to(device)
    for images, labels in tqdm(data_loader, file=sys.stdout):
        pred = model(images.to(device))
        pred = torch.max(pred, dim=1)[1]  # index of the highest logit per sample
        sum_num += torch.eq(pred, labels.to(device)).sum()
    if total_num == 0:
        return 0.0
    return sum_num.item() / total_num

predict.py

import os
import json

import torch
from PIL import Image
from torchvision import transforms
import matplotlib.pyplot as plt

from model import shufflenet_v2_x1_0

def main():
    """Classify a single image with a trained ShuffleNetV2 (x1.0) and show the result.

    Reads the image ``../tulip.jpg``, the class mapping ``./class_indices.json``
    and the weights ``./weights/model-29.pth``; prints every class
    probability and displays the image titled with the top prediction.

    Raises:
        FileNotFoundError: if the image file does not exist.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    # Same preprocessing as the "val" transform used during training.
    data_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    img_path = "../tulip.jpg"
    if not os.path.exists(img_path):
        raise FileNotFoundError("file: '{}' does not exist.".format(img_path))
    # convert("RGB"): grayscale / RGBA inputs would not match the 3-channel Normalize.
    img = Image.open(img_path).convert("RGB")
    plt.imshow(img)
    img = data_transform(img)
    img = torch.unsqueeze(img, dim=0)  # add the batch dimension

    json_path = './class_indices.json'
    with open(json_path, 'r') as f:
        class_indict = json.load(f)

    # Output size follows the class mapping instead of a hard-coded 5.
    model = shufflenet_v2_x1_0(num_classes=len(class_indict)).to(device)
    model_weight_path = "./weights/model-29.pth"
    model.load_state_dict(torch.load(model_weight_path, map_location=device))
    model.eval()
    with torch.no_grad():
        output = torch.squeeze(model(img.to(device))).cpu()
        predict = torch.softmax(output, dim=0)
        predict_cla = torch.argmax(predict).item()

    print_res = "class: {}   prob: {:.3}".format(class_indict[str(predict_cla)],
                                                 predict[predict_cla].item())
    plt.title(print_res)
    for i, prob in enumerate(predict):
        print("class: {:10}   prob: {:.3}".format(class_indict[str(i)], prob.item()))
    plt.show()


if __name__ == '__main__':
    main()
        

model.py

from typing import List, Callable

import torch
from torch import Tensor
import torch.nn as nn


def channel_shuffle(x: Tensor, groups: int) -> Tensor:

    batch_size, num_channels, height, width = x.size()
    channels_per_group = num_channels // groups

    # reshape
    # [batch_size, num_channels, height, width] -> [batch_size, groups, channels_per_group, height, width]
    x = x.view(batch_size, groups, channels_per_group, height, width)

    x = torch.transpose(x, 1, 2).contiguous()

    # flatten
    x = x.view(batch_size, -1, height, width)

    return x


class InvertedResidual(nn.Module):
    """ShuffleNetV2 unit: two branches, concatenation, then channel shuffle.

    stride == 1: the input is split in half along channels; the first half is
        passed through unchanged and the second half goes through ``branch2``.
        This requires ``input_c == output_c``.
    stride == 2: the whole input goes through both ``branch1`` and ``branch2``
        (each contains a stride-2 depthwise conv) and the outputs are concatenated.

    The concatenated ``output_c`` channels are then shuffled with 2 groups.

    Args:
        input_c: number of input channels.
        output_c: number of output channels (must be even).
        stride: 1 or 2.

    Raises:
        ValueError: if ``stride`` is not 1 or 2.
        AssertionError: if ``output_c`` is odd, or ``stride == 1`` and
            ``input_c != output_c``.
    """

    def __init__(self, input_c: int, output_c: int, stride: int):
        super().__init__()

        if stride not in (1, 2):
            raise ValueError("illegal stride value.")
        self.stride = stride

        assert output_c % 2 == 0
        half_c = output_c // 2
        # With stride 1 the input is split into two halves of half_c channels each.
        assert self.stride == 2 or input_c == 2 * half_c

        downsample = self.stride == 2

        def pointwise(in_ch: int, out_ch: int) -> nn.Conv2d:
            # 1x1 convolution without bias (a BatchNorm always follows it here).
            return nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=False)

        if downsample:
            branch1_layers = [
                self.depthwise_conv(input_c, input_c, kernel_s=3, stride=self.stride, padding=1),
                nn.BatchNorm2d(input_c),
                pointwise(input_c, half_c),
                nn.BatchNorm2d(half_c),
                nn.ReLU(inplace=True),
            ]
        else:
            # Not used by forward() when stride == 1; kept as an empty container.
            branch1_layers = []
        self.branch1 = nn.Sequential(*branch1_layers)

        self.branch2 = nn.Sequential(
            pointwise(input_c if downsample else half_c, half_c),
            nn.BatchNorm2d(half_c),
            nn.ReLU(inplace=True),
            self.depthwise_conv(half_c, half_c, kernel_s=3, stride=self.stride, padding=1),
            nn.BatchNorm2d(half_c),
            pointwise(half_c, half_c),
            nn.BatchNorm2d(half_c),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def depthwise_conv(input_c: int,
                       output_c: int,
                       kernel_s: int,
                       stride: int = 1,
                       padding: int = 0,
                       bias: bool = False) -> nn.Conv2d:
        """Build a depthwise convolution (``groups=input_c``: one group per input channel)."""
        return nn.Conv2d(input_c, output_c, kernel_s, stride, padding, bias=bias, groups=input_c)

    def forward(self, x: Tensor) -> Tensor:
        if self.stride != 1:
            merged = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
        else:
            passthrough, processed = x.chunk(2, dim=1)
            merged = torch.cat((passthrough, self.branch2(processed)), dim=1)
        return channel_shuffle(merged, 2)


class ShuffleNetV2(nn.Module):
    """ShuffleNetV2 backbone with a linear classifier.

    Layout: conv1 (3x3, stride 2) -> maxpool (stride 2) -> stage2 / stage3 /
    stage4 (each: one stride-2 block followed by ``repeats - 1`` stride-1
    blocks) -> conv5 (1x1) -> mean over the spatial dims -> fc.

    Args:
        stages_repeats: number of blocks in stage2, stage3 and stage4 (3 ints).
        stages_out_channels: output channels of conv1, stage2, stage3, stage4
            and conv5 (5 ints).
        num_classes: size of the classifier output.
        inverted_residual: block factory called as ``(in_channels, out_channels, stride)``.

    Raises:
        ValueError: if the two lists do not have lengths 3 and 5 respectively.
    """

    def __init__(self,
                 stages_repeats: List[int],
                 stages_out_channels: List[int],
                 num_classes: int = 1000,
                 inverted_residual: Callable[..., nn.Module] = InvertedResidual):
        super().__init__()

        if len(stages_repeats) != 3:
            raise ValueError("expected stages_repeats as list of 3 positive ints")
        if len(stages_out_channels) != 5:
            raise ValueError("expected stages_out_channels as list of 5 positive ints")
        self._stage_out_channels = stages_out_channels

        stem_c = stages_out_channels[0]
        # Stem: RGB (3 channels) -> stem_c, then max-pooling; each halves H and W.
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, stem_c, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(stem_c),
            nn.ReLU(inplace=True),
        )
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Static annotations for mypy (the stages are attached with setattr below)
        self.stage2: nn.Sequential
        self.stage3: nn.Sequential
        self.stage4: nn.Sequential

        in_c = stem_c
        for stage_no, repeats, out_c in zip((2, 3, 4), stages_repeats, stages_out_channels[1:]):
            blocks = [inverted_residual(in_c, out_c, 2)]
            blocks.extend(inverted_residual(out_c, out_c, 1) for _ in range(repeats - 1))
            setattr(self, "stage{}".format(stage_no), nn.Sequential(*blocks))
            in_c = out_c

        head_c = stages_out_channels[-1]
        self.conv5 = nn.Sequential(
            nn.Conv2d(in_c, head_c, kernel_size=1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(head_c),
            nn.ReLU(inplace=True),
        )

        self.fc = nn.Linear(head_c, num_classes)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        features = self.conv5(self.stage4(self.stage3(self.stage2(self.maxpool(self.conv1(x))))))
        pooled = features.mean([2, 3])  # global average pool over H and W
        return self.fc(pooled)

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)


def shufflenet_v2_x0_5(num_classes=1000):
    """
    Build a ShuffleNetV2 with 0.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`.
    weight: https://download.pytorch.org/models/shufflenetv2_x0.5-f707e7126e.pth

    :param num_classes: number of outputs of the final fc layer.
    :return: a freshly initialised ShuffleNetV2 (no weights are loaded here).
    """
    repeats = [4, 8, 4]
    channels = [24, 48, 96, 192, 1024]
    return ShuffleNetV2(stages_repeats=repeats,
                        stages_out_channels=channels,
                        num_classes=num_classes)


def shufflenet_v2_x1_0(num_classes=1000):
    """
    Build a ShuffleNetV2 with 1.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`.
    weight: https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth

    :param num_classes: number of outputs of the final fc layer.
    :return: a freshly initialised ShuffleNetV2 (no weights are loaded here).
    """
    repeats = [4, 8, 4]
    channels = [24, 116, 232, 464, 1024]
    return ShuffleNetV2(stages_repeats=repeats,
                        stages_out_channels=channels,
                        num_classes=num_classes)


def shufflenet_v2_x1_5(num_classes=1000):
    """
    Build a ShuffleNetV2 with 1.5x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`.
    weight: https://download.pytorch.org/models/shufflenetv2_x1_5-3c479a10.pth

    :param num_classes: number of outputs of the final fc layer.
    :return: a freshly initialised ShuffleNetV2 (no weights are loaded here).
    """
    repeats = [4, 8, 4]
    channels = [24, 176, 352, 704, 1024]
    return ShuffleNetV2(stages_repeats=repeats,
                        stages_out_channels=channels,
                        num_classes=num_classes)


def shufflenet_v2_x2_0(num_classes=1000):
    """
    Build a ShuffleNetV2 with 2.0x output channels, as described in
    `"ShuffleNet V2: Practical Guidelines for Efficient CNN Architecture Design"
    <https://arxiv.org/abs/1807.11164>`.
    weight: https://download.pytorch.org/models/shufflenetv2_x2_0-8be3c8ee.pth

    :param num_classes: number of outputs of the final fc layer.
    :return: a freshly initialised ShuffleNetV2 (no weights are loaded here).
    """
    repeats = [4, 8, 4]
    channels = [24, 244, 488, 976, 2048]
    return ShuffleNetV2(stages_repeats=repeats,
                        stages_out_channels=channels,
                        num_classes=num_classes)

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值