CLMLのランダムフォレストを試してみる

ランダムフォレストは決定木ベースの分類/回帰モデルで、ニューラルネットSVMなどと同様に非線形モデルなので線形分離不可能な問題にも使える。SVMはデータ数に対して指数的に計算時間がかかる一方、ランダムフォレストはデータ数をnとしてn*log(n)のオーダであり、軽い。また、SVMは基本的に二値分類器なので複数の学習器を組み合わせてマルチクラス分類することが多いが、ランダムフォレストは元からマルチクラスに対応していて追加のコストがかからない。さらに並列化も簡単にできるなど、扱いやすい性質をたくさん備えているので現実世界でよく使われている。

参考にしたランダムフォレストについての記事いろいろ

CLMLにランダムフォレストの実装があったので試してみる。

CLMLはCommon Lisp用の機械学習パッケージ集であり、Quicklispからインストールできる。ただし、あらかじめ処理系のdynamic-space-sizeを2560以上にして起動しておく必要があることに注意する。

先に必要ライブラリを読み込んでおく。

(ql:quickload :clml)
(ql:quickload :cl-online-learning)
(ql:quickload :alexandria)

データの読み込み

まずはデータの読み込みの例。サンプルデータをネットからダウンロードしてきて読み込む。

(defparameter *syobu*
  (clml.hjs.read-data:read-data-from-file 
   (clml.utility.data:fetch "https://mmaul.github.io/clml.data/sample/syobu.csv")
   :type :csv :csv-type-spec '(string integer integer integer integer)))

;; CL-USER> *syobu*
;; #<CLML.HJS.READ-DATA:UNSPECIALIZED-DATASET >
;; DIMENSIONS: 種類 | がく長 | がく幅 | 花びら長 | 花びら幅
;; TYPES:      UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN | UNKNOWN
;; NUMBER OF DIMENSIONS: 5
;; DATA POINTS: 150 POINTS

;; CL-USER> (aref (clml.hjs.read-data:dataset-points *syobu*) 0)
;; #("Setosa" 51 35 14 2)

ここから分かるように、データセットは特徴量の各次元の名前のリストと、データ点を表すベクタのベクタを持っている。例えば以下のようにして新たにデータセットを作れる。

(defparameter dataset1
  (clml.hjs.read-data:make-unspecialized-dataset
   '("class" "feat1" "feat2")
   (vector (vector "posi" 1 2)
           (vector "nega" -1 -2)
           (vector "posi" 10 20))))

データ点のベクタにラベルの文字列と数値が混在しているのが遅そう。

LIBSVMデータセットのデータを読み込む

今読んでるランダムフォレストのオンライン版の論文に出てくるものと同じデータセットで比較したいので、LIBSVMのデータセット集からmushroomsデータを読み込む。

cl-online-learningのread-data関数で読み込んでCLMLの形式に変換する。予測したい特徴の名前が後で必要になってくるので、クラスラベルに"class"と名前を付けて、残りの特徴名には単に連番を振っておく。

(defparameter *mushrooms-dim* 112)
(defparameter *mushrooms-train* (clol.utils:read-data "/home/wiz/datasets/mushrooms-train" *mushrooms-dim*))
(defparameter *mushrooms-test* (clol.utils:read-data "/home/wiz/datasets/mushrooms-test" *mushrooms-dim*))

(defun clol.datum->clml.datum (datum)
  (let ((label (if (> (car datum) 0) "posi" "nega")))
    (coerce (cons label (coerce (cdr datum) 'list)) 'vector)))

(defun clol.dataset->clml.dataset (dataset)
  (let ((datum-dim (length (cdar dataset))))
    (clml.hjs.read-data:make-unspecialized-dataset
     (cons "class" (mapcar #'(lambda (x) (format nil "~A" x)) (alexandria:iota datum-dim)))
     (map 'vector #'clol.datum->clml.datum dataset))))

(defparameter *mushrooms-train.clml* (clol.dataset->clml.dataset *mushrooms-train*))
(defparameter *mushrooms-test.clml* (clol.dataset->clml.dataset *mushrooms-test*))

ランダムフォレストの学習

lparallelを使って並列化しているので、まずカーネルサイズの設定をしておく。

(setf lparallel:*kernel* (lparallel:make-kernel 4))

上で作った *mushrooms-train.clml* データセットから学習する。予測対象のクラスラベルの特徴名をここで指定する。

(defparameter *forest* (clml.decision-tree.random-forest:make-random-forest *mushrooms-train.clml* "class"))
;; Evaluation took:
;;   81.367 seconds of real time
;;   269.523032 seconds of total run time (265.196590 user, 4.326442 system)
;;   [ Run times consist of 9.140 seconds GC time, and 260.384 seconds non-GC time. ]
;;   331.24% CPU
;;   276,010,481,282 processor cycles
;;   120,282,960,608 bytes consed

4コアCPUで計算しているが、かなり時間がかかってしまっている。

予測は、予測したいデータ点のラベル部分を "?" に置き換えたものを学習済みのモデルとともにpredict-forest関数に与える。例えば、テストセットの最初のデータ点を予測するには、

(aref (clml.hjs.read-data:dataset-points *mushrooms-test.clml*) 0)
;; #("nega" 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 1.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0
;;   0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0
;;   1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 1.0d0 1.0d0 0.0d0 0.0d0 1.0d0
;;   1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0
;;   1.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0
;;   0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 1.0d0 0.0d0
;;   0.0d0 1.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 1.0d0 0.0d0
;;   0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0
;;   0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0)

(defparameter *query*
  #("?" 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 1.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0
    0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0
    1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 1.0d0 1.0d0 0.0d0 0.0d0 1.0d0
    1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0
    1.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0
    0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 1.0d0 0.0d0
    0.0d0 1.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 1.0d0 0.0d0
    0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0
    0.0d0 1.0d0 0.0d0 0.0d0 0.0d0 0.0d0 0.0d0 1.0d0 0.0d0 0.0d0))

(clml.decision-tree.random-forest:predict-forest *query* *mushrooms-train.clml* *forest*)
; => "nega"

テストセットをまとめて予測するには、forest-validation関数を使う。

(clml.decision-tree.random-forest:forest-validation *mushrooms-test.clml* "class" *forest*)

;; Evaluation took:
;;   12.064 seconds of real time
;;   12.075773 seconds of total run time (12.044405 user, 0.031368 system)
;;   [ Run times consist of 0.337 seconds GC time, and 11.739 seconds non-GC time. ]
;;   100.10% CPU
;;   40,922,800,082 processor cycles
;;   10,874,892,336 bytes consed

;; ((("nega" . "posi") . 100) (("posi" . "posi") . 428) (("nega" . "nega") . 1596))
;; => Error rate: 0.04708098

この返り値は、例えば"nega"のものを"posi"と判定したケースが100件あったということを意味する。残りは正解しているので正答率は95.3%程度となる。論文では99%くらい出るとあるのでちょっと低い。なんで。Hivemailでも99%くらい出る模様。

メタパラメータの指定

メタパラメータとして指定できるものは、

  • 決定木の数 (デフォルトは500)
  • 元のデータセットから重複を許してサンプリングしたものをデータセットとして各決定木を学習するのだが、その際のサンプルサイズ (デフォルトは元のデータセット全体のサイズ)
  • 元の特徴量からサンプリングしたものを各決定木の特徴量として使うのだが、その際のサンプリングする特徴の数 (デフォルトはフルの特徴量を使うが、よく使われるのは元の特徴量のサイズの平方根)

これらを調整してより小さなランダムフォレストを作ると、

(defparameter *small-forest*
  (clml.decision-tree.random-forest:make-random-forest
   *mushrooms-train.clml* "class"
   :tree-number 100
   :data-bag-size (floor (/ (length *mushrooms-train*) 10))
   :feature-bag-size (floor (sqrt *mushrooms-dim*))))

;; Evaluation took:
;;   1.217 seconds of real time
;;   4.279771 seconds of total run time (4.191569 user, 0.088202 system)
;;   [ Run times consist of 0.141 seconds GC time, and 4.139 seconds non-GC time. ]
;;   351.68% CPU
;;   4,125,539,396 processor cycles
;;   2,282,142,096 bytes consed

;; => Error rate: 0.041902073

となってかなり高速になり、精度も若干上がった。

番外編: cl-online-learningでもやってみる

(defparameter arow-learner (clol:make-arow *mushrooms-dim* 0.1d0))
(loop repeat 10 do
  (clol:train arow-learner *mushrooms-train*)
  (clol:test arow-learner *mushrooms-test*))

;; Accuracy: 91.99623%, Correct: 1954, Total: 2124
;; Accuracy: 94.3032%, Correct: 2003, Total: 2124
;; Accuracy: 93.92655%, Correct: 1995, Total: 2124
;; Accuracy: 93.87947%, Correct: 1994, Total: 2124
;; Accuracy: 93.83239%, Correct: 1993, Total: 2124
;; Accuracy: 93.87947%, Correct: 1994, Total: 2124
;; Accuracy: 93.83239%, Correct: 1993, Total: 2124
;; Accuracy: 93.83239%, Correct: 1993, Total: 2124
;; Accuracy: 93.83239%, Correct: 1993, Total: 2124
;; Accuracy: 93.83239%, Correct: 1993, Total: 2124

;; Evaluation took:
;; 0.018 seconds of real time
;; 0.018714 seconds of total run time (0.014723 user, 0.003991 system)
;; 105.56% CPU
;; 64,026,632 processor cycles
;; 786,432 bytes consed

(defparameter scw-learner (clol:make-scw *mushrooms-dim* 0.999d0 0.001d0))
(loop repeat 10 do
  (clol:train scw-learner *mushrooms-train*)
  (clol:test scw-learner *mushrooms-test*))

;; Accuracy: 38.93597%, Correct: 827, Total: 2124
;; Accuracy: 66.00753%, Correct: 1402, Total: 2124
;; Accuracy: 69.96233%, Correct: 1486, Total: 2124
;; Accuracy: 81.40301%, Correct: 1729, Total: 2124
;; Accuracy: 88.46516%, Correct: 1879, Total: 2124
;; Accuracy: 93.36158%, Correct: 1983, Total: 2124
;; Accuracy: 96.65725%, Correct: 2053, Total: 2124
;; Accuracy: 97.834274%, Correct: 2078, Total: 2124
;; Accuracy: 98.44633%, Correct: 2091, Total: 2124
;; Accuracy: 98.06968%, Correct: 2083, Total: 2124

;; Evaluation took:
;;   0.038 seconds of real time
;;   0.037766 seconds of total run time (0.037766 user, 0.000000 system)
;;   100.00% CPU
;;   128,015,604 processor cycles
;;   2,064,384 bytes consed

このデータセットではSCW-Iが比較的良い性能を出した。過学習しがちなので早めに止める必要がある。

まとめ/感想

  • CLMLのランダムフォレストを試してみた
  • データセットが型指定しないunspecialized-datasetであり、遅い
  • あまり最適化されていないのでまだまだ速くなりそう
  • 精度も何故か良くない
  • CLMLの実装を叩き台にしてオンラインランダムフォレストを実装してみるつもり