Skip to content

Commit

Permalink
Merge pull request #1794 from chncaesar/issue1649
Browse files Browse the repository at this point in the history
Add pathPrefix for modelExplain
  • Loading branch information
chncaesar authored Jul 18, 2022
2 parents 445b081 + c2eed39 commit a08f56a
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,9 @@ package streaming.dsl.load.batch

import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import tech.mlsql.common.utils.reflect.ClassPath
import tech.mlsql.dsl.adaptor.MLMapping
import tech.mlsql.tool.HDFSOperatorV2

import scala.collection.JavaConversions._

/**
* Created by allwefantasy on 21/9/2018.
Expand Down Expand Up @@ -219,7 +218,32 @@ class ModelExplain(format: String, path: String, option: Map[String, String])(sp
if ((format == "model" && path == "explain"))
option("path")
else {
path
option.get("index") match {
// If user specifies index, try to find _model_n subdirectory
case Some(idx) =>
val subDirs = HDFSOperatorV2.listFiles(path)
.filter( _.isDirectory )
val _modelDirExists = subDirs.exists( _.getPath.getName.startsWith("_model_"))
val _modelIdxDirExists = subDirs.exists( s"_model_${idx}" == _.getPath.getName)
val metaDirExists = subDirs.exists( "meta" == _.getPath.getName)
val modelDirExists = subDirs.exists( "model" == _.getPath.getName)
if( _modelDirExists && _modelIdxDirExists) {
s"${path}/_model_${idx}"
}
else if( _modelDirExists && ! _modelIdxDirExists ) {
throw new RuntimeException(s"model directory with index ${idx} does not exist")
}
else if( metaDirExists && modelDirExists ) {
// `keepVersion`="false", index option is ignored.
path
}
else {
throw new RuntimeException(s"${path}/_model_${idx} does not exist")
}

// If the user does not specify index, use the path as-is
case None => path
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,16 @@ class LoadProcessing(scriptSQLExecListener: ScriptSQLExecListener,
// return the load table
table
}.getOrElse {
// calculate resource real absolute path
// path could be:
// 1) fileSystem path; code example: load modelExplain.`/tmp/model` where alg="RandomForest" as output;
// 2) ET name; code example: load modelExample.`JsonExpandExt` AS output_1; load modelParams.`JsonExpandExt` as output;
// For the modelExplain format (file system path), pass the resolved real path to ModelSelfExplain;
// for any other format (ET name), pass the cleaned original path
val resourcePath = resourceRealPath(scriptSQLExecListener, option.get("owner"), path)

table = ModelSelfExplain(format, cleanStr(path), option, sparkSession).isMatch.thenDo.orElse(() => {
val fsPathOrETName = format match {
case "modelExplain" => resourcePath
case _ => cleanStr(path)
}
table = ModelSelfExplain(format, fsPathOrETName, option, sparkSession).isMatch.thenDo.orElse(() => {
reader.format(format).load(resourcePath)
}).get
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
--%comparator=tech.mlsql.it.IgnoreResultComparator

-- Integration script: trains a RandomForest model with keepVersion enabled and
-- checks that modelExplain can be loaded from the versioned model directory.

-- Inline mock data: one JSON record per line, each with a features array and a label.
-- NOTE(review): the first record ends with a trailing comma, unlike the others - confirm this is intended.
set jsonStr='''
{"features":[5.1,3.5,1.4,0.2],"label":0.0},
{"features":[5.1,3.5,1.4,0.2],"label":1.0}
{"features":[5.1,3.5,1.4,0.2],"label":0.0}
{"features":[4.4,2.9,1.4,0.2],"label":0.0}
{"features":[5.1,3.5,1.4,0.2],"label":1.0}
{"features":[5.1,3.5,1.4,0.2],"label":0.0}
{"features":[5.1,3.5,1.4,0.2],"label":0.0}
{"features":[4.7,3.2,1.3,0.2],"label":1.0}
{"features":[5.1,3.5,1.4,0.2],"label":0.0}
{"features":[5.1,3.5,1.4,0.2],"label":0.0}
''';
-- Expose the inline JSON as the table mock_data.
load jsonStr.`jsonStr` as mock_data;


-- Apply vec_dense to the features column (presumably producing a dense vector - confirm)
-- and keep label, registering the result as mock_data_1.
select vec_dense(features) as features, label as label from mock_data as mock_data_1;


-- Train RandomForest into /tmp/model with keepVersion set to true and two fitParam
-- groups (maxDepth 2 and maxDepth 10).
-- NOTE(review): evaluateTable names mock_data_validate, which is not defined in this
-- script - confirm it is available or intended.
train mock_data_1 as RandomForest.`/tmp/model` where
keepVersion="true"
and evaluateTable="mock_data_validate"

and `fitParam.0.labelCol`="label"
and `fitParam.0.featuresCol`="features"
and `fitParam.0.maxDepth`="2"

and `fitParam.1.featuresCol`="features"
and `fitParam.1.labelCol`="label"
and `fitParam.1.maxDepth`="10"
;

-- Case 1: load modelExplain from the _model_0 subdirectory without the index option.
load modelExplain.`/tmp/model/_model_0/` where alg="RandomForest" as output_1;

-- Keep only the row whose name is uid.
select `name` from output_1 where name="uid" as result;

-- Check the name condition on the selected rows.
!assert result ''':name=="uid"''' "RandomForest modelExplain should be successful";


-- Case 2: same path, this time with the index option set to 0.
load modelExplain.`/tmp/model/_model_0/` where alg="RandomForest" and index = "0" as output_2;
select `name` from output_2 where name="uid" as result_2;
!assert result_2 ''':name=="uid"''' "RandomForest modelExplain should be successful";
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,11 @@ select pred_func(features) as predict_label, label from data1 as output;

select name,value from model_result where name="status" as result;
-- make sure status of all models are success.
!assert result ''':value=="success"''' "all model status should be success";
!assert result ''':value=="success"''' "all model status should be success";


load modelExplain.`/tmp/linearregression` WHERE `alg`="LinearRegression" as lr_model_explain;

select `name` from lr_model_explain where name="uid" as result;

!assert result ''':name=="uid"''' "LinearRegression modelExplain should be successful";

0 comments on commit a08f56a

Please sign in to comment.