PyTorchでモデルの枝刈りを行い軽量なモデルを作成する

学習後のモデルはできるだけ軽量なものを使用してメモリの消費を抑えたり、推論速度を向上したいです

それなら枝刈りが効くぞ

なんですか。それ

モデルの不要なパラメータを除く処理だ。これによって精度低下を抑えて、軽量なモデルを作成できるぞ

参考情報

下記の記事を参考にしました。

https://pytorch.org/tutorials/intermediate/pruning_tutorial.html

Google Colabで動作確認をしました。設定方法は下記記事に記述しています。

モデルの作成

下記のコードでLeNetのモデルを作成します。

import torch
from torch import nn
import torch.nn.utils.prune as prune
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class LeNet(nn.Module):
    def __init__(self):
        super(LeNet, self).__init__()
        # 1 input image channel, 6 output channels, 3x3 square conv kernel
        self.conv1 = nn.Conv2d(1, 6, 3)
        self.conv2 = nn.Conv2d(6, 16, 3)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)  # 5x5 image dimension
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(-1, int(x.nelement() / x.shape[0]))
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = LeNet().to(device=device)

モデルの重みとバイアスを確認します。

module = model.conv1
print(list(module.named_parameters()))

下記のようにパラメータを確認できます。

]
module = model.conv1
print(list(module.named_parameters()))
[('weight', Parameter containing:
tensor([[[[ 0.3062,  0.0112,  0.1236],
          [ 0.2438, -0.3327, -0.0695],
          [ 0.2251,  0.3114, -0.1265]]],


        [[[-0.3025, -0.2474, -0.1669],
          [ 0.0679,  0.2739, -0.3260],
          [ 0.1577,  0.0939, -0.0196]]],


        [[[-0.1976,  0.1710, -0.3251],
          [-0.3188,  0.0758, -0.1598],
          [-0.3105,  0.1251,  0.2656]]],


        [[[-0.2357, -0.0986, -0.3083],
          [ 0.0482,  0.0831, -0.1188],
          [ 0.1228,  0.3241,  0.1533]]],


        [[[ 0.2616,  0.1920, -0.2226],
          [ 0.2361, -0.2428,  0.2704],
          [-0.1442, -0.1851,  0.2758]]],


        [[[ 0.1435,  0.2721, -0.0039],
          [ 0.2585, -0.1883, -0.0184],
          [ 0.2203,  0.1479,  0.0132]]]], device='cuda:0', requires_grad=True)), ('bias', Parameter containing:
tensor([-0.2827, -0.0311, -0.2399, -0.3249,  0.1326,  0.3170], device='cuda:0',
       requires_grad=True))]

モデルの個別レイヤーの枝刈り

まずは下記のコードで`conv1`の重みをランダムに30%の重みを無効にしてみます。ランダムに枝刈りする方法以外は後述します。

prune.random_unstructured(module, name="weight", amount=0.3)

下記のコードでどの重みが有効で無効にされるかを確認できます。1が有効な重み、0が無効な重みになります。

print(list(module.named_buffers()))
[('weight_mask', tensor([[[[0., 1., 1.],
          [1., 1., 1.],
          [1., 0., 1.]]],


        [[[1., 0., 1.],
          [0., 1., 1.],
          [0., 0., 1.]]],


        [[[1., 1., 0.],
          [1., 0., 1.],
          [0., 1., 0.]]],


        [[[1., 1., 0.],
          [1., 1., 1.],
          [1., 0., 1.]]],


        [[[1., 0., 1.],
          [1., 0., 1.],
          [1., 1., 1.]]],


        [[[1., 0., 1.],
          [1., 1., 1.],
          [1., 0., 1.]]]], device='cuda:0'))]

実際にフィルターされた重みが下記になります。上記で0に設定された重みが0でフィルターされていることが確認できます。

print(module.weight)
tensor([[[[ 0.0000,  0.0112,  0.1236],
          [ 0.2438, -0.3327, -0.0695],
          [ 0.2251,  0.0000, -0.1265]]],


        [[[-0.3025, -0.0000, -0.1669],
          [ 0.0000,  0.2739, -0.3260],
          [ 0.0000,  0.0000, -0.0196]]],


        [[[-0.1976,  0.1710, -0.0000],
          [-0.3188,  0.0000, -0.1598],
          [-0.0000,  0.1251,  0.0000]]],


        [[[-0.2357, -0.0986, -0.0000],
          [ 0.0482,  0.0831, -0.1188],
          [ 0.1228,  0.0000,  0.1533]]],


        [[[ 0.2616,  0.0000, -0.2226],
          [ 0.2361, -0.0000,  0.2704],
          [-0.1442, -0.1851,  0.2758]]],


        [[[ 0.1435,  0.0000, -0.0039],
          [ 0.2585, -0.1883, -0.0184],
          [ 0.2203,  0.0000,  0.0132]]]], device='cuda:0',
       grad_fn=<MulBackward0>)

実際にこの枝刈りは`__forward_pre_hooks`を使用して実行されます。

print(module._forward_pre_hooks)
OrderedDict([(0, <torch.nn.utils.prune.RandomUnstructured object at 0x7f937d842710>), (1, <torch.nn.utils.prune.L1Unstructured object at 0x7f937d842828>)])

`conv1`のバイアスにたいしてL1ノルムをかけた際に最も小さい値3つを枝刈りする方法は下記になります。

prune.l1_unstructured(module, name="bias", amount=3)

先程は重みをランダムに枝刈りしていましたが、確実性が薄いのでL2ノルムをかけて重みを除去する方法を記述します。

下記のコードはL2ノルムをかけた際の下位50%を枝刈りする手法です。

prune.ln_structured(module, name="weight", amount=0.5, n=2, dim=0)

# As we can verify, this will zero out all the connections corresponding to 
# 50% (3 out of 6) of the channels, while preserving the action of the 
# previous mask.
print(module.weight)

ここまでは畳み込みレイヤーの枝刈りでしたが畳み込み以外の特定のレイヤーを枝刈りしたいケースがでてきます。このようなケースは下記のようなコードを記述します。畳み込みと全結合のレイヤーが枝刈りされていることが分かります。

new_model = LeNet()
for name, module in new_model.named_modules():
    # prune 20% of connections in all 2D-conv layers 
    if isinstance(module, torch.nn.Conv2d):
        prune.l1_unstructured(module, name='weight', amount=0.2)
    # prune 40% of connections in all linear layers 
    elif isinstance(module, torch.nn.Linear):
        prune.l1_unstructured(module, name='weight', amount=0.4)

print(dict(new_model.named_buffers()).keys())  # to verify that all masks exist
dict_keys(['conv1.weight_mask', 'conv2.weight_mask', 'fc1.weight_mask', 'fc2.weight_mask', 'fc3.weight_mask'])

モデル全体での枝刈り

個別レイヤーの枝刈り処理だったが、ここからはモデル全体での枝刈りの手法の紹介をするぞ

下記のコードではL1ノルムがモデル全体で最も小さい20%を0にして、枝刈り処理を行っています。

model = LeNet()

parameters_to_prune = (
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
    (model.fc2, 'weight'),
    (model.fc3, 'weight'),
)

prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)

20%の要素が枝刈りされているか確認します。

print(
    "Sparsity in conv1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv1.weight == 0))
        / float(model.conv1.weight.nelement())
    )
)
print(
    "Sparsity in conv2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.conv2.weight == 0))
        / float(model.conv2.weight.nelement())
    )
)
print(
    "Sparsity in fc1.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc1.weight == 0))
        / float(model.fc1.weight.nelement())
    )
)
print(
    "Sparsity in fc2.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc2.weight == 0))
        / float(model.fc2.weight.nelement())
    )
)
print(
    "Sparsity in fc3.weight: {:.2f}%".format(
        100. * float(torch.sum(model.fc3.weight == 0))
        / float(model.fc3.weight.nelement())
    )
)
print(
    "Global sparsity: {:.2f}%".format(
        100. * float(
            torch.sum(model.conv1.weight == 0)
            + torch.sum(model.conv2.weight == 0)
            + torch.sum(model.fc1.weight == 0)
            + torch.sum(model.fc2.weight == 0)
            + torch.sum(model.fc3.weight == 0)
        )
        / float(
            model.conv1.weight.nelement()
            + model.conv2.weight.nelement()
            + model.fc1.weight.nelement()
            + model.fc2.weight.nelement()
            + model.fc3.weight.nelement()
        )
    )
)
Sparsity in conv1.weight: 3.70%
Sparsity in conv2.weight: 8.10%
Sparsity in fc1.weight: 22.03%
Sparsity in fc2.weight: 12.21%
Sparsity in fc3.weight: 10.71%
Global sparsity: 20.00%

オリジナルの枝刈りの実装

自分自身で枝刈りの手法の実装は可能です。下記のように枝刈り処理を行うクラスの作成を行います。

class FooBarPruningMethod(prune.BasePruningMethod):
    """Prune every other entry in a tensor
    """
    PRUNING_TYPE = 'unstructured'

    def compute_mask(self, t, default_mask):
        mask = default_mask.clone()
        mask.view(-1)[::2] = 0 
        return mask

先程作成したクラスを適用するための関数を作成します。

def foobar_unstructured(module, name):
    FooBarPruningMethod.apply(module, name)
    return module

実際にモデルに枝刈りを適用してみます。

model = LeNet()
foobar_unstructured(model.fc3, name='bias')

print(model.fc3.bias_mask)

ここまででL1ノルムを使った個別レイヤーの枝刈りからモデル全体の枝刈り、オリジナルの枝刈りの実装まで言及したぞ

これで軽量なモデルを簡単に試せそうですね!

Close Bitnami banner
Bitnami