Add ClassTags to types that depend on Spark's Serializer. #334

Open · wants to merge 14 commits into base: spark-1.0
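The diffs below add `ClassTag` context bounds to Shark's generic serialization helpers. This appears to track the Spark 1.0 serializer API, where `SerializerInstance.serialize` and `deserialize` themselves require a `ClassTag` (as the `ShuffleSerializerInstance` overrides further down show), so any wrapper that calls them generically has to carry a bound of its own to forward. A minimal sketch of the pattern, using a hypothetical stand-in class rather than Spark's real serializer:

```scala
import scala.reflect.ClassTag

// StubSerializer is a hypothetical stand-in for Spark's SerializerInstance,
// whose serialize/deserialize take an implicit ClassTag in Spark 1.0.
class StubSerializer {
  def roundTrip[T: ClassTag](t: T): T = t // placeholder for serialize + deserialize
}

object ClassTagForwarding {
  // Without the `T: ClassTag` bound here, the call to roundTrip would not
  // compile: there would be no ClassTag[T] in scope to forward implicitly.
  def passThrough[T: ClassTag](o: T, ser: StubSerializer): T = ser.roundTrip(o)

  def main(args: Array[String]): Unit = {
    println(passThrough("hello", new StubSerializer)) // prints "hello"
  }
}
```

The same bound is threaded through `KryoSerializerToString`, `JavaSerializer`, `KryoSerializer`, `KryoSerializationWrapper`, and the shuffle serialization streams in the changes below.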
6 changes: 6 additions & 0 deletions project/SharkBuild.scala
@@ -203,6 +203,12 @@ object SharkBuild extends Build {
// See https://code.google.com/p/guava-libraries/issues/detail?id=1095
"com.google.code.findbugs" % "jsr305" % "1.3.+",

// sbt fails to download the javax.servlet artifacts from jetty 8.1:
// http://mvnrepository.com/artifact/org.eclipse.jetty.orbit/javax.servlet/3.0.0.v201112011016
// which may be due to the use of the orbit extension. So, we manually include the servlet API
// from a separate source.
"org.mortbay.jetty" % "servlet-api" % "3.0.20100224",

// Hive unit test requirements. These are used by Hadoop to run the tests, but not necessary
// in usual Shark runs.
"commons-io" % "commons-io" % "2.1",
4 changes: 3 additions & 1 deletion src/main/scala/shark/SharkCliDriver.scala
@@ -162,7 +162,9 @@ object SharkCliDriver {
val cli = new SharkCliDriver(reloadRdds)
cli.setHiveVariables(oproc.getHiveVariables())

SharkEnv.fixUncompatibleConf(conf)
if (!ss.isRemoteMode) {
SharkEnv.fixUncompatibleConf(conf)
}

// Execute -i init files (always in silent mode)
cli.processInitFiles(ss)
6 changes: 3 additions & 3 deletions src/main/scala/shark/execution/LateralViewJoinOperator.scala
@@ -21,7 +21,7 @@ import java.nio.ByteBuffer

import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
import scala.reflect.BeanProperty
import scala.reflect.{BeanProperty, ClassTag}

import org.apache.commons.codec.binary.Base64
import org.apache.hadoop.hive.ql.exec.{ExprNodeEvaluator, ExprNodeEvaluatorFactory}
@@ -174,12 +174,12 @@ object KryoSerializerToString {

@transient val kryoSer = new SparkKryoSerializer(SparkEnv.get.conf)

def serialize[T](o: T): String = {
def serialize[T: ClassTag](o: T): String = {
val bytes = kryoSer.newInstance().serialize(o).array()
new String(Base64.encodeBase64(bytes))
}

def deserialize[T](byteString: String): T = {
def deserialize[T: ClassTag](byteString: String): T = {
val bytes = Base64.decodeBase64(byteString.getBytes())
kryoSer.newInstance().deserialize[T](ByteBuffer.wrap(bytes))
}
9 changes: 6 additions & 3 deletions src/main/scala/shark/execution/MapSplitPruning.scala
@@ -67,7 +67,7 @@ object MapSplitPruning {
true
}

case _: GenericUDFIn =>
case _: GenericUDFIn if e.children(0).isInstanceOf[ExprNodeColumnEvaluator] =>
testInPredicate(
s,
e.children(0).asInstanceOf[ExprNodeColumnEvaluator],
@@ -91,10 +91,13 @@
val columnStats = s.stats(field.fieldID)

if (columnStats != null) {
expEvals.exists {
e =>
expEvals.exists { e =>
if (e.isInstanceOf[ExprNodeConstantEvaluator]) {
val constEval = e.asInstanceOf[ExprNodeConstantEvaluator]
columnStats := constEval.expr.getValue()
} else {
true
}
}
} else {
// If there are no stats on the column, don't prune.
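The new guard on `GenericUDFIn` and the `else true` branch in `testInPredicate` make pruning conservative: a split is skipped only when the IN-list elements are constants that can be checked against the column statistics, and any non-constant element (for example a function call such as `concat`) keeps the split alive. A rough sketch of that rule with illustrative types, not Hive's evaluator classes:

```scala
// Illustrative only: ConstElem/ExprElem stand in for Hive's
// ExprNodeConstantEvaluator and the other evaluator types.
object InPruningSketch {
  sealed trait InElem
  case class ConstElem(value: Int) extends InElem
  case class ExprElem(description: String) extends InElem // e.g. concat("201", "4")

  // A split survives when any IN element could match: constants are tested
  // against the column's stats range, non-constants always pass (no pruning).
  def keepSplit(statsRange: Range, elems: Seq[InElem]): Boolean =
    elems.exists {
      case ConstElem(v) => statsRange.contains(v)
      case _: ExprElem  => true // value unknown until runtime, so keep the split
    }

  def main(args: Array[String]): Unit = {
    // The constant 2014 is outside the stats, but the function element keeps the split.
    println(keepSplit(2000 to 2012, Seq(ConstElem(2014), ExprElem("concat(\"201\", \"4\")"))))
  }
}
```

The new SQLSuite test "map pruning with functions in in clause" further down exercises exactly this case.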
@@ -19,18 +19,20 @@ package shark.execution.serialization

import java.nio.ByteBuffer

import scala.reflect.ClassTag

import org.apache.spark.SparkEnv
import org.apache.spark.serializer.{JavaSerializer => SparkJavaSerializer}


object JavaSerializer {
@transient val ser = new SparkJavaSerializer(SparkEnv.get.conf)

def serialize[T](o: T): Array[Byte] = {
def serialize[T: ClassTag](o: T): Array[Byte] = {
ser.newInstance().serialize(o).array()
}

def deserialize[T](bytes: Array[Byte]): T = {
def deserialize[T: ClassTag](bytes: Array[Byte]): T = {
ser.newInstance().deserialize[T](ByteBuffer.wrap(bytes))
}
}
@@ -17,13 +17,15 @@

package shark.execution.serialization

import scala.reflect.ClassTag

/**
* A wrapper around unserializable objects that makes them Java serializable.
* Internally, Kryo is used for serialization.
*
* Use KryoSerializationWrapper(value) to create a wrapper.
*/
class KryoSerializationWrapper[T] extends Serializable {
class KryoSerializationWrapper[T: ClassTag] extends Serializable {

@transient var value: T = _

@@ -54,7 +56,7 @@ class KryoSerializationWrapper[T] extends Serializable {


object KryoSerializationWrapper {
def apply[T](value: T): KryoSerializationWrapper[T] = {
def apply[T: ClassTag](value: T): KryoSerializationWrapper[T] = {
val wrapper = new KryoSerializationWrapper[T]
wrapper.value = value
wrapper
@@ -19,6 +19,8 @@ package shark.execution.serialization

import java.nio.ByteBuffer

import scala.reflect.ClassTag

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.serializer.{KryoSerializer => SparkKryoSerializer}

@@ -36,11 +38,11 @@ object KryoSerializer {
new SparkKryoSerializer(sparkConf)
}

def serialize[T](o: T): Array[Byte] = {
def serialize[T: ClassTag](o: T): Array[Byte] = {
ser.newInstance().serialize(o).array()
}

def deserialize[T](bytes: Array[Byte]): T = {
def deserialize[T: ClassTag](bytes: Array[Byte]): T = {
ser.newInstance().deserialize[T](ByteBuffer.wrap(bytes))
}
}
@@ -20,6 +20,8 @@ package shark.execution.serialization
import java.io.{InputStream, OutputStream}
import java.nio.ByteBuffer

import scala.reflect.ClassTag

import org.apache.hadoop.io.BytesWritable

import org.apache.spark.SparkConf
@@ -60,11 +62,11 @@ class ShuffleSerializer(conf: SparkConf) extends Serializer with Serializable {

class ShuffleSerializerInstance extends SerializerInstance with Serializable {

override def serialize[T](t: T): ByteBuffer = throw new UnsupportedOperationException
override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException

override def deserialize[T](bytes: ByteBuffer): T = throw new UnsupportedOperationException
override def deserialize[T: ClassTag](bytes: ByteBuffer): T = throw new UnsupportedOperationException

override def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T =
override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
throw new UnsupportedOperationException

override def serializeStream(s: OutputStream): SerializationStream = {
@@ -79,7 +81,7 @@ class ShuffleSerializerInstance extends SerializerInstance with Serializable {

class ShuffleSerializationStream(stream: OutputStream) extends SerializationStream with Serializable {

override def writeObject[T](t: T): SerializationStream = {
override def writeObject[T: ClassTag](t: T): SerializationStream = {
// On the write-side, the ReduceKey should be of type ReduceKeyMapSide.
val (key, value) = t.asInstanceOf[(ReduceKey, BytesWritable)]
writeUnsignedVarInt(key.length)
@@ -110,7 +112,7 @@

class ShuffleDeserializationStream(stream: InputStream) extends DeserializationStream with Serializable {

override def readObject[T](): T = {
override def readObject[T: ClassTag](): T = {
// Return type is (ReduceKeyReduceSide, Array[Byte])
val keyLen = readUnsignedVarInt()
if (keyLen < 0) {
1 change: 1 addition & 0 deletions src/main/scala/shark/memstore2/TableRecovery.scala
@@ -58,6 +58,7 @@ object TableRecovery extends LogHelper {
logInfo(logMessage)
}
val cmd = QueryRewriteUtils.cacheToAlterTable("CACHE %s".format(tableName))
cmdRunner(s"use $databaseName")
cmdRunner(cmd)
}
}
@@ -88,7 +88,7 @@ class RLDecoder[V](buffer: ByteBuffer, columnType: ColumnType[_, V]) extends Ite
private var _count: Int = 0
private val _current: V = columnType.newWritable()

override def hasNext = buffer.hasRemaining()
override def hasNext = _count < _run || buffer.hasRemaining()

override def next(): V = {
if (_count == _run) {
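The `hasNext` fix matters because a run-length decoder can still owe repetitions of the current value after the byte buffer is exhausted; checking only `buffer.hasRemaining()` silently drops the tail of the last run. A toy decoder over (value, runLength) pairs, assuming simplified types rather than Shark's ColumnType machinery, shows the same condition:

```scala
// Toy run-length decoder; `runs` yields (value, runLength) pairs.
class ToyRLDecoder(runs: Iterator[(Int, Int)]) extends Iterator[Int] {
  private var _run = 0     // length of the current run
  private var _count = 0   // values of the current run emitted so far
  private var _current = 0 // the repeated value

  // The buggy version returned only `runs.hasNext`, losing the tail of the last run.
  override def hasNext: Boolean = _count < _run || runs.hasNext

  override def next(): Int = {
    if (_count == _run) {
      val (value, runLength) = runs.next()
      _current = value
      _run = runLength
      _count = 0
    }
    _count += 1
    _current
  }
}

// new ToyRLDecoder(Iterator((7, 3), (9, 1))).toList == List(7, 7, 7, 9)
// With the old check, the last two 7s would never be emitted.
```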
60 changes: 60 additions & 0 deletions src/main/scala/shark/optimizer/SharkMapJoinProcessor.scala
@@ -0,0 +1,60 @@
/*
* Copyright (C) 2012 The Regents of The University California.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package shark.optimizer

import java.util.{LinkedHashMap => JavaLinkedHashMap}

import org.apache.hadoop.hive.ql.exec.{MapJoinOperator, JoinOperator, Operator}
import org.apache.hadoop.hive.ql.optimizer.MapJoinProcessor
import org.apache.hadoop.hive.ql.parse.{ParseContext, QBJoinTree, OpParseContext}
import org.apache.hadoop.hive.ql.plan.OperatorDesc
import org.apache.hadoop.hive.conf.HiveConf

class SharkMapJoinProcessor extends MapJoinProcessor {

/**
* Override generateMapJoinOperator to bypass the step of validating Map Join hints in Hive.
*/
override def generateMapJoinOperator(
pctx: ParseContext,
op: JoinOperator,
joinTree: QBJoinTree,
mapJoinPos: Int): MapJoinOperator = {
val hiveConf: HiveConf = pctx.getConf
val noCheckOuterJoin: Boolean =
HiveConf.getBoolVar(hiveConf, HiveConf.ConfVars.HIVEOPTSORTMERGEBUCKETMAPJOIN) &&
HiveConf.getBoolVar(hiveConf, HiveConf.ConfVars.HIVEOPTBUCKETMAPJOIN)

val opParseCtxMap: JavaLinkedHashMap[Operator[_ <: OperatorDesc], OpParseContext] =
pctx.getOpParseCtx

// Explicitly set validateMapJoinTree to false to bypass the step of validating
// Map Join hints in Hive.
val validateMapJoinTree = false
val mapJoinOp: MapJoinOperator =
MapJoinProcessor.convertMapJoin(
opParseCtxMap, op, joinTree, mapJoinPos, noCheckOuterJoin, validateMapJoinTree)

// Hive originally uses genSelectPlan to insert a dummy select after the MapJoinOperator.
// We should not need this step.
// create a dummy select to select all columns
// MapJoinProcessor.genSelectPlan(pctx, mapJoinOp)

return mapJoinOp
}
}
@@ -1,5 +1,5 @@
/*
* Copyright (C) 2012 The Regents of The University California.
* Copyright (C) 2012 The Regents of The University California.
* All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
@@ -15,14 +15,15 @@
* limitations under the License.
*/

package shark
package shark.optimizer

import java.util.{List => JavaList}

import org.apache.hadoop.hive.ql.optimizer.JoinReorder
import org.apache.hadoop.hive.ql.optimizer.{Optimizer => HiveOptimizer,
SimpleFetchOptimizer, Transform}
import org.apache.hadoop.hive.ql.parse.{ParseContext}
SimpleFetchOptimizer, Transform, MapJoinProcessor => HiveMapJoinProcessor}
import org.apache.hadoop.hive.ql.parse.ParseContext
import shark.LogHelper

class SharkOptimizer extends HiveOptimizer with LogHelper {

@@ -49,6 +50,13 @@ class SharkOptimizer extends HiveOptimizer with LogHelper {
transformation match {
case _: SimpleFetchOptimizer => {}
case _: JoinReorder => {}
case _: HiveMapJoinProcessor => {
// Use SharkMapJoinProcessor to bypass Hive's validation of Map Join hints,
// so that hints can mark tables to be treated as small tables
// (as in Hive 0.9).
val sharkMapJoinProcessor = new SharkMapJoinProcessor
pctx = sharkMapJoinProcessor.transform(pctx)
}
case _ => {
pctx = transformation.transform(pctx)
}
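With `SharkMapJoinProcessor` skipping Hive's hint validation, a map-join hint can nominate the small table explicitly, as in Hive 0.9. A hypothetical usage sketch (table names are invented, and `sc` is assumed to be a `SharkContext`):

```scala
object MapJoinHintExample {
  // Hypothetical query: the MAPJOIN hint marks `s` as the small table that is
  // loaded into memory on each mapper instead of being shuffled.
  val hintedQuery: String =
    """SELECT /*+ MAPJOIN(s) */ b.key, s.value
      |FROM big_table b JOIN small_table s ON (b.key = s.key)""".stripMargin

  // With a SharkContext `sc` in scope (assumption), the query would be run as:
  // sc.sql(hintedQuery)
}
```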
3 changes: 2 additions & 1 deletion src/main/scala/shark/parse/SharkSemanticAnalyzer.scala
@@ -38,11 +38,12 @@ import org.apache.hadoop.hive.ql.parse._
import org.apache.hadoop.hive.ql.plan._
import org.apache.hadoop.hive.ql.session.SessionState

import shark.{LogHelper, SharkConfVars, SharkOptimizer}
import shark.{LogHelper, SharkConfVars}
import shark.execution.{HiveDesc, Operator, OperatorFactory, ReduceSinkOperator}
import shark.execution.{SharkDDLWork, SparkLoadWork, SparkWork, TerminalOperator}
import shark.memstore2.{CacheType, LazySimpleSerDeWrapper, MemoryMetadataManager}
import shark.memstore2.SharkTblProperties
import shark.optimizer.SharkOptimizer


/**
5 changes: 5 additions & 0 deletions src/test/scala/shark/SQLSuite.scala
@@ -718,6 +718,11 @@ class SQLSuite extends FunSuite {
where year(from_unixtime(k)) between "2013" and "2014" """, Array[String]("0"))
}

test("map pruning with functions in in clause") {
expectSql("""select count(*) from mapsplitfunc_cached
where year(from_unixtime(k)) in ("2013", concat("201", "4")) """, Array[String]("0"))
}

//////////////////////////////////////////////////////////////////////////////
// SharkContext APIs (e.g. sql2rdd, sql)
//////////////////////////////////////////////////////////////////////////////
@@ -78,6 +78,7 @@ class CompressedColumnIteratorSuite extends FunSuite {
}

l.foreach { x =>
assert(iter.hasNext)
iter.next()
assert(t.get(iter.current, oi) === x)
}