2017/08/18

deeplearn.jsを遠くからそっと眺めてみた

Googleから発表された機械学習のためのJavaScriptライブラリのdeeplearn.jsを眺めてみたので、纏めておきます。
オフィシャルを舐めて、どういうものか、これからどんな感じになっていくか(いってほしいか)をダラっとタレます。


注意
眺めたのはv0.1.0なので、内容が今後大幅に変更になる可能性があります(というかある)。


触った環境
MacBook Pro (Retina, 15-inch, Mid 2015)
    - プロセッサ: 2.5GHz Intel Core i7
    - メモリ: 16GB 1600 MHz DDR3
    - グラフィックス: AMD Radeon R9 M370X 2048 MB
Chrome 60.0.3112.101 (Official Build) (64ビット)
deeplearn.js v0.1.0


Exampleを眺める

何ができるかを眺めるためのExamplesが準備されています。
Model Builderが面白い(というか分かりやすい)のでおすすめです。
デフォルトでみんなだいすきMNISTをCNNで推論(Inference)するプログラムが走っています。

CPUとGPUの性能差を噛みしめる


そのまま(デフォルト値)だと、正答率が低いので、左側の[TRAIN]ボタンをクリックして学習を開始します。
すると20〜40秒ほどで、学習が収束した感じになる。早い!。左から2つ目のカラムの上にCPUとGPUを切り替えるトグルボタンがあるので、それでCPUに切り替えて実行してみると、途中で「もういいや」ってなるぐらい遅いのでGPUのありがたみを噛みしめられますん。
CPUと言ってもJSインタプリタが間に入っているのでCPUはより遅いのですが、TensorFlowのCPUとくらべてもdeeplearn.jsのGPUは速いんじゃないかなぁという体感値です(正確には計測していませんが…)

モデル定義ファイル

左カラムの「Download model」というリンクがワクワクするわけですが、次のようなJSONファイルでした。
[
    {"layerName":"Convolution","fieldSize":5,"stride":1,"zeroPad":2,"outputDepth":8},
    {"layerName":"ReLU"},{"layerName":"Max pool","fieldSize":2,"stride":2,"zeroPad":0},
    {"layerName":"Convolution","fieldSize":5,"stride":1,"zeroPad":2,"outputDepth":16},
    {"layerName":"ReLU"},
    {"layerName":"Max pool","fieldSize":2,"stride":2,"zeroPad":0},
    {"layerName":"Flatten"},
    {"layerName":"Fully connected","hiddenUnits":10}
]

わー、よみやすーい。
どうやらModel Builderのモデルを定義するためだけのファイルっぽいです。学習済みのパラメータなどは含まれないので、学習結果を保存することは今のところできなさそう。
でも必須機能だからすぐ対応するでしょうけど(対応してください)。

性能ベンチマーク

他のExamplesにBenchmarksというのがあり、CPUとGPUの性能比較とかができます。
例えばMatrix Multiplication(matmul)の性能比較とか面白いです。

我が軍(GPU)は圧倒的ではないか!
てかGPUの得意分野なので当たり前だし、CPUはJSインタプリタ挟んでるので不利だし…。といってもこの性能差を見せられるとグッときます個人的に。


TensorFlowのモデル使いたい

TensorFlowで学習したモデル(チェックポイント)を読み込んで、推論だけブラウザで実行というチュートリアルがありますので、既存モデルをdeeplearn.jsで実行することができます。
一部抜粋です。

import {CheckpointLoader, Graph} from 'deeplearnjs';
// manifest.json is in the same dir as index.html.
const reader = new CheckpointReader('.');
reader.getAllVariables().then(vars => {
  // Write your model here.
  const g = new Graph();
  const input = g.placeholder('input', [784]);
  const hidden1W = g.constant(vars['hidden1/weights']);
  const hidden1B = g.constant(vars['hidden1/biases']);
  const hidden1 = g.relu(g.add(g.matmul(input, hidden1W), hidden1B));
  ...
  ...
  const math = new NDArrayMathGPU();
  const sess = new Session(g, math);
  math.scope(() => {
    const result = sess.eval(...);
    console.log(result.getValues());
  });
});
わー、てがきーぃ。
Caffe to TensorFlowみたいに、モデル作成プログラムを自動生成するツールの登場が待たれる…。

なんとなく分かってきたので書いてみる

何ができそうかが見えてきたので、じゃぁどうやって書くのか。
それもチュートリアルがありますが、先程のTensorFlowのモデルをロードするプログラムが全てです。
Graphオブジェクトにplaceholderやらconstant、variable、reluなどで計算グラフを構築し、Sessionのevalでグラフを実行するという感じっぽいです。TensorFlowのLow Level APIに慣れている人なら読みやすいですが、同時にhigh-level APIはいねぇ〜がぁ〜という気持ちが湧き出てきます。

Model Builderがあるので、high-level APIも遠くない未来にお目見えするんじゃないかなぁとか思ったり、、、というかKeras Model的なモジュールがあれば、WebフロントでのDeep Learning活用が捗りそう。

NumPyっぽくも使える

deeplearn.jsはTensorFlowみたくグラフを定義して実行するだけでなく、NumPyみたく逐次実行APIも提供しているのが面白い。gpu.jsよりもテンソル計算に関しては使いやすそうな印象。

const shape = [2, 3];  // 2 rows, 3 columns
const a = Array2D.new(shape, [1.0, 2.0, 3.0, 10.0, 20.0, 30.0]);

ロードマップ

最後ですが、v0.1.0と産まれたてホヤホヤなdeeplearn.jsなので、現状を評価してもあまり意味はなく、Production Readyに向けてどういう方向で進んでいくのかが大事なわけで、そこんところがロードマップに纏められていました。
テキトーに纏めると、こんな感じです↓(翻訳間違ってたらゴメンナサイ)

More devices: WebGL 1.0と2.0をターゲットにしてるけどモバイルやら他のブラウザへ展開していきたいなぁ

Optimizers: SGDしかサポートしとらんけん、RMSPropやらAdamとかAdamaxとか色々サポートせにゃならん

Logical sampling: deeplearn.jsで扱う高次元配列は2Dのテクスチャ空間にマッピングしてから計算されるのでシェーダープログラム(GPU演算させる部分)が書きづらい。論理空間が定義できたら、書きやすくなるよねー。mutmulだけ論理空間でテスト実装してるけど、これを展開していけるようにしないとね…。

Batch as the outer dimension: いまbatch sizeが1しかサポートしてないけど、Logical samplingの機構ができたら何とかできるはず

Automatic TensorFlow to deeplearn.js: TFのGraphDefから自動的にdeeplearn.jsのモデルにポーティングできるようにする計画アルよ。

Dynamic batching: 計算グラフの計算処理を一括化するような計算最適化をNDArrayMathのレイヤーで実装します。←TensorFlow Foldみたいなことをやるっぽい?

Recurrence in training: 再帰をサポートしていきます。←RNNとかLSTMのことかな?


おわりに

deeplearn.jsはDefined and Run型の機械学習ライブラリで、かつ、Googleから出されているものなんでTensorFlow知ってる人にはすぐ馴染めるライブラリになりそうな印象でした。
というかAPIとか読まずに理解できますw

v0.1.0ですので、これから変化するでしょうし、このエントリ自体がオフィシャルをサラッと舐めただけの内容なので薄〜い!
このエントリを信用せずw ちゃんとオフィシャルを読んだり、手元で実装してみた方が良いと思います。
Typescriptがオススメと書いているので、実装しやすいですし、npmとかでサクッとはじめられます。