diff --git a/CHANGELOG.md b/CHANGELOG.md index c64d5d8..2656559 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,5 @@ unreleased -* use malli to describe options +* use malli to describe model options and validate them in ml/train * renamed :other-metrices -> :other-metrics 0.12 diff --git a/src/scicloj/metamorph/ml.clj b/src/scicloj/metamorph/ml.clj index d85cbca..3a2907d 100644 --- a/src/scicloj/metamorph/ml.clj +++ b/src/scicloj/metamorph/ml.clj @@ -25,7 +25,7 @@ (exporter/export-symbols scicloj.metamorph.ml.ensemble ensemble-pipe) -(def train-predict-cache +(def train-predict-cache "Controls , if train/predict invocations are cached. if 'use-cache' is true, the get-fn and set-fn functions ar called accorddngly" (atom {:use-cache false @@ -33,7 +33,7 @@ :set-fn (fn [key value] nil)})) -(defn get-categorical-maps [ds] +(defn- get-categorical-maps [ds] (->> (ds/column-names ds) (map #(list % (-> ds (get %) meta :categorical-map))) (remove #(nil? (second %))) @@ -141,11 +141,11 @@ eval-pipeline-result-train (eval-pipeline pipeline-fn fitted-ctx metric-fn train-ds (:other-metrics tune-options)) eval-pipeline-result-test (if (-> fitted-ctx :model ::unsupervised?) - {:other-metrics [] - :timing 0 - :ctx fitted-ctx - :metric 0} - (eval-pipeline pipeline-fn fitted-ctx metric-fn test-ds (:other-metrics tune-options)))] + {:other-metrics [] + :timing 0 + :ctx fitted-ctx + :metric 0} + (eval-pipeline pipeline-fn fitted-ctx metric-fn test-ds (:other-metrics tune-options)))] @@ -198,8 +198,8 @@ (defn- evaluate-one-pipeline [pipeline-decl-or-fn train-test-split-seq metric-fn loss-or-accuracy tune-options] (let [pipeline-fn (if (fn? pipeline-decl-or-fn) - pipeline-decl-or-fn - (mm/->pipeline pipeline-decl-or-fn)) + pipeline-decl-or-fn + (mm/->pipeline pipeline-decl-or-fn)) pipeline-decl (when (sequential? pipeline-decl-or-fn) pipeline-decl-or-fn) @@ -334,8 +334,8 @@ [:ppmap-grain-size {:optional true} int?] [:evaluation-handler-fn {:optional true} fn?] [:other-metrics {:optional true} [:sequential [:map - [:name keyword?] - [:metric-fn fn?]]]] + [:name keyword?] + [:metric-fn fn?]]]] [:attach-fn-sources {:optional true} [:map [:ns any?] [:pipe-fns-clj-file string?]]]]] ::evaluation-result @@ -348,9 +348,9 @@ [:train-transform [:map {:closed true} [:other-metrics [:sequential [:map {:closed true} - [:name keyword?] - [:metric-fn fn?] - [:metric float?]]]] + [:name keyword?] + [:metric-fn fn?] + [:metric float?]]]] [:timing int?] [:metric float?] [:probability-distribution [:maybe [:fn dataset?]]] @@ -360,9 +360,9 @@ [:ctx map?]]] [:test-transform [:map {:closed true} [:other-metrics [:sequential [:map {:closed true} - [:name keyword?] - [:metric-fn fn?] - [:metric float?]]]] + [:name keyword?] + [:metric-fn fn?] + [:metric float?]]]] [:timing int?] [:metric float?] [:probability-distribution [:maybe [:fn dataset?]]] @@ -409,7 +409,7 @@ :return-best-crossvalidation-only true :evaluation-handler-fn eval-handler/default-result-dissoc-in-fn :ppmap-grain-size 10} - + options) map-fn @@ -496,9 +496,9 @@ :as opts}] (println "Register model: " model-kwd) - + (malli/model-options->full-schema opts) ;; throws on invalid malli schema for options - + (swap! model-definitions* assoc model-kwd {:train-fn train-fn :predict-fn predict-fn @@ -577,17 +577,14 @@ (let [model-options (options->model-def options) _ (when (some? (:options model-options)) (validate-options model-options options)) - - combined-hash (when (:use-cache @train-predict-cache) - (str (hash dataset) "___" (hash options)) - ) - + + combined-hash (when (:use-cache @train-predict-cache) + (str (hash dataset) "___" (hash options))) + cached (when combined-hash ((:get-fn @train-predict-cache) combined-hash))] (if cached - (do - (println :cache-hit-train! combined-hash) - cached) + cached (let [{:keys [train-fn unsupervised?]} model-options feature-ds (cf/feature dataset) _ (errors/when-not-error (> (ds/row-count feature-ds) 0) @@ -621,14 +618,12 @@ (when-not (== 0 (count cat-maps)) {:target-categorical-maps cat-maps}))] (when combined-hash - (println :cache-miss-train! combined-hash) ((:set-fn @train-predict-cache) combined-hash model)) - model)) - )) - - - + model)))) + + + (defn thaw-model "Thaw a model. Model's returned from train may be 'frozen' meaning a 'thaw' @@ -654,28 +649,36 @@ (thaw-fn (:model-data model))))) -(defn- warn-inconsitent-maps [model pred-ds] +(defn- warn-inconsistent-maps [model pred-ds] + ;; TODO revise + ;;https://github.com/scicloj/metamorph.ml/issues/35 + (let [target-cat-maps-from-train (-> model :target-categorical-maps) target-cat-maps-from-predict (-> pred-ds get-categorical-maps) simple-predicted-values (-> pred-ds cf/prediction (get (first (keys target-cat-maps-from-predict))) seq) inverse-map (-> target-cat-maps-from-predict vals first :lookup-table set/map-invert)] - (when (not (= target-cat-maps-from-predict target-cat-maps-from-train))) - ;; (println - ;; (format - ;; "target categorical maps do not match between train an predict. \n train: %s \n predict: %s " - ;; target-cat-maps-from-train target-cat-maps-from-predict)) + (when (not (= target-cat-maps-from-predict target-cat-maps-from-train)) + + ;; (throw (Exception. + ;; (format + ;; "target categorical maps do not match between train an predict. \n train: %s \n predict: %s " + ;; target-cat-maps-from-train target-cat-maps-from-predict))) + + ) (when (not (every? some? (map inverse-map - (distinct simple-predicted-values))))))) - ;; (println - ;; (format - ;; "Some predicted values are not in catetegorical map. -> Invalid predict fn. - ;; values: %s + (distinct simple-predicted-values)))) + ;; (throw (Exception. + ;; (format + ;; "Some predicted values are not in categorical map. -> Invalid predict fn. + ;; predicted values: %s ;; categorical map: %s " - ;; (vec (distinct simple-predicted-values)) - ;; (-> target-cat-maps-from-predict vals first :lookup-table))) + ;; (vec (distinct simple-predicted-values)) + ;; (-> target-cat-maps-from-predict vals first :lookup-table)))) + + ))) @@ -708,7 +711,7 @@ string in train -> string in predict categorical map in train -> equivalent categorical map in predict - ml/train passes the needed information of the rain target column to the model implementaion to do this. + ml/train passes the needed information of the train target column to the model implementaion to do this. " {:malli/schema [:=> [:cat [:fn dataset?] @@ -716,30 +719,28 @@ [:feature-columns sequential?] [:target-columns sequential?]]] - [map?]]} - [dataset {:keys [feature-columns options train-input-hash] + + [dataset {:keys [feature-columns options train-input-hash] :as model}] (let [predict-hash (when (:use-cache @train-predict-cache) (str train-input-hash "--" (hash dataset))) cached (when predict-hash ((:get-fn @train-predict-cache) predict-hash)) pred-ds (if cached - (do - (println :cache-hit-predict! predict-hash) - cached) + cached + (let [{:keys [predict-fn] :as model-def} (options->model-def options) feature-ds (ds/select-columns dataset feature-columns) thawed-model (thaw-model model model-def) pred-ds (predict-fn feature-ds thawed-model model)] - (warn-inconsitent-maps model pred-ds) - + (warn-inconsistent-maps model pred-ds) + (when predict-hash - (println :cache-miss-predict! predict-hash) - ( (:set-fn @train-predict-cache) predict-hash pred-ds)) - + ((:set-fn @train-predict-cache) predict-hash pred-ds)) + pred-ds))] pred-ds)) diff --git a/src/scicloj/metamorph/ml/classification.clj b/src/scicloj/metamorph/ml/classification.clj index 0c7b168..d009298 100644 --- a/src/scicloj/metamorph/ml/classification.clj +++ b/src/scicloj/metamorph/ml/classification.clj @@ -74,15 +74,20 @@ (defn- get-majority-class [target-ds] (let [target-column-name (first - (ds-mod/inference-target-column-names target-ds))] + (ds-mod/inference-target-column-names target-ds)) + target-column (get target-ds target-column-name) + freqs (frequencies target-column) + ] + ;; (println :target-column target-column) + ;; (println :target-column--meta (meta target-column)) + ;; (println :target-column--freq freqs) (->> - (-> target-ds (get target-column-name) frequencies) + freqs (sort-by second) reverse first first))) - (ml/define-model! :metamorph.ml/dummy-classifier (fn [feature-ds target-ds options] (let [target-column-name (first diff --git a/test/scicloj/metamorph/classification_test.clj b/test/scicloj/metamorph/classification_test.clj index 8fe1515..58fa84f 100644 --- a/test/scicloj/metamorph/classification_test.clj +++ b/test/scicloj/metamorph/classification_test.clj @@ -107,19 +107,33 @@ (is (= [:a ] (-> (ds/->dataset {:x [0]}) (ml/predict model) :y))))) -(deftest dummy-categorical +(deftest dummy-categorical-int (let [ds (-> - {:x [0] :y [:a]} + {:x [3.0] :y [:a]} (ds/->dataset) - (ds/categorical->number [:y]) + (ds/categorical->number [:y] [:a] :int) (ds-mod/set-inference-target :y)) + model (ml/train ds {:model-type :metamorph.ml/dummy-classifier + :dummy-strategy :random-class}) + prediction (ml/predict (ds/->dataset {:x [0]}) model)] + + (is (= [:a] (-> prediction (ds-cat/reverse-map-categorical-xforms) :y seq))))) + + +(deftest dummy-categorical-float32--failing + (let [ds (-> + {:x [3.0] :y [:a]} + (ds/->dataset) + (ds/categorical->number [:y] [:a] :float32) + (ds-mod/set-inference-target :y)) model (ml/train ds {:model-type :metamorph.ml/dummy-classifier - :dummy-strategy :random-class})] + :dummy-strategy :random-class}) + prediction (ml/predict (ds/->dataset {:x [0]}) model)] + (is (= [:a] (-> prediction (ds-cat/reverse-map-categorical-xforms) :y seq))))) - (is (= [:a ] (-> (ds/->dataset {:x [0]}) (ml/predict model) (ds-cat/reverse-map-categorical-xforms) :y))))) (deftest dummy-pipeline-eval (let [pipe-fn (mm/pipeline diff --git a/test/scicloj/metamorph/ml_test.clj b/test/scicloj/metamorph/ml_test.clj index 784c503..8f5ec01 100644 --- a/test/scicloj/metamorph/ml_test.clj +++ b/test/scicloj/metamorph/ml_test.clj @@ -15,29 +15,32 @@ [tech.v3.dataset.categorical :as ds-cat] [tech.v3.dataset.metamorph :as ds-mm] [tech.v3.dataset.modelling :as ds-mod] - [scicloj.metamorph.ml.tools :refer [keys-in]] - ) + [scicloj.metamorph.ml.tools :refer [keys-in]]) (:import (clojure.lang ExceptionInfo) (java.util UUID))) (def iris (tc/dataset "https://raw.githubusercontent.com/techascent/tech.ml/master/test/data/iris.csv" {:key-fn keyword})) + +(def iris-train + (-> iris + (ds-mod/set-inference-target :species))) + (def iris-target-values (-> iris :species distinct sort)) (defn do-define-model [] (ml/define-model! :test-model (fn train - [feature-ds label-ds options] + [_ _ _] {:model-data {:model-as-bytes [1 2 3] :smile-df-used [:blub]}}) (fn predict - [feature-ds thawed-model {:keys [target-columns + [feature-ds _ {:keys [target-columns target-categorical-maps]}] - (let [ - predic-col (ds/new-column :species (repeat (tc/row-count feature-ds) 1) + (let [predic-col (ds/new-column :species (repeat (tc/row-count feature-ds) 1) {:categorical-map (get target-categorical-maps (first target-columns)) :column-type :prediction}) predict-ds (ds/new-dataset [predic-col])] @@ -79,7 +82,7 @@ :metamorph/mode :transform})) (:metamorph/data))] ;; (ds-mod/column-values->categorical :species) - + (is (= (repeat 10 "versicolor") (-> predictions ds-cat/reverse-map-categorical-xforms :species seq))) @@ -103,20 +106,18 @@ (deftest evaluate-pipelines-simplest-witch-cache (do-define-model) (let [cache-map (atom {})] - - - + + + (reset! ml/train-predict-cache {:use-cache true - :get-fn (fn [key] (get @cache-map key)) - :set-fn (fn [key value] (swap! cache-map assoc key value))}) + :get-fn (fn [key] (get @cache-map key)) + :set-fn (fn [key value] (swap! cache-map assoc key value))}) (validate-simple-pipeline) (validate-simple-pipeline) (reset! ml/train-predict-cache {:use-cache false - :get-fn nil - :set-fn nil})) - - ) + :get-fn nil + :set-fn nil}))) @@ -125,9 +126,7 @@ (deftest test-explain (do-define-model) - (let [ - - pipe-fn + (let [pipe-fn (morph/pipeline (ds-mm/set-inference-target :species) (ds-mm/categorical->number (fn [ds] (cf/intersection (cf/categorical ds) (cf/target ds))) iris-target-values :int) @@ -152,9 +151,7 @@ (deftest test-data-removed (do-define-model) - (let [ - - pipe-fn + (let [pipe-fn (morph/pipeline (ds-mm/set-inference-target :species) (ds-mm/categorical->number (fn [ds] (cf/intersection (cf/categorical ds) (cf/target ds))) iris-target-values :int) @@ -180,13 +177,11 @@ - + (deftest evaluate-pipelines-several-cross (do-define-model) - (let [ - - pipe-fn + (let [pipe-fn (morph/pipeline (ds-mm/set-inference-target :species) (ds-mm/categorical->number cf/categorical iris-target-values) @@ -203,12 +198,12 @@ (ml/evaluate-pipelines pipe-fn-seq train-split-seq loss/classification-loss :loss {:return-best-crossvalidation-only false :return-best-pipeline-only false}) - + evaluations-3 (ml/evaluate-pipelines pipe-fn-seq train-split-seq loss/classification-loss :loss {:return-best-pipeline-only false})] - + (is (= 5 (count (first evaluations-1)))) @@ -222,7 +217,7 @@ (is (= 1 (count (first evaluations-3)))) (is (= 2 (count evaluations-3))))) - + @@ -230,13 +225,13 @@ (deftest evaluate-pipelines-without-model (is (thrown? Exception - (let [ ;; the data + (let [;; the data pipe-fn (morph/pipeline (ds-mm/set-inference-target :species) (ds-mm/categorical->number cf/categorical)) - + train-split-seq (tc/split->seq iris :holdout) pipe-fn-seq [pipe-fn]] @@ -263,8 +258,7 @@ :loss {:evaluation-handler-fn ;(fn [result] (def result result) result) - eval/metrics-keep-fn - })] + eval/metrics-keep-fn})] (is (= [[:train-transform] [:train-transform :metric] @@ -323,11 +317,11 @@ (deftest qualify-pipelines-test (is (= (repeat 3 :scicloj.metamorph.ml-test/do-xxx) - (qualify-pipelines [ ;; 'do-xxx + (qualify-pipelines [;; 'do-xxx ::do-xxx 'scicloj.metamorph.ml-test/do-xxx :scicloj.metamorph.ml-test/do-xxx] - (find-ns 'scicloj.metamorph.ml-test))))) + (find-ns 'scicloj.metamorph.ml-test))))) (defn fit-pipe-in-new-ns [file ds] @@ -336,7 +330,7 @@ _ (intern new-ns 'ds ds) _ (.addAlias new-ns 'morph (the-ns 'scicloj.metamorph.core)) _ (.addAlias new-ns 'nippy (the-ns 'taoensso.nippy)) - species-freqs (binding [*ns* new-ns] + species-freqs (binding [*ns* new-ns] (eval '(def thawed-result (nippy/thaw-from-file file))) @@ -363,14 +357,14 @@ [[:tech.v3.dataset.metamorph/set-inference-target [:species]] [:tech.v3.dataset.metamorph/categorical->number [:species] iris-target-values] [:tech.v3.dataset.metamorph/update-column :species :clojure.core/identity] - {:metamorph/id :model}[:scicloj.metamorph.ml/model {:model-type :test-model}]] + {:metamorph/id :model} [:scicloj.metamorph.ml/model {:model-type :test-model}]] files (atom []) nippy-handler (eval/example-nippy-handler files "/tmp" identity) - - eval-result (ml/evaluate-pipelines + + _ (ml/evaluate-pipelines [base-pipe-declr] (tc/split->seq iris) loss/classification-accuracy @@ -383,11 +377,10 @@ (deftest dissoc--all-fn (do-define-model) - (let [ - base-pipe-declrss + (let [base-pipe-declrss [[:tech.v3.dataset.metamorph/set-inference-target [:species]] [:tech.v3.dataset.metamorph/categorical->number [:species] iris-target-values] - {:metamorph/id :model}[:scicloj.metamorph.ml/model {:model-type :test-model}]] + {:metamorph/id :model} [:scicloj.metamorph.ml/model {:model-type :test-model}]] evaluation-result (ml/evaluate-pipelines @@ -397,8 +390,7 @@ :accuracy {:evaluation-handler-fn eval/result-dissoc-in-seq--all-fn})] - ;(def evaluation-result evaluation-result) - (is (= + (is (= [[:train-transform] [:train-transform :metric] [:train-transform :min] @@ -410,20 +402,19 @@ [:test-transform :max] [:test-transform :mean] [:split-uid]] - + (->> - (flatten evaluation-result) - (apply merge) - keys-in - vec))) - + (flatten evaluation-result) + (apply merge) + keys-in + vec))) + (is (pos? (-> evaluation-result first first :train-transform :metric))))) (deftest remove-all (do-define-model) - (let [ - base-pipe-declrss + (let [base-pipe-declrss [[:tech.v3.dataset.metamorph/set-inference-target [:species]] [:tech.v3.dataset.metamorph/categorical->number [:species] iris-target-values] {:metamorph/id :model} [:scicloj.metamorph.ml/model {:model-type :test-model}]] @@ -434,11 +425,11 @@ (tc/split->seq iris) loss/classification-accuracy :accuracy - {:evaluation-handler-fn (fn [result] + {:evaluation-handler-fn (fn [_] {:train-transform {:metric 1} :test-transform {:metric 1}})})] - + (is (pos? (-> evaluation-result first first :train-transform :metric))))) @@ -449,7 +440,7 @@ (let [base-pipe-declrss [[:tech.v3.dataset.metamorph/set-inference-target [:species]] [:tech.v3.dataset.metamorph/categorical->number [:species] iris-target-values] - {:metamorph/id :model}[:scicloj.metamorph.ml/model {:model-type :test-model}]] + {:metamorph/id :model} [:scicloj.metamorph.ml/model {:model-type :test-model}]] evaluation-result (ml/evaluate-pipelines @@ -457,10 +448,9 @@ (tc/split->seq iris) loss/classification-accuracy :accuracy - { - :other-metrics [{:name :acc-2 :metric-fn loss/classification-accuracy} - {:name :fscore :metric-fn (fn [truth prediction] 0)} - {:name :acc :metric-fn scicloj.metamorph.ml.metrics/accuracy}]})] + {:other-metrics [{:name :acc-2 :metric-fn loss/classification-accuracy} + {:name :fscore :metric-fn (fn [_ _] 0)} + {:name :acc :metric-fn scicloj.metamorph.ml.metrics/accuracy}]})] (is (pos? (-> evaluation-result first first :train-transform :other-metrics first :metric))) (is (zero? (-> evaluation-result first first :train-transform :other-metrics second :metric))) @@ -468,38 +458,36 @@ (deftest validate-schema - (do-define-model) - (let [ - - create-base-pipe-decl - (fn [node-size] - [[:tech.v3.dataset.metamorph/set-inference-target [:species]] - [:tech.v3.dataset.metamorph/categorical->number [:species] iris-target-values] - {:metamorph/id :model}[:scicloj.metamorph.ml/model {:model-type :test-model - :node-size node-size}]]) + (do-define-model) + (let [create-base-pipe-decl + (fn [node-size] + [[:tech.v3.dataset.metamorph/set-inference-target [:species]] + [:tech.v3.dataset.metamorph/categorical->number [:species] iris-target-values] + {:metamorph/id :model} [:scicloj.metamorph.ml/model {:model-type :test-model + :node-size node-size}]]) - pipes (map create-base-pipe-decl [1 5 10 20 50 100]) + pipes (map create-base-pipe-decl [1 5 10 20 50 100]) - split (tc/split->seq iris :holdout) + split (tc/split->seq iris :holdout) - result-schema (-> #'ml/evaluate-pipelines meta :malli/schema second :registry :scicloj.metamorph.ml/evaluation-result) + result-schema (-> #'ml/evaluate-pipelines meta :malli/schema second :registry :scicloj.metamorph.ml/evaluation-result) - evaluation-result - (ml/evaluate-pipelines - pipes split - loss/classification-accuracy - :accuracy - {:result-dissoc-in-seq [] - :return-best-crossvalidation-only false - :return-best-pipeline-only false - :attach-fn-sources {:ns (find-ns 'clojure.core) - :pipe-fns-clj-file "test/scicloj/metamorph/ml_test.clj"}})] + evaluation-result + (ml/evaluate-pipelines + pipes split + loss/classification-accuracy + :accuracy + {:result-dissoc-in-seq [] + :return-best-crossvalidation-only false + :return-best-pipeline-only false + :attach-fn-sources {:ns (find-ns 'clojure.core) + :pipe-fns-clj-file "test/scicloj/metamorph/ml_test.clj"}})] - (is (true? - (m/validate - result-schema - evaluation-result))))) + (is (true? + (m/validate + result-schema + evaluation-result))))) (deftest call-without-ds @@ -511,13 +499,10 @@ (ml/define-model! :test-model-float-predictions (fn train - [feature-ds label-ds options]) + [_ _ _]) (fn predict - [feature-ds thawed-model {:keys [target-columns - target-categorical-maps - top-k - options]}] + [feature-ds _ _] (ds/new-dataset [(ds/new-column :species (repeat (tc/row-count feature-ds) 1.0) @@ -527,12 +512,9 @@ (ml/define-model! :test-model-string-predictions (fn train - [feature-ds label-ds options]) + [_ _ _]) - (fn predict [feature-ds thawed-model {:keys [target-columns - target-categorical-maps - top-k - options]}] + (fn predict [feature-ds _ _] (ds/new-dataset [(ds/new-column :species (repeat (tc/row-count feature-ds) "pred") @@ -542,19 +524,19 @@ (deftest test-preditc-float (let [model (-> - (ds/->dataset {:x [0 1 ] :target ["x" "y"]}) + (ds/->dataset {:x [0 1] :target ["x" "y"]}) (ds-mod/set-inference-target :target) (ml/train {:model-type :test-model-float-predictions}))] (is (= [1.0] - (-> (ml/predict (ds/->dataset {:x [0]}) model) :species))))) + (-> (ml/predict (ds/->dataset {:x [0]}) model) :species))))) (deftest test-predict-striong (let [model (-> - (ds/->dataset {:x [0 1 ] :target ["x" "y"]}) + (ds/->dataset {:x [0 1] :target ["x" "y"]}) (ds-mod/set-inference-target :target) (ml/train {:model-type :test-model-string-predictions}))] @@ -568,8 +550,7 @@ (ds/new-dataset [trueth-col]) :species metric-fn - {}) - ) + {})) (defn is-accuracy [predict-col trueth-col metric-fn expected-acc] @@ -580,10 +561,9 @@ (defn- score-categorical [predict-col-seq predict-a-b-table trueth-col-seq trueth-a-b-table - metric-fn - ] - (do-score - (ds/new-column :species predict-col-seq + metric-fn] + (do-score + (ds/new-column :species predict-col-seq (when predict-a-b-table {:categorical-map {:lookup-table predict-a-b-table @@ -592,24 +572,20 @@ (when trueth-a-b-table {:categorical-map {:lookup-table trueth-a-b-table - :src-column :species}}) - ) - metric-fn -)) - -(defn is-mapped-columns-accuracy [ - predict-col-seq predict-a-b-table - trueth-col-seq trueth-a-b-table - metric-fn - expected-accuracy] - + :src-column :species}})) + metric-fn)) + +(defn is-mapped-columns-accuracy [predict-col-seq predict-a-b-table + trueth-col-seq trueth-a-b-table + metric-fn + expected-accuracy] + (is (= {:metric expected-accuracy, :other-metrics-result []} - + (score-categorical predict-col-seq predict-a-b-table trueth-col-seq trueth-a-b-table - metric-fn - )))) + metric-fn)))) (deftest test-score @@ -653,7 +629,7 @@ 1.0) - + (is-mapped-columns-accuracy [0 1] {:a 0 :b 1} [1 0] {:a 0 :b 1} loss/classification-accuracy @@ -663,10 +639,7 @@ (is (thrown? Exception (score-categorical [0.0 1.0] {:a 0.0 :b 1.0} [0 1] {:a 0 :b 1} - loss/classification-accuracy - )) - - ) + loss/classification-accuracy))) (is (thrown? Exception (score-categorical [0 1] {:a 0.0 :b 1.0} [0 1] {:a 0.0 :b 1.0} @@ -675,20 +648,40 @@ (deftest score-other-metrics - (is (= - {:metric 0.6666666666666667, - :other-metrics-result + (is (= + {:metric 0.6666666666666667, + :other-metrics-result [{:name :m-1, :metric-fn loss/classification-accuracy :metric 0.6666666666666667} {:name :m-2, :metric-fn loss/classification-loss :metric 0.33333333333333326}]} - (ml/score - (ds/->dataset {:x [:a :a :a]}) - (ds/->dataset {:x [:a :b :a]}) - :x - loss/classification-accuracy - [ - {:name :m-1 - :metric-fn loss/classification-accuracy} - {:name :m-2 - :metric-fn loss/classification-loss}])))) - \ No newline at end of file + (ml/score + (ds/->dataset {:x [:a :a :a]}) + (ds/->dataset {:x [:a :b :a]}) + :x + loss/classification-accuracy + [{:name :m-1 + :metric-fn loss/classification-accuracy} + {:name :m-2 + :metric-fn loss/classification-loss}])))) + +(deftest define-model-schema + (ml/define-model! :test-model--options + (fn train + [_ _ _]) + (fn predict + [_ _]) + {:options [:map {:closed true}]}) + + + (try + (ml/train iris-train {:model-type :test-model--options + :a 1}) + (throw (Exception.)) + (catch ExceptionInfo e + (is (= + {:a ["disallowed key"]} + (ex-data e))))) + (is (nil? + (-> + (ml/train iris-train {:model-type :test-model--options}) + :model-data))))