読者です 読者をやめる 読者になる 読者になる

ytbilly3636’s 研究備忘録

機械学習,Python,ガンダムなど

Caffe2を使ってみたその2「MNISTの学習」

こんにちは.

今回の記事ではCaffe2を使ってLeNetを実装し,MNISTを学習させてみます. 方法については公式のチュートリアルを参照(というよりそのままです)しました.

MNIST - Create a CNN from Scratch | Caffe2

環境

項目 詳細
OS Ubuntu 14.04
GPU GTX970

データセットの用意

Caffe2ではCaffeの時と同様, 学習画像のデータベースを用意しておき, データベースから読み込んでいきます.

下記サイトから4つのファイルをダウンロードします.

http://yann.lucan.com/exdb/mnist/

  • train-images-idx3-ubyte.gz
  • train-labels-idx1-ubyte.gz
  • t10k-images-idx3-ubyte.gz
  • t10k-labels-idx1-ubyte.gz

ダウンロードしたファイルを解凍し,Caffe2のコマンドによってデータベース化していきます.

$ gunzip train-images-idx3-ubyte.gz
$ gunzip train-labels-idx1-ubyte.gz
$ gunzip t10k-images-idx3-ubyte.gz
$ gunzip t10k-labels-idx1-ubyte.gz

# 解凍したファイルはホームフォルダに置いたとする.
$ cd caffe2
$ mkdir -p caffe2/python/tutorials/tutorial_data/mnist/
$ ./build/caffe2/binaries/make_mnist_db --channel_first --db leveldb --image_file ~/train-images-idx3-ubyte --label_file ~/train-labels-idx1-ubyte --output_file caffe2/python/tutorials/tutorial_data/mnist/mnist-train-nchw-lebeldb
$ ./build/caffe2/binaries/make_mnist_db --channel_first --db leveldb --image_file ~/t10k-images-idx3-ubyte --label_file ~/t10k-labels-idx1-ubyte --output_file caffe2/python/tutorials/tutorial_data/mnist/mnist-test-nchw-lebeldb

スクリプトの用意

Caffe2ではCaffeとは異なりprototxtを用意せずとも, Pythonスクリプト1つで学習の記述ができるようです.

# -*- coding: utf-8 -*-

# 読み込むモジュール
import numpy as np, os, shutil
from   matplotlib import pyplot
from   caffe2.python import core, cnn, net_drawer, workspace, visualize


# 初期化
core.GlobalInit(['caffe2', '--caffe2_log_level=0'])
caffe2_root = '~/caffe2'


# パスのオプション
# 適宜書き換えてください!
data_folder = '/home/username/caffe2/caffe2/python/tutorials/tutorial_data/mnist'


# ネットワークの入力
def AddInput(model, batch_size, db, db_type):
    # DBからデータ(画像、ラベル)の読み込み
    data_uint8, label = model.TensorProtosDBInput(
        [], ['data_uint8', 'label'], batch_size=batch_size, db=db, db_type=db_type
    )
    
    # データ型の変換: uint8 -> float
    data = model.Cast(data_uint8, "data", to=core.DataType.FLOAT)
    
    # スケールの変更: [0 255] -> [0 1]
    data = model.Scale(data, data, scale=float(1./256))
    
    # 逆伝播時に微分の計算をしない(入力だから)
    data = model.StopGradient(data, data)
    
    return data, label
    
    
# ネットワークの本体
def AddLeNetModel(model, data):
    # conv -> pool -> conv -> pool -> Relu(fc) -> SoftMax(fc)
    
    # NxCxHxW: 1x1x28x28 -> 20x1x24x24 -> 20x1x12x12
    conv1 = model.Conv(data, 'conv1', 1, 20, 5)
    pool1 = model.MaxPool(conv1, 'pool1', kernel=2, stride=2)
    
    # NxCxHxW: 20x1x12x12 -> 50x1x8x8 -> 50x1x4x4
    conv2 = model.Conv(pool1, 'conv2', 20, 50, 5)
    pool2 = model.MaxPool(conv2, 'pool2', kernel=2, stride=2)
    
    # NxCxHxW: 800x1x1x1 -> 500x1x1x1
    fc3   = model.FC(pool2, 'fc3', 50 * 4 * 4, 500)
    fc3   = model.Relu(fc3, fc3)
    
    # NxCxHxW: 500x1x1x1 -> 10x1x1x1
    pred  = model.FC(fc3, 'pred', 500, 10)
    softmax = model.Softmax(pred, 'softmax')
    
    return softmax
    
    
# ネットワークの正解率
def AddAccuracy(model, softmax, label):
    accuracy = model.Accuracy([softmax, label], 'accuracy')
    return accuracy
    
    
# 学習
def AddTrainingOperators(model, softmax, label):
    # クロスエントロピーの計算
    xent = model.LabelCrossEntropy([softmax, label], 'xent')
    
    # クロスエントロピーの平均損失の計算
    loss = model.AveragedLoss(xent, 'loss')
    
    # 正解率の計算
    AddAccuracy(model, softmax, label)
    
    # 損失関数の勾配を計算
    model.AddGradientOperators([loss])
    
    # 学習率の設定(lr = base_lr * (t ^ gamma))
    ITER = model.Iter('iter')
    LR = model.LearningRate(ITER, 'LR', base_lr=-0.1, policy='step', stepsize=1, gamma=0.999)
    
    # 更新に使う定数
    ONE = model.param_init_net.ConstantFill([], 'ONE', shape=[1], value=1.0)
    
    # 全パラメータにおいて更新
    # param = param + param_grad * LR(このような更新式の定義なので,学習率は負値になっている.)
    for param in model.params:
        param_grad = model.param_to_grad[param]
        model.WeightedSum([param, ONE, param_grad, LR], param)
        
    # 20イテレーション毎にチェックポイントを作成
    model.Checkpoint([ITER] + model.params, [], db='mnist_lenet_checkpoint_%05d.leveldb', db_type='leveldb', every=20)
    
    
# ログなどの出力用
def AddBookkeepingOperators(model):
    model.Print('accuracy', [], to_file=1)
    model.Print('loss', [], to_file=1)
    for param in model.params:
        model.Summarize(param, [], to_file=1)
        model.Summarize(model.param_to_grad[param], [], to_file=1)
        
        
# ここからネットワークの準備

# CNNのモデル型を学習用として用意
train_model = cnn.CNNModelHelper(order='NCHW', name='mnist_train')
# データセットの読み込み
data, label = AddInput(train_model, batch_size=64, db=os.path.join(data_folder, 'mnist-train-nchw-leveldb'), db_type='leveldb')
# ネットワークの設定
softmax = AddLeNetModel(train_model, data)
# 学習の設定
AddTrainingOperators(train_model, softmax, label)
# ログの設定
AddBookkeepingOperators(train_model)

# CNNのモデル型をテスト用として用意
test_model = cnn.CNNModelHelper(order='NCHW', name='mnist_test', init_params=False)
# データセットの読み込み
data, label = AddInput(test_model, batch_size=100, db=os.path.join(data_folder, 'mnist-test-nchw-leveldb'), db_type='leveldb')
# ネットワークの設定
softmax = AddLeNetModel(test_model, data)
# 正解率の設定
AddAccuracy(test_model, softmax, label)

# デプロイを用意(何に使うのだろう)
deploy_model = cnn.CNNModelHelper(order='NCHW', name='mnist_deploy', init_params=False)
AddLeNetModel(deploy_model, 'data')


# ここから学習の処理

# ネットワークの初期化
workspace.RunNetOnce(train_model.param_init_net)
workspace.CreateNet(train_model.net)

# pyplot用
total_iters = 200
accuracy = np.zeros(total_iters)
loss     = np.zeros(total_iters)

# 学習
for i in xrange(total_iters):
    workspace.RunNet(train_model.net.Proto().name)
    
    # グラフ描画
    accuracy[i] = workspace.FetchBlob('accuracy')
    loss[i]     = workspace.FetchBlob('loss')
    pyplot.clf()
    pyplot.plot(accuracy, 'r')
    pyplot.plot(loss, 'b')
    pyplot.legend(('loss', 'accuracy'), loc='upper right')
    pyplot.pause(.01)

    
# ここからテストの処理

# ネットワークの初期化
workspace.RunNetOnce(test_model.param_init_net)
workspace.CreateNet(test_model.net)

# pyplot用
test_accuracy = np.zeros(100)

# テスト
for i in range(100):
    workspace.RunNet(test_model.net.Proto().name)
    test_accuracy[i] = workspace.FetchBlob('accuracy')

# グラフ描画
pyplot.plot(test_accuracy, 'r')
pyplot.title('Acuracy over test batches.')
print('test_accuracy: %f' % test_accuracy.mean())

スクリプトの前半部分で関数を定義して, 後半部分でそれらの関数を使ってるイメージです. スクリプトを「train_mnist.py」という名前で保存して,下記コマンドで実行できます.

$ python train_mnist.py

使ってみて思ったこと

Caffeのprototxtよりもネットワークの記述がスッキリしたように思えます. Caffe2独自の書き方に慣れさえすれば,Pythonユーザーにとってはそれほど難しいものではないように思えます. ネットワークが小さいせいか,Caffe2の新機能のおかげなのかはわかりませんが学習はサクサク動きます.