TrOCR-基于transformer模型的OCR手写文字识别

前期我们使用大量的篇幅介绍了手写数字识别与手写文字识别,当然那里主要使用的是CNN卷积神经网络,利用CNN卷积神经网络来训练文字识别的模型。

这里一旦提到OCR相关的技术,肯定第一个想到的便是CNN卷积神经网络,毕竟CNN卷积神经网络在计算机视觉任务上起到了至关重要的作用。有关CNN卷积神经网络的相关知识与文章,可以参考往期的文章内容。

但是随着transformer模型attention注意力机制进入计算机视觉任务,我们同样可以使用transformer来进行计算机视觉方面的任务,比如对象检测,对象分类,对象分割等,这里毕竟著名的模型VIT,Swin便是成功的把transformer的注意力机制应用到了计算机视觉任务,那么基于transformer模型的OCR识别任务,便是理所当然的了。

TrOCR是transformer OCR的简写,是microsoft发布的一个OCR识别模型,光看这个模型的名字就知道此模型基于transformer模型,其模型架构如下,完全采用了标准的transformer模型。

需要注意的一件事是,在进入编码器之前,图像的大小已调整为 384 384 分辨率。 这是因为 DeIT 模型统一了输入图片的尺寸。

TrOCR 预训练模型

TrOCR 系列中的预训练模型是根据大规模综合生成的数据进行训练的。 该数据集包括数亿张打印文本行的图像。官方存储库释放了预训练阶段的三个模型。

  1. TrOCR 微调模型预训练阶段结束后,模型在 IAM 手写文本图像和 SROIE 打印收据数据集上进行了微调。IAM 手写数据集包含手写文本的图像。 微调该数据集使模型比其他模型更好地识别手写文本。同样,SROIE 数据集由数千个图像样本组成。 在此数据集上微调的模型在识别印刷文本方面表现非常好。像预训练阶段模型一样,IAM 手写模型和SROIE 打印数据集模型也分别包含三个维度的模型:

  2. 使用TrOCR 来进行图片文字识别,我们可以直接使用GitHub开源代码来实现

import task,deit,trocr_models,torch,fairseq
from fairseq import utils
from fairseq_cli import generate
from PIL import Image
import torchvision.transforms as transforms
def init(model_path, beam=5):
    model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task( [model_path],arg_overrides={"beam": beam, "task": "text_recognition", "data": "", "fp16": False})
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model[0].to(device)
    img_transform = transforms.Compose([transforms.Resize((384, 384), interpolation=3), transforms.ToTensor(), transforms.Normalize(0.5, 0.5) ])
    generator = task.build_generator(model, cfg.generation, extra_gen_cls_kwargs={'lm_model': None, 'lm_weight': None} )
    bpe = task.build_bpe(cfg.bpe)
    return model, cfg, task, generator, bpe, img_transform, device
def preprocess(img_path, img_transform):
    im = Image.open(img_path).convert('RGB').resize((384, 384))
    im = img_transform(im).unsqueeze(0).to(device).float()
    sample = {  'net_input': {"imgs": im},}
    return sample
def get_text(cfg, generator, model, sample, bpe):
    decoder_output = task.inference_step(generator, model, sample, prefix_tokens=None, constraints=None)
    decoder_output = decoder_output[0][0]       
    hypo_tokens, hypo_str, alignment = utils.post_process_prediction( hypo_tokens=decoder_output["tokens"].int().cpu(), src_str="", alignment=decoder_output["alignment"]   align_dict=None,tgt_dict=model[0].decoder.dictionary,remove_bpe=cfg.common_eval.post_process, extra_symbols_to_ignore=generate.get_symbols_to_strip_from_output(generator), )
    detok_hypo_str = bpe.decode(hypo_str)
    return detok_hypo_str
if __name__ == '__main__':
    model_path = 'path/to/model'
    jpg_path = "path/to/pic"
    beam = 5
    model, cfg, task, generator, bpe, img_transform, device = init(model_path, beam)
    sample = preprocess(jpg_path, img_transform)
    text = get_text(cfg, generator, model, sample, bpe)
    print(text)

这里我们需要下载预训练模型,并传递一张需要的图片来进行识别即可。

当然既然是transformer模型,我们就可以使用hugging face的transformers库来实现上面的代码,且代码量就精简了很多。

!pip install transformers
'''
Installing collected packages: tokenizers, safetensors, huggingface-hub, transformers
Successfully installed huggingface-hub-0.16.4 safetensors-0.3.3 tokenizers-0.13.3 transformers-4.32.1
'''

这里我们首先需要安装transform库,并插入如下代码来进行TrOCR的文字识别。

from transformers import TrOCRProcessor, VisionEncoderDecoderModel
import requests
from PIL import Image
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")
image = Image.open('11.jpg').convert("RGB")

pixel_values = processor(image, return_tensors="pt").pixel_values
generated_ids = model.generate(pixel_values)

generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_text)

首先需要使用TrOCRProcessor来进行图片的预处理,然后使用VisionEncoderDecoderModel建立一个OCR图片识别模型,然后打开一张需要识别的图片,当然图片需要使用processor进行图片的预处理操作,最后使用model函数进行图片的预测,预测完成后,就可以识别完整的文本文件了。

Downloading (…)rocessor_config.json: 100% 228/228 [00:00<00:00, 3.64kB/s]
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
Downloading (…)okenizer_config.json: 100% 1.12k/1.12k [00:00<00:00, 23.3kB/s]
Downloading (…)olve/main/vocab.json: 100% 899k/899k [00:00<00:00, 15.8MB/s]
Downloading (…)olve/main/merges.txt: 100% 456k/456k [00:00<00:00, 7.61MB/s]
Downloading (…)cial_tokens_map.json: 100% 772/772 [00:00<00:00, 21.5kB/s]
Downloading (…)lve/main/config.json: 100% 4.17k/4.17k [00:00<00:00, 130kB/s]
Downloading pytorch_model.bin: 100% 1.33G/1.33G [00:13<00:00, 155MB/s]
Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Downloading (…)neration_config.json: 100% 190/190 [00:00<00:00, 10.3kB/s]
/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1254: UserWarning: Using the model-agnostic default `max_length` (=20) to control thegeneration length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
  warnings.warn(

代码运行后,会自动下载相关模型,并输出最终的文本。

industry, " Mr. Brown commented icily. " Let us have a

这里我们可以看一下TrOCR的主要配置,可以看到整个数据库有50265个词表,且embedding维度为1024,每个encoder与decoder有12层的结构,且每个注意力层有16个头,虽然TrOCR完全复制了整个transformer模型,但是更改了一些参数。

vocab_size (int, defaults to 50265) — 模型的词表数量
d_model (int, defaults to 1024) — embedding的维度
decoder_layers (int, defaults to 12) — decoder layers的数量,这里默认12层,那么encoder也同样是12层.
decoder_attention_heads (int, defaults to 16) — 多头注意力机制的头数
decoder_ffn_dim (int, defaults to 4096) — feed forward前馈神经网络的维度
activation_function (str defaults to "gelu") —  "gelu", "relu", "silu" and "gelu_new" 激励函数
max_position_embeddings (int, defaults to 512) — 最大输入的sequence长度
############ 参考代码
https://github.com/microsoft/unilm/tree/master/trocr

展开阅读全文

页面更新:2024-03-29

标签:卷积   模型   神经网络   维度   编码器   注意力   图像   文本   文字   数据   图片

1 2 3 4 5

上滑加载更多 ↓
推荐阅读:
友情链接:
更多:

本站资料均由网友自行发布提供,仅用于学习交流。如有版权问题,请与我联系,QQ:4156828  

© CopyRight 2008-2024 All Rights Reserved. Powered By bs178.com 闽ICP备11008920号-3
闽公网安备35020302034844号

Top