From 23807efb6b00258e0901d5846b775cdc8e3a37fe Mon Sep 17 00:00:00 2001
From: WilliamZhu
Date: Sat, 13 Apr 2019 21:56:37 +0800
Subject: [PATCH] fix solr auth

---
 .../core/datasource/impl/MLSQLSolr.scala      | 48 ++++++++++++++++++-
 .../test/datasource/MLSQLLoadStrSpec.scala    |  1 +
 2 files changed, 48 insertions(+), 1 deletion(-)

diff --git a/streamingpro-mlsql/src/main/java/streaming/core/datasource/impl/MLSQLSolr.scala b/streamingpro-mlsql/src/main/java/streaming/core/datasource/impl/MLSQLSolr.scala
index 74f48ee0b..a0bf22428 100644
--- a/streamingpro-mlsql/src/main/java/streaming/core/datasource/impl/MLSQLSolr.scala
+++ b/streamingpro-mlsql/src/main/java/streaming/core/datasource/impl/MLSQLSolr.scala
@@ -21,7 +21,7 @@ package streaming.core.datasource.impl
 import org.apache.spark.sql.{DataFrame, DataFrameReader, DataFrameWriter, Row}
 import streaming.core.datasource._
 import streaming.dsl.mmlib.algs.param.{BaseParams, WowParams}
-import streaming.dsl.ScriptSQLExec
+import streaming.dsl.{ConnectMeta, DBMappingKey, ScriptSQLExec}
 
 class MLSQLSolr(override val uid: String) extends MLSQLBaseStreamSource with WowParams {
 
@@ -69,6 +69,52 @@ class MLSQLSolr(override val uid: String) extends MLSQLBaseStreamSource with Wow
     }
   }
 
+  override def sourceInfo(config: DataAuthConfig): SourceInfo = {
+
+    val Array(_dbname, _dbtable) = if (config.path.contains(dbSplitter)) {
+      config.path.split(dbSplitter, 2)
+    } else {
+      Array("", config.path)
+    }
+
+    var table = _dbtable
+    var dbName = _dbname
+
+    val newOptions = scala.collection.mutable.HashMap[String, String]() ++ config.config
+    ConnectMeta.options(DBMappingKey(shortFormat, _dbname)) match {
+      case Some(option) =>
+        dbName = ""
+        newOptions ++= option
+
+        table.split(dbSplitter) match {
+          case Array(_db, _table) =>
+            dbName = _db
+            table = _table
+          case _ =>
+        }
+
+      case None =>
+      //dbName = ""
+    }
+
+
+    newOptions.filter(f => f._1 == "collection").map { f =>
+      if (f._2.contains(dbSplitter)) {
+        f._2.split(dbSplitter, 2) match {
+          case Array(_db, _table) =>
+            dbName = _db
+            table = _table
+          case Array(_db) =>
+            dbName = _db
+        }
+      } else {
+        dbName = f._2
+      }
+    }
+
+    SourceInfo(shortFormat, dbName, table)
+  }
+
   override def register(): Unit = {
     DataSourceRegistry.register(MLSQLDataSourceKey(shortFormat, MLSQLSparkDataSourceType), this)
     DataSourceRegistry.register(MLSQLDataSourceKey(shortFormat, MLSQLSparkDataSourceType), this)
diff --git a/streamingpro-mlsql/src/test/scala/streaming/test/datasource/MLSQLLoadStrSpec.scala b/streamingpro-mlsql/src/test/scala/streaming/test/datasource/MLSQLLoadStrSpec.scala
index 6f6787775..19ff35c66 100644
--- a/streamingpro-mlsql/src/test/scala/streaming/test/datasource/MLSQLLoadStrSpec.scala
+++ b/streamingpro-mlsql/src/test/scala/streaming/test/datasource/MLSQLLoadStrSpec.scala
@@ -81,6 +81,7 @@ class MLSQLLoadStrSpec extends BasicSparkOperation with SpecFunctions with Basic
     withBatchContext(setupBatchContext(batchParams, "classpath:///test/empty.json")) { runtime: SparkRuntime =>
       implicit val spark = runtime.sparkSession
       ShellCommand.exec("rm -rf /tmp/user/hive/warehouse/carbon_jack")
+      var ssel = createSSEL
       ScriptSQLExec.parse(
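
Reviewer note (not part of the patch): the sourceInfo method added above is what the auth layer consults to learn which Solr db/table a load statement touches. The standalone Scala sketch below mirrors that resolution order under stated assumptions: dbSplitter is taken as ".", the object SolrSourceInfoSketch, its resolve helper and the sample inputs are hypothetical and exist only for illustration, and the ConnectMeta.options(DBMappingKey(...)) lookup is stood in for by a plain Option parameter.

  // Minimal sketch of the db/table resolution order used by the new sourceInfo:
  //   1. split the load path on dbSplitter into (db, table);
  //   2. if the db part is a `connect` alias, merge its options and re-split the rest;
  //   3. let a "collection" option override whatever was derived from the path.
  object SolrSourceInfoSketch {
    import java.util.regex.Pattern

    val dbSplitter = "."

    def resolve(path: String,
                connectOptions: Option[Map[String, String]],
                config: Map[String, String]): (String, String) = {
      // 1. split the load path into (db, table)
      val Array(_dbname, _dbtable) =
        if (path.contains(dbSplitter)) path.split(Pattern.quote(dbSplitter), 2)
        else Array("", path)

      var dbName = _dbname
      var table = _dbtable

      val newOptions = scala.collection.mutable.HashMap[String, String]() ++ config

      // 2. if the db part is a connect alias, clear it, merge the alias options
      //    and re-split the remaining table
      connectOptions match {
        case Some(option) =>
          dbName = ""
          newOptions ++= option
          table.split(Pattern.quote(dbSplitter)) match {
            case Array(_db, _table) =>
              dbName = _db
              table = _table
            case _ =>
          }
        case None =>
      }

      // 3. a "collection" option wins over whatever was derived from the path
      newOptions.get("collection").foreach { c =>
        if (c.contains(dbSplitter)) {
          c.split(Pattern.quote(dbSplitter), 2) match {
            case Array(_db, _table) =>
              dbName = _db
              table = _table
            case Array(_db) =>
              dbName = _db
          }
        } else {
          dbName = c
        }
      }

      (dbName, table)
    }

    def main(args: Array[String]): Unit = {
      // path carries a connect alias plus db.table -> prints (db,tbl)
      println(resolve("solr1.db.tbl", Some(Map("zkhost" -> "127.0.0.1:9983")), Map.empty))
      // bare table plus a "collection" option -> prints (mydb,tbl)
      println(resolve("tbl", None, Map("collection" -> "mydb")))
    }
  }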