diff --git a/streamingpro-core/src/main/java/streaming/dsl/load/batch/ModelExplain.scala b/streamingpro-core/src/main/java/streaming/dsl/load/batch/ModelExplain.scala index 1a5862837..627840a54 100644 --- a/streamingpro-core/src/main/java/streaming/dsl/load/batch/ModelExplain.scala +++ b/streamingpro-core/src/main/java/streaming/dsl/load/batch/ModelExplain.scala @@ -21,6 +21,7 @@ package streaming.dsl.load.batch import org.apache.spark.sql.types.{StringType, StructField, StructType} import org.apache.spark.sql.{DataFrame, Row, SparkSession} import tech.mlsql.dsl.adaptor.MLMapping +import tech.mlsql.tool.HDFSOperatorV2 /** @@ -217,7 +218,32 @@ class ModelExplain(format: String, path: String, option: Map[String, String])(sp if ((format == "model" && path == "explain")) option("path") else { - path + option.get("index") match { + // An index option was given: look for the matching _model_<index> subdirectory under path + case Some(idx) => + val subDirs = HDFSOperatorV2.listFiles(path) + .filter( _.isDirectory ) + val _modelDirExists = subDirs.exists( _.getPath.getName.startsWith("_model_")) + val _modelIdxDirExists = subDirs.exists( s"_model_${idx}" == _.getPath.getName) + val metaDirExists = subDirs.exists( "meta" == _.getPath.getName) + val modelDirExists = subDirs.exists( "model" == _.getPath.getName) + if( _modelDirExists && _modelIdxDirExists) { + s"${path}/_model_${idx}" + } + else if( _modelDirExists && ! _modelIdxDirExists ) { + throw new RuntimeException(s"model directory with index ${idx} does not exist") + } + else if( metaDirExists && modelDirExists ) { + // No _model_* subdirectory, but meta and model sit directly under path (presumably `keepVersion`="false" -- TODO confirm): the index option is ignored. 
+ path } + else { + throw new RuntimeException(s"${path}/_model_${idx} does not exist") + } + + // No index option given: path is returned unchanged + case None => path + } } } diff --git a/streamingpro-it/src/test/resources/sql/simple/23_random_forest_model_explain.mlsql b/streamingpro-it/src/test/resources/sql/simple/23_random_forest_model_explain.mlsql index d7dc057ac..67ff7508b 100644 --- a/streamingpro-it/src/test/resources/sql/simple/23_random_forest_model_explain.mlsql +++ b/streamingpro-it/src/test/resources/sql/simple/23_random_forest_model_explain.mlsql @@ -35,4 +35,9 @@ load modelExplain.`/tmp/model/_model_0/` where alg="RandomForest" as output_1; select `name` from output_1 where name="uid" as result; -!assert result ''':name=="uid"''' "RandomForest modelExplain should be successful"; \ No newline at end of file +!assert result ''':name=="uid"''' "RandomForest modelExplain should be successful"; + + +load modelExplain.`/tmp/model/_model_0/` where alg="RandomForest" and index = "0" as output_2; +select `name` from output_2 where name="uid" as result_2; +!assert result_2 ''':name=="uid"''' "RandomForest modelExplain should be successful";