PyTorch を用いた手書き数字の分類
[履歴] [最終更新] (2020/07/28 19:40:10)
ここは
趣味のプログラミングを楽しむための情報共有サービス。記事の一部は有料設定にして公開できます。 詳しくはこちらをクリック📝
最近の投稿
注目の記事

概要

こちらのページで基本的な使い方を把握した PyTorch を用いて、手書き数字の分類を行ってみます。サポートベクターマシンを用いた場合は HOG などの特徴量を考える必要がありましたが、ディープラーニングでは十分な質の良いデータがあればその必要がありません。

MNIST データの読み込み

手書き数字のデータとして、MNIST データをダウンロードして利用することにします。Matplotlib で描画する例は以下のようになります。

Uploaded Image

# -*- coding: utf-8 -*-
import gzip
import pickle
import matplotlib.pyplot as plt
import torch

def Main():

    # pickle 形式で保存されています。
    with gzip.open('mnist.pkl.gz', 'rb') as f:
        ((xTrain, yTrain), (xValid, yValid), _) = pickle.load(f, encoding='latin-1')

    # 28x28 ピクセルの画像データが 50000 枚分あります。
    print(xTrain.shape)  #=> (50000, 784)

    # 描画してみます。
    print(yTrain[0])  #=> 5
    plt.imshow(xTrain[0].reshape(28, 28), cmap='gray')
    plt.show()

    # pytorch で利用するためには torch.tensor に変換します。
    print(type(xTrain[0]))  #=> <class 'numpy.ndarray'>

    xTrain, yTrain, xValid, yValid = map(
        torch.tensor, (xTrain, yTrain, xValid, yValid)
    )
    print(type(xTrain[0]))  #=> <class 'torch.Tensor'>

if __name__ == '__main__':
    Main()

ニューラルネットワークの定義

手書き数字の描かれた画像を分類するニューラルネットワークとして、ディープラーニングでよく利用される「畳み込みニューラルネットワーク (CNN; Convolutional Neural Network)」を用いてみます。こちらのページに記載したとおり、torch.nn.Module を継承したクラスを利用してネットワークを定義できます。

MNIST データの分類を考えたときには、以下のようなネットワーク定義となります。ただしこれは CNN の一つの例であり、一般形ではありません。

#!/usr/bin/python
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F

def Main():

    # MNIST データは RGB ではなくグレースケールです。
    inChannels = 1

    # 0 から 9 までの数字への分類を考えます。
    outFeatures = 10

    # MNIST データは 28x28 の画像です。
    inputSize = 28

    # ネットワークを定義します。
    cnn = CNN(inChannels, outFeatures, inputSize)

    # 一つのミニバッチに含まれるデータの個数
    bs = 1

    # 乱数で MNIST と同じサイズのデータを用意してみます。
    x = torch.randn(bs, inChannels, inputSize, inputSize)
    yPred = cnn(x)
    print(cnn)
    print(yPred.shape)


class CNN(nn.Module):

    def __init__(self, inChannels, outFeatures3, inputSize):
        super(CNN, self).__init__()

        # 隠れ層の次元数など
        outChannels = 6
        kernelSize = 3
        outChannels2 = 16
        outFeatures = 120
        outFeatures2 = 84
        poolingStride = 2

        sz = inputSize - kernelSize + 1
        sz = sz // poolingStride
        sz = sz - kernelSize + 1
        sz = sz // poolingStride

        self.__poolingStride = poolingStride
        self.__conv1 = nn.Conv2d(inChannels, outChannels, kernelSize)
        self.__conv2 = nn.Conv2d(outChannels, outChannels2, kernelSize)
        self.__fc1 = nn.Linear(outChannels2 * sz * sz, outFeatures)
        self.__fc2 = nn.Linear(outFeatures, outFeatures2)
        self.__fc3 = nn.Linear(outFeatures2, outFeatures3)

    def forward(self, x):

        # 1 x 1 x 28 x 28

        x = self.__conv1(x)  #=> 1 x 6 x 26 x 26
        x = F.relu(x)  #=> 1 x 6 x 26 x 26
        x = F.max_pool2d(x, self.__poolingStride)  #=> 1 x 6 x 13 x 13

        x = self.__conv2(x)  #=> 1 x 16 x 11 x 11
        x = F.relu(x)  #=> 1 x 16 x 11 x 11
        x = F.max_pool2d(x, self.__poolingStride)  #=> 1 x 16 x 5 x 5

        # note: 第一引数を -1 とすることで、第二引数の値から形状を推定させることができます。
        x = x.reshape(-1, self.__GetNumFlatFeatures(x))  #=> 1 x 400

        x = F.relu(self.__fc1(x))  #=> 1 x 120
        x = F.relu(self.__fc2(x))  #=> 1 x 84
        x = self.__fc3(x)  #=> 1 x 10
        return x

    def __GetNumFlatFeatures(self, x):
        size = x.size()[1:]  # ミニバッチの個数の次元を除く、すべての次元
        numFeatures = 1
        for sz in size:
            numFeatures *= sz
        return numFeatures

if __name__ == '__main__':
    Main()

Conv2d(inChannels, outChannels, kernelSize)

inChannels 方向には動かさず、画像の平面内で畳み込みを行います。この畳み込みを独立に outChannels 個のフィルタで行い、結果を一つのテンソルとしてまとめます。kernelSizeOpenCV での畳み込み処理におけるカーネルと同じ概念です。

Linear(inFeatures, outFeatures)

こちらのページでも利用した線形変換です。重みとバイアスをパラメータとして持ちます。

max_pool2d(x, stride)

stride x stride において最大となる値をフィルタします。

出力例

CNN(
  (_CNN__conv1): Conv2d(1, 6, kernel_size=(3, 3), stride=(1, 1))
  (_CNN__conv2): Conv2d(6, 16, kernel_size=(3, 3), stride=(1, 1))
  (_CNN__fc1): Linear(in_features=400, out_features=120, bias=True)
  (_CNN__fc2): Linear(in_features=120, out_features=84, bias=True)
  (_CNN__fc3): Linear(in_features=84, out_features=10, bias=True)
)
torch.Size([1, 10])

分類問題で利用する損失関数について

こちらのページで回帰問題を扱う際に利用した平均二乗誤差は、分類問題ではそのまま利用できません。そのため、ここでは手書き数字の分類を行うために、以下の式で表される、交差エントロピー誤差という損失関数を用いてニューラルネットワークを学習します。

$$L = -\frac{1}{N} \sum_{j=1}^{N=64} \sum_{i=1}^{10} t_{i,j} \log( p_{i,j} ) $$

MNIST データには学習用のデータが 50000 枚あります。ディープラーニングでパラメータの学習のためにループを回す際に、利用可能なすべての学習用のデータを分割して、小さなバッチデータ毎にループを回す手法があります。本ページではミニバッチのサイズ $N$ を 64 として学習することにします。

ある一つの画像データをニューラルネットワークに入力として与えると、入力画像が 0-9 の数字である確率が、長さ 10 のベクトルとして出力されます。実際には $N$ 個のデータを一度に入力するため、このベクトルが $N$ 個出力されます。

交差エントロピー誤差では、10 の長さのベクトルのうち、例えば入力画像が 0 という数字であった場合は、最初の要素だけを取り出して対数を取ります。10 個の確率から一つのデータを取り出せるように $t_{i,j}$ は 0 または 1 の値を取ります。$N$ 個のデータについて同様の処理を行い、平均を計算したものが誤差となります。

例えば $\log(1)$ は 0 となるため、正しい分類ができている場合の誤差は 0 となります。

ソフトマックス関数について

交差エントロピー誤差を計算するためには、ニューラルネットワークの出力を確率として扱えるように変換する必要があります。PyTorch では交差エントロピー誤差を計算する関数 nn.CrossEntropyLoss の内部で、ソフトマックス関数 nn.Softmax を利用して出力を確率として扱えるように変換しています。

nn.CrossEntropyLoss の利用例

以下では nn.CrossEntropyLoss で計算した誤差と、定義に基いて手動計算した誤差が一致することを確認しています。

import torch
import torch.nn as nn

lossFn = nn.CrossEntropyLoss()

N = 64
y = torch.empty(N, dtype=torch.long).random_(10)

yPred = torch.randn(N, 10)

loss = lossFn(yPred, y)

loss2 = 0.0
for j in range(N):
    loss2 += -torch.log(torch.exp(yPred[j][y[j]]) / sum(torch.exp(yPred[j])))
loss2 /= N

print(loss)  #=> 2.8512
print(loss2)  #=> 2.8512

ニューラルネットワークの学習

上述の CNN と交差エントロピー誤差を用いて MNIST データの分類を試してみます。

# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader

import gzip
import pickle
import matplotlib.pyplot as plt

def Main():

    # MNIST データ
    xTrain, yTrain, xValid, yValid = GetMnistData()

    # CNN モデル
    inChannels = 1
    outFeatures = 10
    inputSize = 28
    model = CNN(inChannels, outFeatures, inputSize)

    # 交差エントロピー誤差
    lossFn = F.cross_entropy

    # 学習率、学習の反復回数
    learningRate = 0.001
    iters = 10

    # 最適化関数
    optimizer = torch.optim.Adam(model.parameters(), lr=learningRate)

    # ミニバッチのサイズ
    bs = 64
    trainDs = TensorDataset(xTrain, yTrain)
    trainDl = DataLoader(trainDs, batch_size=bs, shuffle=True)

    # 全体のループ
    for t in range(iters):

        # ミニバッチ毎のループ
        for x, y in trainDl:
            yPred = model(x)
            loss = lossFn(yPred, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # 誤差の出力
        print(t, loss.item())

    # 学習済みモデルの検証
    CheckOutput(model, xTrain, yTrain)
    CheckOutput(model, xValid, yValid)

def CheckOutput(model, x, y):
    yPred = map(lambda xx: xx.max(0).indices.item(), model(x))
    wrong = 0
    for xx, yy, yyPred in zip(x, y, yPred):
        if yy.item() == yyPred:
            continue
        # print('{} != {}'.format(yy.item(), yyPred))
        # plt.imshow(xx.reshape(28, 28), cmap='gray')
        # plt.show()
        wrong += 1
    print('Accuracy: {}'.format(100 - wrong / len(x) * 100))

def GetMnistData():
    with gzip.open('mnist.pkl.gz', 'rb') as f:
        (xTrain, yTrain), (xValid, yValid), _ = pickle.load(f, encoding='latin-1')
    xTrain = list(map(lambda x: x.reshape(1, 28, 28), xTrain))
    xValid = list(map(lambda x: x.reshape(1, 28, 28), xValid))
    return map(torch.tensor, (xTrain, yTrain, xValid, yValid))

class CNN(nn.Module):

    def __init__(self, inChannels, outFeatures3, inputSize):
        super(CNN, self).__init__()
        outChannels = 6
        kernelSize = 3
        outChannels2 = 16
        outFeatures = 120
        outFeatures2 = 84
        poolingStride = 2
        sz = inputSize - kernelSize + 1
        sz = sz // poolingStride
        sz = sz - kernelSize + 1
        sz = sz // poolingStride
        self.__poolingStride = poolingStride
        self.__conv1 = nn.Conv2d(inChannels, outChannels, kernelSize)
        self.__conv2 = nn.Conv2d(outChannels, outChannels2, kernelSize)
        self.__fc1 = nn.Linear(outChannels2 * sz * sz, outFeatures)
        self.__fc2 = nn.Linear(outFeatures, outFeatures2)
        self.__fc3 = nn.Linear(outFeatures2, outFeatures3)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.__conv1(x)), self.__poolingStride)
        x = F.max_pool2d(F.relu(self.__conv2(x)), self.__poolingStride)
        x = x.reshape(-1, self.__GetNumFlatFeatures(x))
        x = F.relu(self.__fc1(x))
        x = F.relu(self.__fc2(x))
        x = self.__fc3(x)
        return x

    def __GetNumFlatFeatures(self, x):
        size = x.size()[1:]
        numFeatures = 1
        for sz in size:
            numFeatures *= sz
        return numFeatures

if __name__ == '__main__':
    Main()

実行例

0 0.10681144148111343
1 0.06327979266643524
2 0.040145404636859894
3 0.015086745843291283
4 0.005156606901437044
5 0.0018728474387899041
6 0.000745357247069478
7 0.0005938038229942322
8 5.65591617487371e-05
9 0.0003749439201783389
Accuracy: 99.456
Accuracy: 98.58

訓練用のデータで 99.456%、未知のデータで 98.58% となりました。分類に失敗したデータの例としては以下のようなものがあります。

2 と認識 (正しくは 3)

Uploaded Image

8 と認識 (正しくは 3)

Uploaded Image

6 と認識 (正しくは 5)

Uploaded Image

関連ページ