import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import os
import torchvision.transforms as transforms
import sys
# Add the project root to the import path
sys.path.append('.')
# Import the VAE model from the FAR project
from models.vae import AutoencoderKL
def center_crop_arr(pil_image, image_size):
    """
    Center cropping implementation from ADM.
    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
    """
    # Repeatedly halve the image with a box filter while its short side is
    # still at least twice the target size
    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )
    # Rescale so that the short side matches the target size
    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )
    # Center-crop to the target size
    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size]
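# Quick sanity check for center_crop_arr (a minimal sketch; the image path is
# illustrative, any RGB image will do):
#
#   arr = center_crop_arr(Image.open('./data/test.png').convert('RGB'), 256)
#   assert arr.shape == (256, 256, 3)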
def process_image_frequency():
    """
    Apply frequency processing to an image, following the FAR project's method.
    """
    # Hard-coded parameters
    image_path = './data/test.png'
    output_path = './data/test_freq14.png'
    vae_path = 'pretrained_models/vae/kl16.ckpt'
    source_freq = 16   # latent grid size of the kl16 VAE for a 256x256 input
    target_freq = 14   # grid size to downsample the latent to
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Load the VAE model
    print(f"Loading VAE model: {vae_path}")
    vae = AutoencoderKL(
        embed_dim=16,             # vae_embed_dim used in FAR
        ch_mult=(1, 1, 2, 2, 4),  # channel multipliers used in FAR
        ckpt_path=vae_path
    ).to(device)
    vae.eval()
    # Load and preprocess the image with the same cropping as FAR
    print(f"Processing image: {image_path}")
    img = Image.open(image_path).convert('RGB')
    # FAR-style preprocessing: center crop and normalization
    # (no random horizontal flip here, since this is inference)
    image_size = 256  # FAR's default image size
    transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])  # normalize to [-1, 1]
    ])
    # Apply the transform
    img_array = transform(img)
    img_tensor = img_array.unsqueeze(0).to(device)  # [1, 3, 256, 256]
    with torch.no_grad():
        # 1. Encode the image into the latent space with the VAE encoder
        posterior = vae.encode(img_tensor)
        latent = posterior.sample()
        # Normalize the latent, consistent with the FAR project
        latent = latent.mul_(0.2325)  # normalization factor used in FAR
        # 2. Frequency processing in the latent space
        # For a 256x256 input the latent grid is 16x16 (assuming vae_stride=16)
        _, C, H, W = latent.shape
        print(f"Latent size: {H}x{W}x{C}")
        # Frequency processing as in the FAR project:
        # first downsample to the target frequency in 'area' mode...
        latent_downsampled = F.interpolate(latent, size=(target_freq, target_freq), mode='area')
        # ...then upsample back to the original latent size in 'bicubic' mode
        latent_processed = F.interpolate(latent_downsampled, size=(H, W), mode='bicubic')
        # Undo the normalization
        latent_processed = latent_processed / 0.2325
        # 3. Decode the processed latent back to image space with the VAE decoder
        decoded_img = vae.decode(latent_processed)
        # Map the decoded image from [-1, 1] to [0, 1], then to [0, 255]
        decoded_img = decoded_img * 0.5 + 0.5
        decoded_img = decoded_img.clamp(0, 1)
        decoded_img = decoded_img.squeeze(0).permute(1, 2, 0).cpu().numpy()
        decoded_img = (decoded_img * 255).astype(np.uint8)
        # Create the output directory if it does not exist
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        # Save the image
        output_img = Image.fromarray(decoded_img)
        output_img.save(output_path)
        print(f"Done! Saved to: {output_path}")
        # Optional: also save the plain VAE reconstruction for comparison.
        # The latent must be de-normalized before decoding, just like the
        # processed branch above; decoding the scaled latent would give a
        # washed-out image.
        original_latent_decoded = vae.decode(latent / 0.2325)
        original_latent_decoded = original_latent_decoded * 0.5 + 0.5
        original_latent_decoded = original_latent_decoded.clamp(0, 1)
        original_latent_decoded = original_latent_decoded.squeeze(0).permute(1, 2, 0).cpu().numpy()
        original_latent_decoded = (original_latent_decoded * 255).astype(np.uint8)
        original_output_path = os.path.splitext(output_path)[0] + "_original.png"
        original_img = Image.fromarray(original_latent_decoded)
        original_img.save(original_output_path)
        print(f"Original VAE reconstruction saved to: {original_output_path}")
    return decoded_img
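# --- Optional sanity check (a minimal sketch, not part of the FAR pipeline) ---
# The 'area' downsample followed by a 'bicubic' upsample above acts as a
# low-pass filter: spatial detail that does not survive the target_freq grid
# is lost for good. The helper below illustrates the effect on random latents;
# the name `lowpass_roundtrip_error` and the shapes are illustrative only.
def lowpass_roundtrip_error(size=16, target=14, channels=16):
    x = torch.randn(1, channels, size, size)
    down = F.interpolate(x, size=(target, target), mode='area')
    up = F.interpolate(down, size=(size, size), mode='bicubic')
    # Relative energy of the high-frequency residual removed by the round trip
    return ((x - up).norm() / x.norm()).item()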
if __name__ == "__main__":
    process_image_frequency()
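# Usage (assuming the FAR repository layout, a kl16 VAE checkpoint at
# pretrained_models/vae/kl16.ckpt, and a test image at ./data/test.png;
# the script name below is illustrative):
#   python freq_process.py
# This writes ./data/test_freq14.png and, for comparison,
# ./data/test_freq14_original.png.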