MGLで回帰:多次元出力

シンプルな例で、多次元出力の回帰ができるかをテストする。

(ql:quickload :mgl-user)
(ql:quickload :clgplot)

(in-package :mgl-user)

;;; Training data: 1-dimensional input x, 2-dimensional target (cos x, sin x).
;;; Each x is drawn as pi * (random 2.0 - 1) and coerced to single-float.
(defparameter sin-cos-data
  (let ((samples (make-array 10000)))
    ;; Fill every slot with one regression datum; DOTIMES returns the vector.
    (dotimes (idx (length samples) samples)
      (let ((angle (coerce (* pi (1- (random 2.0))) 'single-float)))
        (setf (aref samples idx)
              (make-regression-datum
               :id idx
               ;; Target order matters downstream: element 0 is cos, element 1 is sin.
               :target (make-mat 2 :initial-contents (list (cos angle) (sin angle)))
               :array  (make-mat 1 :initial-contents (list angle))))))))

;; Model definition: a regression network with a size-1 input, two
;; 256-unit ReLU hidden layers, and a size-2 output trained with a
;; squared-difference loss.
;; NOTE(review): :max-n-stripes 100 is presumably the maximum batch size
;; processed at once -- confirm against the MGL documentation.
(defparameter fnn
  (build-fnn (:class 'regression-fnn :max-n-stripes 100)
    ;; Input layer: a single input value per sample
    (inputs (->input :size 1))
    ;; First hidden layer: 256 activations followed by ReLU
    (f1-activations (->activation inputs :name 'f1 :size 256))
    (f1 (->relu f1-activations))
    ;; Second hidden layer: 256 activations followed by ReLU
    (f2-activations (->activation f1 :name 'f2 :size 256))
    (f2 (->relu f2-activations))
    ;; Output activations: size 2, one per target dimension
    (prediction-activations (->activation f2 :name 'prediction :size 2))
    ;; Output lump: squared difference between the output activations and
    ;; a size-2 input lump named 'targets, wrapped in ->loss
    (prediction (->loss (->squared-difference (activations-output prediction-activations)
                                              (->input :name 'targets :size 2))
                        :name 'prediction))))

;; Training: fit fnn to sin-cos-data with :n-epochs 100.
;; NOTE(review): train-regression-fnn-process-with-monitor is not defined in
;; this file (presumably provided by mgl-user) -- confirm what it reports
;; while training.
(train-regression-fnn-process-with-monitor fnn sin-cos-data :n-epochs 100)

;; Plot the two predicted outputs together with the true sin and cos curves.
;; The x grid runs from -pi-1 to pi+1 in steps of 0.1, i.e. slightly beyond
;; the interval the training inputs were drawn from.
(let* ((xs (wiz:seq (- (- pi) 1) (+ pi 1) :by 0.1))
       (outputs
        (mapcar (lambda (x)
                  ;; The :target value is only a placeholder filled with 1.0;
                  ;; presumably it is ignored when predicting -- TODO confirm.
                  (let ((out (predict-regression-datum
                              fnn
                              (make-regression-datum
                               :id 0
                               :target (make-mat 2 :initial-element 1.0)
                               :array (make-mat 1 :initial-element x)))))
                    ;; Pair of (output 0, output 1) for this x.
                    (list (mref out 0) (mref out 1))))
                xs)))
  ;; Per the target order in sin-cos-data, output 0 was trained on cos x
  ;; and output 1 on sin x.
  (clgp:plot-lists (list (mapcar #'first outputs)
                         (mapcar #'second outputs)
                         (mapcar #'sin xs)
                         (mapcar #'cos xs))
                   :x-lists (make-list 4 :initial-element xs)
                   :title-list '("prediction-1dim" "prediction-2dim" "sin(x)" "cos(x)")))



問題なく学習できている。
訓練データは[-π,π]の範囲からしかサンプリングしていないため、その範囲の外側では予測が真のsin・cosの値からズレる。