GNOEM: Gauss Net with On-line EM algorithm

正規化ガウス関数ネットワーク(Normalized Gaussian Networks)は正規分布による線形回帰ユニットを組み合わせて非線形回帰をするモデルだ。論文によってはほとんど同じモデルが確率的ニューラルネットワークと呼ばれていることもある。修士のときにこの論文(On-line EM Algorithm for the Normalized Gaussian Network http://www.mitpressjournals.org/doi/abs/10.1162/089976600300015853)アルゴリズムを実装をしたことがある。このプログラムを再び使う機会があったので、ついでにASDFでパッケージ化してgithubに公開してみた。SBCL 1.0.34 (FreeBSD 8.0 Release/amd64)で動作確認している。

http://github.com/masatoi/gnoem

インストール

wiz-util(http://github.com/masatoi/wiz-util)に依存している。適当なディレクトリで

$ git clone git://github.com/masatoi/wiz-util.git
$ git clone git://github.com/masatoi/gnoem.git

して、asdファイルをASDF:*CENTRAL-REGISTRY*に登録されているディレクトリにリンクを張った後、LISP処理系で

(asdf:oos 'asdf:load-op :gnoem)

とすればコンパイルしてくれる。

簡単なサンプル

学習データの準備

まずライブラリをロードして、近似対象の関数を定義する。

;; GNOEMをロード
(asdf:oos 'asdf:load-op :gnoem)

(in-package :gnoem)

;; 近似対象の関数
(defun schaal-function (x1 x2)
  (max (exp (* -10d0 x1 x1))
       (exp (* -50d0 x2 x2))
       (* 1.25d0 (exp (* -5d0 (+ (* x1 x1)
				 (* x2 x2)))))))

これは入力が2次元で出力が1次元の関数。

上から見た図。

ここからランダムに1000個ほどデータをサンプリングする。

;; 入出力データを1000個用意
(defparameter x-data
  (loop repeat 1000 collect
       (make-vector 2 :initial-contents (list (1- (random 2d0)) (1- (random 2d0))))))

(defparameter y-data
  (mapcar (lambda (x_in)
	    (schaal-function (aref x_in 0 0) (aref x_in 1 0)))
	  x-data))
初期設定

データを用意したら、make-initialized-sessionでセッションをつくる。セッションはパラメータや統計量などが含まれる構造体。

;; 新規にセッションをつくる
;; (make-initialized-session ユニット数 入力次元 出力次元)
(defparameter init-session
  (make-initialized-session 300 2 1 :lambda-factor 0.99d0 :alpha 0.1d0))
学習

学習はオンラインEMアルゴリズムによって行う。oem-data-listでさっきサンプリングした入出力データのリストから一つずつ学習していく。

;; リストに格納されているデータから学習
(defparameter learned-session (oem-data-list x-data y-data init-session))

もちろんオンライン学習なので、一組のデータのみから1ステップずつ学習することもできる。

(defparameter learned-session (oem-1step (car x-data) (car y-data) init-session))
予測

学習によって得られたlearned-sessionを元に、未知入力に対する出力を予測する。

;; 予測
(prediction (make-vector 2 :initial-contents (list 0.5d0 0.5d0)) learned-session)
;; => 0.14885842827023427d0

定義域全体で予測をした結果をプロットすると、次の図になる。

データのリストを2巡したものが次の図。

対数尤度が収束したあたりの図が次になる。ここまでくると若干オーバーフィッティングしている。make-initialized-sessionのときに指定したユニット数が多すぎるのかも。

データ一組当たりの対数尤度は次のようになる。