Commit c5b8604: propagate correctly (#162)
chaoqin-li1123 authored and nlu90 committed on Sep 11, 2023
1 parent: ed90ada
Showing 5 changed files with 20 additions and 12 deletions.

src/main/scala/org/apache/spark/sql/pulsar/PulsarClientFactory.scala (file name inferred from the hunk)

@@ -32,12 +32,13 @@ class DefaultPulsarClientFactory extends PulsarClientFactory {
 
 object PulsarClientFactory {
   val PulsarClientFactoryClassOption = "org.apache.spark.sql.pulsar.PulsarClientFactoryClass"
-  def getOrCreate(sparkConf: SparkConf, params: ju.Map[String, Object]): PulsarClientImpl = {
-    getFactory(sparkConf).getOrCreate(params)
+  def getOrCreate(pulsarClientFactoryClassName: Option[String],
+      params: ju.Map[String, Object]): PulsarClientImpl = {
+    getFactory(pulsarClientFactoryClassName).getOrCreate(params)
   }
 
-  private def getFactory(sparkConf: SparkConf): PulsarClientFactory = {
-    sparkConf.getOption(PulsarClientFactoryClassOption) match {
+  private def getFactory(pulsarClientFactoryClassName: Option[String]): PulsarClientFactory = {
+    pulsarClientFactoryClassName match {
       case Some(factoryClassName) =>
         Utils.classForName(factoryClassName).getConstructor()
           .newInstance().asInstanceOf[PulsarClientFactory]
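
The hunk above resolves a user-supplied factory class reflectively via a no-arg constructor, so any plugin must provide one. A minimal sketch of such a plugin, with a hypothetical class name; the getOrCreate(params) shape is taken from this hunk, and only standard Pulsar ClientBuilder calls are used:

import java.{util => ju}

import org.apache.pulsar.client.api.PulsarClient
import org.apache.pulsar.client.impl.PulsarClientImpl

// Hypothetical example plugin; assumes it is compiled where
// org.apache.spark.sql.pulsar.PulsarClientFactory is visible.
class LoggingPulsarClientFactory extends PulsarClientFactory {
  override def getOrCreate(params: ju.Map[String, Object]): PulsarClientImpl = {
    // Log which keys are in play, then build a client straight from them.
    println(s"Building PulsarClient with keys: ${params.keySet()}")
    PulsarClient.builder()
      .loadConf(params) // expects serviceUrl etc. to be present in params
      .build()
      .asInstanceOf[PulsarClientImpl]
  }
}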

src/main/scala/org/apache/spark/sql/pulsar/PulsarHelper.scala (file name inferred from the hunk)

@@ -61,7 +61,8 @@ private[pulsar] case class PulsarHelper(
   import scala.collection.JavaConverters._
 
   protected var client: PulsarClientImpl =
-    PulsarClientFactory.getOrCreate(sparkContext.conf, clientConf)
+    PulsarClientFactory.getOrCreate(
+      sparkContext.conf.getOption(PulsarClientFactory.PulsarClientFactoryClassOption), clientConf)
 
   private var topics: Seq[String] = _
   private var topicPartitions: Seq[String] = _

src/main/scala/org/apache/spark/sql/pulsar/PulsarSinks.scala (file name inferred from the hunk)

@@ -163,7 +163,7 @@ private[pulsar] object PulsarSinks extends Logging {
 
     try {
       PulsarClientFactory
-        .getOrCreate(SparkEnv.get.conf, clientConf)
+        .getOrCreate(None, clientConf)
        .newProducer(schema)
        .topic(topic)
        .loadConf(producerConf)
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,9 @@ private[pulsar] class PulsarSource(
pollTimeoutMs,
failOnDataLoss,
subscriptionNamePrefix,
jsonOptions)
jsonOptions,
sqlContext.sparkContext.conf
.getOption(PulsarClientFactory.PulsarClientFactoryClassOption))

logInfo(
"GetBatch generating RDD of offset range: " +
Expand Down
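
This hunk reads the factory class name on the driver and hands it to the RDD as a plain constructor argument, so it reaches executors inside the serialized RDD rather than through SparkEnv.get.conf; to my understanding, executors only receive "spark."-prefixed properties from the driver, which would explain why this non-"spark." key did not propagate before. A standalone toy illustrating that mechanism, not code from this repository:

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

// A single-partition placeholder.
case class TagPartition(index: Int) extends Partition

// Toy RDD: `tag` is an ordinary constructor field, so it is serialized with
// the RDD and is visible inside compute() on executors; `sc` is only passed
// to the superclass and is not retained.
class TagRDD(sc: SparkContext, tag: Option[String]) extends RDD[String](sc, Nil) {

  override protected def getPartitions: Array[Partition] = Array(TagPartition(0))

  override def compute(split: Partition, context: TaskContext): Iterator[String] =
    Iterator(s"executor sees tag=${tag.getOrElse("<none>")}")
}

// Usage: new TagRDD(sc, Some("hello")).collect() returns the tag from the task.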

src/main/scala/org/apache/spark/sql/pulsar/PulsarSourceRDD.scala (9 additions, 5 deletions)

@@ -38,7 +38,8 @@ private[pulsar] abstract class PulsarSourceRDDBase(
     pollTimeoutMs: Int,
     failOnDataLoss: Boolean,
     subscriptionNamePrefix: String,
-    jsonOptions: JSONOptionsInRead)
+    jsonOptions: JSONOptionsInRead,
+    pulsarClientFactoryClassName: Option[String])
   extends RDD[InternalRow](sc, Nil) {
 
   val reportDataLoss = reportDataLossFunc(failOnDataLoss)
@@ -59,7 +60,7 @@
       val schema: Schema[_] = SchemaUtils.getPSchema(schemaInfo.si)
 
       lazy val reader = PulsarClientFactory
-        .getOrCreate(SparkEnv.get.conf, clientConf)
+        .getOrCreate(pulsarClientFactoryClassName, clientConf)
         .newReader(schema)
         .subscriptionRolePrefix(subscriptionNamePrefix)
         .topic(topic)
@@ -170,7 +171,8 @@ private[pulsar] class PulsarSourceRDD(
     pollTimeoutMs: Int,
     failOnDataLoss: Boolean,
     subscriptionNamePrefix: String,
-    jsonOptions: JSONOptionsInRead)
+    jsonOptions: JSONOptionsInRead,
+    pulsarClientFactoryClassName: Option[String])
   extends PulsarSourceRDDBase(
     sc,
     schemaInfo,
@@ -180,7 +182,8 @@
       pollTimeoutMs,
       failOnDataLoss,
       subscriptionNamePrefix,
-      jsonOptions) {
+      jsonOptions,
+      pulsarClientFactoryClassName) {
 
   override def getPreferredLocations(split: Partition): Seq[String] = {
     val part = split.asInstanceOf[PulsarSourceRDDPartition]
@@ -221,7 +224,8 @@ private[pulsar] class PulsarSourceRDD4Batch(
       pollTimeoutMs,
       failOnDataLoss,
       subscriptionNamePrefix,
-      jsonOptions) {
+      jsonOptions,
+      None) {
 
   override def compute(split: Partition, context: TaskContext): Iterator[InternalRow] = {
 
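
Taken together, the commit threads the factory class name from the driver's SparkConf down to executor-side reader creation. A hedged end-to-end usage sketch; the option key comes from the first hunk, and LoggingPulsarClientFactory is the hypothetical plugin sketched above:

import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

val conf = new SparkConf()
  // The value must name a class with a public no-arg constructor,
  // loadable on both the driver and the executors.
  .set(PulsarClientFactory.PulsarClientFactoryClassOption,
       classOf[LoggingPulsarClientFactory].getName)

val spark = SparkSession.builder()
  .master("local[2]")
  .config(conf)
  .getOrCreate()

Note that, as of this diff, PulsarSinks and PulsarSourceRDD4Batch pass None to getOrCreate, so the custom factory takes effect on the driver-side client in PulsarHelper and on the streaming-source read path, but not for producers or batch reads.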
