簡単にOCR処理ができるEasyOCRをGoogle Colabで試してみた

OCRやってみたいです。

EasyOCRを使えば簡単に使用できるぞ

日本語対応しているOCRのオープンソースがあったので試してみました。

https://github.com/JaidedAI/EasyOCR

動作環境

環境はGoogle Colabを使用しました。

動作検証

内部のコードはPyTorchベースで実装されているのでGPUで動作するため、RUNTIMEはGPUで動作させます。

Runtime->Change runtime type で下記の画面になるのでHardware acceleratorをGPUに設定します。

まず必要なライブラリをインストールします。

!pip install easyocr --no-deps # Colab already has all dependencies
! pip install python-bidi

検証用の画像をダウンロードします。

# load example images
!npx degit JaidedAI/EasyOCR/examples -f

動作検証する画像を確認します。

# show an image
import PIL
from PIL import ImageDraw
jp_img = "japanese.jpg"
im = PIL.Image.open(jp_img)
im

OCR処理を動作させます。

# Doing OCR. Get bounding boxes.
import easyocr
reader = easyocr.Reader(['ja','en'])
bounds = reader.readtext(jp_img)
bounds

下記のように一部ミスをしていますが、ほとんど正しく認識できています。

[([[71, 49], [489, 49], [489, 159], [71, 159]], 'ポ<捨て禁止!', 0.6339441537857056),
 ([[95, 149], [461, 149], [461, 235], [95, 235]],
  'NOLITTER',
  0.3249408006668091),
 ([[80, 232], [475, 232], [475, 288], [80, 288]],
  '清潔できれいな港区を',
  0.9784270524978638),
 ([[109, 289], [437, 289], [437, 333], [109, 333]],
  '港 区 MINATO CITY',
  0.18789245188236237)]

認識した部分を確認してみます。

# Draw bounding boxes
def draw_boxes(image, bounds, color='yellow', width=2):
    draw = ImageDraw.Draw(image)
    for bound in bounds:
        p0, p1, p2, p3 = bound[0]
        draw.line([*p0, *p1, *p2, *p3, *p0], fill=color, width=width)
    return image

draw_boxes(im, bounds)

文字の部分を抽出していることが確認できます。

内部コード検証

簡単に使用できましたが内部がブラックボックスなので、中身のコードを確認してみます。

OCR処理をしている部分を見てみます。文字の場所を取得するself.detect処理をした後に`self.recognize`で文字認識をしていることが分かります。

https://github.com/JaidedAI/EasyOCR/blob/8fca29a16a0562cde617768f64316cb9ca5c3446/easyocr/easyocr.py#L347

    def readtext(self, image, decoder = 'greedy', beamWidth= 5, batch_size = 1,\
                 workers = 0, allowlist = None, blocklist = None, detail = 1,\
                 rotation_info = None, paragraph = False, min_size = 20,\
                 contrast_ths = 0.1,adjust_contrast = 0.5, filter_ths = 0.003,\
                 text_threshold = 0.7, low_text = 0.4, link_threshold = 0.4,\
                 canvas_size = 2560, mag_ratio = 1.,\
                 slope_ths = 0.1, ycenter_ths = 0.5, height_ths = 0.5,\
                 width_ths = 0.5, add_margin = 0.1):
        '''
        Parameters:
        image: file path or numpy-array or a byte stream object
        '''
        img, img_cv_grey = reformat_input(image)

        horizontal_list, free_list = self.detect(img, min_size, text_threshold,\
                                                 low_text, link_threshold,\
                                                 canvas_size, mag_ratio,\
                                                 slope_ths, ycenter_ths,\
                                                 height_ths,width_ths,\
                                                 add_margin, False)

        result = self.recognize(img_cv_grey, horizontal_list, free_list,\
                                decoder, beamWidth, batch_size,\
                                workers, allowlist, blocklist, detail, rotation_info,\
                                paragraph, contrast_ths, adjust_contrast,\
                                filter_ths, False)

        return result

コードを追っていくとdetect処理はOpenCVのconnectedComponentsWithStatsを使用しているようです。画像中のオブジェクトのサイズや重心の情報を合わせて返す関数になっています。

https://github.com/JaidedAI/EasyOCR/blob/8fca29a16a0562cde617768f64316cb9ca5c3446/easyocr/craft_utils.py#L20

アルゴリズムは下記リンクの内容によると`Kesheng Wu, Ekow Otoo, and Kenji Suzuki. Optimizing two-pass connected-component labeling algorithms. Pattern Analysis and Applications, 12(2):117–135, Jun 2009.`を使用しているようです。

https://docs.opencv.org/3.4/d3/dc0/group__imgproc__shape.html#ga107a78bf7cd25dec05fb4dfc5c9e765f

def getDetBoxes_core(textmap, linkmap, text_threshold, link_threshold, low_text, estimate_num_chars=False):
    # prepare data
    linkmap = linkmap.copy()
    textmap = textmap.copy()
    img_h, img_w = textmap.shape

    """ labeling method """
    ret, text_score = cv2.threshold(textmap, low_text, 1, 0)
    ret, link_score = cv2.threshold(linkmap, link_threshold, 1, 0)

    text_score_comb = np.clip(text_score + link_score, 0, 1)
    nLabels, labels, stats, centroids = cv2.connectedComponentsWithStats(text_score_comb.astype(np.uint8), connectivity=4)

    det = []
    mapper = []
    for k in range(1,nLabels):
        # size filtering
        size = stats[k, cv2.CC_STAT_AREA]
        if size < 10: continue

        # thresholding
        if np.max(textmap[labels==k]) < text_threshold: continue

        # make segmentation map
        segmap = np.zeros(textmap.shape, dtype=np.uint8)
        segmap[labels==k] = 255
        if estimate_num_chars:
            _, character_locs = cv2.threshold((textmap - linkmap) * segmap /255., text_threshold, 1, 0)
            _, n_chars = label(character_locs)
            mapper.append(n_chars)
        else:
            mapper.append(k)
        segmap[np.logical_and(link_score==1, text_score==0)] = 0   # remove link area
        x, y = stats[k, cv2.CC_STAT_LEFT], stats[k, cv2.CC_STAT_TOP]
        w, h = stats[k, cv2.CC_STAT_WIDTH], stats[k, cv2.CC_STAT_HEIGHT]
        niter = int(math.sqrt(size * min(w, h) / (w * h)) * 2)
        sx, ex, sy, ey = x - niter, x + w + niter + 1, y - niter, y + h + niter + 1
        # boundary check
        if sx < 0 : sx = 0
        if sy < 0 : sy = 0
        if ex >= img_w: ex = img_w
        if ey >= img_h: ey = img_h
        kernel = cv2.getStructuringElement(cv2.MORPH_RECT,(1 + niter, 1 + niter))
        segmap[sy:ey, sx:ex] = cv2.dilate(segmap[sy:ey, sx:ex], kernel)

        # make box
        np_contours = np.roll(np.array(np.where(segmap!=0)),1,axis=0).transpose().reshape(-1,2)
        rectangle = cv2.minAreaRect(np_contours)
        box = cv2.boxPoints(rectangle)

        # align diamond-shape
        w, h = np.linalg.norm(box[0] - box[1]), np.linalg.norm(box[1] - box[2])
        box_ratio = max(w, h) / (min(w, h) + 1e-5)
        if abs(1 - box_ratio) <= 0.1:
            l, r = min(np_contours[:,0]), max(np_contours[:,0])
            t, b = min(np_contours[:,1]), max(np_contours[:,1])
            box = np.array([[l, t], [r, t], [r, b], [l, b]], dtype=np.float32)

        # make clock-wise order
        startidx = box.sum(axis=1).argmin()
        box = np.roll(box, 4-startidx, 0)
        box = np.array(box)

        det.append(box)

    return det, labels, mapper

文字検出部分を見てみます。コードを追っていくと下記の部分で文字認識の結果を取得しています。`self.recognizer`で文字認識できるモデルを設定しています。

        result = get_text(self.character, imgH, int(max_width), self.recognizer, self.converter, image_list,\
                      ignore_char, decoder, beamWidth, batch_size, contrast_ths, adjust_contrast, filter_ths,\
                      workers, self.device)

下記のコードで`self.recognizer`を取得しています。

self.recognizer, self.converter = get_recognizer(recog_network, network_params,\
                                                         self.character, separator_list,\
                                                         dict_list, model_path, device = self.device)

下記のコードで文字認識のモデルを取得します。

def get_recognizer(recog_network, network_params, character,\
                   separator_list, dict_list, model_path,\
                   device = 'cpu'):

    converter = CTCLabelConverter(character, separator_list, dict_list)
    num_class = len(converter.character)

    if recog_network == 'standard':
        model_pkg = importlib.import_module("easyocr.model.model")
    elif recog_network == 'lite':
        model_pkg = importlib.import_module("easyocr.model.vgg_model")
    else:
        model_pkg = importlib.import_module(recog_network)
    model = model_pkg.Model(num_class=num_class, **network_params)

    if device == 'cpu':
        state_dict = torch.load(model_path, map_location=device)
        new_state_dict = OrderedDict()
        for key, value in state_dict.items():
            new_key = key[7:]
            new_state_dict[new_key] = value
        model.load_state_dict(new_state_dict)
    else:
        model = torch.nn.DataParallel(model).to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))

    return model, converter

使用しているモデルの一部を見てみます。ResNetで画像を抽出してBidirectionalLSTMで文字認識していることが確認できます。

現状は自然言語処理においてはTransformerの方が性能が高いのでデータを持っている方は`BidirectionalLSTM`を置き換えて見ても良いと思います。

import torch.nn as nn
from .modules import ResNet_FeatureExtractor, BidirectionalLSTM

class Model(nn.Module):

    def __init__(self, input_channel, output_channel, hidden_size, num_class):
        super(Model, self).__init__()
        """ FeatureExtraction """
        self.FeatureExtraction = ResNet_FeatureExtractor(input_channel, output_channel)
        self.FeatureExtraction_output = output_channel  # int(imgH/16-1) * 512
        self.AdaptiveAvgPool = nn.AdaptiveAvgPool2d((None, 1))  # Transform final (imgH/16-1) -> 1

        """ Sequence modeling"""
        self.SequenceModeling = nn.Sequential(
            BidirectionalLSTM(self.FeatureExtraction_output, hidden_size, hidden_size),
            BidirectionalLSTM(hidden_size, hidden_size, hidden_size))
        self.SequenceModeling_output = hidden_size

        """ Prediction """
        self.Prediction = nn.Linear(self.SequenceModeling_output, num_class)


    def forward(self, input, text):
        """ Feature extraction stage """
        visual_feature = self.FeatureExtraction(input)
        visual_feature = self.AdaptiveAvgPool(visual_feature.permute(0, 3, 1, 2))  # [b, c, h, w] -> [b, w, c, h]
        visual_feature = visual_feature.squeeze(3)

        """ Sequence modeling stage """
        contextual_feature = self.SequenceModeling(visual_feature)

        """ Prediction stage """
        prediction = self.Prediction(contextual_feature.contiguous())

        return prediction

EasyOCRは名前通り簡単にOCR処理ができたぞ

中のコードもOpenCVとPyTorchをベースにしているので理解しやすかったですね。

Close Bitnami banner
Bitnami