Caffe2を使ってみたその2「MNISTの学習」
こんにちは.
今回の記事ではCaffe2を使ってLeNetを実装し,MNISTを学習させてみます. 方法については公式のチュートリアルを参照(というよりそのままです)しました.
MNIST - Create a CNN from Scratch | Caffe2
環境
項目 | 詳細 |
---|---|
OS | Ubuntu 14.04 |
GPU | GTX970 |
データセットの用意
Caffe2ではCaffeの時と同様, 学習画像のデータベースを用意しておき, データベースから読み込んでいきます.
下記サイトから4つのファイルをダウンロードします.
http://yann.lucan.com/exdb/mnist/
- train-images-idx3-ubyte.gz
- train-labels-idx1-ubyte.gz
- t10k-images-idx3-ubyte.gz
- t10k-labels-idx1-ubyte.gz
ダウンロードしたファイルを解凍し,Caffe2のコマンドによってデータベース化していきます.
$ gunzip train-images-idx3-ubyte.gz $ gunzip train-labels-idx1-ubyte.gz $ gunzip t10k-images-idx3-ubyte.gz $ gunzip t10k-labels-idx1-ubyte.gz # 解凍したファイルはホームフォルダに置いたとする. $ cd caffe2 $ mkdir -p caffe2/python/tutorials/tutorial_data/mnist/ $ ./build/caffe2/binaries/make_mnist_db --channel_first --db leveldb --image_file ~/train-images-idx3-ubyte --label_file ~/train-labels-idx1-ubyte --output_file caffe2/python/tutorials/tutorial_data/mnist/mnist-train-nchw-lebeldb $ ./build/caffe2/binaries/make_mnist_db --channel_first --db leveldb --image_file ~/t10k-images-idx3-ubyte --label_file ~/t10k-labels-idx1-ubyte --output_file caffe2/python/tutorials/tutorial_data/mnist/mnist-test-nchw-lebeldb
スクリプトの用意
Caffe2ではCaffeとは異なりprototxtを用意せずとも, Pythonのスクリプト1つで学習の記述ができるようです.
# -*- coding: utf-8 -*- # 読み込むモジュール import numpy as np, os, shutil from matplotlib import pyplot from caffe2.python import core, cnn, net_drawer, workspace, visualize # 初期化 core.GlobalInit(['caffe2', '--caffe2_log_level=0']) caffe2_root = '~/caffe2' # パスのオプション # 適宜書き換えてください! data_folder = '/home/username/caffe2/caffe2/python/tutorials/tutorial_data/mnist' # ネットワークの入力 def AddInput(model, batch_size, db, db_type): # DBからデータ(画像、ラベル)の読み込み data_uint8, label = model.TensorProtosDBInput( [], ['data_uint8', 'label'], batch_size=batch_size, db=db, db_type=db_type ) # データ型の変換: uint8 -> float data = model.Cast(data_uint8, "data", to=core.DataType.FLOAT) # スケールの変更: [0 255] -> [0 1] data = model.Scale(data, data, scale=float(1./256)) # 逆伝播時に微分の計算をしない(入力だから) data = model.StopGradient(data, data) return data, label # ネットワークの本体 def AddLeNetModel(model, data): # conv -> pool -> conv -> pool -> Relu(fc) -> SoftMax(fc) # NxCxHxW: 1x1x28x28 -> 20x1x24x24 -> 20x1x12x12 conv1 = model.Conv(data, 'conv1', 1, 20, 5) pool1 = model.MaxPool(conv1, 'pool1', kernel=2, stride=2) # NxCxHxW: 20x1x12x12 -> 50x1x8x8 -> 50x1x4x4 conv2 = model.Conv(pool1, 'conv2', 20, 50, 5) pool2 = model.MaxPool(conv2, 'pool2', kernel=2, stride=2) # NxCxHxW: 800x1x1x1 -> 500x1x1x1 fc3 = model.FC(pool2, 'fc3', 50 * 4 * 4, 500) fc3 = model.Relu(fc3, fc3) # NxCxHxW: 500x1x1x1 -> 10x1x1x1 pred = model.FC(fc3, 'pred', 500, 10) softmax = model.Softmax(pred, 'softmax') return softmax # ネットワークの正解率 def AddAccuracy(model, softmax, label): accuracy = model.Accuracy([softmax, label], 'accuracy') return accuracy # 学習 def AddTrainingOperators(model, softmax, label): # クロスエントロピーの計算 xent = model.LabelCrossEntropy([softmax, label], 'xent') # クロスエントロピーの平均損失の計算 loss = model.AveragedLoss(xent, 'loss') # 正解率の計算 AddAccuracy(model, softmax, label) # 損失関数の勾配を計算 model.AddGradientOperators([loss]) # 学習率の設定(lr = base_lr * (t ^ gamma)) ITER = model.Iter('iter') LR = model.LearningRate(ITER, 'LR', base_lr=-0.1, policy='step', stepsize=1, gamma=0.999) # 更新に使う定数 ONE = model.param_init_net.ConstantFill([], 'ONE', shape=[1], value=1.0) # 全パラメータにおいて更新 # param = param + param_grad * LR(このような更新式の定義なので,学習率は負値になっている.) for param in model.params: param_grad = model.param_to_grad[param] model.WeightedSum([param, ONE, param_grad, LR], param) # 20イテレーション毎にチェックポイントを作成 model.Checkpoint([ITER] + model.params, [], db='mnist_lenet_checkpoint_%05d.leveldb', db_type='leveldb', every=20) # ログなどの出力用 def AddBookkeepingOperators(model): model.Print('accuracy', [], to_file=1) model.Print('loss', [], to_file=1) for param in model.params: model.Summarize(param, [], to_file=1) model.Summarize(model.param_to_grad[param], [], to_file=1) # ここからネットワークの準備 # CNNのモデル型を学習用として用意 train_model = cnn.CNNModelHelper(order='NCHW', name='mnist_train') # データセットの読み込み data, label = AddInput(train_model, batch_size=64, db=os.path.join(data_folder, 'mnist-train-nchw-leveldb'), db_type='leveldb') # ネットワークの設定 softmax = AddLeNetModel(train_model, data) # 学習の設定 AddTrainingOperators(train_model, softmax, label) # ログの設定 AddBookkeepingOperators(train_model) # CNNのモデル型をテスト用として用意 test_model = cnn.CNNModelHelper(order='NCHW', name='mnist_test', init_params=False) # データセットの読み込み data, label = AddInput(test_model, batch_size=100, db=os.path.join(data_folder, 'mnist-test-nchw-leveldb'), db_type='leveldb') # ネットワークの設定 softmax = AddLeNetModel(test_model, data) # 正解率の設定 AddAccuracy(test_model, softmax, label) # デプロイを用意(何に使うのだろう) deploy_model = cnn.CNNModelHelper(order='NCHW', name='mnist_deploy', init_params=False) AddLeNetModel(deploy_model, 'data') # ここから学習の処理 # ネットワークの初期化 workspace.RunNetOnce(train_model.param_init_net) workspace.CreateNet(train_model.net) # pyplot用 total_iters = 200 accuracy = np.zeros(total_iters) loss = np.zeros(total_iters) # 学習 for i in xrange(total_iters): workspace.RunNet(train_model.net.Proto().name) # グラフ描画 accuracy[i] = workspace.FetchBlob('accuracy') loss[i] = workspace.FetchBlob('loss') pyplot.clf() pyplot.plot(accuracy, 'r') pyplot.plot(loss, 'b') pyplot.legend(('loss', 'accuracy'), loc='upper right') pyplot.pause(.01) # ここからテストの処理 # ネットワークの初期化 workspace.RunNetOnce(test_model.param_init_net) workspace.CreateNet(test_model.net) # pyplot用 test_accuracy = np.zeros(100) # テスト for i in range(100): workspace.RunNet(test_model.net.Proto().name) test_accuracy[i] = workspace.FetchBlob('accuracy') # グラフ描画 pyplot.plot(test_accuracy, 'r') pyplot.title('Acuracy over test batches.') print('test_accuracy: %f' % test_accuracy.mean())
スクリプトの前半部分で関数を定義して, 後半部分でそれらの関数を使ってるイメージです. スクリプトを「train_mnist.py」という名前で保存して,下記コマンドで実行できます.
$ python train_mnist.py
使ってみて思ったこと
Caffeのprototxtよりもネットワークの記述がスッキリしたように思えます. Caffe2独自の書き方に慣れさえすれば,Pythonユーザーにとってはそれほど難しいものではないように思えます. ネットワークが小さいせいか,Caffe2の新機能のおかげなのかはわかりませんが学習はサクサク動きます.