画像検索に使用できるMetric LearningをGoogle Colabで試してみた

画像の検索エンジンを作ってみたいです。

Metric Learningが有効だぞ

なんか難しそうですね

思ったより簡単なのでここから詳しく紹介するぞ

Metric Learning

Metric Learningを使用すると顔認証、画像検索、異常検知などに適用できます。識別では対応できない未知の顔、画像、データに対応できる点がメリットです。

下記は顔認証の例になります。顔の投影角度が変わっても距離を小さくできていることが分かります。

Schroff, Florian, Dmitry Kalenichenko, and James Philbin. “Facenet: A unified embedding for face recognition and clustering.” Proceedings of the IEEE conference on computer vision and pattern recognition. 2015.

下記は画像検索の例になります。

Boiarov, Andrei, and Eduard Tyantov. “Large scale landmark recognition via deep metric learning.” Proceedings of the 28th ACM International Conference on Information and Knowledge Management. 2019.

下記は異常検知の例になります。正常データのみで学習しています。正常データのみで学習することで異常なデータが入力されれば距離が遠くなり異常検知できる仕組みです。

Bergmann, Paul, et al. “MVTec AD–A Comprehensive Real-World Dataset for Unsupervised Anomaly Detection.” Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition. 2019.

距離とは

下記のように同一画像、もしくは似ている画像だと距離が近くなり、似ていない画像だと距離が遠くなる指標になります。距離に関してはユークリッド距離やマンハッタン距離などがあります。

https://ja.wikipedia.org/wiki/%E8%B7%9D%E9%9B%A2

画像の場合、比較したい要素以外が含まれているため、画像をそのまま使うのは難しいです。

例えば上記の画像は馬の比較をしたいわけですが背景に馬以外の要素があります。

この要素があると比較が難しくなります。

そこで埋め込み空間というテクニックを使用します。

埋め込み空間

下記のようになんらかの写像関数を使用して馬の画像の次元数を減らします。

この写像関数にDeep Learningを使用することで効果的に馬の画像の特徴量を抽出できます。

SIAMESE NETWORK

距離を近くしたい画像とそれ以外の画像を用意します。例えば馬の画像と車の画像を用意した場合は馬同士、車同士の画像に対して下記のようにロスを計算すると画像間の距離が近くなるように学習してくれます。

Yはラベルを表しています。

次に異なる画像同士の比較を行うケースです。先程との違いは異なるクラスのため、距離を大きくしたので符号を逆にするための処理と最低限離す必要がある距離を設定しています。

使用するデータ

cifar10のデータ・セットを使用しました。

https://www.cs.toronto.edu/~kriz/cifar.html

動作環境

Google Colabを使用しました。Google Colabの使用方法は下記に記述しています。

コード

下記のコードを参考にしました。

https://github.com/fangpin/siamese-pytorch

Cifar10のデータ・セットを取得します。学習用と評価用のデータを取得します。

import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)

dataset_size = trainset.data.shape[0]
dataset_size

ラベルごとにデータを分けて保存するための処理を記述します。同一ラベルと異なるラベルで学習するための準備です。

from collections import defaultdict

train_data_set = defaultdict(list)

for index, label in enumerate(trainset.targets):
  train_data_set[label].append(Image.fromarray(trainset.data[index]))

test_data_set = defaultdict(list)

for index, label in enumerate(testset.targets):
  test_data_set[label].append(Image.fromarray(testset.data[index]))

Cifar10の学習データ取得するためのクラスです。indexが奇数の場合は同一ラベルのデータ、偶数の場合は異なるラベルのデータを取得するようになっています。

from torch.utils.data import Dataset, DataLoader
import os
from numpy.random import choice as npc
import numpy as np
import time
import random
import torchvision.datasets as dset


class Cifar10Train(Dataset):

    def __init__(self, datas, transform=None):
        super(Cifar10Train, self).__init__()
        np.random.seed(0)
        self.datas = datas
        self.num_classes = len(datas)
        self.transform = transform

    def __len__(self):
        size = len(self.datas) * len(self.datas[0])
        return size

    def __getitem__(self, index):
        # image1 = random.choice(self.dataset.imgs)
        label = None
        img1 = None
        img2 = None
        # get image from same class
        if index % 2 == 1:
            label = 1.0
            idx1 = random.randint(0, self.num_classes - 1)
            image1 = random.choice(self.datas[idx1])
            image2 = random.choice(self.datas[idx1])
        # get image from different class
        else:
            label = 0.0
            idx1 = random.randint(0, self.num_classes - 1)
            idx2 = random.randint(0, self.num_classes - 1)
            while idx1 == idx2:
                idx2 = random.randint(0, self.num_classes - 1)
            image1 = random.choice(self.datas[idx1])
            image2 = random.choice(self.datas[idx2])

        if self.transform:
            image1 = self.transform(image1)
            image2 = self.transform(image2)
        return image1, image2, torch.from_numpy(np.array([label], dtype=np.float32))

Cifar10の評価データ取得するためのクラスです。同一ラベルの取得タイミングをwayという変数で決定しています。それ以外の場合は異なるラベルを取得するようになっています。

class Cifar10Test(Dataset):

    def __init__(self, datas, transform=None, times=200, way=20):
        np.random.seed(1)
        super(Cifar10Test, self).__init__()
        self.transform = transform
        self.times = times
        self.way = way
        self.img1 = None
        self.c1 = None
        self.datas = datas 
        self.num_classes = len(datas)

    def __len__(self):
        return self.times * self.way

    def __getitem__(self, index):
        idx = index % self.way
        label = None
        # generate image pair from same class
        if idx == 0:
            self.c1 = random.randint(0, self.num_classes - 1)
            self.img1 = random.choice(self.datas[self.c1])
            img2 = random.choice(self.datas[self.c1])
        # generate image pair from different class
        else:
            c2 = random.randint(0, self.num_classes - 1)
            while self.c1 == c2:
                c2 = random.randint(0, self.num_classes - 1)
            img2 = random.choice(self.datas[c2])

        if self.transform:
            img1 = self.transform(self.img1)
            img2 = self.transform(img2)
        return img1, img2

SIAMESE NETWORKを記述します。

特徴的な部分はforward関数です。入力が2つになっています。これは画像間の距離を計算できるようにするためです。

`out1 = self.forward_one(x1)`で画像1の特徴量をDeep Learningで取得します。

`out2 = self.forward_one(x2)`で画像2の特徴量をDeep Learningで取得します。

`torch.abs(out1 – out2)`で距離を計算しています。

import torch.nn as nn
import torch.nn.functional as F


class Siamese(nn.Module):

    def __init__(self):
        super(Siamese, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, 10),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 7),
            nn.ReLU(), 
            nn.MaxPool2d(2),
            nn.Conv2d(128, 128, 2),
            nn.ReLU(), 
            nn.Conv2d(128, 256, 1),
            nn.ReLU(), 
        )
        self.liner = nn.Sequential(nn.Linear(256, 4096), nn.Sigmoid())
        self.out = nn.Linear(4096, 1)

    def forward_one(self, x):
        x = self.conv(x)
        x = x.view(x.size()[0], -1)
        x = self.liner(x)
        return x

    def forward(self, x1, x2):
        out1 = self.forward_one(x1)
        out2 = self.forward_one(x2)
        dis = torch.abs(out1 - out2)
        out = self.out(dis)
        return out

データを取得するデータローダーを設定します。

PyTorchのデータローダーについてはこちらで説明されています。

学習データはデータの変換処理としてRandomにAffine変換を行っています。

Affine変換に関してはこちらの記事が分かりやすいです。

data_transforms = transforms.Compose([
    transforms.RandomAffine(15),
    transforms.ToTensor()
])


trainSet = Cifar10Train(train_data_set, transform=data_transforms)

trainLoader = DataLoader(trainSet, batch_size=128, shuffle=False, num_workers=4)

testSet = Cifar10Test(test_data_set, transform=transforms.ToTensor(), times = 200, way = 5)
testLoader = DataLoader(testSet, batch_size=5, shuffle=False, num_workers=4)

loss_fn = torch.nn.BCEWithLogitsLoss(size_average=True)
net = Siamese()

SIAMESE NETWORKをGPUで学習するための変換設定を行っています。

学習のための`optimizer`を設定しています。

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device == 'cuda':
    net.cuda()

net.train()

optimizer = torch.optim.Adam(net.parameters(),lr = 0.00006 )
optimizer.zero_grad()

SIAMESE NETWORKを学習します。

from torch.autograd import Variable
import matplotlib.pyplot as plt
import time
import numpy as np
import sys
from collections import deque
import os

train_loss = []
loss_val = 0
time_start = time.time()
queue = deque(maxlen=20)
max_iter = 50000

for batch_id, (img1, img2, label) in enumerate(trainLoader, 1):
    if batch_id > max_iter:
        break
    if device == 'cuda':
        img1, img2, label = Variable(img1.cuda()), Variable(img2.cuda()), Variable(label.cuda())
    else:
        img1, img2, label = Variable(img1), Variable(img2), Variable(label)
    optimizer.zero_grad()
    output = net.forward(img1, img2)
    loss = loss_fn(output, label)
    loss_val += loss.item()
    loss.backward()
    optimizer.step()
    if batch_id % 10 == 0 :
        print('[%d]\tloss:\t%.5f\ttime lapsed:\t%.2f s'%(batch_id, loss_val/10, time.time() - time_start))
        loss_val = 0
        time_start = time.time()
    if batch_id % 100 == 0:
        torch.save(net.state_dict(),  './model-inter-' + str(batch_id+1) + ".pt")
    if batch_id % 100 == 0:
        right, error = 0, 0
        for _, (test1, test2) in enumerate(testLoader, 1):
            if device == 'cuda':
                test1, test2 = test1.cuda(), test2.cuda()
            test1, test2 = Variable(test1), Variable(test2)
            output = net.forward(test1, test2).data.cpu().numpy()
            pred = np.argmax(output)
            if pred == 0:
                right += 1
            else: error += 1
        print('*'*70)
        print('[%d]\tTest set\tcorrect:\t%d\terror:\t%d\tprecision:\t%f'%(batch_id, right, error, right*1.0/(right+error)))
        print('*'*70)
        queue.append(right*1.0/(right+error))
    train_loss.append(loss_val)
#  learning_rate = learning_rate * 0.95

下記のように結果を確認できます。LossとPrecisionを計算しています。

[210]	loss:	0.61768	time lapsed:	10.32 s
[220]	loss:	0.61715	time lapsed:	7.05 s
[230]	loss:	0.62227	time lapsed:	7.06 s
[240]	loss:	0.59738	time lapsed:	7.12 s
[250]	loss:	0.61492	time lapsed:	7.14 s
[260]	loss:	0.63777	time lapsed:	7.18 s
[270]	loss:	0.62249	time lapsed:	7.08 s
[280]	loss:	0.62806	time lapsed:	7.10 s
**********************************************************************
[300]	Test set	correct:	67	error:	133	precision:	0.335000
**********************************************************************
[310]	loss:	0.67323	time lapsed:	10.26 s
[320]	loss:	0.66368	time lapsed:	7.16 s
[330]	loss:	0.66023	time lapsed:	7.04 s
[340]	loss:	0.66601	time lapsed:	7.05 s
[350]	loss:	0.65852	time lapsed:	7.03 s
[360]	loss:	0.65936	time lapsed:	7.07 s
[370]	loss:	0.64518	time lapsed:	7.08 s
[380]	loss:	0.65438	time lapsed:	7.04 s
[390]	loss:	0.65029	time lapsed:	6.95 s

実際に検索処理をしてみます。

検索処理は下記のブログを参考にさせてもらいました。

検索処理のためのデータ・セットクラスを作成します。

class SerachDataSet(Dataset):

    def __init__(self, datas, transform=None):
        super(SerachDataSet, self).__init__()
        np.random.seed(0)
        self.datas = datas
        self.num_classes = len(datas)
        self.transform = transform

    def __len__(self):
        size = len(self.datas) * len(self.datas[0])
        return size

    def __getitem__(self, index):
        # image1 = random.choice(self.dataset.imgs)
        label = None
        img1 = None
        idx1 = random.randint(0, self.num_classes - 1)
        image1 = random.choice(self.datas[idx1])

        if self.transform:
            image1 = self.transform(image1)
        return image1, idx1

testSet = SerachDataSet(test_data_set, transform=transforms.ToTensor())
testLoader = DataLoader(testSet, batch_size=1, shuffle=False, num_workers=4)

先程学習したSIAMESE NETWORKで評価データのすべての画像特徴量を取得します。

import numpy as np
import torch.nn as nn

net = net.eval()


# - test images -
test_output = []
test_output_name = []
test_image = []

for batch_idx, (x, label) in enumerate(testLoader):
    feature = net.forward_one(x)
    test_image.append(x)
    test_output.append(np.array(feature.cpu().tolist()))
    test_output_name.append(label)

効率よく検索できるライブラリ`nmslib`を導入します。

! pip install nmslib

作成した特徴量を用いてIndexを作成します。

import nmslib

index = nmslib.init(method='hnsw', space='cosinesimil')
index.addDataPointBatch(test_output)
index.createIndex({'post': 2}, print_progress=True)

検索対象とする画像を表示します。

import matplotlib.pyplot as plt

def show_images(images, figsize=(20,10), columns = 5):
  plt.figure(figsize=figsize)
  for i, image in enumerate(images):
      image = image.transpose(0, 1).transpose(1, 2)
      plt.subplot(len(images) / columns + 1, columns, i + 1)
      plt.imshow(image)

show_images(test_image[0])

検索対象の画像をクエリとして画像の検索処理をしてTop19のデータを取得します。左上の橋がTopの結果になります。同一画像が一番似ている画像として検索できました。他の画像としてはカラーが似ている画像が取得できています。

ids, distances = index.knnQuery(test_output[0], k=19)
search_images = []
for id, distance in zip(ids, distances):
    search_images.append(test_image[id][0])

show_images(search_images)

Metric Learningで使用するモデルの学習から検索まで行ったぞ

通常の識別モデルと異なって、2つの画像を入力しているところが特徴的でしたね

Close Bitnami banner
Bitnami