【TrOCR】在自己数据集上训练TrOCR


在这里插入图片描述
论文地址: TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models
Huggingface:[ 文档][ 模型]

一、数据集预处理

数据集格式转换与划分:

两个变量记录输入的数据集图片和标签txt文件地址。
标签文件内容格式为:

crop_img/388_crop_2.jpg	通道
crop_img/388_crop_3.jpg	未意
crop_img/389_crop_0.jpg	1-E
crop_img/389_crop_1.jpg	7150
crop_img/391_crop_0.jpg	8400

划分数据集并且保存为一下格式其中用变量记录保存的位置。

data_dir/
   ├── train/
   │   ├── images/
   │   └── labels.json
   └── val/
       ├── images/
       └── labels.json

由于划分数据集要用到sklearn库的train_test_split,需要安装一下。

pip install scikit-learn

数据集格式转换与划分的脚本,可将原始图像和标签文件转换为训练所需的目录结构,并自动划分训练集和验证集。
format_split.py

import os
import json
import shutil
from sklearn.model_selection import train_test_split

# -------------------------- 配置参数(请根据实际情况修改) --------------------------
# 原始数据路径
ORIGINAL_IMG_DIR = r"C:\Users\Virgil\Desktop\TrOCR\TrOCR_Dataset\oringal_dateset\crop_img"
ORIGINAL_LABEL_PATH = r"C:\Users\Virgil\Desktop\TrOCR\TrOCR_Dataset\oringal_dateset\rec_gt2.txt"

# 输出数据路径(最终会生成 train/val 目录)
OUTPUT_DATA_DIR = r"C:\Users\Virgil\Desktop\TrOCR\TrOCR_Dataset\formatted_dataset"

# 划分比例(验证集占比)
VAL_RATIO = 0.2  # 20% 数据作为验证集
# ----------------------------------------------------------------------------------

def main():
    # 1. 读取原始标签文件
    print("开始读取标签文件...")
    image_label_pairs = []
    with open(ORIGINAL_LABEL_PATH, "r", encoding="utf-8") as f:
        for line in f.readlines():
            line = line.strip()
            if not line:
                continue
            # 分割图像路径和标签(按制表符分割)
            img_path, label = line.split("\t")
            # 提取图像文件名(如 "388_crop_2.jpg")
            img_filename = os.path.basename(img_path)
            # 检查图像文件是否存在
            full_img_path = os.path.join(ORIGINAL_IMG_DIR, img_filename)
            if not os.path.exists(full_img_path):
                print(f"警告:图像文件不存在 - {
     
     full_img_path},已跳过")
                continue
            image_label_pairs.append({
   
   
                "file_name": img_filename,  # 图像文件名
                "text": label  # 对应标签
            })
    
    if not image_label_pairs:
        print("错误:未找到有效图像-标签对,请检查输入路径")
        return
    print(f"成功读取 {
     
     len(image_label_pairs)} 个有效图像-标签对")

    # 2. 划分训练集和验证集
    print(f"划分训练集({
     
     1-VAL_RATIO:.0%})和验证集({
     
     VAL_RATIO:.0%})...")
    train_pairs, val_pairs = train_test_split(
        image_label_pairs,
        test_size=VAL_RATIO,
        random_state=42  # 固定随机种子,确保划分结果可复现
    )
    print(f"训练集:{
     
     len(train_pairs)} 个样本,验证集:{
     
     len(val_pairs)} 个样本")

    # 3. 创建输出目录结构
    print("创建输出目录结构...")
    # 训练集路径
    train_img_dir = os.path.join(OUTPUT_DATA_DIR, "train", "images")
    train_label_path = os.path.join(OUTPUT_DATA_DIR, "train", "labels.json")
    # 验证集路径
    val_img_dir = os.path.join(OUTPUT_DATA_DIR, "val", "images")
    val_label_path = os.path.join(OUTPUT_DATA_DIR, "val", "labels.json")

    # 创建目录(递归创建,已存在则忽略)
    os.makedirs(train_img_dir, exist_ok=True)
    os.makedirs(val_img_dir, exist_ok=True)

    # 4. 复制图像文件并保存标签JSON
    # 处理训练集
    print("复制训练集图像并保存标签...")
    for pair in train_pairs:
        src_img = os.path.join(ORIGINAL_IMG_DIR, pair["file_name"])
        dst_img = os.path.join(train_img_dir, pair["file_name"])
        shutil.copy2(src_img, dst_img)  # 保留文件元数据
    # 保存训练集标签JSON
    with open(train_label_path, "w", encoding="utf-8") as f:
        json.dump(train_pairs, f, ensure_ascii=False, indent=2)

    # 处理验证集
    print("复制验证集图像并保存标签...")
    for pair in val_pairs:
        src_img = os.path.join(ORIGINAL_IMG_DIR, pair["file_name"])
        dst_img = os.path.join(val_img_dir, pair["file_name"])
        shutil.copy2(src_img, dst_img)
    # 保存验证集标签JSON
    with open(val_label_path, "w", encoding="utf-8") as f:
        json.dump(val_pairs, f, ensure_ascii=False, indent=2)

    print(f"数据集转换完成!输出路径:{
     
     OUTPUT_DATA_DIR}")
    print(f"训练集图像:{
     
     train_img_dir}")
    print(f"训练集标签:{
     
     train_label_path}")
    print(f"验证集图像:{
     
     val_img_dir}")
    print(f"验证集标签:{
     
     val_label_path}")

if __name__ == "__main__":
    main()

示例输出:

开始读取标签文件...
成功读取 6586 个有效图像-标签对
划分训练集(80%)和验证集(20%)...
训练集:5268 个样本,验证集:1318 个样本
创建输出目录结构...
复制训练集图像并保存标签...
复制验证集图像并保存标签...
数据集转换完成!输出路径:C:\Users\Virgil\Desktop\TrOCR\TrOCR_Dataset\formatted_dataset
训练集图像:C:\Users\Virgil\Desktop\TrOCR\TrOCR_Dataset\formatted_dataset\train\images
训练集标签:C:\Users\Virgil\Desktop\TrOCR\TrOCR_Dataset\formatted_dataset\train\labels.json
验证集图像:C:\Users\Virgil\Desktop\TrOCR\TrOCR_Dataset\formatted_dataset\val\images
验证集标签:C:\Users\Virgil\Desktop\TrOCR\TrOCR_Dataset\formatted_dataset\val\labels.json
(base) root@5de27e9cb8c1:/mnt/Virgil/TrOCR/ChineseDataset# python formet_split.py 
开始读取标签文件...
成功读取 90000 个有效图像-标签对
划分训练集(80%)和验证集(20%)...
训练集:72000 个样本,验证集:18000 个样本
创建输出目录结构...
复制训练集图像并保存标签...
复制验证集图像并保存标签...
数据集转换完成!输出路径:/mnt/Virgil/TrOCR/ChineseDataset/9w_train
训练集图像:/mnt/Virgil/TrOCR/ChineseDataset/9w_train/train/images
训练集标签:/mnt/Virgil/TrOCR/ChineseDataset/9w_train/train/labels.json
验证集图像:/mnt/Virgil/TrOCR/ChineseDataset/9w_train/val/images
验证集标签:/mnt/Virgil/TrOCR/ChineseDataset/9w_train/val/labels.json

数据增强

二、TrOCR推理

huggingface上TrOCR模型介绍

TrOCR预训练权重下载

huggingface上预训练权重trocr-base-printed
在这里插入图片描述

在这里插入图片描述
以下是 trocr-base-printed 目录中各文件的意义:

  1. config.json:存储模型的配置信息,包括模型架构(如编码器和解码器层数、隐藏层维度等)、超参数(如激活函数类型等)。这些信息用于在创建模型实例时正确初始化模型结构。
  2. generation_config.json:与模型生成相关的配置文件,例如设置生成文本时的参数,如最大生成长度、束搜索(beam search)的参数(如束宽 num_beams)、是否使用早停(early_stopping)等,用于控制模型生成文本的过程。
  3. gitattributes:用于定义 Git 仓库中文件的属性,例如可以指定某些文件的合并策略、是否进行文本换行符转换等,主要是在版本控制管理方面起作用。
  4. merges.txt:在基于字节对编码(Byte - Pair Encoding, BPE)的分词器中,该文件记录了合并操作的规则。BPE 是一种将频繁出现的子词单元合并成新的词单元的算法,merges.txt 记录了这些合并的顺序和规则,用于构建分词器的词汇表。
  5. model.safetensors:存储模型的权重参数。safetensors 是一种更安全、更轻量级的张量存储格式,相比传统的 PyTorch 模型权重存储方式,它在加载和保存时更高效,且能更好地防止恶意代码注入。
  6. preprocessor_config.json:保存预处理器(如 TrOCR 中用于处理图像和文本的组件)的配置信息,定义了如何对输入的图像和文本进行预处理,包括图像的归一化参数、文本分词器的相关配置等。
  7. README.md:通常是模型的说明文档,介绍模型的基本信息(如模型架构、用途)、训练数据、如何使用该预训练模型(包括示例代码)、模型性能指标等内容,是了解和使用模型的重要参考文档。
  8. special_tokens_map.json:记录特殊标记(如起始标记 <s>、填充标记 <pad>、结束标记 </s> 等)与它们在词汇表中对应索引的映射关系。这些特殊标记在模型处理文本输入和输出时起到重要作用,如标识句子开头、填充长度等。
  9. tokenizer_config.json:分词器的配置文件,定义了分词器的行为,如使用的分词算法、词汇表大小、是否使用一些特定的分词规则(如是否进行字节对编码等),是分词器正确工作的配置依据。
  10. vocab.json:存储分词器的词汇表,即所有词单元及其对应的索引。模型在处理文本时,会根据这个词汇表将文本转换为对应的数字索引序列,是文本向量化的重要依据。

对于单图TrOCR推理代码

修改image_path变量路径

from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image

# 加载处理器和模型
processor = TrOCRProcessor.from_pretrained("../trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("../trocr-base-printed")

# image_path = "img_test.png"
image_path = r"C:\Users\Virgil\Desktop\TrOCR\trocr\code_1\img_test.png"

try:
    # 打开本地图片
    image = Image.open(image_path).convert("RGB")

    # 预处理图像
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    # 生成文本
    generated_ids = model.generate(pixel_values)
    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # 输出识别结果
    print("识别结果:", generated_text)

except FileNotFoundError:
    print(f"未找到图片文件: {
     
     image_path}")
except Image.UnidentifiedImageError as e:
    print(f"无法识别图像: {
     
     e}")
except Exception as e:
    print(f"发生其他错误: {
     
     e}")

对这个图片进行推理
在这里插入图片描述

识别结果: ACKNOWLEDGHENTS

并且输出相关信息:

Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
   
   
  "attention_probs_dropout_prob": 0.0,
  "encoder_stride": 16,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.0,
  "hidden_size": 768,
  "image_size": 384,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "model_type": "vit",
  "num_attention_heads": 12,
  "num_channels": 3,
  "num_hidden_layers": 12,
  "patch_size": 16,
  "qkv_bias": false,
  "transformers_version": "4.48.0"
}

Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
   
   
  "activation_dropout": 0.0,
  "activation_function": "gelu",
  "add_cross_attention": true,
  "attention_dropout": 0.0,
  "bos_token_id": 0,
  "classifier_dropout": 0.0,
  "cross_attention_hidden_size": 768,
  "d_model": 1024,
  "decoder_attention_heads": 16,
  "decoder_ffn_dim": 4096,
  "decoder_layerdrop": 0.0,
  "decoder_layers": 12,
  "decoder_start_token_id": 2,
  "dropout": 0.1,
  "eos_token_id": 2,
  "init_std": 0.02,
  "is_decoder": true,
  "layernorm_embedding": true,
  "max_position_embeddings": 512,
  "model_type": "trocr",
  "pad_token_id": 1,

这段输出信息主要包含了模型配置的覆盖情况以及模型权重初始化的相关提示

  1. 编码器配置覆盖信息
    Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
         
         
      "attention_probs_dropout_prob": 0.0,
      "encoder_stride": 16,
      "hidden_act": "gelu",
      "hidden_dropout_prob": 0.0,
      "hidden_size": 768,
      "image_size": 384,
      "initializer_range": 0.02,
      "intermediate_size": 3072,
      "layer_norm_eps": 1e-12,
      "model_type": "vit",
      "num_attention_heads": 12,
      "num_channels": 3,
      "num_hidden_layers": 12,
      "patch_size": 16,
      "qkv_bias": false,
      "transformers_version": "4.48.0"
    }
    
  • 含义:这里表明编码器原本的配置(transformers.models.vit.modeling_vit.ViTModel)被一个共享的编码器配置(ViTConfig)覆盖了。覆盖后的编码器配置包含了一系列超参数,例如:
    • attention_probs_dropout_prob:注意力概率的丢弃率,这里设置为 0.0,表示不进行丢弃。
    • hidden_size:隐藏层的维度大小,为 768。
    • num_hidden_layers:隐藏层的数量,为 12。
  1. 解码器配置覆盖信息
    Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
         
         
      "activation_dropout": 0.0,
      "activation_function": "gelu",
      "add_cross_attention": true,
      "attention_dropout": 0.0,
      "bos_token_id": 0,
      "classifier_dropout": 0.0,
      "cross_attention_hidden_size": 768,
      "d_model": 1024,
      "decoder_attention_heads": 16,
      "decoder_ffn_dim": 4096,
      "decoder_layerdrop": 0.0,
      "decoder_layers": 12,
      "decoder_start_token_id": 2,
      "dropout": 0.1,
      "eos_token_id": 2,
      "init_std": 0.02,
      "is_decoder": true,
      "layernorm_embedding": true,
      "max_position_embeddings": 512,
      "model_type": "trocr",
      "pad_token_id": 1,
      "scale_embedding": false,
      "transformers_version": "4.48.0",
      "use_cache": false,
      "use_learned_position_embeddings": true,
      "vocab_size": 50265
    <
### 使用TroCR模型进行ONNX格式的转换 对于将TroCR模型转换为ONNX格式,可以遵循类似的步骤来准备和优化模型以便于后续部署。具体操作涉及使用特定工具或库来进行模型转换。 #### 转换过程概述 为了将PyTorch中的预训练TroCR模型导出到ONNX格式,需先加载该模型并设置输入张量形状。之后通过调用`torch.onnx.export()`函数完成实际转换工作[^1]: ```python import torch from transformers import TrOCRProcessor, VisionEncoderDecoderModel processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten") model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten") dummy_input = torch.randn(1, 3, 224, 224) torch.onnx.export( model, dummy_input, "trocr_base_handwritten.onnx", export_params=True, opset_version=10, do_constant_folding=True, input_names=['input'], output_names=['output'], dynamic_axes={'input': {0: 'batch_size'},'output': {0: 'batch_size'}} ) ``` 此代码片段展示了如何利用Hugging Face Transformers库加载TroCR模型,并将其保存为`.onnx`文件。注意这里创建了一个虚拟输入用于模拟真实场景下的数据流。 #### 部署考虑因素 当涉及到部署经过转换后的ONNX模型时,可以选择多种方式集成至应用程序中。一种常见方法是借助像ONNX Runtime这样的高性能推理引擎,在不同平台上执行已编译好的模型。这不仅限于服务器端应用,还包括移动设备和其他边缘计算环境[^3]。 另外值得注意的是,针对具体的业务需求(如人力资源管理系统),可能还需要额外开发接口和服务层以支持与现有系统的交互以及处理来自用户的请求[^2]。
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值