Skip to content

Commit

Permalink
added :taregt-datatypes
Browse files Browse the repository at this point in the history
fixes #26
  • Loading branch information
behrica committed Nov 24, 2024
1 parent d5d9055 commit 5592274
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 10 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@

unreleased
* added :target-datatypes in train result and clarified expected 'shape' of prediction

0.10.3
* re-added data.json dependency

Expand Down
33 changes: 30 additions & 3 deletions src/scicloj/metamorph/ml.clj
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,17 @@
* `:options` - the options passed in.
* `:id` - new randomly generated UUID.
* `:feature-columns` - vector of column names.
* `:target-columns` - vector of column names."
* `:target-columns` - vector of column names.
* `:target-datatypes` - map of target columns names -> target columns type
* `:target-categorical-maps` - the categorical maps of the target columns, if present
A well behaving model implementaion should use
:target-column
:target-datatypes
:target-categorical-maps
to construct its prediction dataset so that its matches with the train data target column.
"
{:malli/schema [:=> [:cat [:fn dataset?] map?]
[map?]]}
[dataset options]
Expand All @@ -605,14 +615,23 @@

model-data (train-fn feature-ds target-ds options)
;; _ (errors/when-not-error (:model-as-bytes model-data) "train-fn need to return a map with key :model-as-bytes")

targets-datatypes
(zipmap
(keys target-ds)
(->>
(vals target-ds)
(map meta)
(map :datatype)))
cat-maps (ds-mod/dataset->categorical-xforms target-ds)]

(merge
{:model-data model-data
:options options
:id (UUID/randomUUID)
:feature-columns (vec (ds/column-names feature-ds))
:target-columns (vec (ds/column-names target-ds))}
:target-columns (vec (ds/column-names target-ds))
:target-datatypes targets-datatypes}
(when-not (== 0 (count cat-maps))
{:target-categorical-maps cat-maps}))))

Expand Down Expand Up @@ -671,7 +690,15 @@
* For regression, a single column dataset is returned with the column named after the
target
* For classification, a dataset is returned with a float64 column for each target
value and values that describe the probability distribution."
value and values that describe the probability distribution.
Each implementing model should construct its prediction in a shape expressed by
:target-column
:target-datatypes
:target-categorical-maps
it is receiving.
"
{:malli/schema [:=> [:cat [:fn dataset?]
[:map [:options map?]
[:feature-columns sequential?]
Expand Down
9 changes: 2 additions & 7 deletions test/scicloj/metamorph/classification_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,19 @@
model (ml/train ds {:model-type :metamorph.ml/dummy-classifier
:dummy-strategy :fixed-class
:fixed-class 0})



prediction (ml/predict ds model)]


(is (= {:species :int16} (:target-datatypes model)))
(is (= (:species prediction) (repeat 150 0)))))



(deftest dummy-classification-majority []
(let [ds (toydata/breast-cancer-ds)
model (ml/train ds {:model-type :metamorph.ml/dummy-classifier
:dummy-strategy :majority-class})


prediction (ml/predict ds model)]

(is (= (:class prediction) (repeat 569 0)))))


Expand Down

0 comments on commit 5592274

Please sign in to comment.