Code base refactor - Discussion #468
To design a PLP system inspired by mlr3, we can organize the PLP components into mlr3 building blocks.
Forcing our existing PLP functions into mlr3 building blocks could look something like this:

PatientLevelPrediction:

```mermaid
classDiagram
class PatientLevelPrediction
PatientLevelPrediction --> HelperFunctions
PatientLevelPrediction --> Fit
PatientLevelPrediction --> Logging
PatientLevelPrediction --> ParamChecks
PatientLevelPrediction --> DatabaseMigration
PatientLevelPrediction --> RunMultiplePlp
PatientLevelPrediction --> RunPlp
PatientLevelPrediction --> RunPlpHelpers
PatientLevelPrediction --> SaveLoadPlp
PatientLevelPrediction --> LearningCurve
class HelperFunctions {
+configurePython()
+createTempModelLoc()
+cut2()
+ensure_installed()
+getOs()
+is_installed()
+listAppend()
+nrow()
+nrow.default()
+nrow.tbl()
+removeInvalidString()
+setPythonEnvironment()
}
class Fit {
+fitPlp()
}
class Logging {
+checkFileExists()
+closeLog()
+createLog()
+createLogSettings()
}
class ParamChecks {
+checkBoolean()
+checkHigher()
+checkHigherEqual()
+checkInStringVector()
+checkIsClass()
+checkLower()
+checkLowerEqual()
+checkNotNull()
}
class DatabaseMigration {
+getDataMigrator()
+migrateDataModel()
}
class RunMultiplePlp {
+convertToJson()
+createModelDesign()
+loadPlpAnalysesJson()
+runMultiplePlp()
+savePlpAnalysesJson()
+validateMultiplePlp()
}
class RunPlp {
+runPlp()
}
class RunPlpHelpers {
+checkInputs()
+createDefaultExecuteSettings()
+createExecuteSettings()
+printHeader()
}
class SaveLoadPlp {
+applyMinCellCount()
+extractDatabaseToCsv()
+getPlpSensitiveColumns()
+loadPlpData()
+loadPlpModel()
+loadPlpResult()
+loadPlpShareable()
+loadPrediction()
+removeCellCount()
+removeList()
+saveModelPart()
+savePlpData()
+savePlpModel()
+savePlpResult()
+savePlpShareable()
+savePrediction()
}
class LearningCurve {
+createLearningCurve()
+getTrainFractions()
+lcWrapper()
+learningCurveHelper()
+plotLearningCurve()
}
```
Data:

```mermaid
classDiagram
class Data
Data --> ExtractData
Data --> Simulation
Data --> PreprocessingData
Data --> FeatureEngineering
Data --> FeatureImportance
Data --> Formatting
Data --> AdditionalCovariates
Data --> AndromedaHelperFunctions
class ExtractData {
+createDatabaseDetails()
+createRestrictPlpDataSettings()
+getPlpData()
+print.plpData()
+print.summary.plpData()
+summary.plpData()
}
class Simulation {
+simulatePlpData()
}
class PreprocessingData {
+createPreprocessSettings()
+preprocessData()
}
class FeatureEngineering {
+calculateStratifiedMeans()
+createFeatureEngineeringSettings()
+createRandomForestFeatureSelection()
+createSplineSettings()
+createStratifiedImputationSettings()
+createUnivariateFeatureSelection()
+featureEngineer()
+imputeMissingMeans()
+randomForestFeatureSelection()
+splineCovariates()
+splineMap()
+stratifiedImputeCovariates()
+univariateFeatureSelection()
}
class FeatureImportance {
+permute()
+permutePerf()
+pfi()
}
class Formatting {
+checkRam()
+MapIds()
+toSparseM()
}
class AdditionalCovariates {
+createCohortCovariateSettings()
+getCohortCovariateData()
}
class AndromedaHelperFunctions {
+batchRestrict()
+calculatePrevs()
+limitCovariatesToPopulation()
}
```
Resample:

```mermaid
classDiagram
class Resample
class Sampling {
+createSampleSettings()
+overSampleData()
+sameData()
+sampleData()
+underSampleData()
}
class DataSplitting {
+checkInputsSplit()
+createDefaultSplitSetting()
+dataSummary()
+randomSplitter()
+splitData()
+subjectSplitter()
+timeSplitter()
}
Resample --> Sampling
Resample --> DataSplitting
```
Task:

```mermaid
classDiagram
class Task
Task --> PopulationSettings
Task --> DiagnosePlp
class PopulationSettings {
+createStudyPopulation()
+createStudyPopulationSettings()
+getCounts()
+getCounts2()
}
class DiagnosePlp {
+cos_sim()
+diagnoseMultiplePlp()
+diagnosePlp()
+getDiagnostic()
+getMaxEndDaysFromCovariates()
+getOutcomeSummary()
+probastDesign()
+probastOutcome()
+probastParticipants()
+probastPredictors()
}
```
Learner:

```mermaid
classDiagram
class Learner
Learner --> SklearnToJson
Learner --> SklearnClassifierSettings
Learner --> SklearnClassifierHelpers
Learner --> SklearnClassifier
Learner --> RClassifier
Learner --> KNN
Learner --> LightGBM
Learner --> GradientBoostingMachine
Learner --> CyclopsModels
Learner --> CyclopsSettings
class SklearnToJson {
+deSerializeAdaboost()
+deSerializeCsrMatrix()
+deSerializeDecisionTree()
+deSerializeMlp()
+deSerializeNaiveBayes()
+deSerializeRandomForest()
+deSerializeSVM()
+deSerializeTree()
+serializeAdaboost()
+serializeCsrMatrix()
+serializeDecisionTree()
+serializeMLP()
+serializeNaiveBayes()
+serializeRandomForest()
+serializeSVM()
+serializeTree()
+sklearnFromJson()
+sklearnToJson()
}
class SklearnClassifierSettings {
+AdaBoostClassifierInputs()
+DecisionTreeClassifierInputs()
+GaussianNBInputs()
+MLPClassifierInputs()
+RandomForestClassifierInputs()
+setAdaBoost()
+setDecisionTree()
+setMLP()
+setNaiveBayes()
+setRandomForest()
+setSVM()
+SVCInputs()
}
class SklearnClassifierHelpers {
+listCartesian()
}
class SklearnClassifier {
+checkPySettings()
+computeGridPerformance()
+fitPythonModel()
+fitSklearn()
+gridCvPython()
+predictPythonSklearn()
+predictValues()
}
class RClassifier {
+applyCrossValidationInR()
+fitRclassifier()
}
class KNN {
+fitKNN()
+predictKnn()
+setKNN()
}
class LightGBM {
+fitLightGBM()
+predictLightGBM()
+setLightGBM()
+varImpLightGBM()
}
class GradientBoostingMachine {
+fitXgboost()
+predictXgboost()
+setGradientBoostingMachine()
+varImpXgboost()
}
class CyclopsModels {
+createCyclopsModel()
+filterCovariateIds()
+fitCyclopsModel()
+getCV()
+getVariableImportance()
+modelTypeToCyclopsModelType()
+predictCyclops()
+predictCyclopsType()
+reparamTransferCoefs()
}
class CyclopsSettings {
+setCoxModel()
+setIterativeHardThresholding()
+setLassoLogisticRegression()
}
```
Measure:

```mermaid
classDiagram
class Measure
Measure --> ViewShinyPlp
Measure --> uploadToDatabasePerformance
Measure --> uploadToDatabase
Measure --> uploadToDatabaseDiagnostics
Measure --> uploadToDatabaseModelDesign
Measure --> ThresholdSummary
Measure --> PredictionDistribution
Measure --> Plotting
Measure --> CovariateSummary
Measure --> EvaluatePlp
Measure --> EvaluationSummary
Measure --> DemographicSummary
Measure --> CalibrationSummary
Measure --> ImportFromCsv
class ViewShinyPlp {
+viewDatabaseResultPlp()
+viewMultiplePlp()
+viewPlp()
+viewPlps()
}
class uploadToDatabasePerformance {
+addAttrition()
+addCalibrationSummary()
+addCovariateSummary()
+addDemographicSummary()
+addEvaluation()
+addEvaluationStatistics()
+addPerformance()
+addPredictionDistribution()
+addThresholdSummary()
+checkResultExists()
+getColumnNames()
+insertPerformanceInDatabase()
}
class uploadToDatabase {
+addCohort()
+addDatabase()
+addModel()
+addMultipleRunPlpToDatabase()
+addRunPlpToDatabase()
+checkJson()
+checkTable()
+cleanNum()
+createDatabaseList()
+createDatabaseSchemaSettings()
+createPlpResultTables()
+deleteTables()
+enc()
+getCohortDef()
+getPlpResultTables()
+getResultLocations()
+insertModelInDatabase()
+insertResultsToSqlite()
+insertRunPlpToSqlite()
}
class uploadToDatabaseDiagnostics {
+addDiagnosePlpToDatabase()
+addDiagnostic()
+addMultipleDiagnosePlpToDatabase()
+addResultTable()
+insertDiagnosisToDatabase()
}
class uploadToDatabaseModelDesign {
+addCovariateSetting()
+addFESetting()
+addModelDesign()
+addModelSetting()
+addPlpDataSetting()
+addPopulationSetting()
+addSampleSetting()
+addSplitSettings()
+addTar()
+addTidySetting()
+insertModelDesignInDatabase()
+insertModelDesignSettings()
+orderJson()
}
class ThresholdSummary {
+accuracy()
+checkToByTwoTableInputs()
+diagnosticOddsRatio()
+f1Score()
+falseDiscoveryRate()
+falseNegativeRate()
+falseOmissionRate()
+falsePositiveRate()
+getThresholdSummary()
+getThresholdSummary_binary()
+getThresholdSummary_survival()
+negativeLikelihoodRatio()
+negativePredictiveValue()
+positiveLikelihoodRatio()
+positivePredictiveValue()
+sensitivity()
+specificity()
+stdca()
}
class PredictionDistribution {
+getPredictionDistribution()
+getPredictionDistribution_binary()
+getPredictionDistribution_survival()
}
class Plotting {
+outcomeSurvivalPlot()
+plotDemographicSummary()
+plotF1Measure()
+plotGeneralizability()
+plotPlp()
+plotPrecisionRecall()
+plotPredictedPDF()
+plotPredictionDistribution()
+plotPreferencePDF()
+plotSmoothCalibration()
+plotSmoothCalibrationLoess()
+plotSmoothCalibrationRcs()
+plotSparseCalibration()
+plotSparseCalibration2()
+plotSparseRoc()
+plotVariableScatterplot()
}
class CovariateSummary {
+aggregateCovariateSummaries()
+covariateSummary()
+covariateSummarySubset()
+createCovariateSubsets()
+getCovariatesForGroup()
}
class EvaluatePlp {
+evaluatePlp()
+modelBasedConcordance()
}
class EvaluationSummary {
+aucWithCi()
+aucWithoutCi()
+averagePrecision()
+brierScore()
+calculateEStatisticsBinary()
+calibrationInLarge()
+calibrationInLargeIntercept()
+calibrationLine()
+calibrationWeak()
+computeAuc()
+getEvaluationStatistics()
+getEvaluationStatistics_binary()
+getEvaluationStatistics_survival()
+ici()
}
class DemographicSummary {
+getDemographicSummary()
+getDemographicSummary_binary()
+getDemographicSummary_survival()
}
class CalibrationSummary {
+getCalibrationSummary()
+getCalibrationSummary_binary()
+getCalibrationSummary_survival()
}
class ImportFromCsv {
+extractCohortDefinitionsCSV()
+extractDatabaseListCSV()
+extractDiagnosticFromCsv()
+extractObjectFromCsv()
+getModelDesignCsv()
+getModelDesignSettingTable()
+getPerformanceEvaluationCsv()
+getTableNamesPlp()
+insertCsvToDatabase()
}
```
Prediction:

```mermaid
classDiagram
class Prediction
Prediction --> ExternalValidatePlp
Prediction --> Recalibration
Prediction --> Predict
class ExternalValidatePlp {
+createValidationDesign()
+createValidationSettings()
+externalValidateDbPlp()
+externalValidatePlp()
+validateExternal()
+validateModel()
}
class Recalibration {
+inverseLog()
+logFunct()
+recalibratePlp()
+recalibratePlpRefit()
+recalibrationInTheLarge()
+weakRecalibration()
}
class Predict {
+applyFeatureengineering()
+applyTidyCovariateData()
+predictPlp()
}
```
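To make the block mapping above more concrete, here is a minimal sketch of how the building blocks could compose. It is written in Python rather than R and is not mlr3's actual API. All names (`Task`, `Learner`, `MeanLearner`, `Resampling`, `brier`, `resample`) are hypothetical, and `resample()` only loosely mirrors the shape of mlr3's `resample(task, learner, resampling)`. The comments pointing at PLP diagrams are suggested correspondences, not a fixed design.

```python
import random
from dataclasses import dataclass
from typing import Callable, List, Protocol, Sequence


# Task: the data plus the prediction problem
# (roughly PopulationSettings / createStudyPopulation() in the Task diagram).
@dataclass
class Task:
    features: List[List[float]]
    outcomes: List[int]


# Learner interface (roughly the set*() / fit*() / predict*() functions in the Learner diagram).
class Learner(Protocol):
    def fit(self, x: Sequence[Sequence[float]], y: Sequence[int]) -> "Learner": ...
    def predict(self, x: Sequence[Sequence[float]]) -> List[float]: ...


@dataclass
class MeanLearner:
    """Toy learner: predicts the training outcome rate for every row."""
    rate: float = 0.0

    def fit(self, x: Sequence[Sequence[float]], y: Sequence[int]) -> "MeanLearner":
        self.rate = sum(y) / len(y)
        return self

    def predict(self, x: Sequence[Sequence[float]]) -> List[float]:
        return [self.rate] * len(x)


# Resampling: assigns each row to a fold (roughly splitData() in the Resample diagram).
@dataclass
class Resampling:
    nfold: int = 3
    seed: int = 42

    def folds(self, n: int) -> List[int]:
        # Balanced fold labels 0..nfold-1, shuffled reproducibly with the seed.
        labels = [i % self.nfold for i in range(n)]
        random.Random(self.seed).shuffle(labels)
        return labels


# Measure: maps (observed outcomes, predicted risks) to a score
# (roughly brierScore() in the EvaluationSummary diagram).
Measure = Callable[[Sequence[int], Sequence[float]], float]


def brier(y: Sequence[int], p: Sequence[float]) -> float:
    return sum((pi - yi) ** 2 for yi, pi in zip(y, p)) / len(y)


def resample(task: Task, make_learner: Callable[[], Learner],
             resampling: Resampling, measure: Measure) -> List[float]:
    """Fit a fresh learner per fold and score it on that fold's held-out rows."""
    folds = resampling.folds(len(task.outcomes))
    scores = []
    for k in range(resampling.nfold):
        train = [i for i, f in enumerate(folds) if f != k]
        test = [i for i, f in enumerate(folds) if f == k]
        learner = make_learner().fit([task.features[i] for i in train],
                                     [task.outcomes[i] for i in train])
        # Prediction step (roughly predictPlp() in the Prediction diagram).
        preds = learner.predict([task.features[i] for i in test])
        scores.append(measure([task.outcomes[i] for i in test], preds))
    return scores
```

For example, `resample(Task([[float(i)] for i in range(9)], [0, 1, 0, 1, 0, 1, 0, 1, 0]), MeanLearner, Resampling(nfold=3), brier)` returns one Brier score per fold. The point of the sketch is that each block interacts with the others only through a small interface (`fit`/`predict` for learners, fold labels for resampling, a `(y, p) -> score` function for measures). New learners or measures could then be added without touching `resample()`.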
For information: tidymodels uses parsnip to provide model interfaces. They describe their design at https://github.com/tidymodels/parsnip/tree/main/R#readme and appear to rely on plain function calls, although the approach is somewhat complicated.
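The parsnip pattern can be sketched as follows: a model specification is built with plain function calls, and `fit()` dispatches on the chosen engine. This is a hypothetical Python illustration, not parsnip's code. In R, parsnip is used roughly as `logistic_reg(penalty = ...)`, then `set_engine(...)`, then `fit(...)`, and the names below only echo that shape. The registry, `ModelSpec`, and the `"toy"` engine are assumptions made for illustration.

```python
from dataclasses import dataclass, field, replace
from typing import Any, Callable, Dict, Sequence, Tuple


@dataclass(frozen=True)
class ModelSpec:
    """Inert description of a model: what to fit, with which arguments and engine."""
    model_type: str
    args: Dict[str, Any] = field(default_factory=dict)
    engine: str = ""


# Maps (model_type, engine) to the function that actually fits the model.
FIT_REGISTRY: Dict[Tuple[str, str], Callable[..., Any]] = {}


def register(model_type: str, engine: str):
    def decorator(fn: Callable[..., Any]) -> Callable[..., Any]:
        FIT_REGISTRY[(model_type, engine)] = fn
        return fn
    return decorator


def logistic_reg(penalty: float = 0.0) -> ModelSpec:
    return ModelSpec("logistic_reg", {"penalty": penalty})


def set_engine(spec: ModelSpec, engine: str) -> ModelSpec:
    # Returns a new spec; the original is left unchanged.
    return replace(spec, engine=engine)


def fit(spec: ModelSpec, x: Sequence[Sequence[float]], y: Sequence[int]) -> Any:
    fit_fn = FIT_REGISTRY.get((spec.model_type, spec.engine))
    if fit_fn is None:
        raise ValueError(f"no engine '{spec.engine}' registered for '{spec.model_type}'")
    return fit_fn(x, y, **spec.args)


@register("logistic_reg", "toy")
def _fit_toy(x: Sequence[Sequence[float]], y: Sequence[int], penalty: float) -> Dict[str, float]:
    # Toy engine: records the outcome rate and the penalty it received.
    # A real engine would call an actual fitting routine instead.
    return {"rate": sum(y) / len(y), "penalty": penalty}
```

A specification here is inert data until `fit()` looks up an engine. That is close in spirit to the existing `set*()`/`fit*()` pairs in the Learner diagram, for example `setKNN()`/`fitKNN()`.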
A place to discuss the refactor of PLP and to get an overview of the current code base and of the options for a prospective one. Currently the project is organized by file and is function-based. Below is a "class" diagram of all files and functions in the R folder.
Related resources:
Draft PR for new model API: #462