Common Lispによる線形分類器ライブラリcl-online-learningを書いた

去年、オンライン機械学習本(クマ本)を読んで線形分類器を実装する記事を書いたり、それらのアルゴリズムをまとめてcl-online-learningというライブラリを作ってLispmeetupで紹介したりした。

その後放置していたのだが、最近になってもはや使わないようなアルゴリズムは削除したり、疎ベクトルへの対応や、学習器のCLOSオブジェクトを単なる構造体にするなどの大きな変更をした。このあたりで一度ちゃんと紹介記事を書いておこうかと思う。

cl-online-learningの特徴は、

  • アルゴリズム: パーセプトロン、AROW、SCW-I (おすすめはAROW)
  • 二値分類、多値分類に対応 (one-vs-one、one-vs-rest)
  • データが密ベクトル、疎ベクトルのどちらの場合にも対応
  • Common LispC/C++のライブラリ(AROW++)を上回る速度

インストール

local-projectsディレクトリにソースを展開する。

cd ~/quicklisp/local-projects/
git clone https://github.com/masatoi/cl-online-learning.git

あるいは、Roswellがインストールされているなら単に

ros install masatoi/cl-online-learning

データの読み込み

1つのデータはラベル(+1/-1)と入力ベクトルのペア(cons)で、データのシーケンスがデータセットとなる。 libsvm datasetsの形式のファイルからデータセットを作るには、read-data関数が使える。とりあえずデータはlibsvm datasetsの二値分類データからa1aを使うことにする。

(defpackage :clol-user
  (:use :cl :cl-online-learning :cl-online-learning.utils :cl-online-learning.vector))

(in-package :clol-user)

(defparameter a1a-dim 123)
(defparameter a1a-train (read-data "/path/to/a1a"   a1a-dim))
(defparameter a1a-test  (read-data "/path/to/a1a.t" a1a-dim))

(car a1a-train)
;; (-1.0d0
;;  . #(0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 1.0d0 0.0d0
;;      0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0
;;      0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0
;;      0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0
;;      0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0
;;      0.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0
;;      1.0d0 0.0d0 1.0d0 1.0d0 0.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 1.0d0 0.0d0
;;      0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0
;;      0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0
;;      0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0
;;      0.0d0 0.0d0 0.0d0))

モデル定義

学習器のモデルは単なる構造体で、make-系関数で生成できる。その際いずれもデータの次元数を必要とする。その他にAROWは1個、SCWは2個のメタパラメータを指定する必要がある。パーセプトロン、AROW、SCWのモデルをまとめて定義すると、

(defparameter perceptron-learner (make-perceptron a1a-dim))
(defparameter arow-learner (make-arow a1a-dim 10d0))        ; gamma > 0
(defparameter scw-learner  (make-scw  a1a-dim 0.9d0 0.1d0)) ; 0 < eta < 1 , C > 0

訓練

データ1個を学習するには各学習器のupdate関数を使う。AROWならarow-update関数。これにデータの入力ベクトルとラベルを与えることで、arow-learnerが破壊的に更新される。

(arow-update arow-learner (cdar a1a-train) (caar a1a-train))
;; #S(AROW
;;  :INPUT-DIMENSION 123
;;  :WEIGHT #(0.0d0 0.0d0 -0.04d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0) ...
;;  :BIAS -0.04d0
;;  :GAMMA 10.0d0
;;  :SIGMA #(1.0d0 1.0d0 0.96d0 1.0d0 1.0d0 1.0d0 1.0d0 1.0d0 1.0d0 1.0d0) ...
;;  :SIGMA0 0.96d0)

これをデータセット全体に対して行うのがtrain関数である。

(train arow-learner a1a-train)

予測

こうして学習したモデルを使って、ある入力ベクトルに対して予測を立てるには各学習器のpredict関数を使う。AROWならarow-predict関数。

(arow-predict arow-learner (cdar a1a-test))
;; 1.0d0

正解の値(caar a1a-test)が-1.0d0なのでここは外してしまっている。

これをテストデータ全体に対して行ない、正答率を返すのがtest関数である。

(test arow-learner a1a-test)
;; Accuracy: 84.44244%, Correct: 26140, Total: 30956

となって84%弱の精度が出ていることが分かる。

マルチクラス分類

データの読み込み (MNIST)

マルチクラス分類ではデータのラベルが+1/-1ではなく、0以上の整数になる。例えばlibsvm datasetsからMNISTのデータを落としてきて読み込んでみる。読み込みはread-data関数にmulticlass-pキーワードオプションをつけて呼び出す。

(defparameter mnist-dim 780)
(defparameter mnist-train (read-data "/home/wiz/tmp/mnist.scale" mnist-dim :multiclass-p t))
(defparameter mnist-test  (read-data "/home/wiz/tmp/mnist.scale.t" mnist-dim :multiclass-p t))
;; このデータセットはラベルが1からではなく0から始まるので1足しておく
(dolist (datum mnist-train) (incf (car datum)))
(dolist (datum mnist-test)  (incf (car datum)))

(car mnist-train)
;; (5 . #(0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 ...))
モデル定義

マルチクラス分類は二値分類器の組み合わせで実現する。組み合せ方には色々あるが、cl-online-learningではone-vs-oneとone-vs-restを用意している。一般にone-vs-oneの方が精度が高いが、クラス数の二乗に比例する二値分類器が必要になる。一方のone-vs-restはクラス数に比例する。

例えばone-vs-oneで、二値分類器としてAROWを用いる場合の定義はこうなる。

(defparameter mnist-arow
  (make-one-vs-one mnist-dim    ; データの次元
                   10           ; クラス数
                   'arow 10d0)) ; 二値分類器の型とそのパラメータ

この構造体に対しても二値分類のときと同じくone-vs-one-update、one-vs-one-predict関数でデータを一つずつ処理できるし、train、test関数でデータセットをまとめて処理できる。

訓練、予測

データセットを8周訓練する時間を計測し、テストを行うコードは以下のようになる。

(time (loop repeat 8 do (train mnist-arow mnist-train)))
(test mnist-arow mnist-test)
;; Evaluation took:
;;   3.946 seconds of real time
;;   3.956962 seconds of total run time (3.956962 user, 0.000000 system)
;;   100.28% CPU
;;   13,384,797,419 processor cycles
;;   337,643,712 bytes consed

;; Accuracy: 94.6%, Correct: 9460, Total: 10000
liblinearの場合

高速な線形分類器とされるliblinearで同じデータを学習してみる。

wiz@prime:~/tmp$ time liblinear-train -q mnist.scale mnist.model
real    2m26.804s
user    2m26.668s
sys     0m0.312s

wiz@prime:~/tmp$ liblinear-predict mnist.scale.t mnist.model mnist.out
Accuracy = 91.69% (9169/10000)

こちらはデータの読み込みなども含めた時間なのでフェアな比較ではないが、大まかにいってcl-online-learningの方が大幅に速いといえる。また精度もcl-online-learning(AROW + one-vs-one)の方が良い。ちなみにliblinearのマルチクラス分類はone-vs-restを使っているらしい。

疎なデータの分類

a1aのデータを見ると気付くのは、ほとんどの要素が0の疎(スパース)なデータであるということだ。例えば「単語が文書に出現する回数」のような特徴量は高次元かつスパースになる。これをそのまま扱うと空間計算量も時間計算量も膨れ上がってしまうので、このようなデータではデータの次元数の長さのベクタを用意するのではなく、非零値のインデックスと値のペアだけを保持しておけばいい。 cl-online-learning.vectorパッケージに定義されているsparse-vector構造体がそれで、インデックスのベクタと値のベクタをスロットに持つ。

(make-sparse-vector
 (make-array 3 :element-type 'fixnum :initial-contents '(3 5 10))
 (make-array 3 :element-type 'double-float :initial-contents '(10d0 20d0 30d0)))

;; #S(CL-ONLINE-LEARNING.VECTOR::SPARSE-VECTOR
;;    :LENGTH 3
;;    :INDEX-VECTOR #(3 5 10)
;;    :VALUE-VECTOR #(10.0d0 20.0d0 30.0d0))

疎ベクトルの形でデータセットを読み込むにはread-data関数にsparse-pキーワードオプションをつけて呼び出す。試しに、1355191次元という超高次元のデータセットnews20.binaryを読み込んでみる。1つのデータはラベルとsparse-vector構造体のペアになっていることが分かる。

(defparameter news20.binary-dim 1355191)
(defparameter news20.binary (read-data "/home/wiz/datasets/news20.binary" news20.binary-dim :sparse-p t))

(car news20.binary)
;; (-1.0d0
;;  . #S(CL-ONLINE-LEARNING.VECTOR::SPARSE-VECTOR
;;       :LENGTH 3645
;;       :INDEX-VECTOR #(0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
;;                       ...
;;                       3636 3637 3638 3639 3640 3641 3642 3643 3644)
;;       :VALUE-VECTOR #(0.01656300015747547d0 0.01656300015747547d0
;;                       ...
;;                       0.01656300015747547d0)))

データ中の非零値の数をヒストグラムにしてみるとこうなる。

(ql:quickload :clgplot)
(clgp:plot-histogram (mapcar (lambda (d) (clol.vector::sparse-vector-length (cdr d)))
                             news20.binary) 200 :x-range '(0 3000))


1355191次元といってもほとんどのデータが2000次元以下なので疎なデータであることが分かる。

これを学習するためには、二値分類器としてarowの代わりにsparse-arowを使う。同様にパーセプトロンやSCWにもスパース版がある。

(defparameter news20.binary.arow (make-sparse-arow news20.binary-dim 10d0))
(time (loop repeat 20 do (train news20.binary.arow news20.binary)))
(test news20.binary.arow news20.binary)
;; Evaluation took:
;;   1.588 seconds of real time
;;   1.588995 seconds of total run time (1.582495 user, 0.006500 system)
;;   [ Run times consist of 0.006 seconds GC time, and 1.583 seconds non-GC time. ]
;;   100.06% CPU
;;   5,386,830,659 processor cycles
;;   59,931,648 bytes consed

;; Accuracy: 99.74495%, Correct: 19945, Total: 19996
AROW++の場合

同じことをC++によるAROW実装のAROW++を使ってやってみる。

wiz@prime:~/datasets$ arow_learn -i 20 news20.binary news20.binary.model.arow 
Number of features: 1355191
Number of examples: 19996
Number of updates:  37643
Done!
Time: 9.0135 sec.

wiz@prime:~/datasets$ arow_test news20.binary news20.binary.model.arow 
Accuracy 99.915% (19979/19996)
(Answer, Predict): (t,p):9986 (t,n):9993 (f,p):4 (f,n):13
Done!
Time: 2.2762 sec.
liblinearの場合
wiz@prime:~/datasets$ time liblinear-train -q news20.binary news20.binary.model

real    0m2.800s
user    0m2.772s
sys     0m0.265s
wiz@prime:~/datasets$ liblinear-predict news20.binary news20.binary.model news20.binary.out
Accuracy = 99.875% (19971/19996)

なおAROW++もliblinearもベクトルの内部表現は疎ベクトルでやっている模様。

疎なデータの分類(マルチクラス)

マルチクラス分類の場合でも同じことができるので、MNISTでやってみる。MNISTも画像データではあるが、六割程度は0なので疎なデータといえる。
この場合は疎ベクトルかつマルチクラス分類なので、read-data関数にsparse-pとmulticlass-pの両方のオプションをつけてデータを読み込む。密なデータの時と同様にmake-one-vs-oneを使うが、その引数にsparse-arowを指定するところが異なる。あとは大体一緒。

(defparameter mnist-train.sp (read-data "/home/wiz/tmp/mnist.scale" mnist-dim :sparse-p t :multiclass-p t))
(defparameter mnist-test.sp  (read-data "/home/wiz/tmp/mnist.scale.t" mnist-dim :sparse-p t :multiclass-p t))
;; このデータセットはラベルが1からではなく0から始まるので1足しておく
(dolist (datum mnist-train.sp) (incf (car datum)))
(dolist (datum mnist-test.sp)  (incf (car datum)))

(defparameter mnist-arow.sp (make-one-vs-one mnist-dim 10 'sparse-arow 10d0))
(time (loop repeat 8 do (train mnist-arow.sp mnist-train.sp)))

;; Evaluation took:
;;   1.347 seconds of real time
;;   1.348425 seconds of total run time (1.325365 user, 0.023060 system)
;;   [ Run times consist of 0.012 seconds GC time, and 1.337 seconds non-GC time. ]
;;   100.07% CPU
;;   4,570,387,768 processor cycles
;;   337,618,400 bytes consed

となって、約3倍の高速化となっている。

まとめ

  • Common Lispで線形分類器を書いた
    • 学習器をCLOSオブジェクトではなく単なる構造体にしたり、ベクトル演算の実行時の型チェックを外す、学習器の構造体を破壊的に更新して一時的なデータ構造を作らないなどのチューニングにより訓練部分はかなり速い。
    • AROWとSCW-Iは共分散行列の対角成分だけを使う近似をしている
  • コマンドラインベースのliblinearなどとは立ち位置が違うかも
  • Common Lispにはmecab互換の形態素解析エンジンcl-igoもあるので、文書分類などに応用できるかも