From 25bba9227f655b468fdd30223cc0739d1b013adb Mon Sep 17 00:00:00 2001 From: HyunjunA <39776968+HyunjunA@users.noreply.github.com> Date: Tue, 12 Mar 2024 16:44:07 -0700 Subject: [PATCH] Ml backend handling extremely imbalanced and small dataset (#646) * working on sidemenu * replace Aliro with AliroEd. Update footer text "Developed by the Center for AI Research and Education (CAIRE) in the Department of Computational Biomedicine at Cedars-Sinai Medical Center in Los Angeles, California, USA." * fixed typo. updated the broken download link. * Temporary fix to handle NaN values * Addressed issues with an extremely imbalanced and small dataset by removing NaN values from the metrics. (This is a temporary fix.) * remove unnecessary comments --- .../src/components/Results/index-twoside.jsx | 416 ++--- lab/webapp/src/components/Results/index.jsx | 2 +- .../ResultsV2/components/Score/index.jsx | 2 + lab/webapp/src/components/ResultsV2/index.jsx | 1645 +++++++++-------- machine/learn/skl_utils.py | 118 +- 5 files changed, 1152 insertions(+), 1031 deletions(-) diff --git a/lab/webapp/src/components/Results/index-twoside.jsx b/lab/webapp/src/components/Results/index-twoside.jsx index 6c31c04be..99ab33729 100644 --- a/lab/webapp/src/components/Results/index-twoside.jsx +++ b/lab/webapp/src/components/Results/index-twoside.jsx @@ -28,34 +28,34 @@ along with this program. If not, see . (Autogenerated header, do not modify) */ -import React, { Component } from 'react'; -import { connect } from 'react-redux'; -import * as actions from 'data/experiments/selected/actions'; -import SceneHeader from '../SceneHeader'; -import FetchError from '../FetchError'; -import AlgorithmDetails from './components/AlgorithmDetails'; -import RunDetails from './components/RunDetails' -import MSEMAEDetails from './components/MSEMAEDetails';; -import ConfusionMatrix from './components/ConfusionMatrix'; -import ConfusionMatrixJSON from './components/ConfusionMatrixJSON'; -import ROCCurve from './components/ROCCurve'; -import ShapSummaryCurve from './components/ShapSummaryCurve'; -import ImportanceScore from './components/ImportanceScore'; -import ImportanceScoreJSON from './components/ImportanceScoreJSON'; -import LearningCurve from './components/LearningCurve'; -import LearningCurveJSON from './components/LearningCurveJSON'; -import TestChart from './components/TestChart'; -import PCA from './components/PCA'; -import PCAJSON from './components/PCAJSON'; -import TSNE from './components/TSNE'; -import TSNEJSON from './components/TSNEJSON'; -import RegFigure from './components/RegFigure'; -import Score from './components/Score'; -import NoScore from './components/NoScore'; -import { Header, Grid, Loader, Dropdown, Menu} from 'semantic-ui-react'; -import { formatDataset } from 'utils/formatter'; -import ClassRate from './components/ClassRate'; -import ChatGPT from '../ChatGPT'; +import React, { Component } from "react"; +import { connect } from "react-redux"; +import * as actions from "data/experiments/selected/actions"; +import SceneHeader from "../SceneHeader"; +import FetchError from "../FetchError"; +import AlgorithmDetails from "./components/AlgorithmDetails"; +import RunDetails from "./components/RunDetails"; +import MSEMAEDetails from "./components/MSEMAEDetails"; +import ConfusionMatrix from "./components/ConfusionMatrix"; +import ConfusionMatrixJSON from "./components/ConfusionMatrixJSON"; +import ROCCurve from "./components/ROCCurve"; +import ShapSummaryCurve from "./components/ShapSummaryCurve"; +import ImportanceScore from "./components/ImportanceScore"; +import ImportanceScoreJSON from "./components/ImportanceScoreJSON"; +import LearningCurve from "./components/LearningCurve"; +import LearningCurveJSON from "./components/LearningCurveJSON"; +import TestChart from "./components/TestChart"; +import PCA from "./components/PCA"; +import PCAJSON from "./components/PCAJSON"; +import TSNE from "./components/TSNE"; +import TSNEJSON from "./components/TSNEJSON"; +import RegFigure from "./components/RegFigure"; +import Score from "./components/Score"; +import NoScore from "./components/NoScore"; +import { Header, Grid, Loader, Dropdown, Menu } from "semantic-ui-react"; +import { formatDataset } from "utils/formatter"; +import ClassRate from "./components/ClassRate"; +import ChatGPT from "../ChatGPT"; class Results extends Component { constructor(props) { @@ -72,11 +72,11 @@ class Results extends Component { } /** - * Basic helped method to create array containing [key,val] entries where - * key - name of given score - * value - actual score - * passed to Score component which uses javascript library C3 to create graphic - */ + * Basic helped method to create array containing [key,val] entries where + * key - name of given score + * value - actual score + * passed to Score component which uses javascript library C3 to create graphic + */ // async getData(filename){ // const res = await fetch(filename); @@ -90,22 +90,23 @@ class Results extends Component { let expScores = experiment.data.scores; // console.log("experiment.data") - console.log("experiment.data",experiment.data) + console.log("experiment-777.data", experiment.data); // console.log(experiment.data['class_1'][0]) // console.log(experiment.data['class_-1'][0]) - if(typeof(expScores) === 'object'){ - keyList.forEach(scoreKey => { - if(expScores[scoreKey] && typeof expScores[scoreKey].toFixed === 'function'){ + if (typeof expScores === "object") { + keyList.forEach((scoreKey) => { + if ( + expScores[scoreKey] && + typeof expScores[scoreKey].toFixed === "function" + ) { let tempLabel; - scoreKey.includes('train') - ? tempLabel = 'Train (' + expScores[scoreKey].toFixed(2) + ')' - : tempLabel = 'Test (' + expScores[scoreKey].toFixed(2) + ')'; - testList.push( - [tempLabel, expScores[scoreKey]] - ); + scoreKey.includes("train") + ? (tempLabel = "Train (" + expScores[scoreKey].toFixed(2) + ")") + : (tempLabel = "Test (" + expScores[scoreKey].toFixed(2) + ")"); + testList.push([tempLabel, expScores[scoreKey]]); } - }); + }); } return testList; @@ -114,19 +115,15 @@ class Results extends Component { render() { const { experiment, fetchExperiment } = this.props; - if(experiment.isFetching || !experiment.data) { + if (experiment.isFetching || !experiment.data) { return ( ); } - if(experiment.error === 'Failed to fetch') { - return ( - - ); - } else if(experiment.error) { + if (experiment.error === "Failed to fetch") { + return ; + } else if (experiment.error) { return ( { fetch(`/api/v1/experiments/${id}/model`) - .then(response => { - if(response.status >= 400) { + .then((response) => { + if (response.status >= 400) { throw new Error(`${response.status}: ${response.statusText}`); } return response.json(); }) - .then(json => { + .then((json) => { window.location = `/api/v1/files/${json._id}`; }); }; const downloadScript = (id) => { fetch(`/api/v1/experiments/${id}/script`) - .then(response => { - if(response.status >= 400) { + .then((response) => { + if (response.status >= 400) { throw new Error(`${response.status}: ${response.statusText}`); } return response.json(); }) - .then(json => { + .then((json) => { window.location = `/api/v1/files/${json._id}`; }); }; // console.log(experiment.data.prediction_type) // --- get lists of scores --- - if(experiment.data.prediction_type == "classification") { // classification + if (experiment.data.prediction_type == "classification") { + // classification - console.log("experiment.data", experiment.data) + console.log("experiment.data", experiment.data); // console.log("X_pca", experiment.data.X_pca) // console.log("y_pca", experiment.data.y_pca) - let confusionMatrix, rocCurve, importanceScore, learningCurve, pca, pca_json, tsne, tsne_json, shap_explainer, shap_num_samples; - + let confusionMatrix, + rocCurve, + importanceScore, + learningCurve, + pca, + pca_json, + tsne, + tsne_json, + shap_explainer, + shap_num_samples; + let shapSummaryCurveDict = {}; - - experiment.data.experiment_files.forEach(file => { + experiment.data.experiment_files.forEach((file) => { const filename = file.filename; - console.log('filename',filename); - if(filename.includes('confusion_matrix')) { + console.log("filename", filename); + if (filename.includes("confusion_matrix")) { confusionMatrix = file; - } else if(filename.includes('roc_curve')) { + } else if (filename.includes("roc_curve")) { rocCurve = file; - } else if(filename.includes('imp_score')) { + } else if (filename.includes("imp_score")) { importanceScore = file; - } else if(filename.includes('learning_curve')) { + } else if (filename.includes("learning_curve")) { learningCurve = file; - - - } else if(filename.includes('pca') && filename.includes('png')) { + } else if (filename.includes("pca") && filename.includes("png")) { pca = file; - console.log("pca", pca) - } else if (filename.includes('pca-json')) { - console.log("pca_json") + console.log("pca", pca); + } else if (filename.includes("pca-json")) { + console.log("pca_json"); pca_json = file; - console.log("pca_json: ", pca_json) - } - - else if(filename.includes('tsne') && filename.includes('png')) { + console.log("pca_json: ", pca_json); + } else if (filename.includes("tsne") && filename.includes("png")) { tsne = file; - console.log("tsne", tsne) - - } - else if (filename.includes('tsne-json')) { - console.log("tsne_json") + console.log("tsne", tsne); + } else if (filename.includes("tsne-json")) { + console.log("tsne_json"); tsne_json = file; - console.log("tsne_json: ", tsne_json) - } - - else if(filename.includes('shap_summary_curve')) { - console.log("shap_summary_curve") - let class_name = filename.split('_').slice(-2,-1); + console.log("tsne_json: ", tsne_json); + } else if (filename.includes("shap_summary_curve")) { + console.log("shap_summary_curve"); + let class_name = filename.split("_").slice(-2, -1); shapSummaryCurveDict[class_name] = file; - shap_explainer=experiment.data.shap_explainer; - shap_num_samples=experiment.data.shap_num_samples; - } - else if (filename.includes('shap_summary_json')) { - console.log("shap_summary_json") + shap_explainer = experiment.data.shap_explainer; + shap_num_samples = experiment.data.shap_num_samples; + } else if (filename.includes("shap_summary_json")) { + console.log("shap_summary_json"); // shap_json = file; // console.log("shap_json: ", shap_json) } - }); // balanced accuracy - let balancedAccKeys = ['train_balanced_accuracy_score', 'balanced_accuracy_score']; + let balancedAccKeys = [ + "train_balanced_accuracy_score", + "balanced_accuracy_score", + ]; // precision scores - let precisionKeys = ['train_precision_score', 'precision_score'] + let precisionKeys = ["train_precision_score", "precision_score"]; // AUC - let aucKeys = ['train_roc_auc_score', 'roc_auc_score']; + let aucKeys = ["train_roc_auc_score", "roc_auc_score"]; // f1 score - let f1Keys = ['train_f1_score', 'f1_score']; + let f1Keys = ["train_f1_score", "f1_score"]; // recall - let recallKeys = ['train_recall_score', 'recall_score']; + let recallKeys = ["train_recall_score", "recall_score"]; let balancedAccList = this.getGaugeArray(balancedAccKeys); let precisionList = this.getGaugeArray(precisionKeys); @@ -240,67 +239,62 @@ class Results extends Component { let class_percentage = []; // let pca_data = []; - - - - experiment.data.class_names.forEach(eachclass => { - - console.log('eachclass.toString()', eachclass.toString()) + experiment.data.class_names.forEach((eachclass) => { + console.log("eachclass.toString()", eachclass.toString()); // if type of experiment.data['class_' + eachclass.toString()] === 'object' - if ((typeof experiment.data['class_' + eachclass.toString()]) === 'object') - { - class_percentage.push( - [eachclass.toString(), experiment.data['class_' + eachclass.toString()][0]] - ); - console.log("experiment.data['class_1']", experiment.data['class_1']) + if ( + typeof experiment.data["class_" + eachclass.toString()] === "object" + ) { + class_percentage.push([ + eachclass.toString(), + experiment.data["class_" + eachclass.toString()][0], + ]); + console.log("experiment.data['class_1']", experiment.data["class_1"]); + } else { + class_percentage.push([ + eachclass.toString(), + experiment.data["class_" + eachclass.toString()], + ]); + console.log("experiment.data['class_1']", experiment.data["class_1"]); } - else - { - class_percentage.push( - [eachclass.toString(), experiment.data['class_' + eachclass.toString()]] - ); - console.log("experiment.data['class_1']", experiment.data['class_1']) - } - - - }); - - - return ( - -
- + - downloadModel(experiment.data._id)} - />, - downloadScript(experiment.data._id)} - /> + downloadModel(experiment.data._id)} + /> + , + downloadScript(experiment.data._id)} + /> @@ -308,8 +302,6 @@ class Results extends Component { - - @@ -332,7 +324,8 @@ class Results extends Component { type="classification" /> {/* */} - - - - - {/* */} - {/* */} {/* This TestChart is for interactive and responsive confusion matrix */} - - + {/* https://en.wikipedia.org/wiki/Confusion_matrix */} - - + {/* GPT Space */} - + - @@ -454,74 +445,75 @@ class Results extends Component { */} - {/* GPT Space */} - - - - +
- - - ); - } else if(experiment.data.prediction_type == "regression") { // regression + } else if (experiment.data.prediction_type == "regression") { + // regression let importanceScore, reg_cv_pred, reg_cv_resi, reg_cv_qq; - experiment.data.experiment_files.forEach(file => { + experiment.data.experiment_files.forEach((file) => { const filename = file.filename; - if(filename.includes('imp_score')) { + if (filename.includes("imp_score")) { importanceScore = file; - } else if(filename.includes('reg_cv_pred')) { + } else if (filename.includes("reg_cv_pred")) { reg_cv_pred = file; - } else if(filename.includes('reg_cv_resi')) { + } else if (filename.includes("reg_cv_resi")) { reg_cv_resi = file; - } else if(filename.includes('reg_cv_qq')) { + } else if (filename.includes("reg_cv_qq")) { reg_cv_qq = file; } - }); // r2 - let R2Keys = ['train_r2_score', 'r2_score']; + let R2Keys = ["train_r2_score", "r2_score"]; // r - let RKeys = ['train_pearsonr_score', 'pearsonr_score']; + let RKeys = ["train_pearsonr_score", "pearsonr_score"]; // r2 - let VAFKeys = ['train_explained_variance_score', 'explained_variance_score']; + let VAFKeys = [ + "train_explained_variance_score", + "explained_variance_score", + ]; let R2List = this.getGaugeArray(R2Keys); let RList = this.getGaugeArray(RKeys); let VAFList = this.getGaugeArray(VAFKeys); - return (
- + - downloadModel(experiment.data._id)} - />, - downloadScript(experiment.data._id)} - /> + downloadModel(experiment.data._id)} + /> + , + downloadScript(experiment.data._id)} + /> @@ -531,8 +523,6 @@ class Results extends Component { - - - {/* */} {/* */} {/* */} - - - - - - - - - + - {/* GPT Space */} - + @@ -626,11 +610,9 @@ class Results extends Component { {/* */} - - {/* GPT Space */} - - + {/* GPT Space */} +
); } @@ -638,7 +620,7 @@ class Results extends Component { } const mapStateToProps = (state) => ({ - experiment: state.experiments.selected + experiment: state.experiments.selected, }); export { Results }; diff --git a/lab/webapp/src/components/Results/index.jsx b/lab/webapp/src/components/Results/index.jsx index f0025df4e..bfbe9c03b 100644 --- a/lab/webapp/src/components/Results/index.jsx +++ b/lab/webapp/src/components/Results/index.jsx @@ -90,7 +90,7 @@ class Results extends Component { let expScores = experiment.data.scores; // console.log("experiment.data") - console.log("experiment.data", experiment.data); + console.log("experiment-999.data", experiment.data); // console.log(experiment.data['class_1'][0]) // console.log(experiment.data['class_-1'][0]) diff --git a/lab/webapp/src/components/ResultsV2/components/Score/index.jsx b/lab/webapp/src/components/ResultsV2/components/Score/index.jsx index 5bb1edcc5..8c1678bb3 100644 --- a/lab/webapp/src/components/ResultsV2/components/Score/index.jsx +++ b/lab/webapp/src/components/ResultsV2/components/Score/index.jsx @@ -136,7 +136,9 @@ function Score({ if (typeof scoreValue !== "number" && !scoreValueList.length) { return ; } else { + console.log("777-scoreValueList", scoreValueList); let fold = scoreValueList[0][1] / scoreValueList[1][1]; + console.log("777-fold", fold); var icons = foldcheck(fold); let headericon = ( . (Autogenerated header, do not modify) */ -import React, {Component} from 'react'; -import {connect} from 'react-redux'; -import * as actions from 'data/experiments/selected/actions'; -import SceneHeader from '../SceneHeader'; -import FetchError from '../FetchError'; -import AlgorithmDetails from './components/AlgorithmDetails'; -import RunDetails from './components/RunDetails' -import MSEMAEDetails from './components/MSEMAEDetails';; -import ConfusionMatrix from './components/ConfusionMatrix'; -import ConfusionMatrixJSON from './components/ConfusionMatrixJSON'; -import ROCCurve from './components/ROCCurve'; -import ShapSummaryCurve from './components/ShapSummaryCurve'; -import ImportanceScore from './components/ImportanceScore'; -import ImportanceScoreJSON from './components/ImportanceScoreJSON'; -import LearningCurve from './components/LearningCurve'; -import LearningCurveJSON from './components/LearningCurveJSON'; -import TestChart from './components/TestChart'; -import PCA from './components/PCA'; +import React, { Component } from "react"; +import { connect } from "react-redux"; +import * as actions from "data/experiments/selected/actions"; +import SceneHeader from "../SceneHeader"; +import FetchError from "../FetchError"; +import AlgorithmDetails from "./components/AlgorithmDetails"; +import RunDetails from "./components/RunDetails"; +import MSEMAEDetails from "./components/MSEMAEDetails"; +import ConfusionMatrix from "./components/ConfusionMatrix"; +import ConfusionMatrixJSON from "./components/ConfusionMatrixJSON"; +import ROCCurve from "./components/ROCCurve"; +import ShapSummaryCurve from "./components/ShapSummaryCurve"; +import ImportanceScore from "./components/ImportanceScore"; +import ImportanceScoreJSON from "./components/ImportanceScoreJSON"; +import LearningCurve from "./components/LearningCurve"; +import LearningCurveJSON from "./components/LearningCurveJSON"; +import TestChart from "./components/TestChart"; +import PCA from "./components/PCA"; // import PCAJSON from './components/PCAJSON'; -import GenPLOT from './components/GenPLOT'; +import GenPLOT from "./components/GenPLOT"; // import PCAJSONV from './components/PCAJSONV'; // import TSNE from './components/TSNE'; // import TSNEJSON from './components/TSNEJSON'; -import RegFigure from './components/RegFigure'; -import Score from './components/Score'; +import RegFigure from "./components/RegFigure"; +import Score from "./components/Score"; // import NoScore from './components/NoScore'; -import {Header, Grid, Loader, Dropdown, Menu} from 'semantic-ui-react'; -import {formatDataset} from 'utils/formatter'; -import ClassRate from './components/ClassRate'; -import ChatGPT from '../ChatGPT'; +import { Header, Grid, Loader, Dropdown, Menu } from "semantic-ui-react"; +import { formatDataset } from "utils/formatter"; +import ClassRate from "./components/ClassRate"; +import ChatGPT from "../ChatGPT"; function moveSlidermakeBlack(e) { + let block = document.getElementsByClassName("chartsbaseleft")[0]; + let slider = document.getElementsByClassName("slider")[0]; + let chatbox = document.getElementsByClassName("chatbaseright")[0]; + + if (block && slider) { + // console.log("block and slider exist"); + + slider.onmousedown = function dragMouseDown(e) { + // get width of window + let windowWidth = window.innerWidth; + // console.log("windowWidth", windowWidth); + let dragX = e.clientX; + // console.log("e.clientX", e.clientX); + document.onmousemove = function onMouseMove(e) { + // 0.2 --0.3 --0.4 --0.8 -- + + console.log("block.offsetWidth", block.offsetWidth); + + // shift the result block to the right, and make the chatbox invisible + if (block.offsetWidth > 0.8 * windowWidth) { + console.log("range-bigger than 0.8"); + block.style.width = windowWidth + "px"; + dragX = e.clientX; + chatbox.style.visibility = "hidden"; + } - - - let block = document.getElementsByClassName("chartsbaseleft")[0]; - let slider = document.getElementsByClassName("slider")[0]; - let chatbox = document.getElementsByClassName("chatbaseright")[0]; - - - if (block && slider) { - - // console.log("block and slider exist"); - - slider.onmousedown = function dragMouseDown(e) { - // get width of window - let windowWidth = window.innerWidth; - // console.log("windowWidth", windowWidth); - let dragX = e.clientX; - // console.log("e.clientX", e.clientX); - document.onmousemove = function onMouseMove(e) { - - - // 0.2 --0.3 --0.4 --0.8 -- - - console.log("block.offsetWidth", block.offsetWidth) - - // shift the result block to the right, and make the chatbox invisible - if (block.offsetWidth > 0.8 * windowWidth) { - console.log("range-bigger than 0.8") - block.style.width = windowWidth + "px"; - dragX = e.clientX; - chatbox.style.visibility = "hidden"; - } - - // shift the chatbox to the left, and make the result block invisible - else if (block.offsetWidth < 0.2 * windowWidth) { - console.log("range-smaller than 0.2") - block.style.width = 0 + "px"; - block.style.visibility = "hidden"; - slider.style.visibility = "hidden"; - dragX = e.clientX; - - - } - - // else - // { - // console.log("range-bigger than or equal to 0.5 and smaller than or equal to 0.8") - // block.style.visibility = "block"; - // slider.style.visibility = "block"; - - // // origin - // block.style.width = block.offsetWidth + e.clientX - dragX + "px"; - // dragX = e.clientX; - - - - // } - - - - else if (block.offsetWidth >= 0.4 * windowWidth && block.offsetWidth <= 0.8 * windowWidth) - { - console.log("range-bigger than or equal to 0.4 and smaller than or equal to 0.8") - block.style.visibility = "block"; - slider.style.visibility = "block"; - - // origin - block.style.width = block.offsetWidth + e.clientX - dragX + "px"; - dragX = e.clientX; - - // if e.target.parentElement.childNodes[0].childNodes[1] is not undefined - - if (e.target.parentElement.childNodes[0].childNodes[1].className === "ui stackable two column grid") { - - e.target.parentElement.childNodes[0].childNodes[1].className = "ui stackable three column grid" - } - - else if (e.target.parentElement.childNodes[0].childNodes[1].className === "ui stackable one column grid") { - - e.target.parentElement.childNodes[0].childNodes[1].className = "ui stackable three column grid" - } - } - - - - else if (block.offsetWidth >= 0.3 * windowWidth && block.offsetWidth < 0.4 * windowWidth) - { - console.log("range-bigger than or equal to 0.3 and smaller than 0.4") - block.style.visibility = "block"; - slider.style.visibility = "block"; - - // origin - block.style.width = block.offsetWidth + e.clientX - dragX + "px"; - dragX = e.clientX; - - - if (e.target.parentElement.childNodes[0].childNodes[1].className === "ui stackable three column grid") { - - e.target.parentElement.childNodes[0].childNodes[1].className = "ui stackable two column grid" - } - - else if (e.target.parentElement.childNodes[0].childNodes[1].className === "ui stackable one column grid") { - - e.target.parentElement.childNodes[0].childNodes[1].className = "ui stackable two column grid" - } - } - - else if (block.offsetWidth >= 0.2 * windowWidth && block.offsetWidth < 0.3 * windowWidth) - { - console.log("range-bigger than or equal to 0.2 and smaller than 0.3") - block.style.visibility = "block"; - slider.style.visibility = "block"; - - // origin - block.style.width = block.offsetWidth + e.clientX - dragX + "px"; - dragX = e.clientX; - - if (e.target.parentElement.childNodes[0].childNodes[1].className === "ui stackable three column grid") { - - e.target.parentElement.childNodes[0].childNodes[1].className = "ui stackable one column grid" - } - - else if (e.target.parentElement.childNodes[0].childNodes[1].className === "ui stackable two column grid") { - - e.target.parentElement.childNodes[0].childNodes[1].className = "ui stackable one column grid" - } - } - - - } - // remove mouse-move listener on mouse-up - document.onmouseup = () => document.onmousemove = document.onmouseup = null; + // shift the chatbox to the left, and make the result block invisible + else if (block.offsetWidth < 0.2 * windowWidth) { + console.log("range-smaller than 0.2"); + block.style.width = 0 + "px"; + block.style.visibility = "hidden"; + slider.style.visibility = "hidden"; + dragX = e.clientX; } - } + // else + // { + // console.log("range-bigger than or equal to 0.5 and smaller than or equal to 0.8") + // block.style.visibility = "block"; + // slider.style.visibility = "block"; + + // // origin + // block.style.width = block.offsetWidth + e.clientX - dragX + "px"; + // dragX = e.clientX; + + // } + else if ( + block.offsetWidth >= 0.4 * windowWidth && + block.offsetWidth <= 0.8 * windowWidth + ) { + console.log( + "range-bigger than or equal to 0.4 and smaller than or equal to 0.8" + ); + block.style.visibility = "block"; + slider.style.visibility = "block"; + + // origin + block.style.width = block.offsetWidth + e.clientX - dragX + "px"; + dragX = e.clientX; + + // if e.target.parentElement.childNodes[0].childNodes[1] is not undefined + + if ( + e.target.parentElement.childNodes[0].childNodes[1].className === + "ui stackable two column grid" + ) { + e.target.parentElement.childNodes[0].childNodes[1].className = + "ui stackable three column grid"; + } else if ( + e.target.parentElement.childNodes[0].childNodes[1].className === + "ui stackable one column grid" + ) { + e.target.parentElement.childNodes[0].childNodes[1].className = + "ui stackable three column grid"; + } + } else if ( + block.offsetWidth >= 0.3 * windowWidth && + block.offsetWidth < 0.4 * windowWidth + ) { + console.log("range-bigger than or equal to 0.3 and smaller than 0.4"); + block.style.visibility = "block"; + slider.style.visibility = "block"; + + // origin + block.style.width = block.offsetWidth + e.clientX - dragX + "px"; + dragX = e.clientX; + + if ( + e.target.parentElement.childNodes[0].childNodes[1].className === + "ui stackable three column grid" + ) { + e.target.parentElement.childNodes[0].childNodes[1].className = + "ui stackable two column grid"; + } else if ( + e.target.parentElement.childNodes[0].childNodes[1].className === + "ui stackable one column grid" + ) { + e.target.parentElement.childNodes[0].childNodes[1].className = + "ui stackable two column grid"; + } + } else if ( + block.offsetWidth >= 0.2 * windowWidth && + block.offsetWidth < 0.3 * windowWidth + ) { + console.log("range-bigger than or equal to 0.2 and smaller than 0.3"); + block.style.visibility = "block"; + slider.style.visibility = "block"; + + // origin + block.style.width = block.offsetWidth + e.clientX - dragX + "px"; + dragX = e.clientX; + + if ( + e.target.parentElement.childNodes[0].childNodes[1].className === + "ui stackable three column grid" + ) { + e.target.parentElement.childNodes[0].childNodes[1].className = + "ui stackable one column grid"; + } else if ( + e.target.parentElement.childNodes[0].childNodes[1].className === + "ui stackable two column grid" + ) { + e.target.parentElement.childNodes[0].childNodes[1].className = + "ui stackable one column grid"; + } + } + }; + // remove mouse-move listener on mouse-up + document.onmouseup = () => + (document.onmousemove = document.onmouseup = null); + }; + } } function makeOriginColor(e) { - let slider = document.getElementsByClassName("slider")[0]; + let slider = document.getElementsByClassName("slider")[0]; - // make slider color black - slider.style.backgroundColor = "#1B1C1D;" + // make slider color black + slider.style.backgroundColor = "#1B1C1D;"; } class Results extends Component { - constructor(props) { - super(props); - this.getGaugeArray = this - .getGaugeArray - .bind(this); + constructor(props) { + super(props); + this.getGaugeArray = this.getGaugeArray.bind(this); + } + + componentDidMount() { + this.props.fetchExperiment(this.props.params.id); + } + + componentWillUnmount() { + this.props.clearExperiment(); + } + + /** + * Basic helped method to create array containing [key,val] entries where + * key - name of given score + * value - actual score + * passed to Score component which uses javascript library C3 to create graphic + */ + + getGaugeArray(keyList) { + const { experiment } = this.props; + let testList = []; + let expScores = experiment.data.scores; + + if (typeof expScores === "object") { + keyList.forEach((scoreKey) => { + console.log("scoreKey", scoreKey); + // in case of 0 or false, it should satisfy the condition + + if ( + expScores.hasOwnProperty(scoreKey) && + expScores[scoreKey] !== undefined && + expScores[scoreKey] !== null && + typeof expScores[scoreKey].toFixed === "function" + ) { + let tempLabel = scoreKey.includes("train") + ? `Train (${expScores[scoreKey].toFixed(2)})` + : `Test (${expScores[scoreKey].toFixed(2)})`; + console.log("555-scoreKey", scoreKey); + console.log("555-expScores[scoreKey]", expScores[scoreKey]); + testList.push([tempLabel, expScores[scoreKey]]); + } + }); } - componentDidMount() { - this - .props - .fetchExperiment(this.props.params.id); - } + console.log("testList", testList); - componentWillUnmount() { - this - .props - .clearExperiment(); - } + return testList; + } - /** - * Basic helped method to create array containing [key,val] entries where - * key - name of given score - * value - actual score - * passed to Score component which uses javascript library C3 to create graphic - */ - - getGaugeArray(keyList) { - const {experiment} = this.props; - let testList = []; - let expScores = experiment.data.scores; - - // console.log("experiment.data") - console.log("experiment.data", experiment.data) - // console.log(experiment.data['class_1'][0]) - // console.log(experiment.data['class_-1'][0]) - - if (typeof(expScores) === 'object') { - keyList.forEach(scoreKey => { - if (expScores[scoreKey] && typeof expScores[scoreKey].toFixed === 'function') { - let tempLabel; - scoreKey.includes('train') - ? tempLabel = 'Train (' + expScores[scoreKey].toFixed(2) + ')' - : tempLabel = 'Test (' + expScores[scoreKey].toFixed(2) + ')'; - testList.push([ - tempLabel, expScores[scoreKey] - ]); - } - }); - } + render() { + const { experiment, fetchExperiment } = this.props; - return testList; + if (experiment.isFetching || !experiment.data) { + return ( + + ); } - render() { - const {experiment, fetchExperiment} = this.props; + if (experiment.error === "Failed to fetch") { + return ; + } else if (experiment.error) { + return ( + fetchExperiment()} + /> + ); + } - if (experiment.isFetching || !experiment.data) { - return ( - - ); + const downloadModel = (id) => { + // console.log("downloadModel_id",id) + fetch(`/api/v1/experiments/${id}/model`) + .then((response) => { + if (response.status >= 400) { + throw new Error(`${response.status}: ${response.statusText}`); + } + return response.json(); + }) + .then((json) => { + console.log("json", json); + window.location = `/api/v1/files/${json._id}`; + }); + }; + + const downloadScript = (id) => { + fetch(`/api/v1/experiments/${id}/script`) + .then((response) => { + if (response.status >= 400) { + throw new Error(`${response.status}: ${response.statusText}`); + } + return response.json(); + }) + .then((json) => { + window.location = `/api/v1/files/${json._id}`; + }); + }; + + // console.log(experiment.data.prediction_type) --- get lists of scores --- + if (experiment.data.prediction_type == "classification") { + // classification + + console.log("experiment.data", experiment.data); + // console.log("X_pca", experiment.data.X_pca) console.log("y_pca", + // experiment.data.y_pca) + + let confusionMatrix, + rocCurve, + importanceScore, + learningCurve, + pca, + pca_json, + tsne, + tsne_json, + shap_explainer, + shap_num_samples; + + let shapSummaryCurveDict = {}; + + experiment.data.experiment_files.forEach(async (file) => { + const filename = file.filename; + console.log("filename-test", filename); + if (filename.includes("confusion_matrix")) { + confusionMatrix = file; + } else if (filename.includes("roc_curve")) { + rocCurve = file; + // save to local storage localStorage.setItem('rocCurve', rocCurve); + } else if (filename.includes("imp_score")) { + importanceScore = file; + } else if (filename.includes("learning_curve")) { + learningCurve = file; + } else if (filename.includes("pca") && filename.includes("png")) { + pca = file; + console.log("pca", pca); + } else if (filename.includes("pca-json")) { + console.log("pca_json"); + pca_json = file; + } else if (filename.includes("tsne") && filename.includes("png")) { + tsne = file; + console.log("tsne", tsne); + } else if (filename.includes("tsne-json")) { + console.log("tsne_json"); + tsne_json = file; + console.log("tsne_json: ", tsne_json); + } else if (filename.includes("shap_summary_curve")) { + console.log("shap_summary_curve"); + let class_name = filename.split("_").slice(-2, -1); + shapSummaryCurveDict[class_name] = file; + shap_explainer = experiment.data.shap_explainer; + shap_num_samples = experiment.data.shap_num_samples; + + // save to local storage localStorage.setItem( 'shapSummaryCurveDict', + // JSON.stringify(shapSummaryCurveDict) ); + // localStorage.setItem('shap_explainer', shap_explainer); + // localStorage.setItem('shap_num_samples', shap_num_samples); + } else if (filename.includes("shap_summary_json")) { + console.log("shap_summary_json"); + // shap_json = file; console.log("shap_json: ", shap_json) } - - if (experiment.error === 'Failed to fetch') { - return (); - } else if (experiment.error) { - return ( - fetchExperiment()}/> - ); + }); + // balanced accuracy + let balancedAccKeys = [ + "train_balanced_accuracy_score", + "balanced_accuracy_score", + ]; + // precision scores + let precisionKeys = ["train_precision_score", "precision_score"]; + // AUC + let aucKeys = ["train_roc_auc_score", "roc_auc_score"]; + // f1 score + let f1Keys = ["train_f1_score", "f1_score"]; + // recall + let recallKeys = ["train_recall_score", "recall_score"]; + + let balancedAccList = this.getGaugeArray(balancedAccKeys); + let precisionList = this.getGaugeArray(precisionKeys); + let aucList = this.getGaugeArray(aucKeys); + let recallList = this.getGaugeArray(recallKeys); + let f1List = this.getGaugeArray(f1Keys); + let class_percentage = []; + // let pca_data = []; + + experiment.data.class_names.forEach((eachclass) => { + console.log("eachclass.toString()", eachclass.toString()); + // if type of experiment.data['class_' + eachclass.toString()] === 'object' + if ( + typeof experiment.data["class_" + eachclass.toString()] === "object" + ) { + class_percentage.push([ + eachclass.toString(), + experiment.data["class_" + eachclass.toString()][0], + ]); + console.log("experiment.data['class_1']", experiment.data["class_1"]); + } else { + class_percentage.push([ + eachclass.toString(), + experiment.data["class_" + eachclass.toString()], + ]); + console.log("experiment.data['class_1']", experiment.data["class_1"]); } - - const downloadModel = (id) => { - // console.log("downloadModel_id",id) - fetch(`/api/v1/experiments/${id}/model`) - .then(response => { - if (response.status >= 400) { - throw new Error(`${response.status}: ${response.statusText}`); - } - return response.json(); - }) - .then(json => { - console.log("json",json) - window.location = `/api/v1/files/${json._id}`; - }); - }; - - const downloadScript = (id) => { - fetch(`/api/v1/experiments/${id}/script`) - .then(response => { - if (response.status >= 400) { - throw new Error(`${response.status}: ${response.statusText}`); - } - return response.json(); - }) - .then(json => { - window.location = `/api/v1/files/${json._id}`; - }); - }; - - // console.log(experiment.data.prediction_type) --- get lists of scores --- - if (experiment.data.prediction_type == "classification") { // classification - - console.log("experiment.data", experiment.data) - // console.log("X_pca", experiment.data.X_pca) console.log("y_pca", - // experiment.data.y_pca) - - let confusionMatrix, - rocCurve, - importanceScore, - learningCurve, - pca, - pca_json, - tsne, - tsne_json, - shap_explainer, - shap_num_samples; - - let shapSummaryCurveDict = {}; - - experiment - .data - .experiment_files - .forEach(async file => { - const filename = file.filename; - console.log('filename-test', filename); - if (filename.includes('confusion_matrix')) { - confusionMatrix = file; - } else if (filename.includes('roc_curve')) { - rocCurve = file; - // save to local storage localStorage.setItem('rocCurve', rocCurve); - } else if (filename.includes('imp_score')) { - importanceScore = file; - } else if (filename.includes('learning_curve')) { - learningCurve = file; - } else if (filename.includes('pca') && filename.includes('png')) { - pca = file; - console.log("pca", pca) - } else if (filename.includes('pca-json')) { - console.log("pca_json") - pca_json = file; - } else if (filename.includes('tsne') && filename.includes('png')) { - tsne = file; - console.log("tsne", tsne) - } else if (filename.includes('tsne-json')) { - console.log("tsne_json") - tsne_json = file; - console.log("tsne_json: ", tsne_json) - } - - - else if (filename.includes('shap_summary_curve')) { - console.log("shap_summary_curve") - let class_name = filename - .split('_') - .slice(-2, -1); - shapSummaryCurveDict[class_name] = file; - shap_explainer = experiment.data.shap_explainer; - shap_num_samples = experiment.data.shap_num_samples; - - // save to local storage localStorage.setItem( 'shapSummaryCurveDict', - // JSON.stringify(shapSummaryCurveDict) ); - // localStorage.setItem('shap_explainer', shap_explainer); - // localStorage.setItem('shap_num_samples', shap_num_samples); - - } else if (filename.includes('shap_summary_json')) { - console.log("shap_summary_json") - // shap_json = file; console.log("shap_json: ", shap_json) - } - - }); - // balanced accuracy - let balancedAccKeys = ['train_balanced_accuracy_score', 'balanced_accuracy_score']; - // precision scores - let precisionKeys = ['train_precision_score', 'precision_score'] - // AUC - let aucKeys = ['train_roc_auc_score', 'roc_auc_score']; - // f1 score - let f1Keys = ['train_f1_score', 'f1_score']; - // recall - let recallKeys = ['train_recall_score', 'recall_score']; - - let balancedAccList = this.getGaugeArray(balancedAccKeys); - let precisionList = this.getGaugeArray(precisionKeys); - let aucList = this.getGaugeArray(aucKeys); - let recallList = this.getGaugeArray(recallKeys); - let f1List = this.getGaugeArray(f1Keys); - let class_percentage = []; - // let pca_data = []; - - experiment - .data - .class_names - .forEach(eachclass => { - - console.log('eachclass.toString()', eachclass.toString()) - // if type of experiment.data['class_' + eachclass.toString()] === 'object' - if ((typeof experiment.data['class_' + eachclass.toString()]) === 'object') { - class_percentage.push([ - eachclass.toString(), - experiment.data['class_' + eachclass.toString()][0] - ]); - console.log("experiment.data['class_1']", experiment.data['class_1']) - } else { - class_percentage.push([ - eachclass.toString(), - experiment.data['class_' + eachclass.toString()] - ]); - console.log("experiment.data['class_1']", experiment.data['class_1']) - } - - }); - - // console.log('balancedAccList', balancedAccList) save to local storage - // localStorage.setItem('balancedAccList', JSON.stringify(balancedAccList)); - // console.log('precisionList', precisionList) save to local storage - // localStorage.setItem('precisionList', JSON.stringify(precisionList)); save - // to local storage console.log('aucList', aucList) - // localStorage.setItem('aucList', JSON.stringify(aucList)); save to local - // storage console.log('recallList', recallList) - // localStorage.setItem('recallList', JSON.stringify(recallList)); save to - // local storage console.log('f1List', f1List) localStorage.setItem('f1List', - // JSON.stringify(f1List)); save to local storage - // console.log('class_percentage', class_percentage) - // localStorage.setItem('class_percentage', JSON.stringify(class_percentage)); - - return ( -
-
- - - - - - - - - - downloadModel(experiment.data._id)}/>, - downloadScript(experiment.data._id)}/> - - - - - - - - - - - - {/* */} - {/* */} - - {/* */} - - - - - - {/* +
+ + + + + + + + + + downloadModel(experiment.data._id)} + /> + , + downloadScript(experiment.data._id)} + /> + + + + + + + + + + + + {" "} + {/* */} + {" "} + {/* */} + + {/* */} + + + + {/* */ - } - - {/* */} - {/* This TestChart is for interactive and responsive confusion matrix */} - - - - {/* */} - - - - - - - - - - - - -
-
- {/* onChange={moveSlidermakeBlack} */} - ""

- ""

- ""

- ""

- ""

- ""

-
-
- -
-
- - //
- // - // - // - // downloadModel(experiment.data._id)}/>, downloadScript(experiment.data._id)}/> - // - // - // {/* */} {/* */} {/* */} - // {/* */} - // {/* */} {/* */ } - // {/* */ } {/* - // */} {/* - // This TestChart is for interactive and responsive confusion matrix */} - // - // - // - // {/* - // https://en.wikipedia.org/wiki/Confusion_matrix - // */} - // {/* GPT Space */} {/* - // */ } {/* GPT Space */}
- ); - } else if (experiment.data.prediction_type == "regression") { // regression - let importanceScore, - reg_cv_pred, - reg_cv_resi, - reg_cv_qq, - reg_cvp_png, - reg_cvp_json, - reg_cvr_png, - reg_cvr_json, - reg_qqnr_png, - reg_qqnr_json; - - experiment - .data - .experiment_files - .forEach(file => { - const filename = file.filename; - console.log("filename-regression", filename) - if (filename.includes('imp_score')) { - importanceScore = file; - } else if (filename.includes('reg_cv_pred')) { - reg_cv_pred = file; - } else if (filename.includes('reg_cv_resi')) { - reg_cv_resi = file; - } else if (filename.includes('reg_cv_qq')) { - reg_cv_qq = file; - } else if (filename.includes('reg_cv_pred') && filename.includes('png') ) { - reg_cvp_png = file; - console.log("reg_cvp_png", reg_cvp_png) - } else if (filename.includes('reg_cv_resi') && filename.includes('png')) { - reg_cvr_png = file; - console.log("reg_cvr_png", reg_cvr_png) - } else if (filename.includes('reg_cv_qq') && filename.includes('png')) { - reg_qqnr_png = file; - }else if (filename.includes('reg_cvp') && filename.includes('json') ) { - reg_cvp_json = file; - console.log("reg_cvp_json", reg_cvp_json) - } else if (filename.includes('reg_cvr') && filename.includes('json')) { - reg_cvr_json = file; - console.log("reg_cvr_json", reg_cvr_json) - } else if (filename.includes('reg_qqnr') && filename.includes('json')) { - reg_qqnr_json = file; - console.log("reg_qqnr_json", reg_qqnr_json) - } - - }); - // r2 - let R2Keys = ['train_r2_score', 'r2_score']; - // r - let RKeys = ['train_pearsonr_score', 'pearsonr_score']; - // r2 - let VAFKeys = ['train_explained_variance_score', 'explained_variance_score']; - - let R2List = this.getGaugeArray(R2Keys); - let RList = this.getGaugeArray(RKeys); - let VAFList = this.getGaugeArray(VAFKeys); - - return ( - -
-
- - - - - - - - - - downloadModel(experiment.data._id)}/>, - downloadScript(experiment.data._id)}/> - - - - - - - - - - - - {/* */} - - - - - {/* */} - {/* */} - {/* */} - - {/* { + /> */} + {" "} + {/* */} + {/* This TestChart is for interactive and responsive confusion matrix */} + + + + {/* */} + + + + + + + + + + + +
+
+ {/* onChange={moveSlidermakeBlack} */} + ""

+ ""

+ ""

+ ""

+ ""

+ ""

+
+
+ +
+
+ + //
+ // + // + // + // downloadModel(experiment.data._id)}/>, downloadScript(experiment.data._id)}/> + // + // + // {/* */} {/* */} {/* */} + // {/* */} + // {/* */} {/* */ } + // {/* */ } {/* + // */} {/* + // This TestChart is for interactive and responsive confusion matrix */} + // + // + // + // {/* + // https://en.wikipedia.org/wiki/Confusion_matrix + // */} + // {/* GPT Space */} {/* + // */ } {/* GPT Space */}
+ ); + } else if (experiment.data.prediction_type == "regression") { + // regression + let importanceScore, + reg_cv_pred, + reg_cv_resi, + reg_cv_qq, + reg_cvp_png, + reg_cvp_json, + reg_cvr_png, + reg_cvr_json, + reg_qqnr_png, + reg_qqnr_json; + + experiment.data.experiment_files.forEach((file) => { + const filename = file.filename; + console.log("filename-regression", filename); + if (filename.includes("imp_score")) { + importanceScore = file; + } else if (filename.includes("reg_cv_pred")) { + reg_cv_pred = file; + } else if (filename.includes("reg_cv_resi")) { + reg_cv_resi = file; + } else if (filename.includes("reg_cv_qq")) { + reg_cv_qq = file; + } else if ( + filename.includes("reg_cv_pred") && + filename.includes("png") + ) { + reg_cvp_png = file; + console.log("reg_cvp_png", reg_cvp_png); + } else if ( + filename.includes("reg_cv_resi") && + filename.includes("png") + ) { + reg_cvr_png = file; + console.log("reg_cvr_png", reg_cvr_png); + } else if (filename.includes("reg_cv_qq") && filename.includes("png")) { + reg_qqnr_png = file; + } else if (filename.includes("reg_cvp") && filename.includes("json")) { + reg_cvp_json = file; + console.log("reg_cvp_json", reg_cvp_json); + } else if (filename.includes("reg_cvr") && filename.includes("json")) { + reg_cvr_json = file; + console.log("reg_cvr_json", reg_cvr_json); + } else if (filename.includes("reg_qqnr") && filename.includes("json")) { + reg_qqnr_json = file; + console.log("reg_qqnr_json", reg_qqnr_json); + } + }); + // r2 + let R2Keys = ["train_r2_score", "r2_score"]; + // r + let RKeys = ["train_pearsonr_score", "pearsonr_score"]; + // r2 + let VAFKeys = [ + "train_explained_variance_score", + "explained_variance_score", + ]; + + let R2List = this.getGaugeArray(R2Keys); + let RList = this.getGaugeArray(RKeys); + let VAFList = this.getGaugeArray(VAFKeys); + + return ( +
+
+ + + + + + + + + + downloadModel(experiment.data._id)} + /> + , + downloadScript(experiment.data._id)} + /> + + + + + + + + + + + {" "} + {/* */} + + + + {/* */} + {/* */} + {/* */} + + {/* { experiment.data.CVP_2d === undefined ?
: } */} - - - {/* + + {/* */ - } - - {/* { + data={experiment.data}/> */} + + {/* { experiment.data.CVR_2d === undefined ?
: } */} - - - {/* { + + + {/* { experiment.data.QQNR_2d === undefined ?
: } */} - - - - - - - - - - - - -
-
- {/* onChange={moveSlidermakeBlack} */} - ""

- ""

- ""

- ""

- ""

- ""

-
-
- -
-
- - ); - } + + + + + + + + + + +
+
+ {/* onChange={moveSlidermakeBlack} */} + ""

+ ""

+ ""

+ ""

+ ""

+ ""

+
+
+ +
+ + ); } + } } -const mapStateToProps = (state) => ({experiment: state.experiments.selected}); - +const mapStateToProps = (state) => ({ experiment: state.experiments.selected }); -export { - Results -}; +export { Results }; export default connect(mapStateToProps, actions)(Results); diff --git a/machine/learn/skl_utils.py b/machine/learn/skl_utils.py index 8654ff96e..49afe5b80 100644 --- a/machine/learn/skl_utils.py +++ b/machine/learn/skl_utils.py @@ -38,7 +38,7 @@ from sklearn.pipeline import make_pipeline, Pipeline from sklearn.compose import ColumnTransformer from sklearn.preprocessing import OneHotEncoder, OrdinalEncoder, LabelEncoder -from sklearn.model_selection import GridSearchCV, cross_validate, StratifiedKFold, KFold +from sklearn.model_selection import GridSearchCV, cross_validate, StratifiedKFold, RepeatedStratifiedKFold, KFold from sklearn.metrics import SCORERS, roc_curve, auc, make_scorer, confusion_matrix import itertools import json @@ -175,6 +175,32 @@ def get_column_names_from_ColumnTransformer(column_transformer, feature_names): new_feature_names += feature_columns return new_feature_names +# decision rule for choosing number of folds based on the class distribution in the given dataset +def decision_rule_fold_cv_based_on_classes(each_class): + """ + Adjusts the number of cross-validation folds based on the class distribution. + + Parameters + ---------- + each_class : dict + A dictionary where keys are the classes and the values are the number of samples per class. + + Returns + ------- + cv : int + The suitable number of cross-validation folds ensuring that each fold can include instances of each class. + """ + # Find the minimum class count to ensure every fold can contain at least one instance of every class. + min_class_count = min(each_class.values()) + + # The maximum number of folds is determined by the smallest class to ensure representation in each fold. + # However, we cannot have more folds than the minimum class count. + n_folds = min(10, min_class_count) # Starting with a default max of 10 folds + + # Ensure at least 2 folds for meaningful cross-validation. + n_folds = max(n_folds, 2) + + return n_folds def generate_results(model, input_data, tmpdir, _id, target_name='class', @@ -246,6 +272,10 @@ def generate_results(model, input_data, feature_names = np.array( [x for x in input_data.columns.values if x != target_name]) num_classes = input_data[target_name].unique().shape[0] + + # calculate number of each class + each_class = input_data[target_name].value_counts() + features = input_data.drop(target_name, axis=1).values target = input_data[target_name].values @@ -380,16 +410,17 @@ def generate_results(model, input_data, target, cv, return_times=True) model.fit(features, target) - # # plot learning curve - # plot_learning_curve(tmpdir,_id, model,features,target,cv,return_times=True) + + + - # computing cross-validated metrics cv_scores = cross_validate( estimator=model, X=features, y=target, scoring=scoring, - cv=cv, + # cv = stratified_cv, + cv = cv, return_train_score=True, return_estimator=True ) @@ -398,26 +429,73 @@ def generate_results(model, input_data, train_scores = cv_scores['train_' + s] test_scores = cv_scores['test_' + s] + print("train_scores", train_scores) + print("test_scores", test_scores) + + # if abs(train_scores.mean()) is np.nan OR abs(test_scores.mean()) is np.nan + if np.isnan(abs(train_scores.mean())) or np.isnan(abs(test_scores.mean())): + print("777-NaN") + print("train_scores", train_scores) + print("test_scores", test_scores) + # remove _macro score_name = s.replace('_macro', '') # make balanced_accuracy as default score if score_name in ["balanced_accuracy", "neg_mean_squared_error"]: scores['train_score'] = abs(train_scores.mean()) scores['test_score'] = abs(test_scores.mean()) + + # Temporary fix to handle NaN values + if np.isnan(scores['train_score']): + scores['train_score'] = np.nanmean(train_scores) + if np.isnan(scores['test_score']): + scores['test_score'] = np.nanmean(test_scores) # for api will fix later if score_name == "balanced_accuracy": scores['accuracy_score'] = test_scores.mean() + # Temporary fix to handle NaN values + if np.nanmean(test_scores)!=np.nan: + scores['accuracy_score'] = np.nanmean(test_scores) + else: + scores['accuracy_score'] = 0 # for experiment tables if score_name == "balanced_accuracy" or score_name == "r2": scores['exp_table_score'] = test_scores.mean() + # Temporary fix to handle NaN values + if np.nanmean(test_scores)!=np.nan: + scores['exp_table_score'] = np.nanmean(test_scores) + else: + scores['exp_table_score'] = 0 if score_name in ["neg_mean_squared_error", "neg_mean_absolute_error"]: scores['train_{}_score'.format(score_name)] = abs( train_scores.mean()) + # Temporary fix to handle NaN values + if np.nanmean(train_scores)!=np.nan: + scores['train_{}_score'.format(score_name)] = np.nanmean( + train_scores) + else: + scores['train_{}_score'.format(score_name)] = 0 scores['{}_score'.format(score_name)] = abs(test_scores.mean()) + # Temporary fix to handle NaN values + if np.nanmean(test_scores)!=np.nan: + scores['{}_score'.format(score_name)] = np.nanmean(test_scores) + else: + scores['{}_score'.format(score_name)] = 0 else: scores['train_{}_score'.format(score_name)] = train_scores.mean() + # Temporary fix to handle NaN values + if np.nanmean(train_scores)!=np.nan: + scores['train_{}_score'.format(score_name)] = np.nanmean( + train_scores) + else: + scores['train_{}_score'.format(score_name)] = 0 scores['{}_score'.format(score_name)] = test_scores.mean() + # Temporary fix to handle NaN values + if np.nanmean(test_scores)!=np.nan: + scores['{}_score'.format(score_name)] = np.nanmean(test_scores) + else: + scores['{}_score'.format(score_name)] = 0 # dump fitted module as pickle file export_model(tmpdir, _id, model, filename, target_name, mode, random_state) @@ -687,6 +765,8 @@ def plot_confusion_matrix( """ pred_y = np.empty(y.shape) cv = StratifiedKFold(n_splits=10) + # Temporary fix to handle NaN values + # cv = StratifiedKFold(n_splits=8) for cv_split, est in zip(cv.split(X, y), cv_scores['estimator']): train, test = cv_split pred_y[test] = est.predict(X[test]) @@ -980,11 +1060,12 @@ def plot_roc_curve(tmpdir, _id, X, y, cv_scores, figure_export): from scipy import interp from scipy.stats import sem, t cv = StratifiedKFold(n_splits=10) + # Temporary fix to handle NaN values + # cv = StratifiedKFold(n_splits=8) tprs = [] aucs = [] mean_fpr = np.linspace(0, 1, 100) - # print(cv_scores['train_roc_auc']) for cv_split, est in zip(cv.split(X, y), cv_scores['estimator']): train, test = cv_split try: @@ -998,8 +1079,16 @@ def plot_roc_curve(tmpdir, _id, X, y, cv_scores, figure_export): [list(est.classes_).index(c) for c in y[test]], dtype=np.int ) + # print("Each classes_encoded:", classes_encoded) fpr, tpr, thresholds = roc_curve(classes_encoded, probas_) + + # Temporary fix to handle NaN values + # When the given data is extremely unbalanced, as illustrated by the example where classes_encoded consists solely of the class 0, both true positives (TP) and false negatives (FN) are zero. Consequently, the true positive rate (TPR) is calculated as TPR = TP / (TP + FN), which results in an undefined value (NaN) due to division by zero. In the specific scenario provided, where roc_curve([0,0,0], [0,0.9,0]) is called, it highlights a situation with no positive instances present in the true labels. For purposes of data visualization or further analysis where a numerical value is required, this NaN value is replaced with 0 to indicate the absence of true positives under these conditions. + fpr = np.nan_to_num(fpr) + tpr = np.nan_to_num(tpr) + tprs.append(interp(mean_fpr, fpr, tpr)) + tprs[-1][0] = 0.0 roc_auc = auc(fpr, tpr) aucs.append(roc_auc) @@ -1047,6 +1136,7 @@ def plot_roc_curve(tmpdir, _id, X, y, cv_scores, figure_export): 'tpr': mean_tpr.tolist(), 'roc_auc_score': mean_auc } + print("roc_curve_dict:", roc_curve_dict) file_name = 'roc_curve' + '.json' save_json_fmt(outdir=tmpdir, _id=_id, @@ -1159,6 +1249,18 @@ def plot_learning_curve(tmpdir, _id, model, features, target, cv, return_times=T # replace nan with -1 test_scores = np.nan_to_num(test_scores, nan=-1) + + + + # temp solution for nan values + train_sizes = np.nan_to_num(train_sizes, nan=-1) + train_scores = np.nan_to_num(train_scores, nan=-1) + test_scores = np.nan_to_num(test_scores, nan=-1) + + print("train_sizes.tolist():", train_sizes.tolist()) + print("train_scores.tolist():", train_scores.tolist()) + print("test_scores.tolist():", test_scores.tolist()) + learning_curve_dict = { 'train_sizes': train_sizes.tolist(), 'train_scores': train_scores.tolist(), @@ -1240,13 +1342,11 @@ def plot_pca_3d(tmpdir, _id, features, target): # np.random.seed(5) # iris = datasets.load_iris() - # print(features) + X = np.array(features) y = np.array(target) y[y == -1] = 0 - # print(X) - # print(y) fig = plt.figure(1, figsize=(4, 3)) plt.clf()