diff --git a/CHANGELOG.md b/CHANGELOG.md index c7a5f2b..f7db275 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ + +unreleased +* added :target-datatypes in train result and clarified expected 'shape' of prediction + 0.10.3 * re-added data.json dependency diff --git a/src/scicloj/metamorph/ml.clj b/src/scicloj/metamorph/ml.clj index 2051146..731413e 100644 --- a/src/scicloj/metamorph/ml.clj +++ b/src/scicloj/metamorph/ml.clj @@ -588,7 +588,17 @@ * `:options` - the options passed in. * `:id` - new randomly generated UUID. * `:feature-columns` - vector of column names. - * `:target-columns` - vector of column names." + * `:target-columns` - vector of column names. + * `:target-datatypes` - map of target columns names -> target columns type + * `:target-categorical-maps` - the categorical maps of the target columns, if present + + A well behaving model implementaion should use + :target-column + :target-datatypes + :target-categorical-maps + + to construct its prediction dataset so that its matches with the train data target column. + " {:malli/schema [:=> [:cat [:fn dataset?] map?] [map?]]} [dataset options] @@ -605,6 +615,14 @@ model-data (train-fn feature-ds target-ds options) ;; _ (errors/when-not-error (:model-as-bytes model-data) "train-fn need to return a map with key :model-as-bytes") + + targets-datatypes + (zipmap + (keys target-ds) + (->> + (vals target-ds) + (map meta) + (map :datatype))) cat-maps (ds-mod/dataset->categorical-xforms target-ds)] (merge @@ -612,7 +630,8 @@ :options options :id (UUID/randomUUID) :feature-columns (vec (ds/column-names feature-ds)) - :target-columns (vec (ds/column-names target-ds))} + :target-columns (vec (ds/column-names target-ds)) + :target-datatypes targets-datatypes} (when-not (== 0 (count cat-maps)) {:target-categorical-maps cat-maps})))) @@ -671,7 +690,15 @@ * For regression, a single column dataset is returned with the column named after the target * For classification, a dataset is returned with a float64 column for each target - value and values that describe the probability distribution." + value and values that describe the probability distribution. + + Each implementing model should construct its prediction in a shape expressed by + :target-column + :target-datatypes + :target-categorical-maps + + it is receiving. + " {:malli/schema [:=> [:cat [:fn dataset?] [:map [:options map?] [:feature-columns sequential?] diff --git a/test/scicloj/metamorph/classification_test.clj b/test/scicloj/metamorph/classification_test.clj index 3e8a333..ec54e09 100644 --- a/test/scicloj/metamorph/classification_test.clj +++ b/test/scicloj/metamorph/classification_test.clj @@ -47,16 +47,11 @@ model (ml/train ds {:model-type :metamorph.ml/dummy-classifier :dummy-strategy :fixed-class :fixed-class 0}) - - - prediction (ml/predict ds model)] - + (is (= {:species :int16} (:target-datatypes model))) (is (= (:species prediction) (repeat 150 0))))) - - (deftest dummy-classification-majority [] (let [ds (toydata/breast-cancer-ds) model (ml/train ds {:model-type :metamorph.ml/dummy-classifier @@ -64,7 +59,7 @@ prediction (ml/predict ds model)] - + (is (= (:class prediction) (repeat 569 0)))))