In earlier posts we covered handwritten digit and handwritten text recognition at length, mainly using CNN convolutional neural networks to train the recognition models.
Whenever OCR comes up, a CNN is usually the first architecture that comes to mind, since CNNs have played a central role in computer vision tasks. For background on CNNs, see the earlier articles in this series.
But since the transformer's attention mechanism entered computer vision, we can use transformers for vision tasks as well, such as object detection, classification, and segmentation. Well-known models like ViT and Swin successfully applied transformer attention to computer vision, so a transformer-based OCR model is a natural next step.
TrOCR, short for Transformer OCR, is an OCR model released by Microsoft. As the name suggests, it is built on the transformer; its architecture, shown below, follows the standard transformer model throughout.
One thing to note is that before entering the encoder, the image is resized to 384×384 resolution, because the DeiT encoder expects a fixed input size.
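At that resolution, the number of patch tokens the encoder processes can be worked out directly. A quick sketch, assuming DeiT's usual 16×16 patch size (the patch size is an assumption about the backbone, not stated above):

```python
# How many patch tokens a 384x384 input yields,
# assuming a 16x16 patch size (DeiT-base default; an assumption here).
image_size = 384
patch_size = 16
patches_per_side = image_size // patch_size   # 24 patches along each side
num_patches = patches_per_side ** 2           # total patch tokens for the encoder
print(num_patches)  # 576
```

So the encoder attends over a sequence of 576 patch embeddings (plus any special tokens the backbone prepends).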
TrOCR Pretrained Models
The pretrained models in the TrOCR family are trained on large-scale synthetic data, a dataset comprising hundreds of millions of printed text-line images. The official repository releases three models from the pretraining stage.
import task            # task.py from the official TrOCR repo
import deit            # deit.py from the official TrOCR repo
import trocr_models    # trocr_models.py from the official TrOCR repo
import torch
import fairseq
from fairseq import utils
from fairseq_cli import generate
from PIL import Image
import torchvision.transforms as transforms

def init(model_path, beam=5):
    # Load the checkpoint together with its configuration and task
    model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
        [model_path],
        arg_overrides={"beam": beam, "task": "text_recognition", "data": "", "fp16": False})
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model[0].to(device)
    # Match the 384x384 input size expected by the DeiT encoder
    img_transform = transforms.Compose([
        transforms.Resize((384, 384), interpolation=3),
        transforms.ToTensor(),
        transforms.Normalize(0.5, 0.5)])
    generator = task.build_generator(
        model, cfg.generation, extra_gen_cls_kwargs={'lm_model': None, 'lm_weight': None})
    bpe = task.build_bpe(cfg.bpe)
    return model, cfg, task, generator, bpe, img_transform, device

def preprocess(img_path, img_transform, device):
    im = Image.open(img_path).convert('RGB').resize((384, 384))
    im = img_transform(im).unsqueeze(0).to(device).float()
    sample = {'net_input': {"imgs": im}}
    return sample

def get_text(cfg, task, generator, model, sample, bpe):
    decoder_output = task.inference_step(generator, model, sample, prefix_tokens=None, constraints=None)
    decoder_output = decoder_output[0][0]
    # Strip BPE and special symbols from the raw decoder hypothesis
    hypo_tokens, hypo_str, alignment = utils.post_process_prediction(
        hypo_tokens=decoder_output["tokens"].int().cpu(),
        src_str="",
        alignment=decoder_output["alignment"],
        align_dict=None,
        tgt_dict=model[0].decoder.dictionary,
        remove_bpe=cfg.common_eval.post_process,
        extra_symbols_to_ignore=generate.get_symbols_to_strip_from_output(generator))
    detok_hypo_str = bpe.decode(hypo_str)
    return detok_hypo_str

if __name__ == '__main__':
    model_path = 'path/to/model'
    jpg_path = "path/to/pic"
    beam = 5
    model, cfg, task, generator, bpe, img_transform, device = init(model_path, beam)
    sample = preprocess(jpg_path, img_transform, device)
    text = get_text(cfg, task, generator, model, sample, bpe)
    print(text)
Here we just need to download a pretrained model and pass in an image to be recognized.
Since TrOCR is a transformer model, we can also implement the above with Hugging Face's transformers library, which makes the code much shorter.
!pip install transformers
'''
Installing collected packages: tokenizers, safetensors, huggingface-hub, transformers
Successfully installed huggingface-hub-0.16.4 safetensors-0.3.3 tokenizers-0.13.3 transformers-4.32.1
'''
First install the transformers library, then run the following code to perform text recognition with TrOCR.
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from PIL import Image

# The processor handles image preprocessing; the model is the encoder-decoder itself
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-handwritten")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-handwritten")

image = Image.open('11.jpg').convert("RGB")
pixel_values = processor(image, return_tensors="pt").pixel_values  # resize + normalize
generated_ids = model.generate(pixel_values)
generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_text)
TrOCRProcessor first preprocesses the image (resizing and normalizing it), VisionEncoderDecoderModel builds the OCR model, and model.generate produces token ids from the preprocessed pixel values. Decoding those ids with the processor yields the recognized text.
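The preprocessing the processor performs mirrors the Normalize(0.5, 0.5) transform used in the fairseq version above; a minimal stdlib sketch of that per-pixel normalization (the helper name is illustrative, not a transformers API):

```python
def normalize(pixel, mean=0.5, std=0.5):
    """Map an 8-bit pixel value into the [-1, 1] range fed to the encoder."""
    return (pixel / 255.0 - mean) / std

print(normalize(0), normalize(255))  # -1.0 1.0
```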
'''
Downloading (…)rocessor_config.json: 100% 228/228 [00:00<00:00, 3.64kB/s]
Could not find image processor class in the image processor config or the model config. Loading based on pattern matching with the model's feature extractor configuration.
Downloading (…)okenizer_config.json: 100% 1.12k/1.12k [00:00<00:00, 23.3kB/s]
Downloading (…)olve/main/vocab.json: 100% 899k/899k [00:00<00:00, 15.8MB/s]
Downloading (…)olve/main/merges.txt: 100% 456k/456k [00:00<00:00, 7.61MB/s]
Downloading (…)cial_tokens_map.json: 100% 772/772 [00:00<00:00, 21.5kB/s]
Downloading (…)lve/main/config.json: 100% 4.17k/4.17k [00:00<00:00, 130kB/s]
Downloading pytorch_model.bin: 100% 1.33G/1.33G [00:13<00:00, 155MB/s]
Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-handwritten and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Downloading (…)neration_config.json: 100% 190/190 [00:00<00:00, 10.3kB/s]
/usr/local/lib/python3.10/dist-packages/transformers/generation/utils.py:1254: UserWarning: Using the model-agnostic default `max_length` (=20) to control the generation length. We recommend setting `max_new_tokens` to control the maximum length of the generation.
  warnings.warn(
'''
After the code runs, the relevant model files are downloaded automatically and the recognized text is printed.
'''
industry, " Mr. Brown commented icily. " Let us have a
'''
Looking at TrOCR's main configuration, the vocabulary contains 50,265 tokens, the embedding dimension is 1024, the encoder and decoder each have 12 layers, and each attention layer has 16 heads. TrOCR keeps the full transformer architecture but changes some of its hyperparameters.
vocab_size (int, defaults to 50265) — size of the model's vocabulary
d_model (int, defaults to 1024) — dimensionality of the embeddings
decoder_layers (int, defaults to 12) — number of decoder layers; the encoder likewise has 12 layers
decoder_attention_heads (int, defaults to 16) — number of heads in each multi-head attention layer
decoder_ffn_dim (int, defaults to 4096) — dimensionality of the feed-forward network
activation_function (str, defaults to "gelu") — activation function, one of "gelu", "relu", "silu", and "gelu_new"
max_position_embeddings (int, defaults to 512) — maximum input sequence length
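A quick back-of-the-envelope check on what these defaults imply for model size — plain arithmetic on the values above; the totals are illustrative, not official parameter counts:

```python
# Parameter counts implied by the config defaults listed above
# (biases and attention/layer-norm weights are ignored for simplicity).
vocab_size = 50265
d_model = 1024
decoder_layers = 12
decoder_ffn_dim = 4096

embedding_params = vocab_size * d_model               # token-embedding matrix
ffn_params_per_layer = 2 * d_model * decoder_ffn_dim  # up- and down-projection

print(embedding_params)                    # 51471360: ~51.5M weights in embeddings alone
print(ffn_params_per_layer * decoder_layers)
```

The embedding table alone accounts for roughly 51.5M weights, which helps explain the 1.33 GB checkpoint downloaded earlier.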
############ Reference code
https://github.com/microsoft/unilm/tree/master/trocr
Page updated: 2024-03-29