MMOCR使用指南

原创
07/25 15:35
阅读数 1.5K

MMOCR是通用视觉框架OpenMMLab的光学字符识别器。

安装配置环境

MMOCR github主页:GitHub - open-mmlab/mmocr: OpenMMLab Text Detection, Recognition and Understanding Toolbox

pip install mmcv-full -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html
pip install mmdet -i https://pypi.tuna.tsinghua.edu.cn/simple
pip install lmdb
pip install shapely
pip install rapidfuzz
pip install lanms
pip install pyclipper
pip install scikit-image
pip install imgaug

验证是否安装成功代码

import torch, torchvision
import mmcv
from mmcv.ops import get_compiling_cuda_version, get_compiler_version
import mmdet
import mmocr
from mmocr.utils.ocr import MMOCR

mmocr = MMOCR(det=None, recog='SAR', device='cpu')
print('mmocr载入成功')

文本检测与文本提取

import torch, torchvision
import mmcv
from mmcv.ops import get_compiling_cuda_version, get_compiler_version
import mmdet
import mmocr
from mmocr.utils.ocr import MMOCR

# mmocr = MMOCR(det=None, recog='SAR', device='cpu')
# print('mmocr载入成功')

if __name__ == '__main__':

    detector = MMOCR(det='TextSnake', recog='SAR', device='cuda')
    result = detector.readtext('demo/demo_densetext_det.jpg', output='output/demo_densetext_det.jpg')

文字分类

import torch, torchvision
import mmcv
from mmcv.ops import get_compiling_cuda_version, get_compiler_version
import mmdet
import mmocr
from mmocr.utils.ocr import MMOCR

# mmocr = MMOCR(det=None, recog='SAR', device='cpu')
# print('mmocr载入成功')

if __name__ == '__main__':

    detector = MMOCR(det='TextSnake', recog='SAR', kie='SDMGR', device='cuda')
    result = detector.readtext('data/wildreceipt/image_files/Image_1/0/0ea337776eb4a57010accaf2814ea7351770819b.jpeg', output='output/0ea337776eb4a57010accaf2814ea7351770819b.jpeg')
    print(result[0]['text'])

中文检测与提取

在mmocr主目录下新建文件夹/data/chineseocr/labels

进入该文件夹执行

wget http://download.openmmlab.com/mmocr/textrecog/sar/dict_printed_chinese_english_digits.txt
wget http://download.openmmlab.com/mmocr/data/font.TTF

下载字体和字库

import torch, torchvision
import mmcv
from mmcv.ops import get_compiling_cuda_version, get_compiler_version
import mmdet
import mmocr
from mmocr.utils.ocr import MMOCR

# mmocr = MMOCR(det=None, recog='SAR', device='cpu')
# print('mmocr载入成功')

if __name__ == '__main__':

    detector = MMOCR(det='TextSnake', recog='SAR_CN', device='cuda')
    result = detector.readtext('demo/demo_densetext_det.jpg', output='output/demo_densetext_det.jpg')

模型训练

Kaggle验证码文本识别

数据集下载地址:CAPTCHA Images | Kaggle

下载完成后,将samples文件夹下的图片放入mmocr主目录下的tests/data/ocr_toy_dataset/imgs目录下,图片样式大致如下

划分训练集和验证集

import pandas as pd
import os

print('imgs文件夹中的文件总数', len(os.listdir('tests/data/ocr_toy_dataset/imgs')))
df = pd.DataFrame()
# 获取所有图像的文件名
df['file_name'] = os.listdir('tests/data/ocr_toy_dataset/imgs')
# 由文件名提取文本内容标签
df['label'] = df['file_name'].apply(lambda x: x.split('.')[0])
# 随机打乱
df = df.sample(frac=1, random_state=666)
# 重排索引
df.reset_index(drop=True, inplace=True)
# 训练集
train_df = df.iloc[:800]
# 测试集
test_df = df.iloc[801:]
# 生成训练集标签
train_df.to_csv('tests/data/ocr_toy_dataset/train_label.txt', sep=' ', index=False, header=None)
# 生成测试集标签
test_df.to_csv('tests/data/ocr_toy_dataset/test_label.txt', sep=' ', index=False, header=None)
print('标签文件生成成功')

此时会在tests/data/ocr_toy_dataset目录下生成两个标签文件train_label.txt和test_label.txt,内容大致如下

2wc38.png 2wc38
y5n6d.png y5n6d
men4f.png men4f
57b27.png 57b27
x3deb.png x3deb
f858x.png f858x
xxw44.png xxw44

下载toy_data.py,放入configs/_base_/recog_datasets目录

wget https://download.openmmlab.com/mmocr/tutorial/toy_data.py

内容如下

dataset_type = 'OCRDataset'

root = 'tests/data/ocr_toy_dataset'
img_prefix = f'{root}/imgs'
train_anno_file1 = f'{root}/train_label.txt'

train1 = dict(
    type=dataset_type,
    img_prefix=img_prefix,
    ann_file=train_anno_file1,
    loader=dict(
        type='HardDiskLoader',
        repeat=10, # 与训练轮次相关
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=None,
    test_mode=False)

test_anno_file1 = f'{root}/test_label.txt'
test = dict(
    type=dataset_type,
    img_prefix=img_prefix,
    ann_file=test_anno_file1,
    loader=dict(
        type='HardDiskLoader',
        repeat=1,
        parser=dict(
            type='LineStrParser',
            keys=['filename', 'text'],
            keys_idx=[0, 1],
            separator=' ')),
    pipeline=None,
    test_mode=True)

train_list = [train1]

test_list = [test]

修改configs/textrecog/sar目录下的sar_r31_parallel_decoder_toy_dataset.py文件内容如下

_base_ = [
    '../../_base_/default_runtime.py', '../../_base_/recog_models/sar.py',
    '../../_base_/schedules/schedule_adam_step_5e.py',
    '../../_base_/recog_pipelines/sar_pipeline.py',
    '../../_base_/recog_datasets/toy_data.py'
]

train_list = {{_base_.train_list}}
test_list = {{_base_.test_list}}

train_pipeline = {{_base_.train_pipeline}}
test_pipeline = {{_base_.test_pipeline}}

data = dict(
    workers_per_gpu=2,
    samples_per_gpu=8,
    train=dict(
        type='UniformConcatDataset',
        datasets=train_list,
        pipeline=train_pipeline),
    val=dict(
        type='UniformConcatDataset',
        datasets=test_list,
        pipeline=test_pipeline),
    test=dict(
        type='UniformConcatDataset',
        datasets=test_list,
        pipeline=test_pipeline))

evaluation = dict(interval=1, metric='acc')

default_runtime.py是配置训练时间和迭代的轮次的,也就是epoch;sar.py是配置算法模型的;schedule_adam_step_5e.py是配置学习率和优化器的;sar_pipeline.py是配置工作流的。

训练代码

from mmcv import Config
from mmdet.apis import set_random_seed
import mmcv
from mmocr.datasets import build_dataset
from mmocr.models import build_detector
from mmocr.apis import train_detector, init_detector, model_inference
import os.path as osp

cfg = Config.fromfile('./configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py')
# 存放输出结果和日志目录
cfg.work_dir = './demo/tutorial_exps'
cfg.optimizer.lr = 0.001 / 8
cfg.lr_config.warmup = None
# 每训练500张图片记录一次日志
cfg.log_config.interval = 500

cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)

print(cfg.pretty_text)
# 建立数据集
datasets = [build_dataset(cfg.data.train)]
# 建立模型
model = build_detector(cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
# 创建新目录,保存训练结果
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# 开始训练
train_detector(model, datasets, cfg, distributed=False, validate=True)

对测试图片进行推理代码

from mmcv import Config
from mmdet.apis import set_random_seed
import mmcv
from mmocr.datasets import build_dataset
from mmocr.models import build_detector
from mmocr.apis import init_detector, model_inference

cfg = Config.fromfile('./configs/textrecog/sar/sar_r31_parallel_decoder_toy_dataset.py')
# 存放输出结果和日志目录
cfg.work_dir = './demo/tutorial_exps'
cfg.optimizer.lr = 0.001 / 8
cfg.lr_config.warmup = None
# 每训练500张图片记录一次日志
cfg.log_config.interval = 500

cfg.seed = 0
set_random_seed(0, deterministic=False)
cfg.gpu_ids = range(1)

checkpoint = './demo/tutorial_exps/epoch_5.pth'

model = init_detector(cfg, checkpoint, device='cuda')
input_path = 'tests/data/ocr_toy_dataset/imgs/f6ne5.png'
result = model_inference(model, input_path)
out_img = model.show_result(input_path, result, out_file='output/demo_f6ne5.jpg', show=False)

推理结果

不规则文字识别方法之 SAR: Show, Attend and Read

对于不规则(曲形文字、艺术字等)的识别,作者没有采用基于修正(rectification)的策略,而是提出利用基于不规则文字而构造的(tailored)基于二维注意力机制模块(2D attention module)的模型来定位和逐个识别字符的弱监督方法。之所以说是弱监督是由于该模型可以在不用额外的监督信息就可以定位单个字符(即不需要字符级别或像素级别的标注)。

图像送入主干网(backbone),经过31层的ResNet卷积得到的特征图(feature maps),分别送入编解码器以及注意力机制模块,最终输出识别的字符串

ResNet CNN 模块

  1. 共31层的ResNet,对于每个残差模块,如果输入-输出维度不同,则使用1x1卷积做projection shortcut;
  2. 同时使用了2x2最大池化和1x2最大池化(为了保留更多水平轴上的信息,对于'i', 'l'这种偏瘦的字符增益较大)
  3. 最终输出V为 H x W x D (高、宽、通道数)的二维特征图,以用于提取图像的整体特征(holistic feature);
  4. 在保持原输入图像宽高比的基础上,将图像缩放至固定高度(论文中是48)和随之变化的宽度,因此得到的特征图的宽度也是不固定的;
import torch
import torch.nn as nn

__all__ = ['basicblock', 'backbone']
    
class basicblock(nn.Module):
    # 残差模块
    def __init__(self, depth_in, output_dim, kernel_size, stride):
        super(basicblock, self).__init__()
        self.identity = nn.Identity()
        self.conv_res = nn.Conv2d(depth_in, output_dim, kernel_size=1, stride=1)
        self.batchnorm_res = nn.BatchNorm2d(output_dim)
        self.conv1 = nn.Conv2d(depth_in, output_dim, kernel_size=kernel_size, stride=stride, padding=1)
        self.conv2 = nn.Conv2d(output_dim, output_dim, kernel_size=kernel_size, stride=stride, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(output_dim)
        self.batchnorm2 = nn.BatchNorm2d(output_dim)
        self.relu1 = nn.ReLU(inplace=True)
        self.relu2 = nn.ReLU(inplace=True)
        self.depth_in = depth_in
        self.output_dim = output_dim

    def forward(self, x):
        # create shortcut path
        if self.depth_in == self.output_dim:
            residual = self.identity(x)
        else:
            # 如果输入 - 输出维度不同,则使用1x1卷积做projection shortcut
            residual = self.conv_res(x)
            residual = self.batchnorm_res(residual)
        out = self.conv1(x)
        out = self.batchnorm1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.batchnorm2(out)

        out += residual
        out = self.relu2(out)

        return out

class backbone(nn.Module):
    # 主干网络

    def __init__(self, input_dim):
        super(backbone, self).__init__()
        self.conv1 = nn.Conv2d(input_dim, 64, kernel_size=3, stride=1, padding=1)
        self.batchnorm1 = nn.BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.batchnorm2 = nn.BatchNorm2d(128)
        self.relu2 = nn.ReLU(inplace=True)
        # 2*2最大池化
        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Block 1 starts
        self.basicblock1 = basicblock(128, 256, kernel_size=3, stride=1)
        # Block 1 ends
        self.conv3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.batchnorm3 = nn.BatchNorm2d(256)
        self.relu3 = nn.ReLU(inplace=True)
        # 2*2最大池化
        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        # Block 2 starts
        self.basicblock2 = basicblock(256, 256, kernel_size=3, stride=1)
        self.basicblock3 = basicblock(256, 256, kernel_size=3, stride=1)
        # Block 2 ends
        self.conv4 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.batchnorm4 = nn.BatchNorm2d(256)
        self.relu4 = nn.ReLU(inplace=True)
        # 1*2最大池化
        self.maxpool3 = nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))
        # Block 5 starts
        self.basicblock4 = basicblock(256, 512, kernel_size=3, stride=1)
        self.basicblock5 = basicblock(512, 512, kernel_size=3, stride=1)
        self.basicblock6 = basicblock(512, 512, kernel_size=3, stride=1)
        self.basicblock7 = basicblock(512, 512, kernel_size=3, stride=1)
        self.basicblock8 = basicblock(512, 512, kernel_size=3, stride=1)
        # Block 5 ends
        self.conv5 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.batchnorm5 = nn.BatchNorm2d(512)
        self.relu5 = nn.ReLU(inplace=True)
        # Block 3 starts
        self.basicblock9 = basicblock(512, 512, kernel_size=3, stride=1)
        self.basicblock10 = basicblock(512, 512, kernel_size=3, stride=1)
        self.basicblock11 = basicblock(512, 512, kernel_size=3, stride=1)
        # Block 3 ends
        self.conv6 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
        self.batchnorm6 = nn.BatchNorm2d(512)
        self.relu6 = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv1(x)
        x = self.batchnorm1(x)
        x = self.relu1(x)
        x = self.conv2(x)
        x = self.batchnorm2(x)
        x = self.relu2(x)
        x = self.maxpool1(x)
        x = self.basicblock1(x)
        x = self.conv3(x)
        x = self.batchnorm3(x)
        x = self.relu3(x)
        x = self.maxpool2(x)
        x = self.basicblock2(x)
        x = self.basicblock3(x)
        x = self.conv4(x)
        x = self.batchnorm4(x)
        x = self.relu4(x)
        x = self.maxpool3(x)
        x = self.basicblock4(x)
        x = self.basicblock5(x)
        x = self.basicblock6(x)
        x = self.basicblock7(x)
        x = self.basicblock8(x)
        x = self.conv5(x)
        x = self.batchnorm5(x)
        x = self.relu5(x)
        x = self.basicblock9(x)
        x = self.basicblock10(x)
        x = self.basicblock11(x)
        x = self.conv6(x)
        x = self.batchnorm6(x)
        x = self.relu6(x)

        return x

# unit test
if __name__ == '__main__':

    batch_size = 32
    Height = 48
    Width = 160
    Channel = 3

    input_images = torch.randn(batch_size, Channel, Height, Width)
    model = backbone(Channel)
    output_features = model(input_images)

    print("Input size is:", input_images.shape)
    print("Output feature map size is:", output_features.shape)

运行结果

Input size is: torch.Size([32, 3, 48, 160])
Output feature map size is: torch.Size([32, 512, 12, 20])

LSTM 编码器-解码器 模块

编码,就是将输入序列转化成一个固定长度的向量;解码,就是将之前生成的固定向量再转化成输出序列。 当前 time step 的 hidden state 是由上一 time step 的state和当前 time step 输入决定的,也就是获得了各个时间段的隐藏层以后,再将隐藏层的信息汇总,生成最后的语义向量C;通常传统的 encoder-decoder 结构将 encoder 最后的隐藏层作为语义向量C,作为 decoder 的输入;

  • 不改变原文字图片(不修正)
  • 编码器encoder:

  1. 2层,每层各512个hidden state的LSTM模型;
  2. 每一个time step编码器的一项输入(图中下方)是CNN得到的二维特征图的第 i 列经过垂直方向最大池化的特征信息
  3. 经过W(特征图的宽)个time step后,第二层LSTM的最后一个hidden state   就是输入图像的一个固定尺寸的特征表示,称为 holistic feature;

有关LSTM的内容请参考Tensorflow深度学习算法整理(二) 中的长短期记忆网络,不过它这里输入LSTM的是特征图的每一列池化后的值,而在Tensorflow那边是一个一个的文字或者字符。

import torch
import torch.nn as nn

__all__ = ['encoder']

class encoder(nn.Module):
    # LSTM编码器

    def __init__(self, H, C, hidden_units=512, layers=2, keep_prob=1.0, device='cpu'):
        super(encoder, self).__init__()
        self.maxpool = nn.MaxPool2d(kernel_size=(H, 1), stride=1)
        self.lstm = nn.LSTM(input_size=C, hidden_size=hidden_units, num_layers=layers, batch_first=True, dropout=keep_prob)
        # 层数
        self.layers = layers
        # 序列数
        self.hidden_units = hidden_units
        self.device = device

    def forward(self, x):
        self.lstm.flatten_parameters()
        # x is feature map in [batch, C, H, W]
        # 初始化两个状态
        h_0 = torch.zeros(self.layers * 1, x.size(0), self.hidden_units).to(self.device)
        c_0 = torch.zeros(self.layers * 1, x.size(0), self.hidden_units).to(self.device)
        # 先池化
        x = self.maxpool(x)  # [batch, C, 1, W]
        x = torch.squeeze(x)  # [batch, C, W]
        if len(x.size()) == 2:  # [C, W]
            x = x.unsqueeze(0)  # [batch, C, W]
        x = x.permute(0, 2, 1)  # [batch, W, C]
        # 将池化后的feature map的每一列输入lstm网络
        _, (h, _) = self.lstm(x, (h_0, c_0))  # h with shape [layers*1, batch, hidden_uints]

        return h[-1]  # shape [batch, hidden_units]

# unit test
if __name__ == '__main__':

    batch_size = 32
    Height = 48
    Width = 160
    Channel = 3
    input_feature = torch.randn(batch_size, Channel, Height, Width)
    print("Input feature size is:", input_feature.shape)

    encoder_model = encoder(Height, Channel, hidden_units=512, layers=2, keep_prob=1.0)
    output_encoder = encoder_model(input_feature)

    print("Output feature of encoder size is:", output_encoder.shape) # (batch, hidden_units)

运行结果

Input feature size is: torch.Size([32, 3, 48, 160])
Output feature of encoder size is: torch.Size([32, 512])
  • 解码器decoder:

  1. 2层,每层各512个hidden state的LSTM模型;
  2. 编码器和解码器之间不共享参数;
  3. (初始化的输入), "START" token,以及前一层 的输出,依次作为当前step的输入,直到被"END" token终止;
  4. 所有的LSTM输入都是经过one-hot向量表示后,再经过一个线性变化  函数;
  5. 训练阶段,解码器LSTM的输入由 ground truth 的字符序列代替;
  6. 每一个step的输出y_i 由当前step的 hidden state 和attention的输出作为\phi() 函数的输入得到:
    1. y_t = \phi(h'_t, c_t) = softmax(W_o[h'_t;c_t]), 其中h'_t 是当前的 hidden state,c_t是attention模块的输出, W_o是一个线性变化,将特征嵌入输出空间的94个类别(10个数字,26*2 个字符,31个标点符号);
class decoder(nn.Module):
    # LSTM解码器

    def __init__(self, output_classes, H, W, D=512, hidden_units=512, seq_len=40, device='cpu'):
        super(decoder, self).__init__()
        '''
        output_classes: 解码后的分类数
        H: 特征图高
        W: 特征图宽
        D: 特征图通道数
        hidden_units: 序列数
        seq_len: 输出的序列化长度
        '''
        self.linear1 = nn.Linear(output_classes, hidden_units)
        self.lstmcell1 = [nn.LSTMCell(hidden_units, hidden_units) for i in range(seq_len + 1)]
        self.lstmcell2 = [nn.LSTMCell(hidden_units, hidden_units) for i in range(seq_len + 1)]
        self.attention = attention(hidden_units, H, W, D)
        self.linear2 = nn.Linear(hidden_units + D, output_classes)
        self.softmax = nn.LogSoftmax(dim=1)
        self.seq_len = seq_len
        self.START_TOKEN = output_classes - 3  # Same as END TOKEN
        self.output_classes = output_classes
        self.hidden_units = hidden_units
        self.device = device

        self.lstmcell1 = torch.nn.ModuleList(self.lstmcell1)
        self.lstmcell2 = torch.nn.ModuleList(self.lstmcell2)

    def forward(self, hw, y, V):
        '''
        hw: 编码后的向量特征 [batch, hidden_units]
        y: 标签的one-hot编码 [batch, seq, output_classes]
        V: 主干网输出的特征图 [batch, D, H, W]
        '''
        outputs = []
        attention_weights = []
        batch_size = hw.shape[0]
        # 初始化一个one-hot输出编码
        y_onehot = torch.zeros(batch_size, self.output_classes).to(self.device)
        for t in range(self.seq_len + 1):
            if t == 0:
                # step为0的时候输入编码器输出的向量
                inputs_y = hw  # size [batch, hidden_units]
                # LSTM layer 1 initialization:
                hx_1 = torch.zeros(batch_size, self.hidden_units).to(self.device)  # initial h0_1
                cx_1 = torch.zeros(batch_size, self.hidden_units).to(self.device)  # initial c0_1
                # LSTM layer 2 initialization:
                hx_2 = torch.zeros(batch_size, self.hidden_units).to(self.device)  # initial h0_2
                cx_2 = torch.zeros(batch_size, self.hidden_units).to(self.device)  # initial c0_2
            elif t == 1:
                y_onehot.zero_()
                y_onehot[:, self.START_TOKEN] = 1.0
                # step为1的时候输入标签的one-hot编码
                inputs_y = y_onehot
                inputs_y = self.linear1(inputs_y)  # [batch, hidden_units]
            else:
                if self.training:
                    # 继续输入标签的one-hot编码
                    inputs_y = y[:, t-2, :]  # [batch, output_classes]
                else:
                    # greedy search for now - beam search to be implemented!
                    index = torch.argmax(outputs[t-1], dim=-1) # [batch]
                    index = index.unsqueeze(1)  # [batch, 1]
                    y_onehot.zero_()
                    inputs_y = y_onehot.scatter_(1, index, 1) # [batch, output_classes]
                # 经过一个全连接转成序列数的向量
                inputs_y = self.linear1(inputs_y)  # [batch, hidden_units_encoder]

            # 将inputs送入LSTM网络
            hx_1, cx_1 = self.lstmcell1[t](inputs_y, (hx_1, cx_1))
            hx_2, cx_2 = self.lstmcell2[t](hx_1, (hx_2, cx_2))
            # 将主干网络输出的特征图以及LSTM网络该step的输出送入注意力机制中
            glimpse, att_weights = self.attention(hx_2, V) # [batch, D], [batch, 1, H, W]
            # 拼接LSTM网络与注意力机制的输出
            combine = torch.cat((hx_2, glimpse), dim=1) # [batch, hidden_units_decoder+D]
            # 拼接后的结果经过一个全连接转成分类数的向量
            out = self.linear2(combine)  # [batch, output_classes]
            # 将该向量转成分类概率
            out = self.softmax(out)  # [batch, output_classes]
            outputs.append(out)
            attention_weights.append(att_weights)

        outputs = outputs[1:]  # [seq_len, batch, output_classes]
        attention_weights = attention_weights[1:]  # [seq_len, batch, 1, H, W]
        outputs = torch.stack(outputs)  # [seq_len, batch, output_classes]
        outputs = outputs.permute(1, 0, 2)  # [batch, seq_len, output_classes]
        attention_weights = torch.stack(attention_weights)  # [seq_len, batch, 1, H, W]
        attention_weights = attention_weights.permute(1, 0, 2, 3, 4)  # [batch, seq_len, 1, H, W]

        return outputs, attention_weights

我们再来看一下注意力机制

  • 2D Attention 模块:

传统的 2D Attention 模块独立处理每一个位置,不能很好的利用二维的空间信息,作者提出针对曲形文字的 tailored 2D Attention 机制:
输入由feature map V和 LSTM解码器的hidden state h' 组成:

  1. h' 经过 1x1 卷积维度从L转换至d,再通过tile操作(堆叠广播 H x W次)形成 H x W x d 的特征图;
  2. feature map 经过 stride=1, padding=1, 3x3卷积后输出位 H x W x d  的特征图;
  3. 上述两个特征图相加后经过 tanh 形成新的特征图(H x W x d), 再经过 1x1 卷积和softmax得到attention weights 这样一个 H x W 的激活图\alpha ;
  4. 最终\alpha 和 原特征图V 加权求和得到最后的 glimpse c (1 x 1 x D);

g_{ij}=tanh(W_vv_{ij} +\sum_{p,q\in N_{ij}}\tilde{W}_{p-i,q-j} \cdot v_{pq}+W_h h'_t)

\alpha_{ij}=softmax(W^T_g \cdot g_{ij})

c_t = \sum_{i,j}\alpha_{ij}\cdot v_{ij}, i = 1,2,...,H; j=1,2,...,W.

其中v_{ij} 是 V(feature map)在(i, j) 处 的特征向量,  N{ij}是该位置周边的8个相邻点; h'_t是 time step t 的LSTM解码器的hidden state;W_v, W_h, \tilde{W}_s 是可学习的线性变换;\alpha_{ij} 是(i, j) 处的注意力权重(attention weight); c_t是当下位置特征的加权和,即开始结构图中的glimpse;
当计算v_{ij}的权重时,引入\sum_{p,q\in N_{ij}}\tilde{W}_{p-i,q-j} \cdot v_{pq} ,从而有效利用了当前位置周围的二维空间信息;

class attention(nn.Module):
    # 注意力机制

    def __init__(self, hidden_units, H, W, D):
        super(attention, self).__init__()
        '''
        hidden_units: 序列数
        H: 特征图的高
        W: 特征图的宽
        D: 特征图的通道数
        '''
        self.conv1 = nn.Conv2d(hidden_units, D, kernel_size=1, stride=1)
        self.conv2 = nn.Conv2d(D, D, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(D, 1, kernel_size=1, stride=1)
        self.dropout = nn.Dropout(0.5)
        self.softmax = nn.Softmax(dim=-1)
        self.H = H
        self.W = W
        self.D = D

    def forward(self, h, feature_map):
        '''
        h: LSTM解码器每一个step的输出 [batch, hidden_units]
        feature_map: 特征图 [batch, channel, H, W]
        '''
        # reshape hidden state [batch, hidden_units] to [batch, hidden_units, 1, 1]
        h = h.unsqueeze(2)
        h = h.unsqueeze(3)
        # 1*1的卷积
        h = self.conv1(h)  # [batch, D, 1, 1]
        # tile操作(堆叠广播H*W次)
        h = h.repeat(1, 1, self.H, self.W)  # tiling to [batch, D, H, W]
        feature_map_origin = feature_map
        # 3*3的卷积
        feature_map = self.conv2(feature_map)  # [batch, D, H, W]
        # 两个特征图相加后经过tanh形成新的特征图,再经过 1*1 卷积
        combine = self.conv3(self.dropout(torch.tanh(feature_map + h)))  # [batch, 1, H, W]
        combine_flat = combine.view(combine.size(0), -1)  # resize to [batch, H*W]
        # 经过softmax得到attention weights的激活图
        attention_weights = self.softmax(combine_flat)  # [batch, H*W]
        attention_weights = attention_weights.view(combine.size())  # [batch, 1, H, W]
        # 该激活图与原特征图V加权求和得到最后的 glimpse
        glimpse = feature_map_origin * attention_weights.repeat(1, self.D, 1, 1)  # [batch, D, H, W]
        glimpse = torch.sum(glimpse, dim=(2, 3))  # [batch, D]

        return glimpse, attention_weights

测试代码

import torch
import torch.nn as nn

__all__ = ['word_embedding', 'attention', 'decoder']

class word_embedding(nn.Module):
    def __init__(self, output_classes, embedding_dim):
        super(word_embedding, self).__init__()
        '''
        output_classes: number of output classes for the one hot encoding of a word
        embedding_dim: embedding dimension for a word
        '''
        self.linear = nn.Linear(output_classes, embedding_dim) # linear transformation

    def forward(self, x):
        x = self.linear(x)

        return x

# unit test
if __name__ == '__main__':

    batch_size = 2
    Height = 48
    Width = 160
    Channel = 512
    output_classes = 94
    embedding_dim = 512
    hidden_units = 512
    layers_decoder = 2
    seq_len = 40
    # 模拟标签的one-hot编码
    one_hot_embedding = torch.randn(batch_size, output_classes)
    one_hot_embedding[one_hot_embedding > 0] = torch.ones(1)
    one_hot_embedding[one_hot_embedding < 0] = torch.zeros(1)
    print("Word embedding size is:", one_hot_embedding.shape)
    # 将该one-hot编码通过全连接层进行通道数变换
    embedding_model = word_embedding(output_classes, embedding_dim)
    embedding_transform = embedding_model(one_hot_embedding)
    print("Embedding transform size is:", embedding_transform.shape)
    # 模拟一个编码器输出的向量
    hw = torch.randn(batch_size, hidden_units)
    # 模拟一个特征图
    feature_map = torch.randn(batch_size, Channel, Height, Width)
    print("Feature map size is:", feature_map.shape)
    # 创建注意力模型对象
    attention_model = attention(hidden_units, Height, Width, Channel)
    # 将编码器输出的向量与特征图送入注意力模型,获取glimpse和attention_weights激活图
    glimpse, attention_weights = attention_model(hw, feature_map)
    print("Glimpse size is:", glimpse.shape)
    print("Attention weight size is:", attention_weights.shape)
    # 模拟标签
    label = torch.randn(batch_size, seq_len, output_classes)
    # 创建一个解码器对象
    decoder_model = decoder(output_classes, Height, Width, Channel, hidden_units, seq_len)
    # 对编码器输出向量,标签,特征图进行解码,获取分类概率
    outputs, attention_weights = decoder_model(hw, label, feature_map)
    print("Output size is:", outputs.shape)
    print("Attention_weights size is:", attention_weights.shape)

运行结果

Word embedding size is: torch.Size([2, 94])
Embedding transform size is: torch.Size([2, 512])
Feature map size is: torch.Size([2, 512, 48, 160])
Glimpse size is: torch.Size([2, 512])
Attention weight size is: torch.Size([2, 1, 48, 160])
Output size is: torch.Size([2, 40, 94])
Attention_weights size is: torch.Size([2, 40, 1, 48, 160])

 

展开阅读全文
加载中

作者的其它热门文章

打赏
0
0 收藏
分享
打赏
0 评论
0 收藏
0
分享
返回顶部
顶部