diff --git a/streamingpro-core/src/main/java/streaming/dsl/load/batch/ModelExplain.scala b/streamingpro-core/src/main/java/streaming/dsl/load/batch/ModelExplain.scala index 042665605..627840a54 100644 --- a/streamingpro-core/src/main/java/streaming/dsl/load/batch/ModelExplain.scala +++ b/streamingpro-core/src/main/java/streaming/dsl/load/batch/ModelExplain.scala @@ -20,10 +20,9 @@ package streaming.dsl.load.batch import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Row, SparkSession} -import tech.mlsql.common.utils.reflect.ClassPath import tech.mlsql.dsl.adaptor.MLMapping +import tech.mlsql.tool.HDFSOperatorV2 -import scala.collection.JavaConversions._ /** * Created by allwefantasy on 21/9/2018. @@ -219,7 +218,32 @@ class ModelExplain(format: String, path: String, option: Map[String, String])(sp if ((format == "model" && path == "explain")) option("path") else { - path + option.get("index") match { + // If user specifies index, try to find _model_n subdirectory + case Some(idx) => + val subDirs = HDFSOperatorV2.listFiles(path) + .filter( _.isDirectory ) + val _modelDirExists = subDirs.exists( _.getPath.getName.startsWith("_model_")) + val _modelIdxDirExists = subDirs.exists( s"_model_${idx}" == _.getPath.getName) + val metaDirExists = subDirs.exists( "meta" == _.getPath.getName) + val modelDirExists = subDirs.exists( "model" == _.getPath.getName) + if( _modelDirExists && _modelIdxDirExists) { + s"${path}/_model_${idx}" + } + else if( _modelDirExists && ! _modelIdxDirExists ) { + throw new RuntimeException(s"model directory with index ${idx} does not exist") + } + else if( metaDirExists && modelDirExists ) { + // `keepVersion`="false", index option is ignored. 
+ path + } + else { + throw new RuntimeException(s"${path}/_model_${idx} does not exist") + } + + // If keepVersion is not enabled + case None => path + } } } diff --git a/streamingpro-core/src/main/java/tech/mlsql/dsl/adaptor/LoadAdaptor.scala b/streamingpro-core/src/main/java/tech/mlsql/dsl/adaptor/LoadAdaptor.scala index eb235760e..1be5a5ae4 100644 --- a/streamingpro-core/src/main/java/tech/mlsql/dsl/adaptor/LoadAdaptor.scala +++ b/streamingpro-core/src/main/java/tech/mlsql/dsl/adaptor/LoadAdaptor.scala @@ -136,10 +136,16 @@ class LoadProcessing(scriptSQLExecListener: ScriptSQLExecListener, // return the load table table }.getOrElse { - // calculate resource real absolute path + // path could be: + // 1) fileSystem path; code example: load modelExplain.`/tmp/model` where alg="RandomForest" as output; + // 2) ET name; code example: load modelExample.`JsonExpandExt` AS output_1; load modelParams.`JsonExpandExt` as output; + // For fileSystem path, pass the real path to ModelSelfExplain; for ET name pass original path val resourcePath = resourceRealPath(scriptSQLExecListener, option.get("owner"), path) - - table = ModelSelfExplain(format, cleanStr(path), option, sparkSession).isMatch.thenDo.orElse(() => { + val fsPathOrETName = format match { + case "modelExplain" => resourcePath + case _ => cleanStr(path) + } + table = ModelSelfExplain(format, fsPathOrETName, option, sparkSession).isMatch.thenDo.orElse(() => { reader.format(format).load(resourcePath) }).get } diff --git a/streamingpro-it/src/test/resources/sql/simple/23_random_forest_model_explain.mlsql b/streamingpro-it/src/test/resources/sql/simple/23_random_forest_model_explain.mlsql new file mode 100644 index 000000000..67ff7508b --- /dev/null +++ b/streamingpro-it/src/test/resources/sql/simple/23_random_forest_model_explain.mlsql @@ -0,0 +1,43 @@ +--%comparator=tech.mlsql.it.IgnoreResultComparator + +set jsonStr=''' +{"features":[5.1,3.5,1.4,0.2],"label":0.0} +{"features":[5.1,3.5,1.4,0.2],"label":1.0} 
+{"features":[5.1,3.5,1.4,0.2],"label":0.0} +{"features":[4.4,2.9,1.4,0.2],"label":0.0} +{"features":[5.1,3.5,1.4,0.2],"label":1.0} +{"features":[5.1,3.5,1.4,0.2],"label":0.0} +{"features":[5.1,3.5,1.4,0.2],"label":0.0} +{"features":[4.7,3.2,1.3,0.2],"label":1.0} +{"features":[5.1,3.5,1.4,0.2],"label":0.0} +{"features":[5.1,3.5,1.4,0.2],"label":0.0} +'''; +load jsonStr.`jsonStr` as mock_data; + + +select vec_dense(features) as features, label as label from mock_data as mock_data_1; + + +train mock_data_1 as RandomForest.`/tmp/model` where +keepVersion="true" +and evaluateTable="mock_data_validate" + +and `fitParam.0.labelCol`="label" +and `fitParam.0.featuresCol`="features" +and `fitParam.0.maxDepth`="2" + +and `fitParam.1.featuresCol`="features" +and `fitParam.1.labelCol`="label" +and `fitParam.1.maxDepth`="10" +; + +load modelExplain.`/tmp/model/_model_0/` where alg="RandomForest" as output_1; + +select `name` from output_1 where name="uid" as result; + +!assert result ''':name=="uid"''' "RandomForest modelExplain should be successful"; + + +load modelExplain.`/tmp/model/_model_0/` where alg="RandomForest" and index = "0" as output_2; +select `name` from output_2 where name="uid" as result_2; +!assert result_2 ''':name=="uid"''' "RandomForest modelExplain should be successful"; diff --git a/streamingpro-it/src/test/resources/sql/simple/andie_huang_02_train_linear_regression_model.mlsql b/streamingpro-it/src/test/resources/sql/simple/andie_huang_02_train_linear_regression_model.mlsql index 1381a94fb..a6cbb2907 100644 --- a/streamingpro-it/src/test/resources/sql/simple/andie_huang_02_train_linear_regression_model.mlsql +++ b/streamingpro-it/src/test/resources/sql/simple/andie_huang_02_train_linear_regression_model.mlsql @@ -48,4 +48,11 @@ select pred_func(features) as predict_label, label from data1 as output; select name,value from model_result where name="status" as result; -- make sure status of all models are success. 
-!assert result ''':value=="success"''' "all model status should be success"; \ No newline at end of file +!assert result ''':value=="success"''' "all model status should be success"; + + +load modelExplain.`/tmp/linearregression` WHERE `alg`="LinearRegression" as lr_model_explain; + +select `name` from lr_model_explain where name="uid" as result_2; + +!assert result_2 ''':name=="uid"''' "LinearRegression modelExplain should be successful"; \ No newline at end of file