物体検出にHuggingFaceに用意してあるTransformerと畳込みを組み合わせたモデルを適用してみる

Transformerが画像にも適用できると聞いたのですが

HuggingFaceで学習済みモデルが提供されているので簡単に試せるぞ

下記リンクに物体検出にTransformerを適用したコードの紹介されています。

https://github.com/NielsRogge/Transformers-Tutorials/blob/master/DETR/DETR_minimal_example_(with_DetrFeatureExtractor).ipynb

DETRについて

画像をCNNに入力して、特徴量にしたあとで、TransformerのEncoder-Decoderで物体検出、クラス分類を行っています。

Set of box predictionsで各Bounding Boxを推論しています。

detr.PNG

モデルの全体アーキテクチャは下記です。

ここで気になるのはObject Queriesですが、これは乱数を用いているようです。

詳細な説明はこちらのリンクをご参照ください。

Transformerがベースとなっているので、Transformerについて理解したい場合は下記のリンクをご覧ください。

動作環境

Google Colabで動作確認しました。

動作確認

動作に必要なライブラリを導入します。

!pip install -q git+https://github.com/huggingface/transformers.git
!pip install -q timm

推論に使用する画像を取得します。

from PIL import Image
import requests

url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
im = Image.open(requests.get(url, stream=True).raw)
im

猫の画像を使用します。

下記でDETRモデルに入力するために画像データの前処理を行います。

from transformers import DetrFeatureExtractor

feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")

encoding = feature_extractor(im, return_tensors="pt")
encoding.keys()

DETRモデルを取得します。

from transformers import DetrForObjectDetection

model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

前処理されたデータをモデルに入力して推論処理をします。

outputs = model(**encoding)

物体検出の結果を確認するための関数を作成します。

import matplotlib.pyplot as plt

# colors for visualization
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]

def plot_results(pil_img, prob, boxes):
    plt.figure(figsize=(16,10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    for p, (xmin, ymin, xmax, ymax), c in zip(prob, boxes.tolist(), colors):
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        cl = p.argmax()
        text = f'{model.config.id2label[cl.item()]}: {p[cl]:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()

一定の確信度以上のデータを取得するために各クラスの出力値を確率値にしています。

probas = outputs.logits.softmax(-1)[0, :, :-1]

確率値が0.9以上のものを取得しています。

keep = probas.max(-1).values > 0.9

残りのコードで後処理して、画像データのサイズにリスケールしています。

import torch

# keep only predictions of queries with 0.9+ confidence (excluding no-object class)
probas = outputs.logits.softmax(-1)[0, :, :-1]
keep = probas.max(-1).values > 0.9

# rescale bounding boxes
target_sizes = torch.tensor(im.size[::-1]).unsqueeze(0)
postprocessed_outputs = feature_extractor.post_process(outputs, target_sizes)
bboxes_scaled = postprocessed_outputs[0]['boxes'][keep]

下記のコードで物体検出の結果を確認します。

plot_results(im, probas[keep], bboxes_scaled)

TransformerはAttentionベースのモデルのため、モデルがどの部分に注目しているか確認することができます。

conv_featureは特徴量空間のサイズを取得するだけに使用しています。

画像とObject Queryのcross_attentionの値を取得して8ヘッドのAttentionの平均値を取得します。

# use lists to store the outputs via up-values
conv_features = []

hooks = [
    model.model.backbone.conv_encoder.register_forward_hook(
        lambda self, input, output: conv_features.append(output)
    ),
]

# propagate through the model
outputs = model(**encoding, output_attentions=True)

for hook in hooks:
    hook.remove()

# don't need the list anymore
conv_features = conv_features[0]
# get cross-attention weights of last decoder layer - which is of shape (batch_size, num_heads, num_queries, width*height)
dec_attn_weights = outputs.cross_attentions[-1]
# average them over the 8 heads and detach from graph
dec_attn_weights = torch.mean(dec_attn_weights, dim=1).detach()

画像データとAttentionの情報を可視化します。

# get the feature map shape
h, w = conv_features[-1][0].shape[-2:]

fig, axs = plt.subplots(ncols=len(bboxes_scaled), nrows=2, figsize=(22, 7))
colors = COLORS * 100
for idx, ax_i, (xmin, ymin, xmax, ymax) in zip(keep.nonzero(), axs.T, bboxes_scaled):
    ax = ax_i[0]
    ax.imshow(dec_attn_weights[0, idx].view(h, w))
    ax.axis('off')
    ax.set_title(f'query id: {idx.item()}')
    ax = ax_i[1]
    ax.imshow(im)
    ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                               fill=False, color='blue', linewidth=3))
    ax.axis('off')
    ax.set_title(model.config.id2label[probas[idx].argmax().item()])
fig.tight_layout()

思ったより簡単にTransformerを画像に適用できました。

論文の内容は難しいがHuggingFaceのエコシステムによって簡単に試せるようになっているぞ

Close Bitnami banner
Bitnami