Qwen2.5-VL-7B如何从输入到输出-代码解析(图片解析)

出发点

有一段时间没写过代码解析类的文章了,这次重新出发从Qwen2.5-VL-7B出发,从这个代码开始就没有tokenizer了,不得不感叹一下:transformer的代码更新太快了,稍微有一段时间不看了,就看不懂了

模型架构

环境部署:monkeyocr

from transformers import Qwen2_5_VLForConditionalGeneration, AutoTokenizer, AutoProcessor

from qwen_vl_utils import process_vision_info


path = "/usr/downloads/Qwen/Qwen2.5-VL-7B-Instruct/"
# default: Load the  model on the available device(s)
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    path, torch_dtype="auto", device_map="auto"
)

# We recommend enabling flash_attention_2 for better acceleration and memory saving, especially in multi-image and video scenarios.
# model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
#     "Qwen/Qwen2.5-VL-7B-Instruct",
#     torch_dtype=torch.bfloat16,
#     attn_implementation="flash_attention_2",
#     device_map="auto",
# )

# default processer
processor = AutoProcessor.from_pretrained(path)

# The default range for the number of visual tokens per image in the model is 4-16384.
# You can set min_pixels and max_pixels according to your needs, such as a token range of 256-1280, to balance performance and cost.
# min_pixels = 256*28*28
# max_pixels = 1280*28*28
# processor = AutoProcessor.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", min_pixels=min_pixels, max_pixels=max_pixels)

messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]

# Preparation for inference
text = processor.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
image_inputs, video_inputs = process_vision_info(messages)
inputs = processor(
    text=[text],
    images=image_inputs,
    videos=video_inputs,
    padding=True,
    return_tensors="pt",
)
inputs = inputs.to("cuda")

# Inference: Generation of the output

generated_ids = model.generate(**inputs, max_new_tokens=128)
generated_ids_trimmed = [
    out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
]
output_text = processor.batch_decode(
    generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
print(output_text)
# model结构
Qwen2_5_VLForConditionalGeneration(
  (model): Qwen2_5_VLModel(
    (visual): Qwen2_5_VisionTransformerPretrainedModel(
      (patch_embed): Qwen2_5_VisionPatchEmbed(
        (proj): Conv3d(3, 1280, kernel_size=(2, 14, 14), stride=(2, 14, 14), bias=False)
      )
      (rotary_pos_emb): Qwen2_5_VisionRotaryEmbedding()
      (blocks): ModuleList(
        (0-31): 32 x Qwen2_5_VLVisionBlock(
          (norm1): Qwen2RMSNorm((1280,), eps=1e-06)
          (norm2): Qwen2RMSNorm((1280,), eps=1e-06)
          (attn): Qwen2_5_VLVisionSdpaAttention(
            (qkv): Linear(in_features=1280, out_features=3840, bias=True)
            (proj): Linear(in_features=1280, out_features=1280, bias=True)
          )
          (mlp): Qwen2_5_VLMLP(
            (gate_proj): Linear(in_features=1280, out_features=3420, bias=True)
            (up_proj): Linear(in_features=1280, out_features=3420, bias=True)
            (down_proj): Linear(in_features=3420, out_features=1280, bias=True)
            (act_fn): SiLU()
          )
        )
      )
      (merger): Qwen2_5_VLPatchMerger(
        (ln_q): Qwen2RMSNorm((1280,), eps=1e-06)
        (mlp): Sequential(
          (0): Linear(in_features=5120, out_features=5120, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=5120, out_features=3584, bias=True)
        )
      )
    )
    (language_model): Qwen2_5_VLTextModel(
      (embed_tokens): Embedding(152064, 3584)
      (layers): ModuleList(
        (0-27): 28 x Qwen2_5_VLDecoderLayer(
          (self_attn): Qwen2_5_VLSdpaAttention(
            (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
            (k_proj): Linear(in_features=3584, out_features=512, bias=True)
            (v_proj): Linear(in_features=3584, out_features=512, bias=True)
            (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
            (rotary_emb): Qwen2_5_VLRotaryEmbedding()
          )
          (mlp): Qwen2MLP(
            (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
            (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
            (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
            (act_fn): SiLU()
          )
          (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
          (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
        )
      )
      (norm): Qwen2RMSNorm((3584,), eps=1e-06)
      (rotary_emb): Qwen2_5_VLRotaryEmbedding()
    )
  )
  (lm_head): Linear(in_features=3584, out_features=152064, bias=False)
)

qwen_vl_utils

from qwen_vl_utils import process_vision_info
messages = [
    {
        "role": "user",
        "content": [
            {
                "type": "image",
                "image": "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen-VL/assets/demo.jpeg",
            },
            {"type": "text", "text": "Describe this image."},
        ],
    }
]
image_inputs, video_inputs = process_vision_info(messages)
# image_inputs [<PIL.Image.Image image mode=RGB size=2044x1372 at 0x7F294B1BC450>]
# 2044/28=73
# 1372/28=49
# 原图片大小是2048*1365

这里主要看一下process_vision_info这个函数
vision_process.py
这份代码主要是在提取messages中的image和video,并将image和video进行resize(28*28的倍数),并指定有一个最小尺寸和最大尺寸

一些定义和包引用

from __future__ import annotations

import base64
import math
import warnings
from io import BytesIO

import requests
import torch
from PIL import Image
from torchvision import io, transforms
from torchvision.transforms import InterpolationMode


IMAGE_FACTOR = 28# 宽和高最后得是28的倍数
MIN_PIXELS = 4 * 28 * 28# 最小像素数,这是宽*高,等宽高的情况下,宽和高都是2*28
MAX_PIXELS = 16384 * 28 * 28# 最大像素数,这是宽*高,等宽高的情况下,宽和高都是128*28
MAX_RATIO = 200# 最大宽和高的比例,比如宽=1,高是200

VIDEO_MIN_PIXELS = 128 * 28 * 28# 视频的最小像素数,这是宽*高
VIDEO_MAX_PIXELS = 768 * 28 * 28# 视频的最大像素数,这是宽*高
# VIDEO_TOTAL_PIXELS = 24576 * 28 * 28
VIDEO_TOTAL_PIXELS = 128000 * 28 * 28 * 0.9# 视频的总像素数,28*28作为一个token,即可以有128000*0.9=115200个token
"""
这里看起来115200个token数好多啊,简单算一下
如果video的单张图片对应的像素数到了最大,也就是768 * 28 * 28
那单张图片占了768个token
所以最多可以从视频里拿出115200/768=150张图片
代码里的实现逻辑是先找到取多少帧,然后再算每张图片有多少像素点
"""
FRAME_FACTOR = 2# 每秒视频里可以取到几张图片
FPS = 2.0
FPS_MIN_FRAMES = 4# 最小取4帧
FPS_MAX_FRAMES = 768# 最大取768帧

取整

def round_by_factor(number: int, factor: int) -> int:
    """Returns the closest integer to 'number' that is divisible by 'factor'."""
    # 返回离number最近的可整除factor的整数
    # 四舍五入
    # round_by_factor(3000, 28)=2996
    # round_by_factor(6, 28)=0
    return round(number / factor) * factor


def ceil_by_factor(number: int, factor: int) -> int:
    """Returns the smallest integer greater than or equal to 'number' that is divisible by 'factor'."""
    # 返回大于等于number且能被factor整除的最小值
    # 向上取
    # ceil_by_factor(5,28)=28
    # ceil_by_factor(29,28)=56
    return math.ceil(number / factor) * factor


def floor_by_factor(number: int, factor: int) -> int:
    """Returns the largest integer less than or equal to 'number' that is divisible by 'factor'."""
    # 返回小于等于number且能被factor整除的最大值
    # 向下取
    # floor_by_factor(5,28)=0
    # floor_by_factor(29,28)=28
    return math.floor(number / factor) * factor

如果对round、ceil、floor很熟,很简单的三个函数,找到最符合条件的(四舍五入、向下、向上)能够整除factor的number,这是每张图片进行resize前的最后一步,为了后续token化做准备

smart_resize

def smart_resize(
    height: int, width: int, factor: int = IMAGE_FACTOR, min_pixels: int = MIN_PIXELS, max_pixels: int = MAX_PIXELS
) -> tuple[int, int]:
    """
    Rescales the image so that the following conditions are met:
    1. Both dimensions (height and width) are divisible by 'factor'.
    2. The total number of pixels is within the range ['min_pixels', 'max_pixels'].
    3. The aspect ratio of the image is maintained as closely as possible.
    """
    # smart_resize(6,6) = (56, 56) 56*56==4*28*28
    # smart_resize(30006,30006) = (3584, 3584) 最大尺寸是16384 * 28 * 28
    # smart_resize(30006,300006) = (1120, 11312) 保持长宽比的情况下达到最大尺寸
    # smart_resize(3000, 2000) = (2996, 1988) 正常尺寸,能够整除28
    # 长宽比要小于MAX_RATIO,即200
    if max(height, width) / min(height, width) > MAX_RATIO:
        raise ValueError(
            f"absolute aspect ratio must be smaller than {MAX_RATIO}, got {max(height, width) / min(height, width)}"
        )
    h_bar = max(factor, round_by_factor(height, factor))# 最小为factor
    w_bar = max(factor, round_by_factor(width, factor))# 最小为factor
    if h_bar * w_bar > max_pixels:# 16384 * 28 * 28
        beta = math.sqrt((height * width) / max_pixels)
        h_bar = floor_by_factor(height / beta, factor)
        w_bar = floor_by_factor(width / beta, factor)
    elif h_bar * w_bar < min_pixels:# 
        beta = math.sqrt(min_pixels / (height * width))
        h_bar = ceil_by_factor(height * beta, factor)
        w_bar = ceil_by_factor(width * beta, factor)
    return h_bar, w_bar

这里是对height和width进行处理,找到最符合条件的能被28整除的h_bar, w_bar
最后得到的两个值符合以下几个条件:

  • 宽高比小于200(强制条件)
  • h_bar, w_bar能被28整除
  • 总的像素值和比min_pixels大,比max_pixels大,且调整的时候,尽量保持原始图片的宽高比

fetch_image

def fetch_image(ele: dict[str, str | Image.Image], size_factor: int = IMAGE_FACTOR) -> Image.Image:
    if "image" in ele:
        image = ele["image"]
    else:
        image = ele["image_url"]
    image_obj = None
    if isinstance(image, Image.Image):
        image_obj = image
    elif image.startswith("http://") or image.startswith("https://"):
        image_obj = Image.open(requests.get(image, stream=True).raw)
    elif image.startswith("file://"):
        image_obj = Image.open(image[7:])
    elif image.startswith("data:image"):
        data = image.split(";", 1)[1]
        if data.startswith("base64,"):
            data = base64.b64decode(data[7:])
            image_obj = Image.open(BytesIO(data))
    else:
        image_obj = Image.open(image)
    if image_obj is None:
        raise ValueError(f"Unrecognized image input, support local path, http url, base64 and PIL.Image, got {image}")
    image = image_obj.convert("RGB")
    ## resize
    if "resized_height" in ele and "resized_width" in ele:
        resized_height, resized_width = smart_resize(
            ele["resized_height"],
            ele["resized_width"],
            factor=size_factor,
        )
    else:
        width, height = image.size
        min_pixels = ele.get("min_pixels", MIN_PIXELS)
        max_pixels = ele.get("max_pixels", MAX_PIXELS)
        resized_height, resized_width = smart_resize(
            height,
            width,
            factor=size_factor,
            min_pixels=min_pixels,
            max_pixels=max_pixels,
        )
    image = image.resize((resized_width, resized_height))

    return image

这里的代码是在通过路径(base64、http等方式)拿到图片
然后调用smart_resize拿到新的宽高,进行resize

fetch_video

到了最长的代码了,这里是通过路径(文件)拿到视频后读取视频,拿到视频、声频和信息(帧数)
视频有5334帧,每秒有29.97302764666217帧,所以有5334/29.97=177.977977977978秒
每秒取2(FPS)帧,即5334/29.97*2=355,将这个值按照round的方式到2的倍数356,所以得到了nframes
把nframes和 max_frames、min_frames 进行对比,确保帧数在这个范围内

所以视频是先拿到了帧数,从video中读取这些帧,然后再找到每张图片的总像素数

如果是图片,则是先将每张图片进行resize,这里没有VIDEO_MIN_PIXELS的限制,可能是觉得没人会上传100多长图片吧。。

def fetch_video(ele: dict, size_factor: int = FRAME_FACTOR) -> torch.Tensor | list[Image.Image]:
    if isinstance(ele["video"], str):# 传进来的是一个video
        # TODO: support http url

        video = ele["video"]
        if video.startswith("file://"):
            video = video[7:]

        video, audio, info = io.read_video(
            video,
            start_pts=ele.get("video_start", 0.0),
            end_pts=ele.get("video_end", None),
            pts_unit="sec",
            output_format="TCHW",
        )
        # info = {'video_fps': 29.97302764666217}
        # video.shape = torch.Size([354, 3, 720, 1280])
        # video.size(0) = 5334
        # video_inputs[0].shape = torch.Size([354, 3, 532, 952])

        assert not ("fps" in ele and "nframes" in ele), "Only accept either `fps` or `nframes`"
        if "nframes" in ele:
            nframes = round_by_factor(ele["nframes"], size_factor)# 2
        else:
            fps = ele.get("fps", FPS)# 2
            nframes = video.size(0) / info["video_fps"] * fps# 355
            nframes = round_by_factor(nframes, size_factor)# 356
            if "min_frames" in ele:
                min_frames = ele["min_frames"]
                if nframes < min_frames:
                    nframes = ceil_by_factor(min_frames, size_factor)
            else:
                min_frames = FPS_MIN_FRAMES# 4
                if nframes < min_frames:
                    warnings.warn(f"nframes is less than DEFAULT_MIN_FRAMES {min_frames}, set to {nframes}.")
                    nframes = ceil_by_factor(min_frames, size_factor)
            if "max_frames" in ele:
                max_frames = ele["max_frames"]
                if nframes > max_frames:
                    nframes = floor_by_factor(max_frames, size_factor)
            else:
                max_frames = FPS_MAX_FRAMES# 768
                if nframes > max_frames:
                    warnings.warn(f"nframes is greater than DEFAULT_MAX_FRAMES {max_frames}, set to {nframes}.")
                    nframes = floor_by_factor(max_frames, size_factor)

        if not (size_factor <= nframes and nframes <= video.size(0)):
            raise ValueError(f"nframes should in interval [{size_factor}, {video.size(0)}], but got {nframes}.")

        idx = torch.linspace(0, video.size(0) - 1, nframes).round().long()
        # 总的长度是nframes,间隔是(video.size(0) - 1)/nframes
        height, width = video.shape[2:]# 720 1280
        video = video[idx]# 取出该拿的帧

        min_pixels = ele.get("min_pixels", VIDEO_MIN_PIXELS)# 128 * 28 * 28
        total_pixels = ele.get("total_pixels", VIDEO_TOTAL_PIXELS)# 128000 * 28 * 28 * 0.9
        # VIDEO_MAX_PIXELS = 768 * 28 * 28
        max_pixels = max(min(VIDEO_MAX_PIXELS, total_pixels / nframes * size_factor), min_pixels * 1.05)# 
        # 能够在保证帧数的情况下,每帧最大的像素数
        max_pixels = ele.get("max_pixels", max_pixels)
        if "resized_height" in ele and "resized_width" in ele:
            resized_height, resized_width = smart_resize(
                ele["resized_height"],
                ele["resized_width"],
                factor=IMAGE_FACTOR,
            )
        else:
            resized_height, resized_width = smart_resize(
                height,
                width,
                factor=IMAGE_FACTOR,
                min_pixels=min_pixels,
                max_pixels=max_pixels,
            )
        video = transforms.functional.resize(
            video,
            [resized_height, resized_width],
            interpolation=InterpolationMode.BICUBIC,
            antialias=True,
        ).float()
        return video
    else:# 传进来的是一堆图片
        assert isinstance(ele["video"], (list, tuple))
        process_info = ele.copy()
        process_info.pop("type", None)
        process_info.pop("video", None)
        images = [fetch_image({"image": video_element, **process_info}) for video_element in ele["video"]]
        nframes = ceil_by_factor(len(images), size_factor)
        if len(images) < nframes:
            images.extend([images[-1]] * (nframes - len(images)))
        return images


总结

这篇主要是查看了qwen2.5-VL的模型使用方法、模型结构、已经对图片(视频)的处理流程,每个图像块28*28,比之前的14/*14要大,而且视频中每秒可以取到2帧,和之前默认取128(64)的方法也不一样。取图片块的方法和之前看的MiNicpm也不一样,下一篇看输入的得到啦,加油加油

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值