jiangcaiynag

import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import os
import torchvision.transforms as transforms
import sys
import math

# 添加模型路径
sys.path.append('.')

# 导入FAR项目中的VAE模型
from models.vae import AutoencoderKL, DiagonalGaussianDistribution

def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    # 首先降采样,如果图像最小尺寸大于目标尺寸的两倍
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    # 计算缩放比例,将最小边缩放到目标尺寸
    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    # 中心裁剪为目标尺寸
    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]

def process_image_frequency():
    """
    按照FAR项目的方法对图像进行频率处理
    """
    # 硬编码参数
    image_path = './data/test.png'
    output_path = './data/test_freq14.png'
    vae_path = 'pretrained_models/vae/kl16.ckpt'
    source_freq = 16
    target_freq = 14
    
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # 加载VAE模型
    print(f"加载VAE模型: {vae_path}")
    vae = AutoencoderKL(
        embed_dim=16,  # 设置为FAR中的vae_embed_dim
        ch_mult=(1, 1, 2, 2, 4),  # 按照FAR中的设置
        ckpt_path=vae_path
    ).to(device)
    vae.eval()
    
    # 加载图像并预处理,使用与FAR相同的裁剪方法
    print(f"处理图像: {image_path}")
    img = Image.open(image_path).convert('RGB')
    
    # 使用FAR项目中的预处理方式:中心裁剪、随机水平翻转、归一化
    image_size = 256  # FAR的默认图像大小
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # 归一化到[-1, 1]
    ])
    
    # 应用转换
    img_array = transform(img)
    img_tensor = img_array.unsqueeze(0).to(device)  # [1, 3, 256, 256]
    
    with torch.no_grad():
        # 1. 使用VAE编码器将图像编码到潜在空间
        posterior = vae.encode(img_tensor)
        latent = posterior.sample()
        
        # 标准化潜变量,与FAR项目保持一致
        latent = latent.mul_(0.2325)  # FAR项目中的标准化因子
        
        # 2. 在潜在空间中进行频率处理
        # 潜在空间的图像大小为16x16(假设vae_stride=16,对于256x256的输入)
        _, C, H, W = latent.shape
        print(f"潜在空间的大小: {H}x{W}x{C}")
        
        # 根据FAR项目的频率处理方式
        # 首先用area模式降采样到目标频率
        latent_downsampled = F.interpolate(latent, size=(target_freq, target_freq), mode='area')
        
        # 然后用bicubic模式上采样回原始潜在空间大小
        latent_processed = F.interpolate(latent_downsampled, size=(H, W), mode='bicubic')
        
        # 逆标准化
        latent_processed = latent_processed / 0.2325
        
        # 3. 使用VAE解码器将处理后的潜在表示解码回图像空间
        decoded_img = vae.decode(latent_processed)
    
    # 将解码后的图像从[-1,1]转换到[0,1]再转到[0,255]
    decoded_img = decoded_img * 0.5 + 0.5
    decoded_img = decoded_img.clamp(0, 1)
    decoded_img = decoded_img.squeeze(0).permute(1, 2, 0).cpu().numpy()
    decoded_img = (decoded_img * 255).astype(np.uint8)
    
    # 创建输出目录(如果不存在)
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    
    # 保存图像
    output_img = Image.fromarray(decoded_img)
    output_img.save(output_path)
    
    print(f"处理完成! 已保存到: {output_path}")
    
    # 可选:保存原始VAE重建的图像进行对比
    original_latent_decoded = vae.decode(latent)
    original_latent_decoded = original_latent_decoded * 0.5 + 0.5
    original_latent_decoded = original_latent_decoded.clamp(0, 1)
    original_latent_decoded = original_latent_decoded.squeeze(0).permute(1, 2, 0).cpu().numpy()
    original_latent_decoded = (original_latent_decoded * 255).astype(np.uint8)
    
    original_output_path = os.path.splitext(output_path)[0] + "_original.png"
    original_img = Image.fromarray(original_latent_decoded)
    original_img.save(original_output_path)
    
    print(f"原始VAE重建图像已保存到: {original_output_path}")
    
    return decoded_img

if __name__ == "__main__":
    process_image_frequency() 

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值