Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
behrica committed Feb 2, 2025
1 parent 79a3d1a commit 984c9dc
Show file tree
Hide file tree
Showing 5 changed files with 217 additions and 204 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
unreleased
* use malli to describe options
* use malli to describe model options and validate them in ml/train
* renamed :other-metrices -> :other-metrics

0.12
Expand Down
117 changes: 59 additions & 58 deletions src/scicloj/metamorph/ml.clj
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,15 @@

(exporter/export-symbols scicloj.metamorph.ml.ensemble ensemble-pipe)

(def train-predict-cache
(def train-predict-cache
"Controls whether train/predict invocations are cached.
If 'use-cache' is true, the get-fn and set-fn functions are called accordingly"
(atom {:use-cache false
:get-fn (fn [key] nil)
:set-fn (fn [key value] nil)}))


(defn get-categorical-maps [ds]
(defn- get-categorical-maps [ds]
(->> (ds/column-names ds)
(map #(list % (-> ds (get %) meta :categorical-map)))
(remove #(nil? (second %)))
Expand Down Expand Up @@ -141,11 +141,11 @@

eval-pipeline-result-train (eval-pipeline pipeline-fn fitted-ctx metric-fn train-ds (:other-metrics tune-options))
eval-pipeline-result-test (if (-> fitted-ctx :model ::unsupervised?)
{:other-metrics []
:timing 0
:ctx fitted-ctx
:metric 0}
(eval-pipeline pipeline-fn fitted-ctx metric-fn test-ds (:other-metrics tune-options)))]
{:other-metrics []
:timing 0
:ctx fitted-ctx
:metric 0}
(eval-pipeline pipeline-fn fitted-ctx metric-fn test-ds (:other-metrics tune-options)))]



Expand Down Expand Up @@ -198,8 +198,8 @@
(defn- evaluate-one-pipeline [pipeline-decl-or-fn train-test-split-seq metric-fn loss-or-accuracy tune-options]

(let [pipeline-fn (if (fn? pipeline-decl-or-fn)
pipeline-decl-or-fn
(mm/->pipeline pipeline-decl-or-fn))
pipeline-decl-or-fn
(mm/->pipeline pipeline-decl-or-fn))
pipeline-decl (when (sequential? pipeline-decl-or-fn)
pipeline-decl-or-fn)

Expand Down Expand Up @@ -334,8 +334,8 @@
[:ppmap-grain-size {:optional true} int?]
[:evaluation-handler-fn {:optional true} fn?]
[:other-metrics {:optional true} [:sequential [:map
[:name keyword?]
[:metric-fn fn?]]]]
[:name keyword?]
[:metric-fn fn?]]]]
[:attach-fn-sources {:optional true} [:map [:ns any?]
[:pipe-fns-clj-file string?]]]]]
::evaluation-result
Expand All @@ -348,9 +348,9 @@

[:train-transform [:map {:closed true}
[:other-metrics [:sequential [:map {:closed true}
[:name keyword?]
[:metric-fn fn?]
[:metric float?]]]]
[:name keyword?]
[:metric-fn fn?]
[:metric float?]]]]
[:timing int?]
[:metric float?]
[:probability-distribution [:maybe [:fn dataset?]]]
Expand All @@ -360,9 +360,9 @@
[:ctx map?]]]
[:test-transform [:map {:closed true}
[:other-metrics [:sequential [:map {:closed true}
[:name keyword?]
[:metric-fn fn?]
[:metric float?]]]]
[:name keyword?]
[:metric-fn fn?]
[:metric float?]]]]
[:timing int?]
[:metric float?]
[:probability-distribution [:maybe [:fn dataset?]]]
Expand Down Expand Up @@ -409,7 +409,7 @@
:return-best-crossvalidation-only true
:evaluation-handler-fn eval-handler/default-result-dissoc-in-fn
:ppmap-grain-size 10}


options)
map-fn
Expand Down Expand Up @@ -496,9 +496,9 @@
:as opts}]

(println "Register model: " model-kwd)

(malli/model-options->full-schema opts) ;; throws on invalid malli schema for options


(swap! model-definitions* assoc model-kwd {:train-fn train-fn
:predict-fn predict-fn
Expand Down Expand Up @@ -577,17 +577,14 @@
(let [model-options (options->model-def options)
_ (when (some? (:options model-options))
(validate-options model-options options))

combined-hash (when (:use-cache @train-predict-cache)
(str (hash dataset) "___" (hash options))
)


combined-hash (when (:use-cache @train-predict-cache)
(str (hash dataset) "___" (hash options)))

cached (when combined-hash ((:get-fn @train-predict-cache) combined-hash))]

(if cached
(do
(println :cache-hit-train! combined-hash)
cached)
cached
(let [{:keys [train-fn unsupervised?]} model-options
feature-ds (cf/feature dataset)
_ (errors/when-not-error (> (ds/row-count feature-ds) 0)
Expand Down Expand Up @@ -621,14 +618,12 @@
(when-not (== 0 (count cat-maps))
{:target-categorical-maps cat-maps}))]
(when combined-hash
(println :cache-miss-train! combined-hash)
((:set-fn @train-predict-cache) combined-hash model))

model))
))



model))))




(defn thaw-model
"Thaw a model. Models returned from train may be 'frozen', meaning a 'thaw'
Expand All @@ -654,28 +649,36 @@
(thaw-fn (:model-data model)))))


(defn- warn-inconsitent-maps [model pred-ds]
(defn- warn-inconsistent-maps [model pred-ds]
;; TODO revise
;;https://github.com/scicloj/metamorph.ml/issues/35


(let [target-cat-maps-from-train (-> model :target-categorical-maps)
target-cat-maps-from-predict (-> pred-ds get-categorical-maps)
simple-predicted-values (-> pred-ds cf/prediction (get (first (keys target-cat-maps-from-predict))) seq)
inverse-map (-> target-cat-maps-from-predict vals first :lookup-table set/map-invert)]
(when (not (= target-cat-maps-from-predict target-cat-maps-from-train)))
;; (println
;; (format
;; "target categorical maps do not match between train an predict. \n train: %s \n predict: %s "
;; target-cat-maps-from-train target-cat-maps-from-predict))
(when (not (= target-cat-maps-from-predict target-cat-maps-from-train))

;; (throw (Exception.
;; (format
;; "target categorical maps do not match between train and predict. \n train: %s \n predict: %s "
;; target-cat-maps-from-train target-cat-maps-from-predict)))

)

(when (not (every? some?
(map inverse-map
(distinct simple-predicted-values)))))))
;; (println
;; (format
;; "Some predicted values are not in catetegorical map. -> Invalid predict fn.
;; values: %s
(distinct simple-predicted-values))))
;; (throw (Exception.
;; (format
;; "Some predicted values are not in categorical map. -> Invalid predict fn.
;; predicted values: %s
;; categorical map: %s "
;; (vec (distinct simple-predicted-values))
;; (-> target-cat-maps-from-predict vals first :lookup-table)))
;; (vec (distinct simple-predicted-values))
;; (-> target-cat-maps-from-predict vals first :lookup-table))))

)))



Expand Down Expand Up @@ -708,38 +711,36 @@
string in train -> string in predict
categorical map in train -> equivalent categorical map in predict
ml/train passes the needed information of the rain target column to the model implementaion to do this.
ml/train passes the needed information of the train target column to the model implementation to do this.
"
{:malli/schema [:=> [:cat [:fn dataset?]
[:map [:options map?]
[:feature-columns sequential?]
[:target-columns sequential?]]]


[map?]]}
[dataset {:keys [feature-columns options train-input-hash]

[dataset {:keys [feature-columns options train-input-hash]
:as model}]
(let [predict-hash (when (:use-cache @train-predict-cache) (str train-input-hash "--" (hash dataset)))
cached (when predict-hash ((:get-fn @train-predict-cache) predict-hash))

pred-ds
(if cached
(do
(println :cache-hit-predict! predict-hash)
cached)
cached

(let [{:keys [predict-fn] :as model-def} (options->model-def options)
feature-ds (ds/select-columns dataset feature-columns)
thawed-model (thaw-model model model-def)
pred-ds (predict-fn feature-ds
thawed-model
model)]
(warn-inconsitent-maps model pred-ds)
(warn-inconsistent-maps model pred-ds)

(when predict-hash
(println :cache-miss-predict! predict-hash)
( (:set-fn @train-predict-cache) predict-hash pred-ds))

((:set-fn @train-predict-cache) predict-hash pred-ds))

pred-ds))]
pred-ds))

Expand Down
11 changes: 8 additions & 3 deletions src/scicloj/metamorph/ml/classification.clj
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,20 @@

(defn- get-majority-class [target-ds]
(let [target-column-name (first
(ds-mod/inference-target-column-names target-ds))]
(ds-mod/inference-target-column-names target-ds))
target-column (get target-ds target-column-name)
freqs (frequencies target-column)
]
;; (println :target-column target-column)
;; (println :target-column--meta (meta target-column))
;; (println :target-column--freq freqs)
(->>
(-> target-ds (get target-column-name) frequencies)
freqs
(sort-by second)
reverse
first
first)))


(ml/define-model! :metamorph.ml/dummy-classifier
(fn [feature-ds target-ds options]
(let [target-column-name (first
Expand Down
24 changes: 19 additions & 5 deletions test/scicloj/metamorph/classification_test.clj
Original file line number Diff line number Diff line change
Expand Up @@ -107,19 +107,33 @@
(is (= [:a ] (-> (ds/->dataset {:x [0]}) (ml/predict model) :y)))))


(deftest dummy-categorical
(deftest dummy-categorical-int

(let [ds (->
{:x [0] :y [:a]}
{:x [3.0] :y [:a]}
(ds/->dataset)
(ds/categorical->number [:y])
(ds/categorical->number [:y] [:a] :int)
(ds-mod/set-inference-target :y))
model (ml/train ds {:model-type :metamorph.ml/dummy-classifier
:dummy-strategy :random-class})
prediction (ml/predict (ds/->dataset {:x [0]}) model)]

(is (= [:a] (-> prediction (ds-cat/reverse-map-categorical-xforms) :y seq)))))


(deftest dummy-categorical-float32--failing

(let [ds (->
{:x [3.0] :y [:a]}
(ds/->dataset)
(ds/categorical->number [:y] [:a] :float32)
(ds-mod/set-inference-target :y))
model (ml/train ds {:model-type :metamorph.ml/dummy-classifier
:dummy-strategy :random-class})]
:dummy-strategy :random-class})
prediction (ml/predict (ds/->dataset {:x [0]}) model)]

(is (= [:a] (-> prediction (ds-cat/reverse-map-categorical-xforms) :y seq)))))

(is (= [:a ] (-> (ds/->dataset {:x [0]}) (ml/predict model) (ds-cat/reverse-map-categorical-xforms) :y)))))

(deftest dummy-pipeline-eval
(let [pipe-fn (mm/pipeline
Expand Down
Loading

0 comments on commit 984c9dc

Please sign in to comment.