物体検出でよく使用されるモデルSSDのモデルがPytorchで簡単に使えるようになったぞ
どのくらい簡単なんですか
サンプルコードは50行未満。これはデータの処理や可視化も含めてこの行数だ
目次
SSD(Single Shot MultiBox Detector)とは
物体検出に主に使用されるモデル
SSD: Single Shot MultiBox Detector の論文で提唱された手法です。
細かい部分は他の方がブログで説明されていますので省きますが、主要な部分は複数のスケールの特徴量を使用して物体検出性能を上げる手法になります。
この複雑なモデルを簡単に使えるような内容を公開してくれています。
https://pytorch.org/hub/nvidia_deeplearningexamples_ssd/
PyTorchでSSD(物体検出モデル)が簡単に使えるようになっている
— Yurui⛅DeepLearning (@DeepYurui) August 14, 2020
– 学習済みモデルが用意
– サンプルコード内に後処理の関数が用意されているため、後処理の実装が不要
– サンプルコード内に前処理も実装https://t.co/WmB9VF0p6m
Google Colaboratoryで動作確認
では早速動作確認をしていくぞ
Google Colaboratoryの設定方法は下記に書いています。
モデルを取得
まず学習済みモデルを取得します。
import torch
precision = 'fp32'
ssd_model = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd', model_math=precision)
CPUだと動作せずに下記のようなエラーが発生するのでRuntimeはGPUに設定して実行する必要があります。
RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.
前処理や後処理を使用できる共通関数を取得します。
utils = torch.hub.load('NVIDIA/DeepLearningExamples:torchhub', 'nvidia_ssd_processing_utils')
GPUでモデルを動作するようにします。batchnormとdropoutレイヤーをevalモードにするため、model.evalを使用します。
ssd_model.to('cuda')
ssd_model.eval()
データの前処理
画像データを取得するためのリンクのリストを作成します。
uris = [
'http://images.cocodataset.org/val2017/000000397133.jpg',
'http://images.cocodataset.org/val2017/000000037777.jpg',
'http://images.cocodataset.org/val2017/000000252219.jpg'
]
画像データを前処理してPyTorchで処理しやすいTensor形式にします。
inputs = [utils.prepare_input(uri) for uri in uris]
tensor = utils.prepare_tensor(inputs, precision == 'fp16')
物体検出処理
torch.no_grad()でautogradエンジンをオフにしてメモリの使用量を減らし、計算速度を向上させてから推論処理をします。下記のフォーラムで内容について言及されています。
https://discuss.pytorch.org/t/model-eval-vs-with-torch-no-grad/19615
with torch.no_grad():
detections_batch = ssd_model(tensor)
推論実行時間は9.16msでした。ただしGPUはT4と推論が高速にできるGPUだったので速度は参考程度にしておいた方が良いです。
Fri Aug 14 04:20:12 2020
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.57 Driver Version: 418.67 CUDA Version: 10.1 |
|-------------------------------+----------------------+----------------------+
| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|===============================+======================+======================|
| 0 Tesla T4 Off | 00000000:00:04.0 Off | 0 |
| N/A 51C P0 29W / 70W | 1091MiB / 15079MiB | 0% Default |
| | | ERR! |
+-------------------------------+----------------------+----------------------+
+-----------------------------------------------------------------------------+
| Processes: |
| GPU GI CI PID Type Process name GPU Memory |
| ID ID Usage |
|=============================================================================|
| No running processes found |
+-----------------------------------------------------------------------------+
データの後処理
SSDで処理した結果の後処理をします。大量の検出結果が発生するので一定のしきい値(下記の例では0.4)以上の値のみ採用します。
results_per_input = utils.decode_results(detections_batch)
best_results_per_input = [utils.pick_best(results, 0.40) for results in results_per_input]
クラス情報は数値情報なのでラベル情報に変換するための処理をします。
classes_to_labels = utils.get_coco_object_dictionary()
結果の確認
物体検出した結果を確認するための処理をします。
バッチ分のデータがあるのでまずそのデータを取得します。
各データの物体検出結果を取得して、画像のサイズにリスケールして表示します。
from matplotlib import pyplot as plt
import matplotlib.patches as patches
for image_idx in range(len(best_results_per_input)):
fig, ax = plt.subplots(1)
# Show original, denormalized image...
image = inputs[image_idx] / 2 + 0.5
ax.imshow(image)
# ...with detections
bboxes, classes, confidences = best_results_per_input[image_idx]
for idx in range(len(bboxes)):
left, bot, right, top = bboxes[idx]
x, y, w, h = [val * 300 for val in [left, bot, right - left, top - bot]]
rect = patches.Rectangle((x, y), w, h, linewidth=1, edgecolor='r', facecolor='none')
ax.add_patch(rect)
ax.text(x, y, "{} {:.0f}%".format(classes_to_labels[classes[idx] - 1], confidences[idx]*100), bbox=dict(facecolor='white', alpha=0.5))
plt.show()
下記のように表示されます。
少ない工数で実現できるので余った時間を他のことに使おう!!