
论文地址: TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models
Huggingface:[ 文档][ 模型]
一、数据集预处理
数据集格式转换与划分:
两个变量记录输入的数据集图片和标签txt文件地址。
标签文件内容格式为:
crop_img/388_crop_2.jpg 通道
crop_img/388_crop_3.jpg 未意
crop_img/389_crop_0.jpg 1-E
crop_img/389_crop_1.jpg 7150
crop_img/391_crop_0.jpg 8400
划分数据集并且保存为一下格式其中用变量记录保存的位置。
data_dir/
├── train/
│ ├── images/
│ └── labels.json
└── val/
├── images/
└── labels.json
由于划分数据集要用到sklearn库的train_test_split,需要安装一下。
pip install scikit-learn
数据集格式转换与划分的脚本,可将原始图像和标签文件转换为训练所需的目录结构,并自动划分训练集和验证集。
format_split.py
import os
import json
import shutil
from sklearn.model_selection import train_test_split
# -------------------------- 配置参数(请根据实际情况修改) --------------------------
# 原始数据路径
ORIGINAL_IMG_DIR = r"C:\Users\Virgil\Desktop\TrOCR\TrOCR_Dataset\oringal_dateset\crop_img"
ORIGINAL_LABEL_PATH = r"C:\Users\Virgil\Desktop\TrOCR\TrOCR_Dataset\oringal_dateset\rec_gt2.txt"
# 输出数据路径(最终会生成 train/val 目录)
OUTPUT_DATA_DIR = r"C:\Users\Virgil\Desktop\TrOCR\TrOCR_Dataset\formatted_dataset"
# 划分比例(验证集占比)
VAL_RATIO = 0.2 # 20% 数据作为验证集
# ----------------------------------------------------------------------------------
def main():
# 1. 读取原始标签文件
print("开始读取标签文件...")
image_label_pairs = []
with open(ORIGINAL_LABEL_PATH, "r", encoding="utf-8") as f:
for line in f.readlines():
line = line.strip()
if not line:
continue
# 分割图像路径和标签(按制表符分割)
img_path, label = line.split("\t")
# 提取图像文件名(如 "388_crop_2.jpg")
img_filename = os.path.basename(img_path)
# 检查图像文件是否存在
full_img_path = os.path.join(ORIGINAL_IMG_DIR, img_filename)
if not os.path.exists(full_img_path):
print(f"警告:图像文件不存在 - {
full_img_path},已跳过")
continue
image_label_pairs.append({
"file_name": img_filename, # 图像文件名
"text": label # 对应标签
})
if not image_label_pairs:
print("错误:未找到有效图像-标签对,请检查输入路径")
return
print(f"成功读取 {
len(image_label_pairs)} 个有效图像-标签对")
# 2. 划分训练集和验证集
print(f"划分训练集({
1-VAL_RATIO:.0%})和验证集({
VAL_RATIO:.0%})...")
train_pairs, val_pairs = train_test_split(
image_label_pairs,
test_size=VAL_RATIO,
random_state=42 # 固定随机种子,确保划分结果可复现
)
print(f"训练集:{
len(train_pairs)} 个样本,验证集:{
len(val_pairs)} 个样本")
# 3. 创建输出目录结构
print("创建输出目录结构...")
# 训练集路径
train_img_dir = os.path.join(OUTPUT_DATA_DIR, "train", "images")
train_label_path = os.path.join(OUTPUT_DATA_DIR, "train", "labels.json")
# 验证集路径
val_img_dir = os.path.join(OUTPUT_DATA_DIR, "val", "images")
val_label_path = os.path.join(OUTPUT_DATA_DIR, "val", "labels.json")
# 创建目录(递归创建,已存在则忽略)
os.makedirs(train_img_dir, exist_ok=True)
os.makedirs(val_img_dir, exist_ok=True)
# 4. 复制图像文件并保存标签JSON
# 处理训练集
print("复制训练集图像并保存标签...")
for pair in train_pairs:
src_img = os.path.join(ORIGINAL_IMG_DIR, pair["file_name"])
dst_img = os.path.join(train_img_dir, pair["file_name"])
shutil.copy2(src_img, dst_img) # 保留文件元数据
# 保存训练集标签JSON
with open(train_label_path, "w", encoding="utf-8") as f:
json.dump(train_pairs, f, ensure_ascii=False, indent=2)
# 处理验证集
print("复制验证集图像并保存标签...")
for pair in val_pairs:
src_img = os.path.join(ORIGINAL_IMG_DIR, pair["file_name"])
dst_img = os.path.join(val_img_dir, pair["file_name"])
shutil.copy2(src_img, dst_img)
# 保存验证集标签JSON
with open(val_label_path, "w", encoding="utf-8") as f:
json.dump(val_pairs, f, ensure_ascii=False, indent=2)
print(f"数据集转换完成!输出路径:{
OUTPUT_DATA_DIR}")
print(f"训练集图像:{
train_img_dir}")
print(f"训练集标签:{
train_label_path}")
print(f"验证集图像:{
val_img_dir}")
print(f"验证集标签:{
val_label_path}")
if __name__ == "__main__":
main()
示例输出:
开始读取标签文件...
成功读取 6586 个有效图像-标签对
划分训练集(80%)和验证集(20%)...
训练集:5268 个样本,验证集:1318 个样本
创建输出目录结构...
复制训练集图像并保存标签...
复制验证集图像并保存标签...
数据集转换完成!输出路径:C:\Users\Virgil\Desktop\TrOCR\TrOCR_Dataset\formatted_dataset
训练集图像:C:\Users\Virgil\Desktop\TrOCR\TrOCR_Dataset\formatted_dataset\train\images
训练集标签:C:\Users\Virgil\Desktop\TrOCR\TrOCR_Dataset\formatted_dataset\train\labels.json
验证集图像:C:\Users\Virgil\Desktop\TrOCR\TrOCR_Dataset\formatted_dataset\val\images
验证集标签:C:\Users\Virgil\Desktop\TrOCR\TrOCR_Dataset\formatted_dataset\val\labels.json
(base) root@5de27e9cb8c1:/mnt/Virgil/TrOCR/ChineseDataset# python formet_split.py
开始读取标签文件...
成功读取 90000 个有效图像-标签对
划分训练集(80%)和验证集(20%)...
训练集:72000 个样本,验证集:18000 个样本
创建输出目录结构...
复制训练集图像并保存标签...
复制验证集图像并保存标签...
数据集转换完成!输出路径:/mnt/Virgil/TrOCR/ChineseDataset/9w_train
训练集图像:/mnt/Virgil/TrOCR/ChineseDataset/9w_train/train/images
训练集标签:/mnt/Virgil/TrOCR/ChineseDataset/9w_train/train/labels.json
验证集图像:/mnt/Virgil/TrOCR/ChineseDataset/9w_train/val/images
验证集标签:/mnt/Virgil/TrOCR/ChineseDataset/9w_train/val/labels.json
数据增强
二、TrOCR推理
TrOCR预训练权重下载
huggingface上预训练权重trocr-base-printed
以下是 trocr-base-printed
目录中各文件的意义:
config.json
:存储模型的配置信息,包括模型架构(如编码器和解码器层数、隐藏层维度等)、超参数(如激活函数类型等)。这些信息用于在创建模型实例时正确初始化模型结构。generation_config.json
:与模型生成相关的配置文件,例如设置生成文本时的参数,如最大生成长度、束搜索(beam search)的参数(如束宽num_beams
)、是否使用早停(early_stopping
)等,用于控制模型生成文本的过程。gitattributes
:用于定义 Git 仓库中文件的属性,例如可以指定某些文件的合并策略、是否进行文本换行符转换等,主要是在版本控制管理方面起作用。merges.txt
:在基于字节对编码(Byte - Pair Encoding, BPE)的分词器中,该文件记录了合并操作的规则。BPE 是一种将频繁出现的子词单元合并成新的词单元的算法,merges.txt
记录了这些合并的顺序和规则,用于构建分词器的词汇表。model.safetensors
:存储模型的权重参数。safetensors
是一种更安全、更轻量级的张量存储格式,相比传统的 PyTorch 模型权重存储方式,它在加载和保存时更高效,且能更好地防止恶意代码注入。preprocessor_config.json
:保存预处理器(如 TrOCR 中用于处理图像和文本的组件)的配置信息,定义了如何对输入的图像和文本进行预处理,包括图像的归一化参数、文本分词器的相关配置等。README.md
:通常是模型的说明文档,介绍模型的基本信息(如模型架构、用途)、训练数据、如何使用该预训练模型(包括示例代码)、模型性能指标等内容,是了解和使用模型的重要参考文档。special_tokens_map.json
:记录特殊标记(如起始标记<s>
、填充标记<pad>
、结束标记</s>
等)与它们在词汇表中对应索引的映射关系。这些特殊标记在模型处理文本输入和输出时起到重要作用,如标识句子开头、填充长度等。tokenizer_config.json
:分词器的配置文件,定义了分词器的行为,如使用的分词算法、词汇表大小、是否使用一些特定的分词规则(如是否进行字节对编码等),是分词器正确工作的配置依据。vocab.json
:存储分词器的词汇表,即所有词单元及其对应的索引。模型在处理文本时,会根据这个词汇表将文本转换为对应的数字索引序列,是文本向量化的重要依据。
对于单图TrOCR推理代码
修改image_path
变量路径
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image
# 加载处理器和模型
processor = TrOCRProcessor.from_pretrained("../trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("../trocr-base-printed")
# image_path = "img_test.png"
image_path = r"C:\Users\Virgil\Desktop\TrOCR\trocr\code_1\img_test.png"
try:
# 打开本地图片
image = Image.open(image_path).convert("RGB")
# 预处理图像
pixel_values = processor(images=image, return_tensors="pt").pixel_values
# 生成文本
generated_ids = model.generate(pixel_values)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# 输出识别结果
print("识别结果:", generated_text)
except FileNotFoundError:
print(f"未找到图片文件: {
image_path}")
except Image.UnidentifiedImageError as e:
print(f"无法识别图像: {
e}")
except Exception as e:
print(f"发生其他错误: {
e}")
对这个图片进行推理
识别结果: ACKNOWLEDGHENTS
并且输出相关信息:
Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig {
"attention_probs_dropout_prob": 0.0,
"encoder_stride": 16,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.0,
"hidden_size": 768,
"image_size": 384,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"model_type": "vit",
"num_attention_heads": 12,
"num_channels": 3,
"num_hidden_layers": 12,
"patch_size": 16,
"qkv_bias": false,
"transformers_version": "4.48.0"
}
Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig {
"activation_dropout": 0.0,
"activation_function": "gelu",
"add_cross_attention": true,
"attention_dropout": 0.0,
"bos_token_id": 0,
"classifier_dropout": 0.0,
"cross_attention_hidden_size": 768,
"d_model": 1024,
"decoder_attention_heads": 16,
"decoder_ffn_dim": 4096,
"decoder_layerdrop": 0.0,
"decoder_layers": 12,
"decoder_start_token_id": 2,
"dropout": 0.1,
"eos_token_id": 2,
"init_std": 0.02,
"is_decoder": true,
"layernorm_embedding": true,
"max_position_embeddings": 512,
"model_type": "trocr",
"pad_token_id": 1,
这段输出信息主要包含了模型配置的覆盖情况以及模型权重初始化的相关提示
- 编码器配置覆盖信息
Config of the encoder: <class 'transformers.models.vit.modeling_vit.ViTModel'> is overwritten by shared encoder config: ViTConfig { "attention_probs_dropout_prob": 0.0, "encoder_stride": 16, "hidden_act": "gelu", "hidden_dropout_prob": 0.0, "hidden_size": 768, "image_size": 384, "initializer_range": 0.02, "intermediate_size": 3072, "layer_norm_eps": 1e-12, "model_type": "vit", "num_attention_heads": 12, "num_channels": 3, "num_hidden_layers": 12, "patch_size": 16, "qkv_bias": false, "transformers_version": "4.48.0" }
- 含义:这里表明编码器原本的配置(
transformers.models.vit.modeling_vit.ViTModel
)被一个共享的编码器配置(ViTConfig
)覆盖了。覆盖后的编码器配置包含了一系列超参数,例如:attention_probs_dropout_prob
:注意力概率的丢弃率,这里设置为 0.0,表示不进行丢弃。hidden_size
:隐藏层的维度大小,为 768。num_hidden_layers
:隐藏层的数量,为 12。
- 解码器配置覆盖信息
Config of the decoder: <class 'transformers.models.trocr.modeling_trocr.TrOCRForCausalLM'> is overwritten by shared decoder config: TrOCRConfig { "activation_dropout": 0.0, "activation_function": "gelu", "add_cross_attention": true, "attention_dropout": 0.0, "bos_token_id": 0, "classifier_dropout": 0.0, "cross_attention_hidden_size": 768, "d_model": 1024, "decoder_attention_heads": 16, "decoder_ffn_dim": 4096, "decoder_layerdrop": 0.0, "decoder_layers": 12, "decoder_start_token_id": 2, "dropout": 0.1, "eos_token_id": 2, "init_std": 0.02, "is_decoder": true, "layernorm_embedding": true, "max_position_embeddings": 512, "model_type": "trocr", "pad_token_id": 1, "scale_embedding": false, "transformers_version": "4.48.0", "use_cache": false, "use_learned_position_embeddings": true, "vocab_size": 50265 <