画像分類の際に判断根拠を把握するGrad-CAMをPyTorch(Google Colab)で簡単に使用する

画像分類の判断根拠を把握したいです。

Grad-CAMを使えばできるぞ

簡単に使う方法ありますか

PyTorchだとpytorch-gradcamがあるぞ

Grad-CAM

下記のように’Cat’と分類した場合にどこに注目したか分かる仕組みになります。

予測値”Tigar Cat”に寄与が高い(勾配が大きい)部分を取得しています。

Grad-CAMは、予測クラス値の寄与の大きいところ(勾配の大きいところ)が分類予測を行う上で、重要という発想の手法です。
勾配に関しては最後の畳み込み層(以下最後のconv層)の予測クラス値に対する勾配が用いられます。

Overview

動作確認

下記のコードはpytorchでgrad-camを簡単に使用できるようにしているライブラリです。

https://github.com/vickyliin/gradcam_plus_plus-pytorch

Google Colabで動作確認をします。使用方法は下記をご覧ください。

まずライブラリをインストールします。

! pip install pytorch-gradcam

ライブラリを導入します。

import os
import PIL
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.models as models
from torchvision import transforms
from torchvision.utils import make_grid, save_image

from gradcam.utils import visualize_cam
from gradcam import GradCAM, GradCAMpp

device = 'cuda' if torch.cuda.is_available() else 'cpu'

解析用に画像を取得します。

! wget https://raw.githubusercontent.com/vickyliin/gradcam_plus_plus-pytorch/master/images/water-bird.JPEG

解析用の画像を確認します。

img_name = 'water-bird.JPEG'

pil_img = PIL.Image.open(img_name)
pil_img

画像の前処理をします。サイズを224×224にリサイズして正規化します。

torch_img = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])(pil_img).to(device)
normed_torch_img = transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])(torch_img)[None]

Grad-GAMの様子を比較するためモデルを5種類取得します。


alexnet = models.alexnet(pretrained=True)
vgg = models.vgg16(pretrained=True)
resnet = models.resnet101(pretrained=True)
densenet = models.densenet161(pretrained=True)
squeezenet = models.squeezenet1_1(pretrained=True)

各モデルの勾配を取得するレイヤーを設定します。

configs = [
    dict(model_type='alexnet', arch=alexnet, layer_name='features_11'),
    dict(model_type='vgg', arch=vgg, layer_name='features_29'),
    dict(model_type='resnet', arch=resnet, layer_name='layer4'),
    dict(model_type='densenet', arch=densenet, layer_name='features_norm5'),
    dict(model_type='squeezenet', arch=squeezenet, layer_name='features_12_expand3x3_activation')
]

勾配と勾配を元の画像に重ね合わせるための関数を設定します。

for config in configs:
    config['arch'].to(device).eval()

cams = [
    [cls.from_config(**config) for cls in (GradCAM, GradCAMpp)]
    for config in configs
]

各モデルのGrad-CAMの様子を可視化して確認してみます。

images = []
for gradcam, gradcam_pp in cams:
    mask, _ = gradcam(normed_torch_img)
    heatmap, result = visualize_cam(mask, torch_img)

    mask_pp, _ = gradcam_pp(normed_torch_img)
    heatmap_pp, result_pp = visualize_cam(mask_pp, torch_img)
    
    images.extend([torch_img.cpu(), heatmap, heatmap_pp, result, result_pp])
    
grid_image = make_grid(images, nrow=5)
transforms.ToPILImage()(grid_image)

下記のモデルを使った結果になります。

  • alexnet:くちばし部分の寄与度が低くなっているように見えます。
  • vgg:比較的、寄与度が高い部分を取得できているようにみえます。
  • resnet:もっとも範囲を広くとっているように見えます。
  • densenet:resnetの次に範囲を広く取っており、寄与度の高い部分の確認が難しくなっています。
  • squeezenet:分類に寄与が高い部分のみを他のモデルに比べて正確に取得しているように見えます。

勾配を取得するレイヤーを浅くしてみます。深いレイヤーほど抽象化された情報であり、浅いレイヤーほど、どの部分を細かく重要視しているかを把握できます。

configs = [
    dict(model_type='alexnet', arch=alexnet, layer_name='features_0'),
    dict(model_type='vgg', arch=vgg, layer_name='features_0'),
    dict(model_type='resnet', arch=resnet, layer_name='layer1'),
    dict(model_type='densenet', arch=densenet, layer_name='features_norm0'),
    dict(model_type='squeezenet', arch=squeezenet, layer_name='features_3_expand3x3_activation')
]

同様に勾配と勾配を元の画像に重ね合わせるための関数を設定します。

for config in configs:
    config['arch'].to(device).eval()

cams = [
    [cls.from_config(**config) for cls in (GradCAM, GradCAMpp)]
    for config in configs
]

同様に可視化を行います。

images = []
for gradcam, gradcam_pp in cams:
    mask, _ = gradcam(normed_torch_img)
    heatmap, result = visualize_cam(mask, torch_img)

    mask_pp, _ = gradcam_pp(normed_torch_img)
    heatmap_pp, result_pp = visualize_cam(mask_pp, torch_img)
    
    images.extend([torch_img.cpu(), heatmap, heatmap_pp, result, result_pp])
    
grid_image = make_grid(images, nrow=5)

transforms.ToPILImage()(grid_image)

深いレイヤーに比べて細かく重要視している部分が把握できます。

Close Bitnami banner
Bitnami