
A PyTorch Implementation of UNet StyleGAN2: A New Approach to Enhancing Image Quality

4 stars · rated above 85% of resources | 543 KB | Updated 2024-12-16 | 93 views | 7 comments | 7 downloads
The main goal of this project is to improve the quality of generated images in the field of generative adversarial networks (GANs). Here, a UNet-style discriminator replaces the standard StyleGAN2 discriminator in the hope of better performance. UNet is a convolutional neural network architecture widely used in medical image segmentation; its symmetric U-shaped structure helps image features propagate quickly and fine details get captured. StyleGAN2 is a GAN variant that uses a novel style-based generation scheme, mapping a latent space into image space to produce high-quality, high-resolution images.

The project offers a simple installation and usage flow: install the unet-stylegan2 package via pip, then train and generate with the unet_stylegan2 command-line tool. The command format is $ unet_stylegan2 --data ./path/to/data, where the --data argument is the path to the training data.

According to the project's changelog, generation results are strong; the author plans to investigate combining unet-stylegan2 with other techniques and to write complete documentation.

Since the project carries the tags deep-learning, artificial-intelligence, generative-adversarial-network, and u-net, it clearly targets image generation with GANs against a deep-learning and AI background; the u-net tag ties it directly to the UNet model from medical image processing.

From an implementation standpoint, the project covers the training and inference of a deep-learning model and touches several key parts of deep learning and the PyTorch framework, such as custom layers, loss functions, and optimizers. It also likely involves modifying the original StyleGAN2 architecture and integrating the UNet discriminator, which calls for solid experience with deep-learning frameworks and algorithms.

In short, unet-stylegan2 combines the structural strengths of UNet with the generative power of StyleGAN2 and offers a new angle and a new tool for producing high-quality images. Its ease of use and potential performance make it a strong instrument for research on GANs and image generation.
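
This page does not show the package's internal API, so purely as orientation, here is a minimal PyTorch sketch of the U-Net discriminator idea the description refers to: besides the usual image-level real/fake score from the encoder path, a decoder path with skip connections emits a per-pixel real/fake map, giving the generator denser feedback. All class names and layer sizes below are illustrative assumptions, not the actual unet_stylegan2 code.

import torch
import torch.nn as nn

class UNetDiscriminator(nn.Module):
    """Sketch of a U-Net discriminator: the encoder yields a global
    real/fake logit, the decoder yields a per-pixel real/fake map."""
    def __init__(self, in_ch=3, base=64):
        super().__init__()
        self.down1 = nn.Sequential(nn.Conv2d(in_ch, base, 4, 2, 1), nn.LeakyReLU(0.2))
        self.down2 = nn.Sequential(nn.Conv2d(base, base * 2, 4, 2, 1), nn.LeakyReLU(0.2))
        self.down3 = nn.Sequential(nn.Conv2d(base * 2, base * 4, 4, 2, 1), nn.LeakyReLU(0.2))
        self.global_head = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(),
                                         nn.Linear(base * 4, 1))
        # Decoder path; input channels are doubled by the skip concatenations
        self.up1 = nn.Sequential(nn.ConvTranspose2d(base * 4, base * 2, 4, 2, 1), nn.LeakyReLU(0.2))
        self.up2 = nn.Sequential(nn.ConvTranspose2d(base * 4, base, 4, 2, 1), nn.LeakyReLU(0.2))
        self.up3 = nn.ConvTranspose2d(base * 2, 1, 4, 2, 1)  # per-pixel logits

    def forward(self, x):
        d1 = self.down1(x)                             # H/2
        d2 = self.down2(d1)                            # H/4
        d3 = self.down3(d2)                            # H/8
        logit_global = self.global_head(d3)            # (N, 1) image-level score
        u1 = self.up1(d3)                              # H/4
        u2 = self.up2(torch.cat([u1, d2], 1))          # H/2
        pixel_map = self.up3(torch.cat([u2, d1], 1))   # (N, 1, H, W) pixel scores
        return logit_global, pixel_map

A quick shape check: for a 2x3x64x64 batch, this sketch returns a (2, 1) global logit and a (2, 1, 64, 64) pixel map; a StyleGAN2-style trainer would feed both outputs into its adversarial loss.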

Related recommendations


#---Part1--- SSAnetwork.py
# -*- coding: utf-8 -*-
"""
Created on Sat Aug 27 13:43:52 2022

@author: User
"""
import torch.nn as nn


def weights_init(mod):
    """
    Custom weights initialization called on netG, netD and netE
    :param m:
    :return:
    """
    classname = mod.__class__.__name__
    if classname.find('Conv') != -1:
        mod.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        mod.weight.data.normal_(1.0, 0.02)
        mod.bias.data.fill_(0)


# ================================================== Alister add 2022-10-14 ======================
import torch.nn.functional as F
import torch
from network.attension import ChannelAttention, SpatialAttention
from network.SelfAttension import Self_Attn
import torchvision.models as models


def weights_init_normal(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
        torch.nn.init.constant_(m.bias.data, 0.0)


##############################
#           U-NET
##############################
# torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0,
#                 dilation=1, groups=1, bias=True, padding_mode='zeros', device=None, dtype=None)
'''
class UNetDown(nn.Module):  # earlier attention-based variant, kept commented out
    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        self.leakyrelu = nn.LeakyReLU(0.2)
        self.conv = nn.Conv2d(in_size, out_size, 4, 2, 1)
        self.ca = ChannelAttention(out_size)
        self.sa = SpatialAttention()
        self.norm = nn.InstanceNorm2d(out_size)
        self.normalize = normalize
        self.dropout = dropout
        self.drop = nn.Dropout()
        #layers.append(nn.Conv2d(in_size, out_size, 4, 2, 1))
        #if normalize:
        #    layers.append(nn.InstanceNorm2d(out_size))
        #if dropout:
        #    layers.append(nn.Dropout(dropout))
        #self.model = nn.Sequential(*layers)

    def forward(self, x):
        x = self.leakyrelu(x)
        xc = self.conv(x)
        xa = self.ca(xc) * xc
        xa = self.sa(xa) * xa
        x = xc + xa
        if self.normalize:
            x = self.norm(x)
        if self.dropout:
            x = self.drop(x)
        return x
'''


class UNetDown(nn.Module):
    def __init__(self, in_size, out_size, normalize=True, dropout=0.0):
        super(UNetDown, self).__init__()
        layers = [nn.LeakyReLU(0.2)]
        layers.append(nn.Conv2d(in_size, out_size, 4, 2, 1))
        if normalize:
            layers.append(nn.InstanceNorm2d(out_size))
        if dropout:
            layers.append(nn.Dropout(dropout))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)


class UNetUp(nn.Module):
    def __init__(self, in_size, out_size):
        super(UNetUp, self).__init__()
        self.up = nn.Sequential(
            nn.ConvTranspose2d(in_size, out_size, 4, 2, 1, bias=False),
            nn.BatchNorm2d(out_size),
            nn.ReLU(True)
        )
        # Widen the conv input channels to account for the concatenated skip connection
        self.conv = nn.Sequential(
            nn.Conv2d(out_size + out_size, out_size, 3, 1, 1, bias=False),  # input is out_size * 2
            nn.BatchNorm2d(out_size),
            nn.ReLU(True)
        )

    def forward(self, x, skip_input):
        x = self.up(x)
        # Make sure skip_input matches the spatial size of x
        if skip_input.size() != x.size():
            skip_input = F.interpolate(skip_input, size=x.size()[2:], mode='bilinear', align_corners=True)
        x = torch.cat((x, skip_input), 1)
        x = self.conv(x)
        return x


class ResNet50FeatureExtractor(nn.Module):
    def __init__(self, pretrained=True):
        super(ResNet50FeatureExtractor, self).__init__()
        resnet = models.resnet50(pretrained=pretrained)
        self.conv1 = nn.Sequential(
            resnet.conv1,
            resnet.bn1,
            resnet.relu,
            resnet.maxpool
        )
        self.layer1 = resnet.layer1  # 256 channels
        self.layer2 = resnet.layer2  # 512 channels
        self.layer3 = resnet.layer3  # 1024 channels
        self.layer4 = resnet.layer4  # 2048 channels

    def forward(self, x):
        x1 = self.conv1(x)    # 1/4
        x2 = self.layer1(x1)  # 1/4
        x3 = self.layer2(x2)  # 1/8
        x4 = self.layer3(x3)  # 1/16
        x5 = self.layer4(x4)  # 1/32
        return [x1, x2, x3, x4, x5]


class SSAEncoder(nn.Module):
    """
    UNET ENCODER NETWORK
    """
    def __init__(self, isize=32, nz=100, nc=3, ndf=64, ngpu=1, n_extra_layers=0, add_final_conv=True):
        super(SSAEncoder, self).__init__()
        self.isize = isize
        # Use ResNet50 as the feature extractor
        self.resnet = ResNet50FeatureExtractor(pretrained=True)
        # Feature projection layers
        self.conv1 = nn.Conv2d(64, 64, 1)     # project the conv1 output
        self.conv2 = nn.Conv2d(256, 128, 1)   # project the layer1 output
        self.conv3 = nn.Conv2d(512, 256, 1)   # project the layer2 output
        self.conv4 = nn.Conv2d(1024, 512, 1)  # project the layer3 output
        self.conv5 = nn.Conv2d(2048, 512, 1)  # project the layer4 output
        if add_final_conv:
            self.final_conv = nn.Conv2d(512, nz, 4, 1, 0, bias=False)
        # Attention modules
        self.ca1 = ChannelAttention(64)
        self.sa1 = SpatialAttention()
        self.ca2 = ChannelAttention(128)
        self.sa2 = SpatialAttention()
        self.ca3 = ChannelAttention(256)
        self.sa3 = SpatialAttention()
        self.ca4 = ChannelAttention(512)
        self.sa4 = SpatialAttention()
        self.ca5 = ChannelAttention(512)
        self.sa5 = SpatialAttention()

    def forward(self, x):
        # Multi-scale features from ResNet50
        features = self.resnet(x)
        # Project the feature dimensions and apply attention
        d1 = self.conv1(features[0])
        d1 = self.ca1(d1) * d1
        d1 = self.sa1(d1) * d1
        d2 = self.conv2(features[1])
        d2 = self.ca2(d2) * d2
        d2 = self.sa2(d2) * d2
        d3 = self.conv3(features[2])
        d3 = self.ca3(d3) * d3
        d3 = self.sa3(d3) * d3
        d4 = self.conv4(features[3])
        d4 = self.ca4(d4) * d4
        d4 = self.sa4(d4) * d4
        d5 = self.conv5(features[4])
        d5 = self.ca5(d5) * d5
        d5 = self.sa5(d5) * d5
        if hasattr(self, 'final_conv'):
            output = self.final_conv(d5)
        else:
            output = d5
        return output


class SSADecoder(nn.Module):
    """
    UNET DECODER NETWORK
    """
    def __init__(self, isize, nc, nz):
        super(SSADecoder, self).__init__()
        # Initial transposed-convolution layer
        self.init_conv = nn.Sequential(
            nn.ConvTranspose2d(nz, 512, 4, 1, 0, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True)
        )
        # Upsampling layers -- channel counts chosen to match the feature-map sizes
        self.up1 = UNetUp(512, 512)  # 4x4 -> 8x8
        self.up2 = UNetUp(512, 256)  # 8x8 -> 16x16
        self.up3 = UNetUp(256, 128)  # 16x16 -> 32x32
        self.up4 = UNetUp(128, 64)   # 32x32 -> 64x64
        # Feature projection layers -- outputs must match the upsampling layers' inputs
        self.feat_conv1 = nn.Conv2d(2048, 512, 1)  # ResNet50 layer4 -> 512
        self.feat_conv2 = nn.Conv2d(1024, 256, 1)  # ResNet50 layer3 -> 256
        self.feat_conv3 = nn.Conv2d(512, 128, 1)   # ResNet50 layer2 -> 128
        self.feat_conv4 = nn.Conv2d(256, 64, 1)    # ResNet50 layer1 -> 64
        # Attention modules
        self.attn1 = Self_Attn(512, 1, 'relu')  # 512 channels
        self.attn2 = Self_Attn(256, 1, 'relu')  # 256 channels
        self.attn3 = Self_Attn(128, 1, 'relu')  # 128 channels
        self.attn4 = Self_Attn(64, 1, 'relu')   # 64 channels
        # Final output layer
        self.final = nn.Sequential(
            nn.Upsample(scale_factor=2),  # 64x64 -> 128x128
            nn.Conv2d(64, nc, 3, 1, 1),
            nn.Tanh()
        )
        # Attention-map output layer
        self.attn_out = nn.Sequential(
            nn.Upsample(scale_factor=2),  # 64x64 -> 128x128
            nn.Conv2d(64, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x, features):
        # Initial features
        x = self.init_conv(x)  # 4x4
        # Project the skip features
        feat1 = self.feat_conv1(features[4])  # layer4 -> 512
        feat2 = self.feat_conv2(features[3])  # layer3 -> 256
        feat3 = self.feat_conv3(features[2])  # layer2 -> 128
        feat4 = self.feat_conv4(features[1])  # layer1 -> 64
        # Upsampling path
        x = self.up1(x, feat1)  # 4x4 -> 8x8
        x, _ = self.attn1(x)
        x = self.up2(x, feat2)  # 8x8 -> 16x16
        x, _ = self.attn2(x)
        x = self.up3(x, feat3)  # 16x16 -> 32x32
        x, _ = self.attn3(x)
        x = self.up4(x, feat4)  # 32x32 -> 64x64
        x, _ = self.attn4(x)
        # Generate the image and the attention map
        gen_imag = self.final(x)     # 64x64 -> 128x128
        gen_attn = self.attn_out(x)  # 64x64 -> 128x128
        return gen_imag, gen_attn


class SSANetG(nn.Module):
    """
    GENERATOR NETWORK
    """
    def __init__(self, isize=32, nc=3, nz=100, ngf=64, ndf=64, ngpu=1, extralayers=0):
        super(SSANetG, self).__init__()
        self.isize = isize
        self.nc = nc
        self.nz = nz
        self.ngf = ngf
        self.ndf = ndf
        self.ngpu = ngpu
        self.extralayers = extralayers
        self.encoder1 = SSAEncoder(self.isize, self.nz, self.nc, self.ngf, self.ngpu, self.extralayers)
        self.decoder = SSADecoder(self.isize, self.nc, self.nz)
        self.encoder2 = SSAEncoder(self.isize, self.nz, self.nc, self.ngf, self.ngpu, self.extralayers)

    def forward(self, x):
        # First encoding pass
        features = self.encoder1.resnet(x)
        latent_i = self.encoder1(x)
        # Decoding
        gen_imag, gen_attn = self.decoder(latent_i, features)
        # Second encoding pass
        features2 = self.encoder2.resnet(gen_imag)
        latent_o = self.encoder2(gen_imag)
        return gen_imag, latent_i, latent_o, gen_attn


class SSANetD(nn.Module):
    """
    DISCRIMINATOR NETWORK
    """
    def __init__(self, isize=32, nc=3, attn_c=3):
        super(SSANetD, self).__init__()
        self.resnet = ResNet50FeatureExtractor(pretrained=True)
        self.fusion_conv = nn.Conv2d(nc + 3, 3, 1)  # always concatenate a 3-channel attention map
        self.attn_adjust = nn.Conv2d(1, 3, 1)       # expands a 1-channel attention map to 3 channels
        self.conv1 = nn.Conv2d(64, 64, 1)
        self.conv2 = nn.Conv2d(256, 128, 1)
        self.conv3 = nn.Conv2d(512, 256, 1)
        self.conv4 = nn.Conv2d(1024, 512, 1)
        self.conv5 = nn.Conv2d(2048, 512, 1)
        self.ca1 = ChannelAttention(64)
        self.sa1 = SpatialAttention()
        self.ca2 = ChannelAttention(128)
        self.sa2 = SpatialAttention()
        self.ca3 = ChannelAttention(256)
        self.sa3 = SpatialAttention()
        self.ca4 = ChannelAttention(512)
        self.sa4 = SpatialAttention()
        self.ca5 = ChannelAttention(512)
        self.sa5 = SpatialAttention()
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1, 1),
            nn.Sigmoid()
        )

    def forward(self, x, attention):
        # Automatically adapt the attention channel count
        if attention.shape[1] == 1:
            attention = self.attn_adjust(attention)
        elif attention.shape[1] != 3:
            raise ValueError(f"attention has {attention.shape[1]} channels; only 1 or 3 are supported")
        fusion = torch.cat((x, attention), 1)
        if fusion.shape[1] != self.fusion_conv.in_channels:
            raise ValueError(f"fusion_conv expects {self.fusion_conv.in_channels} input channels "
                             f"but got {fusion.shape[1]}; check the attention channel count!")
        fusion = self.fusion_conv(fusion)
        features = self.resnet(fusion)
        d1 = self.conv1(features[0])
        d1 = self.ca1(d1) * d1
        d1 = self.sa1(d1) * d1
        d2 = self.conv2(features[1])
        d2 = self.ca2(d2) * d2
        d2 = self.sa2(d2) * d2
        d3 = self.conv3(features[2])
        d3 = self.ca3(d3) * d3
        d3 = self.sa3(d3) * d3
        d4 = self.conv4(features[3])
        d4 = self.ca4(d4) * d4
        d4 = self.sa4(d4) * d4
        d5 = self.conv5(features[4])
        d5 = self.ca5(d5) * d5
        d5 = self.sa5(d5) * d5
        features = d5
        classifier = self.classifier(features)
        classifier = classifier.view(-1, 1).squeeze(1)
        return classifier, features
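
As a quick orientation for the classes above, here is a hypothetical smoke test (not part of the original file). It assumes the project-local attention modules (network.attension, network.SelfAttension) are importable and that torchvision can fetch the pretrained ResNet50 weights; a 128x128 RGB input matches the decoder's fixed 128x128 output.

import torch

# Hypothetical smoke test; shape comments follow the strides noted above.
netG = SSANetG(isize=128, nc=3, nz=100)
netD = SSANetD(isize=128, nc=3)

x = torch.randn(2, 3, 128, 128)                  # dummy batch of two RGB images
gen_imag, latent_i, latent_o, gen_attn = netG(x)
print(gen_imag.shape, gen_attn.shape)            # (2, 3, 128, 128), (2, 1, 128, 128)

score, feats = netD(gen_imag, gen_attn)          # the 1-channel attention map is expanded to 3
print(score.shape)                               # (2,)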


Next, a walkthrough of the second segment:

# Note: this fragment assumes the enclosing file already imports math, copy,
# numpy as np, torch, pymongo, tqdm (from tqdm import tqdm), the MS_SSIM metric,
# and the project-local modules `network`, `channel`, `get_loader`, and `UNet`.

def train_JSCC_with_DnCNN(config, CHDDIM_config):
    from DnCNN.models import DnCNN
    import torch.nn as nn
    encoder = network.JSCC_encoder(config, config.C).cuda()
    decoder = network.JSCC_decoder(config, config.C).cuda()
    encoder_path = config.encoder_path
    decoder_path = config.decoder_path
    pass_channel = channel.Channel(config)
    trainLoader, _ = get_loader(config)
    encoder.eval()
    DnCNN = DnCNN(config.C).cuda()
    # encoder = torch.nn.DataParallel(encoder, device_ids=CHDDIM_config.device_ids)
    # decoder = torch.nn.DataParallel(decoder, device_ids=CHDDIM_config.device_ids)
    # CHDDIM = torch.nn.DataParallel(CHDDIM, device_ids=CHDDIM_config.device_ids)
    #
    # encoder = encoder.cuda(device=CHDDIM_config.device_ids[0])
    # decoder = decoder.cuda(device=CHDDIM_config.device_ids[0])
    # CHDDIM = CHDDIM.cuda(device=CHDDIM_config.device_ids[0])
    encoder.load_state_dict(torch.load(encoder_path))
    decoder.load_state_dict(torch.load(decoder_path))
    ckpt = torch.load(CHDDIM_config.save_path)
    DnCNN.load_state_dict(ckpt)
    DnCNN.eval()
    # optimizer_encoder = torch.optim.AdamW(
    #     encoder.parameters(), lr=CHDDIM_config.lr, weight_decay=1e-4)
    optimizer_decoder = torch.optim.Adam(
        decoder.parameters(), lr=CHDDIM_config.lr)
    # start training
    if config.dataset == "CIFAR10":
        CalcuSSIM = MS_SSIM(window_size=3, data_range=1., levels=4, channel=3).cuda()
    else:
        CalcuSSIM = MS_SSIM(data_range=1., levels=4, channel=3).cuda()
    for e in range(config.retrain_epoch):
        with tqdm(trainLoader, dynamic_ncols=True) as tqdmtrainLoader:
            for i, (images, labels) in enumerate(tqdmtrainLoader):
                # train
                snr = config.SNRs - CHDDIM_config.large_snr
                x_0 = images.cuda()
                feature, _ = encoder(x_0)
                y = feature
                y, pwr, h = pass_channel.forward(y, snr)  # normalize
                sigma_square = 1.0 / (2 * 10 ** (config.SNRs / 10))
                if config.channel_type == "awgn":
                    y_awgn = torch.cat((torch.real(y), torch.imag(y)), dim=2)
                    # mse1 = torch.nn.MSELoss()(y_awgn * math.sqrt(2), y * math.sqrt(2) / torch.sqrt(pwr))
                    receive = y_awgn
                elif config.channel_type == 'rayleigh':
                    y_mmse = y * torch.conj(h) / (torch.abs(h) ** 2 + sigma_square * 2)
                    y_mmse = torch.cat((torch.real(y_mmse), torch.imag(y_mmse)), dim=2)
                    # mse1 = torch.nn.MSELoss()(y_mmse * math.sqrt(2), y * math.sqrt(2) / torch.sqrt(pwr))
                    receive = y_mmse
                else:
                    raise ValueError
                feature_hat = receive - DnCNN(receive)
                feature_hat = feature_hat * torch.sqrt(pwr)
                x_0_hat = decoder(feature_hat)
                # mse1=torch.nn.MSEloss()()
                if config.loss_function == "MSE":
                    loss = torch.nn.MSELoss()(x_0, x_0_hat)
                elif config.loss_function == "MSSSIM":
                    loss = CalcuSSIM(x_0, x_0_hat.clamp(0., 1.)).mean()
                else:
                    raise ValueError
                optimizer_decoder.zero_grad()
                loss.backward()
                optimizer_decoder.step()
                # optimizer_encoder.step()
                if config.loss_function == "MSE":
                    mse = torch.nn.MSELoss()(x_0 * 255., x_0_hat.clamp(0., 1.) * 255)
                    psnr = 10 * math.log10(255. * 255. / mse.item())
                    matric = psnr
                elif config.loss_function == "MSSSIM":
                    msssim = 1 - CalcuSSIM(x_0, x_0_hat.clamp(0., 1.)).mean().item()
                    matric = msssim
                tqdmtrainLoader.set_postfix(ordered_dict={
                    "dataset": config.dataset,
                    "state": "train_decoder" + config.loss_function,
                    "noise_schedule": CHDDIM_config.noise_schedule,
                    "channel": config.channel_type,
                    "CBR:": feature.numel() / 2 / x_0.numel(),
                    "SNR": snr,
                    "matric": matric,
                    "T_max": CHDDIM_config.t_max
                })
        if (e + 1) % config.retrain_save_model_freq == 0:
            torch.save(decoder.state_dict(), config.re_decoder_path)


def eval_JSCC_with_DnCNN(config, CHDDIM_config):
    from DnCNN.models import DnCNN
    import torch.nn as nn
    encoder = network.JSCC_encoder(config, config.C).cuda()
    decoder = network.JSCC_decoder(config, config.C).cuda()
    encoder_path = config.encoder_path
    decoder_path = config.re_decoder_path
    pass_channel = channel.Channel(config)
    encoder.eval()
    decoder.eval()
    _, testLoader = get_loader(config)
    DnCNN = DnCNN(config.C).cuda()
    # encoder = torch.nn.DataParallel(encoder, device_ids=CHDDIM_config.device_ids)
    # decoder = torch.nn.DataParallel(decoder, device_ids=CHDDIM_config.device_ids)
    # CHDDIM = torch.nn.DataParallel(CHDDIM, device_ids=CHDDIM_config.device_ids)
    #
    # encoder = encoder.cuda(device=CHDDIM_config.device_ids[0])
    # decoder = decoder.cuda(device=CHDDIM_config.device_ids[0])
    # CHDDIM = CHDDIM.cuda(device=CHDDIM_config.device_ids[0])
    encoder.load_state_dict(torch.load(encoder_path))
    ckpt = torch.load(CHDDIM_config.save_path)
    DnCNN.load_state_dict(ckpt)
    DnCNN.eval()
    decoder.load_state_dict(torch.load(decoder_path))
    if config.dataset == "CIFAR10":
        CalcuSSIM = MS_SSIM(window_size=3, data_range=1., levels=4, channel=3).cuda()
    else:
        CalcuSSIM = MS_SSIM(data_range=1., levels=4, channel=3).cuda()
    # start evaluation
    snr_in = config.SNRs - CHDDIM_config.large_snr
    matric_aver = 0
    mse1_aver = 0
    mse2_aver = 0
    # sigma_eps_aver=torch.zeros()
    with tqdm(testLoader, dynamic_ncols=True) as tqdmtestLoader:
        for i, (images, labels) in enumerate(tqdmtestLoader):
            x_0 = images.cuda()
            feature, _ = encoder(x_0)
            y = feature
            y_0 = y
            y, pwr, h = pass_channel.forward(y, snr_in)  # normalize
            sigma_square = 1.0 / (2 * 10 ** (config.SNRs / 10))
            if config.channel_type == "awgn":
                y_awgn = torch.cat((torch.real(y), torch.imag(y)), dim=2)
                mse1 = torch.nn.MSELoss()(y_awgn * math.sqrt(2), y_0 * math.sqrt(2) / torch.sqrt(pwr))
                receive = y_awgn
            elif config.channel_type == 'rayleigh':
                y_mmse = y * torch.conj(h) / (torch.abs(h) ** 2 + sigma_square * 2)
                y_mmse = torch.cat((torch.real(y_mmse), torch.imag(y_mmse)), dim=2)
                mse1 = torch.nn.MSELoss()(y_mmse * math.sqrt(2), y_0 * math.sqrt(2) / torch.sqrt(pwr))
                receive = y_mmse
            else:
                raise ValueError
            feature_hat = receive - DnCNN(receive)
            mse2 = torch.nn.MSELoss()(feature_hat * math.sqrt(2), y_0 * math.sqrt(2) / torch.sqrt(pwr))
            feature_hat = feature_hat * torch.sqrt(pwr)
            x_0_hat = decoder(feature_hat)
            # optimizer1.step()
            # optimizer2.step()
            if config.loss_function == "MSE":
                mse = torch.nn.MSELoss()(x_0 * 255., x_0_hat.clamp(0., 1.) * 255)
                psnr = 10 * math.log10(255. * 255. / mse.item())
                matric = psnr
                # save_image(x_0_hat, "/home/wutong/semdif_revise/DIV2K_JSCCCDDM_rayleigh_PSNR_10dB/{}.png".format(i))
            elif config.loss_function == "MSSSIM":
                msssim = 1 - CalcuSSIM(x_0, x_0_hat.clamp(0., 1.)).mean().item()
                matric = msssim
                # save_image(x_0_hat, "/home/wutong/semdif_revise/DIV2K_JSCCCDDM_rayleigh_MSSSIM_10dB/{}.png".format(i))
            mse1_aver += mse1.item()
            mse2_aver += mse2.item()
            matric_aver += matric
            CBR = feature.numel() / 2 / x_0.numel()
            tqdmtestLoader.set_postfix(ordered_dict={
                "dataset": config.dataset,
                "re_weight": str(CHDDIM_config.re_weight),
                "state": 'eval JSCC with CDDM' + config.loss_function,
                "channel": config.channel_type,
                "noise_schedule": CHDDIM_config.noise_schedule,
                "CBR": CBR,
                "SNR": snr_in,
                "matric ": matric,
                "MSE_channel": mse1.item(),
                "MSE_channel+CDDM": mse2.item(),
                "T_max": CHDDIM_config.t_max
            })
    mse1_aver = (mse1_aver / (i + 1))
    mse2_aver = (mse2_aver / (i + 1))
    matric_aver = (matric_aver / (i + 1))
    if config.loss_function == "MSE":
        name = 'PSNR'
    elif config.loss_function == "MSSSIM":
        name = "MSSSIM"
    else:
        raise ValueError
    # print("matric:{}", matric_aver)
    myclient = pymongo.MongoClient(config.database_address)
    mydb = myclient[config.dataset]
    if 'SNRs' in config.encoder_path:
        mycol = mydb[name + '_' + config.channel_type + '_SNRs_' + 'JSCC+CDDM' + '_CBR_' + str(CBR)]
        mydic = {'SNR': snr_in, name: matric_aver}
        mycol.insert_one(mydic)
        print('writing successfully', mydic)
        mycol = mydb['MSE' + name + '_' + config.channel_type + '_SNRs_' + 'JSCC' + '_CBR_' + str(CBR)]
        mydic = {'SNR': snr_in, 'MSE': mse1_aver}
        mycol.insert_one(mydic)
        print('writing successfully', mydic)
        mycol = mydb["MSE" + name + '_' + config.channel_type + '_SNRs_' + 'JSCC+CDDM' + '_CBR_' + str(CBR)]
        mydic = {'SNR': snr_in, 'MSE': mse2_aver}
        mycol.insert_one(mydic)
        print('writing successfully', mydic)
    elif 'CBRs' in config.encoder_path:
        mycol = mydb[name + '_' + config.channel_type + '_CBRs_' + 'JSCC+CDDM' + '_SNR_' + str(snr_in)]
        mydic = {'CBR': CBR, name: matric_aver}
        mycol.insert_one(mydic)
        print('writing successfully', mydic)
        mycol = mydb['MSE' + name + '_' + config.channel_type + '_CBRs_' + 'JSCC' + '_SNR_' + str(snr_in)]
        mydic = {'CBR': CBR, 'MSE': mse1_aver}
        mycol.insert_one(mydic)
        print('writing successfully', mydic)
        mycol = mydb["MSE" + name + '_' + config.channel_type + '_CBRs_' + 'JSCC+CDDM' + '_SNR_' + str(snr_in)]
        mydic = {'CBR': CBR, 'MSE': mse2_aver}
        mycol.insert_one(mydic)
        print('writing successfully', mydic)
    else:
        raise ValueError


def train_GAN(config, CHDDIM_config):
    from WGANVGG.networks import WGAN_VGG, WGAN_VGG_generator
    train_losses = []
    encoder = network.JSCC_encoder(config, config.C).cuda()
    encoder_path = config.encoder_path
    pass_channel = channel.Channel(config)
    encoder.eval()
    GAN_config = copy.deepcopy(config)
    GAN_config.batch_size = config.CDDM_batch
    trainLoader, _ = get_loader(GAN_config)
    WGAN_VGG = WGAN_VGG(input_size=16, in_channels=config.C).cuda()
    WGAN_VGG_generator = WGAN_VGG_generator()  # note: shadows the imported name; unused below
    criterion_perceptual = torch.nn.L1Loss()
    optimizer_g = torch.optim.Adam(WGAN_VGG.generator.parameters(), CHDDIM_config.lr)
    optimizer_d = torch.optim.Adam(WGAN_VGG.discriminator.parameters(), CHDDIM_config.lr)
    encoder.load_state_dict(torch.load(encoder_path))
    for e in range(CHDDIM_config.epoch):
        with tqdm(trainLoader, dynamic_ncols=True) as tqdmDataLoader:
            for images, labels in tqdmDataLoader:
                snr = config.SNRs - CHDDIM_config.large_snr
                x_0 = images.cuda()
                feature, _ = encoder(x_0)
                y = feature
                y, pwr, h = pass_channel.forward(y, snr)  # normalize
                sigma_square = 1.0 / (2 * 10 ** (config.SNRs / 10))
                if config.channel_type == "awgn":
                    y_awgn = torch.cat((torch.real(y), torch.imag(y)), dim=2)
                    # mse1 = torch.nn.MSELoss()(y_awgn * math.sqrt(2), y * math.sqrt(2) / torch.sqrt(pwr))
                    receive = y_awgn * torch.sqrt(pwr)
                elif config.channel_type == 'rayleigh':
                    y_mmse = y * torch.conj(h) / (torch.abs(h) ** 2 + sigma_square * 2)
                    y_mmse = torch.cat((torch.real(y_mmse), torch.imag(y_mmse)), dim=2)
                    # mse1 = torch.nn.MSELoss()(y_mmse * math.sqrt(2), y * math.sqrt(2) / torch.sqrt(pwr))
                    receive = y_mmse * torch.sqrt(pwr)
                else:
                    raise ValueError
                for index_2 in range(GAN_config.n_d_train):
                    optimizer_d.zero_grad()
                    # WGAN_VGG.discriminator.zero_grad()
                    # mse1 = torch.nn.MSELoss()(receive / torch.sqrt(pwr) * math.sqrt(2), feature * math.sqrt(2) / torch.sqrt(pwr))
                    # print(mse1.item())
                    d_loss, gp_loss = WGAN_VGG.d_loss(receive, feature, gp=True, return_gp=True)
                    d_loss.backward(retain_graph=True)
                    optimizer_d.step()
                optimizer_g.zero_grad()
                g_loss, p_loss = WGAN_VGG.g_loss(receive, feature, perceptual=True, return_p=True)
                # print(pwr)
                g_loss.backward()
                optimizer_g.step()
                tqdmDataLoader.set_postfix(ordered_dict={
                    "epoch": e,
                    "state": 'train_GAN',
                    "channel type": config.channel_type,
                    "g loss: ": g_loss.item() - p_loss.item(),
                    "p loss: ": p_loss.item(),
                    "d loss: ": d_loss.item(),
                    "d-gp loss: ": d_loss.item() - gp_loss.item(),
                    "gp loss: ": gp_loss.item(),
                    "input shape: ": x_0.shape,
                    "CBR": feature.numel() / 2 / x_0.numel(),
                })
        if (e + 1) % CHDDIM_config.save_model_freq == 0:
            torch.save(WGAN_VGG.state_dict(), CHDDIM_config.save_path)


def eval_JSCC_with_GAN(config, CHDDIM_config):
    from WGANVGG.networks import WGAN_VGG, WGAN_VGG_generator
    encoder = network.JSCC_encoder(config, config.C).cuda()
    decoder = network.JSCC_decoder(config, config.C).cuda()
    encoder_path = config.encoder_path
    decoder_path = config.re_decoder_path
    pass_channel = channel.Channel(config)
    encoder.eval()
    decoder.eval()
    _, testLoader = get_loader(config)
    WGAN_VGG = WGAN_VGG(input_size=16, in_channels=config.C).cuda()
    # encoder = torch.nn.DataParallel(encoder, device_ids=CHDDIM_config.device_ids)
    # decoder = torch.nn.DataParallel(decoder, device_ids=CHDDIM_config.device_ids)
    # CHDDIM = torch.nn.DataParallel(CHDDIM, device_ids=CHDDIM_config.device_ids)
    #
    # encoder = encoder.cuda(device=CHDDIM_config.device_ids[0])
    # decoder = decoder.cuda(device=CHDDIM_config.device_ids[0])
    # CHDDIM = CHDDIM.cuda(device=CHDDIM_config.device_ids[0])
    encoder.load_state_dict(torch.load(encoder_path))
    ckpt = torch.load(CHDDIM_config.save_path)
    WGAN_VGG.load_state_dict(ckpt)
    WGAN_VGG.eval()
    decoder.load_state_dict(torch.load(decoder_path))
    if config.dataset == "CIFAR10":
        CalcuSSIM = MS_SSIM(window_size=3, data_range=1., levels=4, channel=3).cuda()
    else:
        CalcuSSIM = MS_SSIM(data_range=1., levels=4, channel=3).cuda()
    # start evaluation
    snr_in = config.SNRs - CHDDIM_config.large_snr
    matric_aver = 0
    mse1_aver = 0
    mse2_aver = 0
    # sigma_eps_aver=torch.zeros()
    with tqdm(testLoader, dynamic_ncols=True) as tqdmtestLoader:
        for i, (images, labels) in enumerate(tqdmtestLoader):
            x_0 = images.cuda()
            feature, _ = encoder(x_0)
            y = feature
            y_0 = y
            y, pwr, h = pass_channel.forward(y, snr_in)  # normalize
            sigma_square = 1.0 / (2 * 10 ** (snr_in / 10))
            if config.channel_type == "awgn":
                y_awgn = torch.cat((torch.real(y), torch.imag(y)), dim=2)
                mse1 = torch.nn.MSELoss()(y_awgn * math.sqrt(2), y_0 * math.sqrt(2) / torch.sqrt(pwr))
                receive = y_awgn * torch.sqrt(pwr)
            elif config.channel_type == 'rayleigh':
                y_mmse = y * torch.conj(h) / (torch.abs(h) ** 2 + sigma_square * 2)
                y_mmse = torch.cat((torch.real(y_mmse), torch.imag(y_mmse)), dim=2)
                mse1 = torch.nn.MSELoss()(y_mmse * math.sqrt(2), y_0 * math.sqrt(2) / torch.sqrt(pwr))
                receive = y_mmse * torch.sqrt(pwr)
            else:
                raise ValueError
            feature_hat = WGAN_VGG.generator(receive)
            mse2 = torch.nn.MSELoss()(feature_hat * math.sqrt(2) / torch.sqrt(pwr), y_0 * math.sqrt(2) / torch.sqrt(pwr))
            # feature_hat = feature_hat * torch.sqrt(pwr)
            x_0_hat = decoder(feature_hat)
            # optimizer1.step()
            # optimizer2.step()
            if config.loss_function == "MSE":
                mse = torch.nn.MSELoss()(x_0 * 255., x_0_hat.clamp(0., 1.) * 255)
                psnr = 10 * math.log10(255. * 255. / mse.item())
                matric = psnr
                # save_image(x_0_hat, "/home/wutong/semdif_revise/DIV2K_JSCCCDDM_rayleigh_PSNR_10dB/{}.png".format(i))
            elif config.loss_function == "MSSSIM":
                msssim = 1 - CalcuSSIM(x_0, x_0_hat.clamp(0., 1.)).mean().item()
                matric = msssim
                # save_image(x_0_hat, "/home/wutong/semdif_revise/DIV2K_JSCCCDDM_rayleigh_MSSSIM_10dB/{}.png".format(i))
            mse1_aver += mse1.item()
            mse2_aver += mse2.item()
            matric_aver += matric
            CBR = feature.numel() / 2 / x_0.numel()
            tqdmtestLoader.set_postfix(ordered_dict={
                "dataset": config.dataset,
                "state": 'eval JSCC with GAN' + config.loss_function,
                "channel": config.channel_type,
                "CBR": CBR,
                "SNR": snr_in,
                "matric ": matric,
                "MSE_channel": mse1.item(),
                "MSE_channel+GAN": mse2.item(),
            })
    mse1_aver = (mse1_aver / (i + 1))
    mse2_aver = (mse2_aver / (i + 1))
    matric_aver = (matric_aver / (i + 1))
    if config.loss_function == "MSE":
        name = 'PSNR'
    elif config.loss_function == "MSSSIM":
        name = "MSSSIM"
    else:
        raise ValueError
    # print("matric:{}", matric_aver)
    myclient = pymongo.MongoClient(config.database_address)
    mydb = myclient[config.dataset]
    if 'SNRs' in config.encoder_path:
        mycol = mydb[name + '_' + config.channel_type + '_SNRs_' + 'JSCC+CDDM' + '_CBR_' + str(CBR)]
        mydic = {'SNR': snr_in, name: matric_aver}
        mycol.insert_one(mydic)
        print('writing successfully', mydic)
        mycol = mydb['MSE' + name + '_' + config.channel_type + '_SNRs_' + 'JSCC' + '_CBR_' + str(CBR)]
        mydic = {'SNR': snr_in, 'MSE': mse1_aver}
        mycol.insert_one(mydic)
        print('writing successfully', mydic)
        mycol = mydb["MSE" + name + '_' + config.channel_type + '_SNRs_' + 'JSCC+CDDM' + '_CBR_' + str(CBR)]
        mydic = {'SNR': snr_in, 'MSE': mse2_aver}
        mycol.insert_one(mydic)
        print('writing successfully', mydic)
    elif 'CBRs' in config.encoder_path:
        mycol = mydb[name + '_' + config.channel_type + '_CBRs_' + 'JSCC+CDDM' + '_SNR_' + str(snr_in)]
        mydic = {'CBR': CBR, name: matric_aver}
        mycol.insert_one(mydic)
        print('writing successfully', mydic)
        mycol = mydb['MSE' + name + '_' + config.channel_type + '_CBRs_' + 'JSCC' + '_SNR_' + str(snr_in)]
        mydic = {'CBR': CBR, 'MSE': mse1_aver}
        mycol.insert_one(mydic)
        print('writing successfully', mydic)
        mycol = mydb["MSE" + name + '_' + config.channel_type + '_CBRs_' + 'JSCC+CDDM' + '_SNR_' + str(snr_in)]
        mydic = {'CBR': CBR, 'MSE': mse2_aver}
        mycol.insert_one(mydic)
        print('writing successfully', mydic)
    else:
        raise ValueError


def train_JSCC_with_GAN(config, CHDDIM_config):
    from WGANVGG.networks import WGAN_VGG, WGAN_VGG_generator
    encoder = network.JSCC_encoder(config, config.C).cuda()
    decoder = network.JSCC_decoder(config, config.C).cuda()
    encoder_path = config.encoder_path
    decoder_path = config.decoder_path
    pass_channel = channel.Channel(config)
    trainLoader, _ = get_loader(config)
    encoder.eval()
    WGAN_VGG = WGAN_VGG(input_size=16, in_channels=config.C).cuda()
    # encoder = torch.nn.DataParallel(encoder, device_ids=CHDDIM_config.device_ids)
    # decoder = torch.nn.DataParallel(decoder, device_ids=CHDDIM_config.device_ids)
    # CHDDIM = torch.nn.DataParallel(CHDDIM, device_ids=CHDDIM_config.device_ids)
    #
    # encoder = encoder.cuda(device=CHDDIM_config.device_ids[0])
    # decoder = decoder.cuda(device=CHDDIM_config.device_ids[0])
    # CHDDIM = CHDDIM.cuda(device=CHDDIM_config.device_ids[0])
    encoder.load_state_dict(torch.load(encoder_path))
    decoder.load_state_dict(torch.load(decoder_path))
    ckpt = torch.load(CHDDIM_config.save_path)
    WGAN_VGG.load_state_dict(ckpt)
    WGAN_VGG.eval()
    # optimizer_encoder = torch.optim.AdamW(
    #     encoder.parameters(), lr=CHDDIM_config.lr, weight_decay=1e-4)
    optimizer_decoder = torch.optim.Adam(
        decoder.parameters(), lr=CHDDIM_config.lr)
    # start training
    if config.dataset == "CIFAR10":
        CalcuSSIM = MS_SSIM(window_size=3, data_range=1., levels=4, channel=3).cuda()
    else:
        CalcuSSIM = MS_SSIM(data_range=1., levels=4, channel=3).cuda()
    for e in range(config.retrain_epoch):
        with tqdm(trainLoader, dynamic_ncols=True) as tqdmtrainLoader:
            for i, (images, labels) in enumerate(tqdmtrainLoader):
                # train
                snr = config.SNRs - CHDDIM_config.large_snr
                x_0 = images.cuda()
                feature, _ = encoder(x_0)
                y = feature
                y, pwr, h = pass_channel.forward(y, snr)  # normalize
                sigma_square = 1.0 / (2 * 10 ** (config.SNRs / 10))
                if config.channel_type == "awgn":
                    y_awgn = torch.cat((torch.real(y), torch.imag(y)), dim=2)
                    # mse1 = torch.nn.MSELoss()(y_awgn * math.sqrt(2), y * math.sqrt(2) / torch.sqrt(pwr))
                    receive = y_awgn * torch.sqrt(pwr)
                elif config.channel_type == 'rayleigh':
                    y_mmse = y * torch.conj(h) / (torch.abs(h) ** 2 + sigma_square * 2)
                    y_mmse = torch.cat((torch.real(y_mmse), torch.imag(y_mmse)), dim=2)
                    # mse1 = torch.nn.MSELoss()(y_mmse * math.sqrt(2), y * math.sqrt(2) / torch.sqrt(pwr))
                    receive = y_mmse * torch.sqrt(pwr)
                else:
                    raise ValueError
                feature_hat = WGAN_VGG.generator(receive)
                # feature_hat = feature_hat * torch.sqrt(pwr)
                x_0_hat = decoder(feature_hat)
                # mse1=torch.nn.MSEloss()()
                if config.loss_function == "MSE":
                    loss = torch.nn.MSELoss()(x_0, x_0_hat)
                elif config.loss_function == "MSSSIM":
                    loss = CalcuSSIM(x_0, x_0_hat.clamp(0., 1.)).mean()
                else:
                    raise ValueError
                optimizer_decoder.zero_grad()
                loss.backward()
                optimizer_decoder.step()
                # optimizer_encoder.step()
                if config.loss_function == "MSE":
                    mse = torch.nn.MSELoss()(x_0 * 255., x_0_hat.clamp(0., 1.) * 255)
                    psnr = 10 * math.log10(255. * 255. / mse.item())
                    matric = psnr
                elif config.loss_function == "MSSSIM":
                    msssim = 1 - CalcuSSIM(x_0, x_0_hat.clamp(0., 1.)).mean().item()
                    matric = msssim
                tqdmtrainLoader.set_postfix(ordered_dict={
                    "dataset": config.dataset,
                    "state": "train_decoder",
                    "noise_schedule": "GAN",
                    "channel": config.channel_type,
                    "CBR:": feature.numel() / 2 / x_0.numel(),
                    "SNR": snr,
                    "matric": matric,
                })
        if (e + 1) % config.retrain_save_model_freq == 0:
            torch.save(decoder.state_dict(), config.re_decoder_path)


class netCDDM(nn.Module):
    def __init__(self, config, CHDDIM_config):
        super().__init__()
        self.CDDM = UNet(T=CHDDIM_config.T, ch=CHDDIM_config.channel, ch_mult=CHDDIM_config.channel_mult,
                         attn=CHDDIM_config.attn, num_res_blocks=CHDDIM_config.num_res_blocks,
                         dropout=CHDDIM_config.dropout, input_channel=CHDDIM_config.C).cuda()

    def forward(self, input):
        # Sample Rayleigh-distributed fading magnitudes for the dummy forward pass
        h = torch.sqrt(torch.normal(mean=0.0, std=1, size=np.shape(input)) ** 2
                       + torch.normal(mean=0.0, std=1, size=np.shape(input)) ** 2) / np.sqrt(2)
        h = h.cuda()
        t = input.new_ones([input.shape[0], ], dtype=torch.long) * 100
        t = t.cuda()
        x = self.CDDM(input, t, h)
        return x


class netJSCC(nn.Module):
    def __init__(self, config, CHDDIM_config):
        super().__init__()
        self.encoder = network.JSCC_encoder(config, config.C).cuda()
        self.decoder = network.JSCC_decoder(config, config.C).cuda()

    def forward(self, input):
        x, _ = self.encoder(input)
        y = self.decoder(x)
        return y


def test_mem_and_comp(config, CHDDIM_config):
    from thop import profile
    from thop import clever_format
    network = netJSCC(config, CHDDIM_config)  # note: shadows the `network` module locally
    input = torch.randn(1, 3, 256, 256).cuda()
    macs, params = profile(network, inputs=(input,))
    macs, params = clever_format([macs, params], "%.3f")
    print(macs, params)
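
One step repeated throughout the Rayleigh branches above is MMSE equalization of the received symbols, y * conj(h) / (|h|^2 + 2*sigma^2), where sigma^2 is the per-real-dimension noise variance implied by the SNR in dB. The toy snippet below (illustrative tensors only; the project's channel.Channel additionally handles power normalization) shows that computation in isolation:

import math
import torch

# Toy MMSE equalization demo, not project code.
snr_db = 10.0
sigma_square = 1.0 / (2 * 10 ** (snr_db / 10))      # per-real-dimension noise variance, as above

x = torch.randn(4, 2, 16, dtype=torch.cfloat)       # transmitted complex symbols (toy shape)
h = torch.randn_like(x.real) + 1j * torch.randn_like(x.real)
h = h / math.sqrt(2)                                # CN(0, 1) Rayleigh fading coefficients
n = (torch.randn_like(x.real) + 1j * torch.randn_like(x.real)) * math.sqrt(sigma_square)

y = h * x + n                                       # received symbols
x_mmse = y * torch.conj(h) / (torch.abs(h) ** 2 + 2 * sigma_square)

print(torch.mean(torch.abs(x - x_mmse) ** 2))       # equalization error shrinks as SNR grows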

Resource comments
本本纲目
2025.07.26
The library update improves performance and it is simple to use; for developers who want to dig deeper into GANs, this is a resource worth watching.
臭人鹏
2025.07.09
StyleGAN2 with a UNet discriminator offers a new way of optimizing generative adversarial networks, worth studying and exploring in the deep-learning field.
不美的阿美
2025.06.29
By improving the discriminator structure with UNet, this StyleGAN2 raises the quality of generated images; the codebase installs easily and gives researchers a new direction.
郑瑜伊
2025.06.18
Applying UNet to StyleGAN2's discriminator brings a new research direction to AI image generation, and the simple install and usage flow lowers the barrier to entry.
高工-老罗
2025.05.25
The UNet-plus-StyleGAN2 implementation offers a fresh perspective, and results after the code update are notable. The potential for combining it with other techniques is large; a complete tutorial guide is worth looking forward to.
鲸阮
2025.05.07
This PyTorch implementation simplifies working with StyleGAN2 and improves generation quality via the UNet discriminator; it could become a popular tool in the AI field.
叫我叔叔就行
2025.04.20
The project shows an innovative attempt to combine UNet with StyleGAN2; the code's usability has improved, and it may bring new breakthroughs to AI generation.
iwbunny
  • Followers: 41