TechFULの中の人

triple-four’s blog

PyTorchで機械学習!

こんにちは!TechFULでアルバイトしているあたかです。 今回はPyTorchを用いて、MNISTというデータセットを用いた手書き文字の分類の実装を解説します!

PyTorchとは?

PyTorchはFacebookによって開発が主導されたPython向けのディープラーニング用のフレームワークです。 自動で微分してくれたりする機能があり、自分で一からコーディングしなくても比較的楽に実装することができます。

MNIST データセット

MNISTとは70,000個の手書きの数字の画像と正解ラベルの組みが用意されたデータセットです。画像は28×28のピクセルで、グレースケール(白黒) で1ピクセルに0以上255以下の整数が割り当てられています。 正解ラベルは書かれた数字が0~9のどれかを示す数字です。以下に例を示します。この画像の一つ一つの数字がラベルとセットで一つのデータです。 f:id:sanbaiefforts:20190315112228p:plain

実験

ニューラルネットワークの構造は 中間層は全結合層の4層です。 訓練データとして60,000個、テストデータとして10,000個のデータを用います。

入力

28×28=784のそれぞれの画素を0~1の実数値に正規化します。

出力

出力層の活性化関数はsoftmax関数です。 それぞれの数字に対しての一致率とみなすことができます。

損失関数

損失関数はクロスエントロピー関数をつかいます。分類問題では二乗誤差よりもクロスエントロピー関数の方が優れています。

PyTorchのコード

実際のコードを解説していきます。

import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from sklearn import datasets, model_selection

まずは、必要なPyTorchとScikit-learnのライブラリを読み込みます。 これで、いろいろな機能が使えるようになりました。

from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, cache=True)

OpenMLからMNISTのデータセットをとってきます。

mnist.target = mnist.target.astype(np.long)

このデータセットではラベルがstring型で与えられるのでlong型になおしておきます。

mnist_data = mnist.data / 255

学習しやすいように画素値を正規化します。

from matplotlib import pyplot as plt
from matplotlib import cm

plt.imshow(mnist_data[1119].reshape(28, 28), cmap=cm.gray_r)
plt.show()

ここで、学習とは関係ないですがデータの中見をみてみます。 1119番目のデータはこのようになっています。正解ラベルはもちろん0です。

f:id:sanbaiefforts:20190321115952p:plain

mnist_label = mnist.target

わかりやすいようにmnist_labelという名前の変数に正解ラベルをいれときます。

train_size = 60000
test_size = 10000
train_X, test_X, train_Y, test_Y = model_selection.train_test_split(
    mnist_data, mnist_label, train_size=train_size,test_size=test_size)

train_X, test_X, train_Y, test_Yという変数にそれぞれデータを入れます。 Xに画像データ、Yに正解ラベルがはいっています。

train_X = torch.from_numpy(train_X).float()
train_Y = torch.from_numpy(train_Y).long()
test_X = torch.from_numpy(test_X).float()
test_Y = torch.from_numpy(test_Y).long()

PyTorchで扱えるようにnumpy.ndarrayからtensorに型変換します。

train = TensorDataset(train_X, train_Y)
train_loader = DataLoader(train, batch_size=100, shuffle=True)

trainという変数に訓練データと正解ラベルを一緒にいれます。 train_loaderに100個ごとに訓練データと正解ラベルの組みをまとめます。

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 256)
        self.fc3 = nn.Linear(256, 256)
        self.fc4 = nn.Linear(256, 128)
        self.fc5 = nn.Linear(128, 128)
        self.fc6 = nn.Linear(128, 10)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = F.relu(self.fc3(x))
        x = F.relu(self.fc4(x))
        x = F.relu(self.fc5(x))
        x = F.dropout(x, training=self.training)
        x = self.fc6(x)
        return F.log_softmax(x,dim=0)

model = Net()

ネットワークを定義しています。 ドロップアウトというテクニックをつかっています。

# 誤差関数
criterion = nn.CrossEntropyLoss()
# 最適化関数
optimizer = optim.SGD(model.parameters(), lr=0.01)
# ネットワークの学習
for epoch in range(1000):
    total_loss = 0
    # 分割したデータの取り出し
    for train_x, train_y in train_loader:
     #データ型の変換
        train_x, train_y = Variable(train_x), Variable(train_y)
        # 勾配をリセット
        optimizer.zero_grad()
        # 順伝播
        output = model(train_x)
        # 誤差
        loss = criterion(output, train_y)
        # 逆伝播
        loss.backward()
        # 重みの更新
        optimizer.step()
        # 誤差の累積
        total_loss += loss.item()
    # 累積誤差を100回ごとに表示
    if (epoch+1) % 100 == 0:
        print(epoch+1, total_loss)

結果は以下のようになりました。

100 0.007138569951499418
200 0.008241292930844635
300 0.019155481463363433
400 0.023139646113246704
500 0.005373297337906591
600 0.005126366108063829
700 0.018377367902964403
800 0.0029309967260018333
900 0.06596824008159818
1000 0.016231205472653

テストデータをつかって、一致率を計算します。

test_x, test_y = Variable(test_X), Variable(test_Y)
result = torch.max(model(test_x).data, 1)[1]
# 一致率を計算
accuracy = sum(test_y.data.numpy() == result.numpy()) / len(test_y.data.numpy())
# 一致率を表示
accuracy

結果は以下のようになりました。

0.9766

おわりに

ニューラルネットワークの基礎知識があれば、理解するのはそれほど難しくなかったと思います。 ぜひ実際に自分で試してみてください! わからないことなどあれば、そのたびにググってみてください。優秀な人はググりかたが上手だと思います。
StackOverflowなどのサイトで自分と同じような疑問をもっている人がいるかもしれません。英語の勉強にもなるので一石二鳥ですね。
そもそもPyTorchのインストール方法がわからなかった人もいるかもしれません。開発環境構築は初心者にとって大きなハードルだと思います。次回はそのあたりを易しく解説できたらいいなとおもいます。

最後まで読んでいただきありがとうございます。それでは。


運営会社 / サービス

444株式会社

triple-four.com

TechFUL

procon.techful.jp