"""中兴捧月 RAW 图像去噪训练代码。

使用 PyTorch 进行训练。
每张 RAW 图像先按 Bayer 排列打包为 4 通道（半分辨率），再裁剪为四个象限分块，分别作为训练样本。
使用时只需修改训练数据路径；数据目录中只保留 .dng 文件。
"""

import torch
from  torch.utils.data import Dataset
import torch.utils.data as Data
import numpy as np
import os
import rawpy
from unetTorch import Unet
import torch.nn.functional as F
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

#加载数据集
class facadeData1(Dataset):
    """Paired noisy / ground-truth RAW dataset yielding the TOP-LEFT crop.

    Each file listed in ``<path>/<data_file>`` is paired with the file of the
    same name (with ``'noise'`` replaced by ``'gt'``) in ``<path>/<label_file>``.
    The Bayer mosaic is packed into 4 channels at half resolution, cropped to
    packed rows ``ROWS`` / cols ``COLS`` and normalized. Each item is an
    ``(image, label)`` pair of float32 tensors shaped ``(1, 4, h, w)``.
    """

    # Crop window in packed (half-resolution) coordinates.
    ROWS = slice(0, 868)
    COLS = slice(0, 1156)

    def __init__(self, path='./data/train', data_file='noisy', label_file='ground'):
        self.path = path
        self.data_file = data_file
        self.label_file = label_file
        self.file_list = os.listdir(os.path.join(path, data_file))

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def pack_bayer(raw_data):
        """Pack an (H, W) mosaic into an (H/2, W/2, 4) array (H and W even).

        Channel order: (even row, even col), (even row, odd col),
        (odd row, even col), (odd row, odd col). The result is a new array.
        """
        return np.stack((raw_data[0::2, 0::2],
                         raw_data[0::2, 1::2],
                         raw_data[1::2, 0::2],
                         raw_data[1::2, 1::2]), axis=2)

    def read_image(self, input_path):
        """Read a RAW file and return ``(packed, height, width)``.

        ``height``/``width`` are the visible mosaic size before packing. The
        rawpy handle is closed before returning; ``pack_bayer`` copies the
        pixels, so the result does not reference the closed buffer.
        """
        with rawpy.imread(input_path) as raw:
            raw_data = raw.raw_image_visible
            height, width = raw_data.shape[0], raw_data.shape[1]
            packed = self.pack_bayer(raw_data)
        return packed, height, width

    def normalization(self, input_data, black_level=1024, white_level=16383):
        """Linearly map ``black_level`` -> 0.0 and ``white_level`` -> 1.0 (float64)."""
        return (input_data.astype(float) - black_level) / (white_level - black_level)

    def _to_cropped_tensor(self, packed):
        """Crop an (H/2, W/2, 4) array, normalize it, return a (1, 4, h, w) float32 tensor."""
        # Crop first so only the needed window is normalized and converted.
        cropped = self.normalization(packed[self.ROWS, self.COLS, :])
        chw = np.ascontiguousarray(np.transpose(cropped, (2, 0, 1)))
        return torch.from_numpy(chw).float().unsqueeze(0)

    def __getitem__(self, index):
        image_name = self.file_list[index]
        label_name = image_name.replace('noise', 'gt')
        image_path = os.path.join(self.path, self.data_file, image_name)
        label_path = os.path.join(self.path, self.label_file, label_name)

        image, _, _ = self.read_image(image_path)
        label, _, _ = self.read_image(label_path)
        return self._to_cropped_tensor(image), self._to_cropped_tensor(label)
class facadeData2(Dataset):
    """Paired noisy / ground-truth RAW dataset yielding the BOTTOM-LEFT crop.

    Each file listed in ``<path>/<data_file>`` is paired with the file of the
    same name (with ``'noise'`` replaced by ``'gt'``) in ``<path>/<label_file>``.
    The Bayer mosaic is packed into 4 channels at half resolution, cropped to
    packed rows ``ROWS`` / cols ``COLS`` and normalized. Each item is an
    ``(image, label)`` pair of float32 tensors shaped ``(1, 4, h, w)``.
    """

    # Crop window in packed (half-resolution) coordinates.
    ROWS = slice(868, 1736)
    COLS = slice(0, 1156)

    def __init__(self, path='./data/train', data_file='noisy', label_file='ground'):
        self.path = path
        self.data_file = data_file
        self.label_file = label_file
        self.file_list = os.listdir(os.path.join(path, data_file))

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def pack_bayer(raw_data):
        """Pack an (H, W) mosaic into an (H/2, W/2, 4) array (H and W even).

        Channel order: (even row, even col), (even row, odd col),
        (odd row, even col), (odd row, odd col). The result is a new array.
        """
        return np.stack((raw_data[0::2, 0::2],
                         raw_data[0::2, 1::2],
                         raw_data[1::2, 0::2],
                         raw_data[1::2, 1::2]), axis=2)

    def read_image(self, input_path):
        """Read a RAW file and return ``(packed, height, width)``.

        ``height``/``width`` are the visible mosaic size before packing. The
        rawpy handle is closed before returning; ``pack_bayer`` copies the
        pixels, so the result does not reference the closed buffer.
        """
        with rawpy.imread(input_path) as raw:
            raw_data = raw.raw_image_visible
            height, width = raw_data.shape[0], raw_data.shape[1]
            packed = self.pack_bayer(raw_data)
        return packed, height, width

    def normalization(self, input_data, black_level=1024, white_level=16383):
        """Linearly map ``black_level`` -> 0.0 and ``white_level`` -> 1.0 (float64)."""
        return (input_data.astype(float) - black_level) / (white_level - black_level)

    def _to_cropped_tensor(self, packed):
        """Crop an (H/2, W/2, 4) array, normalize it, return a (1, 4, h, w) float32 tensor."""
        # Crop first so only the needed window is normalized and converted.
        cropped = self.normalization(packed[self.ROWS, self.COLS, :])
        chw = np.ascontiguousarray(np.transpose(cropped, (2, 0, 1)))
        return torch.from_numpy(chw).float().unsqueeze(0)

    def __getitem__(self, index):
        image_name = self.file_list[index]
        label_name = image_name.replace('noise', 'gt')
        image_path = os.path.join(self.path, self.data_file, image_name)
        label_path = os.path.join(self.path, self.label_file, label_name)

        image, _, _ = self.read_image(image_path)
        label, _, _ = self.read_image(label_path)
        return self._to_cropped_tensor(image), self._to_cropped_tensor(label)
class facadeData3(Dataset):
    """Paired noisy / ground-truth RAW dataset yielding the TOP-RIGHT crop.

    Each file listed in ``<path>/<data_file>`` is paired with the file of the
    same name (with ``'noise'`` replaced by ``'gt'``) in ``<path>/<label_file>``.
    The Bayer mosaic is packed into 4 channels at half resolution, cropped to
    packed rows ``ROWS`` / cols ``COLS`` and normalized. Each item is an
    ``(image, label)`` pair of float32 tensors shaped ``(1, 4, h, w)``.
    """

    # Crop window in packed (half-resolution) coordinates.
    ROWS = slice(0, 868)
    COLS = slice(1156, 2312)

    def __init__(self, path='./data/train', data_file='noisy', label_file='ground'):
        self.path = path
        self.data_file = data_file
        self.label_file = label_file
        self.file_list = os.listdir(os.path.join(path, data_file))

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def pack_bayer(raw_data):
        """Pack an (H, W) mosaic into an (H/2, W/2, 4) array (H and W even).

        Channel order: (even row, even col), (even row, odd col),
        (odd row, even col), (odd row, odd col). The result is a new array.
        """
        return np.stack((raw_data[0::2, 0::2],
                         raw_data[0::2, 1::2],
                         raw_data[1::2, 0::2],
                         raw_data[1::2, 1::2]), axis=2)

    def read_image(self, input_path):
        """Read a RAW file and return ``(packed, height, width)``.

        ``height``/``width`` are the visible mosaic size before packing. The
        rawpy handle is closed before returning; ``pack_bayer`` copies the
        pixels, so the result does not reference the closed buffer.
        """
        with rawpy.imread(input_path) as raw:
            raw_data = raw.raw_image_visible
            height, width = raw_data.shape[0], raw_data.shape[1]
            packed = self.pack_bayer(raw_data)
        return packed, height, width

    def normalization(self, input_data, black_level=1024, white_level=16383):
        """Linearly map ``black_level`` -> 0.0 and ``white_level`` -> 1.0 (float64)."""
        return (input_data.astype(float) - black_level) / (white_level - black_level)

    def _to_cropped_tensor(self, packed):
        """Crop an (H/2, W/2, 4) array, normalize it, return a (1, 4, h, w) float32 tensor."""
        # Crop first so only the needed window is normalized and converted.
        cropped = self.normalization(packed[self.ROWS, self.COLS, :])
        chw = np.ascontiguousarray(np.transpose(cropped, (2, 0, 1)))
        return torch.from_numpy(chw).float().unsqueeze(0)

    def __getitem__(self, index):
        image_name = self.file_list[index]
        label_name = image_name.replace('noise', 'gt')
        image_path = os.path.join(self.path, self.data_file, image_name)
        label_path = os.path.join(self.path, self.label_file, label_name)

        image, _, _ = self.read_image(image_path)
        label, _, _ = self.read_image(label_path)
        return self._to_cropped_tensor(image), self._to_cropped_tensor(label)
class facadeData4(Dataset):
    """Paired noisy / ground-truth RAW dataset yielding the BOTTOM-RIGHT crop.

    Each file listed in ``<path>/<data_file>`` is paired with the file of the
    same name (with ``'noise'`` replaced by ``'gt'``) in ``<path>/<label_file>``.
    The Bayer mosaic is packed into 4 channels at half resolution, cropped to
    packed rows ``ROWS`` / cols ``COLS`` and normalized. Each item is an
    ``(image, label)`` pair of float32 tensors shaped ``(1, 4, h, w)``.
    """

    # Crop window in packed (half-resolution) coordinates.
    ROWS = slice(868, 1736)
    COLS = slice(1156, 2312)

    def __init__(self, path='./data/train', data_file='noisy', label_file='ground'):
        self.path = path
        self.data_file = data_file
        self.label_file = label_file
        self.file_list = os.listdir(os.path.join(path, data_file))

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def pack_bayer(raw_data):
        """Pack an (H, W) mosaic into an (H/2, W/2, 4) array (H and W even).

        Channel order: (even row, even col), (even row, odd col),
        (odd row, even col), (odd row, odd col). The result is a new array.
        """
        return np.stack((raw_data[0::2, 0::2],
                         raw_data[0::2, 1::2],
                         raw_data[1::2, 0::2],
                         raw_data[1::2, 1::2]), axis=2)

    def read_image(self, input_path):
        """Read a RAW file and return ``(packed, height, width)``.

        ``height``/``width`` are the visible mosaic size before packing. The
        rawpy handle is closed before returning; ``pack_bayer`` copies the
        pixels, so the result does not reference the closed buffer.
        """
        with rawpy.imread(input_path) as raw:
            raw_data = raw.raw_image_visible
            height, width = raw_data.shape[0], raw_data.shape[1]
            packed = self.pack_bayer(raw_data)
        return packed, height, width

    def normalization(self, input_data, black_level=1024, white_level=16383):
        """Linearly map ``black_level`` -> 0.0 and ``white_level`` -> 1.0 (float64)."""
        return (input_data.astype(float) - black_level) / (white_level - black_level)

    def _to_cropped_tensor(self, packed):
        """Crop an (H/2, W/2, 4) array, normalize it, return a (1, 4, h, w) float32 tensor."""
        # Crop first so only the needed window is normalized and converted.
        cropped = self.normalization(packed[self.ROWS, self.COLS, :])
        chw = np.ascontiguousarray(np.transpose(cropped, (2, 0, 1)))
        return torch.from_numpy(chw).float().unsqueeze(0)

    def __getitem__(self, index):
        image_name = self.file_list[index]
        label_name = image_name.replace('noise', 'gt')
        image_path = os.path.join(self.path, self.data_file, image_name)
        label_path = os.path.join(self.path, self.label_file, label_name)

        image, _, _ = self.read_image(image_path)
        label, _, _ = self.read_image(label_path)
        return self._to_cropped_tensor(image), self._to_cropped_tensor(label)
class Trainer(object):
    """Supervised training loop: Adam on the MSE between ``net(noisy)`` and ground truth.

    Args:
        net: the model to train; moved to the module-level ``device``.
        lr: Adam learning rate.
        batch_size: DataLoader batch size.
        num_epoch: number of passes over ``train_data``.
        train_data: a ``Dataset`` yielding ``(image, label)`` tensor pairs.
    """

    def __init__(self, net, lr=1e-4, batch_size=10, num_epoch=1000, train_data=None):
        self.net = net.to(device)
        self.batch_size = batch_size
        self.lr = lr
        self.num_epoch = num_epoch
        self.train_data = train_data
        self.data_loader = Data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
        # ``reduce``/``size_average`` are deprecated; ``reduction='mean'`` is
        # the equivalent of the old ``reduce=True, size_average=True``.
        self.loss = torch.nn.MSELoss(reduction='mean').to(device)

    def train(self, save_dir='./models3', save_every=10, resume_from=None):
        """Run ``num_epoch`` epochs, checkpointing every ``save_every`` epochs.

        Args:
            save_dir: directory for ``model{epoch}.pth`` checkpoints; created
                if missing (``torch.save`` does not create directories).
            save_every: checkpoint interval in epochs; must be positive.
            resume_from: optional path of a state dict to load before training.

        Returns:
            None.
        """
        optim = torch.optim.Adam(self.net.parameters(), lr=self.lr)
        if resume_from is not None:
            self.net.load_state_dict(torch.load(resume_from, map_location=device))
        os.makedirs(save_dir, exist_ok=True)
        self.net.train()
        for epoch in range(self.num_epoch):
            for i, (bx, by) in enumerate(self.data_loader):
                # The facadeData* samples in this file are (1, 4, h, w), so a
                # batch is (B, 1, 4, h, w); drop the singleton dim -> (B, 4, h, w).
                bx = bx.squeeze(dim=1).to(device)
                by = by.squeeze(dim=1).to(device)
                pre = self.net(bx)
                loss = self.loss(pre, by)
                optim.zero_grad()
                loss.backward()
                optim.step()
                print('i', i, 'epoch', epoch, 'loss', loss.item())
            if epoch % save_every == 0:
                torch.save(self.net.state_dict(), os.path.join(save_dir, f'model{epoch}.pth'))

        return None

if __name__ == '__main__':
    # One dataset per quadrant of the packed image. Each quadrant must appear
    # exactly once; the previous list repeated facadeData3 and omitted
    # facadeData4, so the bottom-right quadrant was never trained on.
    train_data = torch.utils.data.ConcatDataset(
        [facadeData1(), facadeData2(), facadeData3(), facadeData4()])

    net = Unet()
    trainer = Trainer(net=net, train_data=train_data)
    trainer.train()
评论 11
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值