Caffe2を使ってみたその3「自作データセットの学習」
こんにちは.
前回の記事ではチュートリアルに則ってMNISTデータセットを学習させました. 今回は自分で用意した画像のデータセットの作成し,CNNに学習させる方法についてまとめます.
データセットの作成
以下のチュートリアルを参考にしました.
チュートリアル曰く,DBの中身は文字列のようです. どうやらTensorProtosというやつでDBを作って,TensorProtosDBInputでデータを読み込めば良いようです.
To a DB, it treats the keys and values as strings, but you probably want structured contents. One way to do this is to use a TensorProtos protocol buffer: it essentially wraps Tensors, aka multi-dimensional arrays, together with the tensor data type and shape information. Then, one can use the TensorProtosDBInput operator to load the data into an SGD training fashion.
以下にDB作成のコードをのせておきます. このコードではCIFAR10のデータベースを使っているのですが,chainerのdatasetsから呼び出しています. とにかくNumPyで読み込めれば良いので,OpenCVのimreadで画像を読んできても良いと思います.
NumPyをCaffe2Tensorという型に変換してTensorProtosに入れていきます. なお,int型は対応していないと言われたので,画像のラベルをわざわざfloat型に変換しています.
データセットを読み込んで学習
前回のMNISTを学習したときのコードを少しだけ改変したものです. AddInputの部分が少し変わっています.
結果
なんかいまいちですが,ネットワークの構造や学習回数を見直せばもう少し良くなると思います. そのうちやり直します.