diff --git a/src/main/scala/org/apache/spark/sql/hbase/HBaseRelation.scala b/src/main/scala/org/apache/spark/sql/hbase/HBaseRelation.scala
index b8fee42..c1363ed 100644
--- a/src/main/scala/org/apache/spark/sql/hbase/HBaseRelation.scala
+++ b/src/main/scala/org/apache/spark/sql/hbase/HBaseRelation.scala
@@ -815,11 +815,8 @@ private[hbase] case class HBaseRelation(
 
   def buildRow(projections: Seq[(Attribute, Int)],
               result: Result,
-              buffer: ListBuffer[HBaseRawType],
-              aBuffer: ArrayBuffer[Byte],
               row: MutableRow): Row = {
     assert(projections.size == row.length, "Projection size and row size mismatched")
-    // TODO: replaced with the new Key method
     val rowKeys = HBaseKVHelper.decodingRawKeyColumns(result.getRow, keyColumns)
     projections.foreach { p =>
diff --git a/src/main/scala/org/apache/spark/sql/hbase/HBaseSQLReaderRDD.scala b/src/main/scala/org/apache/spark/sql/hbase/HBaseSQLReaderRDD.scala
index 7d59ab3..896b65b 100755
--- a/src/main/scala/org/apache/spark/sql/hbase/HBaseSQLReaderRDD.scala
+++ b/src/main/scala/org/apache/spark/sql/hbase/HBaseSQLReaderRDD.scala
@@ -54,10 +54,8 @@ class HBaseSQLReaderRDD(
   }
 
   private def createIterator(context: TaskContext,
-                             scanner: ResultScanner, otherFilters: Option[Expression]): Iterator[Row] = {
-    val lBuffer = ListBuffer[HBaseRawType]()
-    val aBuffer = ArrayBuffer[Byte]()
-
+                             scanner: ResultScanner,
+                             otherFilters: Option[Expression]): Iterator[Row] = {
     var finalOutput = output.distinct
     if (otherFilters.isDefined) {
       finalOutput = finalOutput.union(otherFilters.get.references.toSeq)
     }
@@ -95,7 +93,7 @@ class HBaseSQLReaderRDD(
       override def next(): Row = {
         if (hasNext) {
           gotNext = false
-          relation.buildRow(projections, result, lBuffer, aBuffer, row)
+          relation.buildRow(projections, result, row)
         } else {
           null
         }
@@ -182,33 +180,33 @@ class HBaseSQLReaderRDD(
         case nkc => distinctProjectionList.exists(nkc.sqlName == _.name)
       }
 
-      var resultRows: Iterator[Row] = null
-
-      for (range <- expandedCPRs) {
+      def generateGet(range: MDCriticalPointRange[_]): Get = {
         val rowKey = constructRowKey(range, isStart = true)
         val get = new Get(rowKey)
         for (nonKeyColumn <- nonKeyColumns) {
           get.addColumn(Bytes.toBytes(nonKeyColumn.family), Bytes.toBytes(nonKeyColumn.qualifier))
         }
+        get
+      }
+      val predForEachRange: Seq[Expression] = expandedCPRs.map(range => {
+        gets.add(generateGet(range))
+        range.lastRange.pred
+      })
+      val resultsWithPred = relation.htable.get(gets).zip(predForEachRange).filter(!_._1.isEmpty)
 
-        gets.add(get)
-        val results = relation.htable.get(gets)
-        val predicate = range.lastRange.pred
-
-        val lBuffer = ListBuffer[HBaseRawType]()
-        val aBuffer = ArrayBuffer[Byte]()
-        val row = new GenericMutableRow(output.size)
-        val projections = output.zipWithIndex
-
-        resultRows = if (predicate == null) {
-          results.map(relation.buildRow(projections, _, lBuffer, aBuffer, row)).toIterator
-        } else {
-          val boundPredicate = BindReferences.bindReference(predicate, output)
-          results.map(relation.buildRow(projections, _, lBuffer, aBuffer, row))
-            .filter(boundPredicate.eval(_).asInstanceOf[Boolean]).toIterator
-        }
+      def evalResultForBoundPredicate(input: Row, predicate: Expression): Boolean = {
+        val boundPredicate = BindReferences.bindReference(predicate, output)
+        boundPredicate.eval(input).asInstanceOf[Boolean]
       }
-      resultRows
+
+      val projections = output.zipWithIndex
+      val resultRows: Seq[Row] = for {
+        (result, predicate) <- resultsWithPred
+        row = new GenericMutableRow(output.size)
+        resultRow = relation.buildRow(projections, result, row)
+        if predicate == null || evalResultForBoundPredicate(resultRow, predicate)
+      } yield resultRow
+
+      resultRows.toIterator
     } else {
       // isPointRanges is false
diff --git a/src/test/scala/org/apache/spark/sql/hbase/HBaseAdvancedSQLQuerySuite.scala b/src/test/scala/org/apache/spark/sql/hbase/HBaseAdvancedSQLQuerySuite.scala
index 3659d48..6b7ca53 100755
--- a/src/test/scala/org/apache/spark/sql/hbase/HBaseAdvancedSQLQuerySuite.scala
+++ b/src/test/scala/org/apache/spark/sql/hbase/HBaseAdvancedSQLQuerySuite.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.hbase
 import org.apache.spark.sql.types._
 import org.apache.spark.sql.{SQLConf, _}
 
-class HBaseAdvancedSQLQuerySuite extends HBaseSplitedTestData {
+class HBaseAdvancedSQLQuerySuite extends HBaseSplitTestData {
   import org.apache.spark.sql.hbase.TestHbase._
   import org.apache.spark.sql.hbase.TestHbase.implicits._
 
diff --git a/src/test/scala/org/apache/spark/sql/hbase/HBaseBasicOperationSuite.scala b/src/test/scala/org/apache/spark/sql/hbase/HBaseBasicOperationSuite.scala
index d56c6f7..b6e5eb9 100755
--- a/src/test/scala/org/apache/spark/sql/hbase/HBaseBasicOperationSuite.scala
+++ b/src/test/scala/org/apache/spark/sql/hbase/HBaseBasicOperationSuite.scala
@@ -21,7 +21,7 @@ package org.apache.spark.sql.hbase
  * Test insert / query against the table created by HBaseMainTest
  */
 
-class HBaseBasicOperationSuite extends HBaseSplitedTestData {
+class HBaseBasicOperationSuite extends HBaseSplitTestData {
   import org.apache.spark.sql.hbase.TestHbase._
 
   override def afterAll() = {
@@ -43,22 +43,28 @@ class HBaseBasicOperationSuite extends HBaseSplitedTestData {
         column4=family2.qualifier2])"""
     )
-    assert(sql( """SELECT * FROM tb0""").count() == 0)
+    assert(sql( """SELECT * FROM tb0""").collect().size == 0)
     sql( """INSERT INTO tb0 SELECT col4,col4,col6,col3 FROM ta""")
-    assert(sql( """SELECT * FROM tb0""").count() == 14)
+    assert(sql( """SELECT * FROM tb0""").collect().size == 14)
     sql( """DROP TABLE tb0""")
   }
 
-  test("Insert Into table 1") {
+  test("Insert and Query Single Row") {
     sql( """CREATE TABLE tb1 (column1 INTEGER, column2 STRING,
-        PRIMARY KEY(column2))
-        MAPPED BY (ht1, COLS=[column1=cf.cq])"""
+        PRIMARY KEY(column1))
+        MAPPED BY (ht1, COLS=[column2=cf.cq])"""
     )
-    assert(sql( """SELECT * FROM tb1""").count() == 0)
+    assert(sql( """SELECT * FROM tb1""").collect().size == 0)
     sql( """INSERT INTO tb1 VALUES (1024, "abc")""")
-    assert(sql( """SELECT * FROM tb1""").count() == 1)
+    sql( """INSERT INTO tb1 VALUES (1028, "abd")""")
+    assert(sql( """SELECT * FROM tb1""").collect().size == 2)
+    assert(
+      sql( """SELECT * FROM tb1 WHERE (column1 = 1023 AND column2 ="abc")""").collect().size == 0)
+    assert(sql(
+      """SELECT * FROM tb1 WHERE (column1 = 1024)
+        |OR (column1 = 1028 AND column2 ="abd")""".stripMargin).collect().size == 2)
     sql( """DROP TABLE tb1""")
   }
 
@@ -81,13 +87,13 @@ class HBaseBasicOperationSuite extends HBaseSplitedTestData {
   }
 
   test("Select test 1 (AND, OR)") {
-    assert(sql( """SELECT * FROM ta WHERE col7 = 255 OR col7 = 127""").count == 2)
-    assert(sql( """SELECT * FROM ta WHERE col7 < 0 AND col4 < -255""").count == 4)
+    assert(sql( """SELECT * FROM ta WHERE col7 = 255 OR col7 = 127""").collect().size == 2)
+    assert(sql( """SELECT * FROM ta WHERE col7 < 0 AND col4 < -255""").collect().size == 4)
   }
 
   test("Select test 2 (WHERE)") {
     assert(sql( """SELECT * FROM ta WHERE col7 > 128""").count() == 3)
-    assert(sql( """SELECT * FROM ta WHERE (col7 - 10 > 128) AND col1 = ' p255 '""").count() == 1)
+    assert(sql( """SELECT * FROM ta WHERE (col7 - 10 > 128) AND col1 = ' p255 '""").collect().size == 1)
   }
 
   test("Select test 3 (ORDER BY)") {
@@ -100,10 +106,10 @@ class HBaseBasicOperationSuite extends HBaseSplitedTestData {
   }
 
   test("Select test 4 (join)") {
-    assert(sql( """SELECT ta.col2 FROM ta join tb on ta.col4=tb.col7""").count == 2)
-    assert(sql( """SELECT * FROM ta FULL OUTER JOIN tb WHERE tb.col7 = 1""").count == 14)
-    assert(sql( """SELECT * FROM ta LEFT JOIN tb WHERE tb.col7 = 1""").count == 14)
-    assert(sql( """SELECT * FROM ta RIGHT JOIN tb WHERE tb.col7 = 1""").count == 14)
+    assert(sql( """SELECT ta.col2 FROM ta join tb on ta.col4=tb.col7""").collect().size == 2)
+    assert(sql( """SELECT * FROM ta FULL OUTER JOIN tb WHERE tb.col7 = 1""").collect().size == 14)
+    assert(sql( """SELECT * FROM ta LEFT JOIN tb WHERE tb.col7 = 1""").collect().size == 14)
+    assert(sql( """SELECT * FROM ta RIGHT JOIN tb WHERE tb.col7 = 1""").collect().size == 14)
   }
 
   test("Alter Add column and Alter Drop column") {
diff --git a/src/test/scala/org/apache/spark/sql/hbase/HBaseSQLQuerySuite.scala b/src/test/scala/org/apache/spark/sql/hbase/HBaseSQLQuerySuite.scala
index ce57c8c..3960aa1 100644
--- a/src/test/scala/org/apache/spark/sql/hbase/HBaseSQLQuerySuite.scala
+++ b/src/test/scala/org/apache/spark/sql/hbase/HBaseSQLQuerySuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.hbase.TestData._
 import org.apache.spark.sql.types._
 
-class HBaseSQLQuerySuite extends HBaseSplitedTestData {
+class HBaseSQLQuerySuite extends HBaseSplitTestData {
   // Make sure the tables are loaded.
   import org.apache.spark.sql.hbase.TestHbase._
   import org.apache.spark.sql.hbase.TestHbase.implicits._
diff --git a/src/test/scala/org/apache/spark/sql/hbase/HBaseSplitedTestData.scala b/src/test/scala/org/apache/spark/sql/hbase/HBaseSplitTestData.scala
similarity index 99%
rename from src/test/scala/org/apache/spark/sql/hbase/HBaseSplitedTestData.scala
rename to src/test/scala/org/apache/spark/sql/hbase/HBaseSplitTestData.scala
index c0fcf8c..e88e9c4 100755
--- a/src/test/scala/org/apache/spark/sql/hbase/HBaseSplitedTestData.scala
+++ b/src/test/scala/org/apache/spark/sql/hbase/HBaseSplitTestData.scala
@@ -30,7 +30,7 @@ import org.apache.spark.sql.hbase.util.{DataTypeUtils, HBaseKVHelper, BytesUtils
  * HBaseMainTest
  * create HbTestTable and metadata table, and insert some data
  */
-class HBaseSplitedTestData extends HBaseIntegrationTestBase
+class HBaseSplitTestData extends HBaseIntegrationTestBase
 {
   val TableName_a: String = "ta"
   val TableName_b: String = "tb"