Skip to content

Commit

Permalink
Support multi-directory modelExplain
Browse files Browse the repository at this point in the history
  • Loading branch information
chncaesar committed Jul 16, 2022
1 parent 618788c commit c2eed39
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package streaming.dsl.load.batch
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import tech.mlsql.dsl.adaptor.MLMapping
import tech.mlsql.tool.HDFSOperatorV2


/**
Expand Down Expand Up @@ -217,7 +218,32 @@ class ModelExplain(format: String, path: String, option: Map[String, String])(sp
if ((format == "model" && path == "explain"))
option("path")
else {
path
option.get("index") match {
// If user specifies index, try to find _model_n subdirectory
case Some(idx) =>
val subDirs = HDFSOperatorV2.listFiles(path)
.filter( _.isDirectory )
val _modelDirExists = subDirs.exists( _.getPath.getName.startsWith("_model_"))
val _modelIdxDirExists = subDirs.exists( s"_model_${idx}" == _.getPath.getName)
val metaDirExists = subDirs.exists( "meta" == _.getPath.getName)
val modelDirExists = subDirs.exists( "model" == _.getPath.getName)
if( _modelDirExists && _modelIdxDirExists) {
s"${path}/_model_${idx}"
}
else if( _modelDirExists && ! _modelIdxDirExists ) {
throw new RuntimeException(s"model directory with index ${idx} does not exist")
}
else if( metaDirExists && modelDirExists ) {
// `keepVersion`="false", index option is ignored.
path
}
else {
throw new RuntimeException(s"${path}/_model_${idx} does not exist")
}

// If keepVersion is not enabled
case None => path
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,9 @@ load modelExplain.`/tmp/model/_model_0/` where alg="RandomForest" as output_1;

select `name` from output_1 where name="uid" as result;

!assert result ''':name=="uid"''' "RandomForest modelExplain should be successful";
!assert result ''':name=="uid"''' "RandomForest modelExplain should be successful";


load modelExplain.`/tmp/model/_model_0/` where alg="RandomForest" and index = "0" as output_2;
select `name` from output_2 where name="uid" as result_2;
!assert result_2 ''':name=="uid"''' "RandomForest modelExplain should be successful";

0 comments on commit c2eed39

Please sign in to comment.