diff --git a/README.md b/README.md
index fd4b941f..d2f29020 100644
--- a/README.md
+++ b/README.md
@@ -325,9 +325,77 @@ You can use `org.apache.spark.sql.pulsar.JsonUtils.topicOffsets(Map[String, Mess
This may cause a false alarm. You can set it to `false` when it doesn't work as you expected.
A batch query always fails if it fails to read any data from the provided offsets due to data loss.
+
+
+
+`maxEntriesPerTrigger`
+ |
+
+Number of entries to include in a single micro-batch during
+streaming.
+ |
+-1 |
+Streaming query |
+This parameter controls how many Pulsar entries are read by
+the connector from the topic backlog at once. If the topic
+backlog is considerably high, users can use this parameter
+to limit the size of the micro-batch. If multiple topics are read,
+this parameter controls the complete number of entries fetched from
+all of them.
+
+*Note:* Entries might contain multiple messages. The default value of `-1` means that the
+complete backlog is read at once. |
+
+
+
+
+`forwardStrategy`
+ |
+
+`simple`, `large-first` or `proportional`
+ |
+`simple` |
+Streaming query |
+If `maxEntriesPerTrigger` is set, this parameter controls
+which forwarding strategy is in use during the read of multiple
+topics.
+
+`simple` just divides the allowed number of entries equally
+between all topics, regardless of their backlog size
+
+
+`large-first` will load the largest topic backlogs first,
+as the maximum number of allowed entries allows
+
+
+`proportional` will forward all topics proportional to the
+topic backlog/overall backlog ratio
+
+ |
+
+
+
+`ensureEntriesPerTopic`
+ |
+Number to forward each topic with during a micro-batch. |
+0 |
+Streaming query |
+If multiple topics are read, and the maximum number of
+entries is also specified, always forward all topics with the
+amount of entries specified here. Using this, users can ensure that topics
+with considerably smaller backlogs than others are also forwarded
+and read. Note that:
+If this number is higher than the maximum allowed entries divided
+by the number of topics, then this value is taken into account, overriding
+the maximum number of entries per micro-batch.
+
+This parameter has an effect only for forwarding strategies
+`large-first` and `proportional`.
+ |
+
#### Authentication
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarMetadataReader.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarMetadataReader.scala
index 91e70f6f..df0df3d5 100644
--- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarMetadataReader.scala
+++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarMetadataReader.scala
@@ -15,18 +15,20 @@ package org.apache.spark.sql.pulsar
import java.{util => ju}
import java.io.Closeable
-import java.util.{Optional, UUID}
+import java.util.Optional
import java.util.concurrent.TimeUnit
import java.util.regex.Pattern
import org.apache.pulsar.client.admin.{PulsarAdmin, PulsarAdminException}
-import org.apache.pulsar.client.api.{Message, MessageId, PulsarClient, SubscriptionInitialPosition, SubscriptionType}
+import org.apache.pulsar.client.api.{Message, MessageId, PulsarClient}
import org.apache.pulsar.client.impl.schema.BytesSchema
+import org.apache.pulsar.client.internal.DefaultImplementation
import org.apache.pulsar.common.naming.TopicName
import org.apache.pulsar.common.schema.SchemaInfo
import org.apache.spark.internal.Logging
-import org.apache.spark.sql.pulsar.PulsarOptions.{AUTH_PARAMS, AUTH_PLUGIN_CLASS_NAME, TLS_ALLOW_INSECURE_CONNECTION, TLS_HOSTNAME_VERIFICATION_ENABLE, TLS_TRUST_CERTS_FILE_PATH, TOPIC_OPTION_KEYS}
+import org.apache.spark.sql.pulsar.PulsarOptions._
+import org.apache.spark.sql.pulsar.topicinternalstats.forward._
import org.apache.spark.sql.types.StructType
/**
@@ -205,6 +207,82 @@ private[pulsar] case class PulsarMetadataReader(
}.toMap)
}
+
+ def forwardOffset(actualOffset: Map[String, MessageId],
+ strategy: String,
+ numberOfEntriesToForward: Long,
+ ensureEntriesPerTopic: Long): SpecificPulsarOffset = {
+ getTopicPartitions()
+
+ // Collect internal stats for all topics
+ val topicStats = topicPartitions.map( topic => {
+ val internalStats = admin.topics().getInternalStats(topic)
+ val topicActualMessageId = actualOffset.getOrElse(topic, MessageId.earliest)
+ topic -> TopicState(internalStats,
+ PulsarSourceUtils.getLedgerId(topicActualMessageId),
+ PulsarSourceUtils.getEntryId(topicActualMessageId))
+ } ).toMap
+
+ val forwarder = strategy match {
+ case PulsarOptions.ProportionalForwardStrategy =>
+ new ProportionalForwardStrategy(numberOfEntriesToForward, ensureEntriesPerTopic)
+ case PulsarOptions.LargeFirstForwardStrategy =>
+ new LargeFirstForwardStrategy(numberOfEntriesToForward, ensureEntriesPerTopic)
+ case _ =>
+ new LinearForwardStrategy(numberOfEntriesToForward)
+ }
+
+ SpecificPulsarOffset(topicPartitions.map { topic =>
+ topic -> PulsarSourceUtils.seekableLatestMid {
+ // Fetch actual offset for topic
+ val topicActualMessageId = actualOffset.getOrElse(topic, MessageId.earliest)
+ try {
+ // Get the actual ledger
+ val actualLedgerId = PulsarSourceUtils.getLedgerId(topicActualMessageId)
+ // Get the actual entry ID
+ val actualEntryId = PulsarSourceUtils.getEntryId(topicActualMessageId)
+ // Get the partition index
+ val partitionIndex = PulsarSourceUtils.getPartitionIndex(topicActualMessageId)
+ // Cache topic internal stats
+ val internalStats = topicStats.get(topic).get.internalStat
+ // Calculate the amount of messages we will pull in
+ val numberOfEntriesPerTopic = forwarder.forward(topicStats)(topic)
+ // Get a future message ID which corresponds
+ // to the maximum number of messages
+ val (nextLedgerId, nextEntryId) = TopicInternalStatsUtils.forwardMessageId(
+ internalStats,
+ actualLedgerId,
+ actualEntryId,
+ numberOfEntriesPerTopic)
+ // Build a message id
+ val forwardedMessageId =
+ DefaultImplementation.newMessageId(nextLedgerId, nextEntryId, partitionIndex)
+ // Log state
+ val forwardedEntry = TopicInternalStatsUtils.numOfEntriesUntil(
+ internalStats, nextLedgerId, nextEntryId)
+ val entryCount = internalStats.numberOfEntries
+ val progress = f"${forwardedEntry.toFloat / entryCount.toFloat}%1.3f"
+ val logMessage = s"Pulsar Connector forward on topic. " +
+ s"[$numberOfEntriesPerTopic/$numberOfEntriesToForward]" +
+ s"${topic.reverse.take(30).reverse} $topicActualMessageId -> " +
+ s"$forwardedMessageId ($forwardedEntry/$entryCount) [$progress]"
+ log.debug(logMessage)
+ // Return the message ID
+ forwardedMessageId
+ } catch {
+ case e: PulsarAdminException if e.getStatusCode == 404 =>
+ MessageId.earliest
+ case e: Throwable =>
+ throw new RuntimeException(
+ s"Failed to get forwarded messageId for ${TopicName.get(topic).toString} " +
+ s"(tried to forward ${forwarder.forward(topicStats)(topic)} messages " +
+ s"starting from `$topicActualMessageId` using strategy $strategy)", e)
+ }
+
+ }
+ }.toMap)
+ }
+
def fetchLatestOffsetForTopic(topic: String): MessageId = {
PulsarSourceUtils.seekableLatestMid( try {
admin.topics().getLastMessageId(topic)
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarOptions.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarOptions.scala
index c2b56a7d..164816c6 100644
--- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarOptions.scala
+++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarOptions.scala
@@ -31,6 +31,12 @@ private[pulsar] object PulsarOptions {
val TOPIC_MULTI = "topics"
val TOPIC_PATTERN = "topicspattern"
+ val MaxEntriesPerTrigger = "maxentriespertrigger"
+ val EnsureEntriesPerTopic = "ensureentriespertopic"
+ val ForwardStrategy = "forwardstrategy"
+ val ProportionalForwardStrategy = "proportional"
+ val LargeFirstForwardStrategy = "large-first"
+
val PARTITION_SUFFIX = TopicName.PARTITIONED_TOPIC_SUFFIX
val TOPIC_OPTION_KEYS = Set(
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarProvider.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarProvider.scala
index 0b5c84bf..e833f811 100644
--- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarProvider.scala
+++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarProvider.scala
@@ -110,7 +110,10 @@ private[pulsar] class PulsarProvider
pollTimeoutMs(caseInsensitiveParams),
failOnDataLoss(caseInsensitiveParams),
subscriptionNamePrefix,
- jsonOptions
+ jsonOptions,
+ maxEntriesPerTrigger(caseInsensitiveParams),
+ minEntriesPerTopic(caseInsensitiveParams),
+ forwardStrategy(caseInsensitiveParams)
)
}
@@ -365,6 +368,15 @@ private[pulsar] object PulsarProvider extends Logging {
(SparkEnv.get.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000).toString)
.toInt
+ private def maxEntriesPerTrigger(caseInsensitiveParams: Map[String, String]): Long =
+ caseInsensitiveParams.getOrElse(MaxEntriesPerTrigger, "-1").toLong
+
+ private def minEntriesPerTopic(caseInsensitiveParams: Map[String, String]): Long =
+ caseInsensitiveParams.getOrElse(EnsureEntriesPerTopic, "0").toLong
+
+ private def forwardStrategy(caseInsensitiveParams: Map[String, String]): String =
+ caseInsensitiveParams.getOrElse(ForwardStrategy, "simple")
+
private def validateGeneralOptions(
caseInsensitiveParams: Map[String, String]): Map[String, String] = {
if (!caseInsensitiveParams.contains(SERVICE_URL_OPTION_KEY)) {
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarSource.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarSource.scala
index cbc2544b..98378f69 100644
--- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarSource.scala
+++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarSource.scala
@@ -36,7 +36,10 @@ private[pulsar] class PulsarSource(
pollTimeoutMs: Int,
failOnDataLoss: Boolean,
subscriptionNamePrefix: String,
- jsonOptions: JSONOptionsInRead)
+ jsonOptions: JSONOptionsInRead,
+ maxEntriesPerTrigger: Long,
+ ensureEntriesPerTopic: Long,
+ forwardStrategy: String)
extends Source
with Logging {
@@ -63,12 +66,21 @@ private[pulsar] class PulsarSource(
override def schema(): StructType = SchemaUtils.pulsarSourceSchema(pulsarSchema)
override def getOffset: Option[Offset] = {
- // Make sure initialTopicOffsets is initialized
initialTopicOffsets
- val latest = metadataReader.fetchLatestOffsets()
- currentTopicOffsets = Some(latest.topicOffsets)
- logDebug(s"GetOffset: ${latest.topicOffsets.toSeq.map(_.toString).sorted}")
- Some(latest.asInstanceOf[Offset])
+ val nextOffsets = if (maxEntriesPerTrigger == -1) {
+ metadataReader.fetchLatestOffsets()
+ } else {
+ currentTopicOffsets match {
+ case Some(value) =>
+ metadataReader.forwardOffset(value,
+ forwardStrategy, maxEntriesPerTrigger, ensureEntriesPerTopic)
+ case _ =>
+ metadataReader.forwardOffset(initialTopicOffsets.topicOffsets,
+ forwardStrategy, maxEntriesPerTrigger, ensureEntriesPerTopic)
+ }
+ }
+ logDebug(s"GetOffset: ${nextOffsets.topicOffsets.toSeq.map(_.toString).sorted}")
+ Some(nextOffsets.asInstanceOf[Offset])
}
override def getBatch(start: Option[Offset], end: Offset): DataFrame = {
@@ -78,9 +90,7 @@ private[pulsar] class PulsarSource(
logInfo(s"getBatch called with start = $start, end = $end")
val endTopicOffsets = SpecificPulsarOffset.getTopicOffsets(end)
- if (currentTopicOffsets.isEmpty) {
- currentTopicOffsets = Some(endTopicOffsets)
- }
+ currentTopicOffsets = Some(endTopicOffsets)
if (start.isDefined && start.get == end) {
return sqlContext.internalCreateDataFrame(
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/PulsarSources.scala b/src/main/scala/org/apache/spark/sql/pulsar/PulsarSources.scala
index 12857b43..75bbb57d 100644
--- a/src/main/scala/org/apache/spark/sql/pulsar/PulsarSources.scala
+++ b/src/main/scala/org/apache/spark/sql/pulsar/PulsarSources.scala
@@ -120,6 +120,36 @@ private[pulsar] object PulsarSourceUtils extends Logging {
}
}
+ def getLedgerId(mid: MessageId): Long = {
+ mid match {
+ case bmid: BatchMessageIdImpl =>
+ bmid.getLedgerId
+ case midi: MessageIdImpl => midi.getLedgerId
+ case t: TopicMessageIdImpl => getLedgerId(t.getInnerMessageId)
+ case up: UserProvidedMessageId => up.getLedgerId
+ }
+ }
+
+ def getEntryId(mid: MessageId): Long = {
+ mid match {
+ case bmid: BatchMessageIdImpl =>
+ bmid.getEntryId
+ case midi: MessageIdImpl => midi.getEntryId
+ case t: TopicMessageIdImpl => getEntryId(t.getInnerMessageId)
+ case up: UserProvidedMessageId => up.getEntryId
+ }
+ }
+
+ def getPartitionIndex(mid: MessageId): Int = {
+ mid match {
+ case bmid: BatchMessageIdImpl =>
+ bmid.getPartitionIndex
+ case midi: MessageIdImpl => midi.getPartitionIndex
+ case t: TopicMessageIdImpl => getPartitionIndex(t.getInnerMessageId)
+ case up: UserProvidedMessageId => up.getPartitionIndex
+ }
+ }
+
def seekableLatestMid(mid: MessageId): MessageId = {
if (messageExists(mid)) mid else MessageId.earliest
}
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/ForwardStrategy.scala b/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/ForwardStrategy.scala
new file mode 100644
index 00000000..ec75ba7b
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/ForwardStrategy.scala
@@ -0,0 +1,24 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pulsar.topicinternalstats.forward
+
+import org.apache.pulsar.common.policies.data.PersistentTopicInternalStats
+
+trait ForwardStrategy {
+ def forward(topics: Map[String, TopicState]): Map[String, Long]
+}
+
+case class TopicState(internalStat: PersistentTopicInternalStats,
+ actualLedgerId: Long,
+ actualEntryId: Long)
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/LargeFirstForwardStrategy.scala b/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/LargeFirstForwardStrategy.scala
new file mode 100644
index 00000000..ed39885b
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/LargeFirstForwardStrategy.scala
@@ -0,0 +1,96 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pulsar.topicinternalstats.forward
+
+/**
+ * Forward strategy which sorts the topics by their backlog size starting
+ * with the largest, and forwards topics starting from the beginning of
+ * this list as the maximum entries parameter allows (taking into account
+ * the number entries that need to be added anyway if
+ *
+ * @param additionalEntriesPerTopic is set).
+ *
+ * If the maximum entries to forward is `100`, topics will be forwarded
+ * like this (provided there is no minimum entry number specified:
+ * | topic name | backlog size | forward amount |
+ * |------------|--------------|----------------|
+ * | topic-1 | 60 | 60 |
+ * | topic-2 | 50 | 40 |
+ * | topic-3 | 40 | 0 |
+ *
+ * If @param ensureEntriesPerTopic is specified, then every topic will be
+ * forwarded by that value in addition to this (taking the backlog size of
+ * the topic into account so that bandwidth is not wasted). Given maximum
+ * entries is `100`, minimum entries is `10`, topics will be forwarded like
+ * this:
+ *
+ * | topic name | backlog size | forward amount |
+ * |------------|--------------|----------------|
+ * | topic-1 | 60 | 10 + 50 = 60 |
+ * | topic-2 | 50 | 10 + 30 = 30 |
+ * | topic-3 | 40 | 10 + 0 = 10 |
+ * @param maxEntriesAltogetherToForward Maximum entries in all topics to forward.
+ * Individual topics forward values will sum
+ * up to this value.
+ * @param ensureEntriesPerTopic All topics will be forwarded by this value. The goal
+ * of this parameter is to ensure that topics with a very
+ * small backlog are also forwarded with a given minimal
+ * value. Has a higher precedence than
+ * @param maxEntriesAltogetherToForward.
+ */
+class LargeFirstForwardStrategy(maxEntriesAltogetherToForward: Long,
+ ensureEntriesPerTopic: Long) extends ForwardStrategy {
+ override def forward(topics: Map[String, TopicState]): Map[String, Long] = {
+
+ // calculate all remaining entries per topic, ordering them by remaining entry count
+ // in a reverse order
+ val topicBacklogs = topics
+ .map{
+ case(topicName, topicStat) =>
+ val internalStat = topicStat.internalStat
+ val ledgerId = topicStat.actualLedgerId
+ val entryId = topicStat.actualEntryId
+ (topicName, TopicInternalStatsUtils.numOfEntriesAfter(internalStat, ledgerId, entryId))
+ }
+ .toList
+ .sortBy{ case(_, numOfEntriesAfterPosition) => numOfEntriesAfterPosition }
+ .reverse
+
+ // calculate quota based on the ensured entry count
+ // this will be distributed between individual topics
+ var quota = Math.max(maxEntriesAltogetherToForward - ensureEntriesPerTopic * topics.size, 0)
+
+ val result = for ((topic, topicBacklogSize) <- topicBacklogs) yield {
+ // try to increase topic by this number
+ // - if we have already ran out of quota, do not move topic
+ // - if we do not have enough quota, proceed with the quota (exhaust it completely)
+ // - if we have enough quota, exhaust all topic content (and decrease it later)
+ // - take the number of ensured entries into account when calculating quota
+ val forwardTopicBy = if (quota > 0) {
+ Math.min(quota, topicBacklogSize - ensureEntriesPerTopic)
+ } else {
+ 0
+ }
+ // calculate forward position for a topic, make sure that it is
+ // always increased by the configured ensure entry count
+ val resultEntry = topic -> (ensureEntriesPerTopic + forwardTopicBy)
+ // decrease the overall quota separately
+ quota -= (topicBacklogSize - ensureEntriesPerTopic)
+ // return already calculated forward position
+ resultEntry
+ }
+
+ result.toMap
+ }
+}
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/LinearForwardStrategy.scala b/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/LinearForwardStrategy.scala
new file mode 100644
index 00000000..5dc994a0
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/LinearForwardStrategy.scala
@@ -0,0 +1,40 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pulsar.topicinternalstats.forward
+
+/**
+ * Simple forward strategy, which forwards every topic evenly, not
+ * taking actual backlog sizes into account. Might waste bandwidth
+ * when the backlog of the topic is smaller than the calculated value
+ * for that topic.
+ *
+ * If the maximum entries to forward is `150`, topics will be forwarded
+ * like this (provided there is no minimum entry number specified:
+ * | topic name | backlog size | forward amount |
+ * |------------|--------------|----------------|
+ * | topic-1 | 60 | 50 |
+ * | topic-2 | 50 | 50 |
+ * | topic-3 | 40 | 50 |
+ *
+ * @param maxEntriesAltogetherToForward Maximum entries in all topics to
+ * forward. Will forward every topic
+ * by dividing this with the number of
+ * topics.
+ */
+class LinearForwardStrategy(maxEntriesAltogetherToForward: Long) extends ForwardStrategy {
+ override def forward(topics: Map[String, TopicState]): Map[String, Long] =
+ topics
+ .map{ case (topicName, _) =>
+ topicName -> (maxEntriesAltogetherToForward / topics.size) }
+}
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/ProportionalForwardStrategy.scala b/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/ProportionalForwardStrategy.scala
new file mode 100644
index 00000000..b52f7049
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/ProportionalForwardStrategy.scala
@@ -0,0 +1,92 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pulsar.topicinternalstats.forward
+
+/**
+ * This forward strategy will forward individual topic backlogs based on
+ * their size proportional to the size of the overall backlog (considering
+ * all topics).
+ *
+ * If the maximum entries to forward is `100`, topics will be forwarded
+ * like this (provided there is no minimum entry number specified:
+ * | topic name | backlog size | forward amount |
+ * |------------|--------------|--------------------------|
+ * |topic-1 | 60 | 100*(60/(60+50+40)) = 40 |
+ * |topic-2 | 50 | 100*(50/(60+50+40)) = 33 |
+ * |topic-3 | 40 | 100*(40/(60+50+40)) = 27 |
+ *
+ * If @param ensureEntriesPerTopic is specified, then every topic will be
+ * forwarded by that value in addition to this (taking the backlog size of
+ * the topic into account so that bandwidth is not wasted).
+ * Given maximum entries is `100`, minimum entries is `10`, topics will be
+ * forwarded like this:
+ *
+ * | topic name | backlog size | forward amount |
+ * |------------|--------------|----------------------------|
+ * |topic-1 | 60 | 10+70*(60/(60+50+40)) = 38 |
+ * |topic-2 | 50 | 10+70*(50/(60+50+40)) = 33 |
+ * |topic-3 | 40 | 10+70*(40/(60+50+40)) = 29 |
+ *
+ * @param maxEntriesAltogetherToForward Maximum entries in all topics to forward.
+ * Individual topics forward values will sum
+ * up to this value.
+ * @param ensureEntriesPerTopic All topics will be forwarded by this value. The goal
+ * of this parameter is to ensure that topics with a very
+ * small backlog are also forwarded with a given minimal
+ * value. Has a higher precedence than
+ * @param maxEntriesAltogetherToForward.
+ */
+class ProportionalForwardStrategy(maxEntriesAltogetherToForward: Long,
+ ensureEntriesPerTopic: Long) extends ForwardStrategy {
+ override def forward(topics: Map[String, TopicState]): Map[String, Long] = {
+ // calculate all remaining entries per topic
+ val topicBacklogs = topics
+ .map{
+ case (topicName, topicStat) =>
+ val internalStat = topicStat.internalStat
+ val ledgerId = topicStat.actualLedgerId
+ val entryId = topicStat.actualEntryId
+ (topicName, TopicInternalStatsUtils.numOfEntriesAfter(internalStat, ledgerId, entryId))
+ }
+ .toList
+
+ // this is the size of the complete backlog (the sum of all individual topic
+ // backlogs)
+ val completeBacklogSize = topicBacklogs
+ .map{ case (_, topicBacklogSize) => topicBacklogSize }
+ .sum
+
+ // calculate quota based on the ensured entry count
+ // this will be distributed between individual topics
+ val quota = Math.max(maxEntriesAltogetherToForward - ensureEntriesPerTopic * topics.size, 0)
+
+ topicBacklogs.map {
+ case (topicName: String, backLog: Long) =>
+ // when calculating the coefficient, do not take the number of additional entries into
+ // account (that we will add anyway)
+ val topicBacklogCoefficient = if (completeBacklogSize == 0) {
+ 0.0 // do not forward if there is no backlog
+ } else {
+ // take the ensured entries into account when calculating
+ // backlog coefficient
+ val backlogWithoutAdditionalEntries =
+ Math.max(backLog - ensureEntriesPerTopic, 0).toFloat
+ val completeBacklogWithoutAdditionalEntries =
+ (completeBacklogSize - ensureEntriesPerTopic * topics.size).toFloat
+ backlogWithoutAdditionalEntries / completeBacklogWithoutAdditionalEntries
+ }
+ topicName -> (ensureEntriesPerTopic + (quota * topicBacklogCoefficient).toLong)
+ }.toMap
+ }
+}
diff --git a/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicInternalStatsUtils.scala b/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicInternalStatsUtils.scala
new file mode 100644
index 00000000..19761e01
--- /dev/null
+++ b/src/main/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicInternalStatsUtils.scala
@@ -0,0 +1,118 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pulsar.topicinternalstats.forward
+
+import scala.collection.JavaConverters.asScalaBufferConverter
+
+import org.apache.pulsar.common.policies.data.PersistentTopicInternalStats
+
+object TopicInternalStatsUtils {
+
+ def forwardMessageId(stats: PersistentTopicInternalStats,
+ startLedgerId: Long,
+ startEntryId: Long,
+ forwardByEntryCount: Long): (Long, Long) = {
+ val ledgers = fixLastLedgerInInternalStat(stats).ledgers.asScala.toList
+ if (ledgers.isEmpty) {
+ // If there are no ledger info, stay at current ID
+ (startLedgerId, startEntryId)
+ } else {
+ // Find the start ledger and entry ID
+ var actualLedgerIndex = if (ledgers.exists(_.ledgerId == startLedgerId)) {
+ ledgers.indexWhere(_.ledgerId == startLedgerId)
+ } else if (startLedgerId == -1) {
+ 0
+ } else {
+ ledgers.size - 1
+ }
+
+ var actualEntryId = Math.min(Math.max(startEntryId, 0), ledgers(actualLedgerIndex).entries)
+ var entriesToSkip = forwardByEntryCount
+
+ while (entriesToSkip > 0) {
+ val currentLedger = ledgers(actualLedgerIndex)
+ val remainingElementsInCurrentLedger = currentLedger.entries - actualEntryId
+
+ if (entriesToSkip <= remainingElementsInCurrentLedger) {
+ actualEntryId += entriesToSkip
+ entriesToSkip = 0
+ } else if ((remainingElementsInCurrentLedger < entriesToSkip)
+ && (actualLedgerIndex < (ledgers.size-1))) {
+ // Moving onto the next ledger
+ entriesToSkip -= remainingElementsInCurrentLedger
+ actualLedgerIndex += 1
+ actualEntryId = 0
+ } else {
+ // This is the last ledger
+ val entriesInLastLedger = ledgers(actualLedgerIndex).entries
+ actualEntryId = Math.min(entriesToSkip + actualEntryId, entriesInLastLedger)
+ entriesToSkip = 0
+ }
+ }
+
+ (ledgers(actualLedgerIndex).ledgerId, actualEntryId)
+ }
+ }
+
+ def numOfEntriesUntil(stats: PersistentTopicInternalStats,
+ ledgerId: Long,
+ entryId: Long): Long = {
+ val ledgers = fixLastLedgerInInternalStat(stats).ledgers.asScala
+ if (ledgers.isEmpty) {
+ 0
+ } else {
+ val ledgersBeforeStartLedger = fixLastLedgerInInternalStat(stats).ledgers
+ .asScala
+ .filter(_.ledgerId < ledgerId)
+ val boundedEntryId = if (ledgersBeforeStartLedger.isEmpty) {
+ Math.max(entryId, 0)
+ } else {
+ Math.min(Math.max(entryId, 0), ledgersBeforeStartLedger.last.entries)
+ }
+ boundedEntryId + ledgersBeforeStartLedger.map(_.entries).sum
+ }
+ }
+
+ def numOfEntriesAfter(stats: PersistentTopicInternalStats,
+ ledgerId: Long,
+ entryId: Long): Long = {
+ val ledgers = fixLastLedgerInInternalStat(stats).ledgers.asScala
+ if (ledgers.isEmpty) {
+ 0
+ } else {
+ val entryCountIncludingCurrentLedger = fixLastLedgerInInternalStat(stats).ledgers
+ .asScala
+ .filter(_.ledgerId >= ledgerId)
+ val boundedEntryId = if (entryCountIncludingCurrentLedger.isEmpty) {
+ Math.max(entryId, 0)
+ } else {
+ Math.min(Math.max(entryId, 0), entryCountIncludingCurrentLedger.last.entries)
+ }
+ entryCountIncludingCurrentLedger.map(_.entries).sum - boundedEntryId
+ }
+ }
+
+ private def fixLastLedgerInInternalStat(
+ stats: PersistentTopicInternalStats): PersistentTopicInternalStats = {
+ if (stats.ledgers.isEmpty) {
+ stats
+ } else {
+ val lastLedgerInfo = stats.ledgers.get(stats.ledgers.size() - 1)
+ lastLedgerInfo.entries = stats.currentLedgerEntries
+ stats.ledgers.set(stats.ledgers.size() - 1, lastLedgerInfo)
+ stats
+ }
+ }
+
+}
diff --git a/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/LargeFirstForwardStrategySuite.scala b/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/LargeFirstForwardStrategySuite.scala
new file mode 100644
index 00000000..993b383a
--- /dev/null
+++ b/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/LargeFirstForwardStrategySuite.scala
@@ -0,0 +1,205 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pulsar.topicinternalstats.forward
+
+import TopicStateFixture.{createLedgerInfo, _}
+import org.apache.spark.SparkFunSuite
+
+class LargeFirstForwardStrategySuite extends SparkFunSuite {
+
+ test("forward empty topics") {
+ val fakeState = Map( "topic1" ->
+ createTopicState(
+ createPersistentTopicInternalStat(),
+ 0, 0
+ ))
+ val testForwarder = new LargeFirstForwardStrategy(10, 0)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 1)
+ assert(result("topic1") == 0)
+ }
+
+ test("forward a single topic with a single ledger") {
+ val fakeState = Map( "topic1" ->
+ createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 200)
+ ),
+ 0, 0
+ ))
+ val testForwarder = new LargeFirstForwardStrategy(10, 0)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 1)
+ assert(result("topic1") == 10)
+ }
+
+ test("forward a single topic with multiple ledgers") {
+ val fakeState = Map( "topic1" ->
+ createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 200),
+ createLedgerInfo(2000, 200)
+ ),
+ 0, 0
+ ))
+ val testForwarder = new LargeFirstForwardStrategy(350, 0)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 1)
+ assert(result("topic1") == 350)
+ }
+
+ test("forward a single topic with the biggest backlog") {
+ val fakeState = Map(
+ "topic1" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 200),
+ ),
+ 0, 0
+ ),
+ "topic2" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 400),
+ ),
+ 0, 0
+ ),
+ "topic3" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 600),
+ ),
+ 0, 0
+ ))
+ val testForwarder = new LargeFirstForwardStrategy(15, 0)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 3)
+ assert(result("topic3") == 15)
+ assert(result("topic2") == 0)
+ assert(result("topic1") == 0)
+ }
+
+ test("forward multiple topics if the backlog is small enough") {
+ val fakeState = Map(
+ "topic1" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 20),
+ ),
+ 0, 0
+ ),
+ "topic2" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 40),
+ ),
+ 0, 0
+ ),
+ "topic3" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 60),
+ ),
+ 0, 0
+ ))
+ val testForwarder = new LargeFirstForwardStrategy(100, 0)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 3)
+ assert(result("topic3") == 60)
+ assert(result("topic2") == 40)
+ assert(result("topic1") == 0)
+ }
+
+ test("forward by additional entries regardless of backlog size") {
+ val maxEntries = 130
+ val additionalEntries = 10
+ val topic1Backlog = 80
+ val topic2Backlog = 60
+ val topic3Backlog = 40
+ val fakeState = Map(
+ "topic1" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, topic1Backlog),
+ ),
+ 0, 0
+ ),
+ "topic2" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, topic2Backlog),
+ ),
+ 0, 0
+ ),
+ "topic3" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, topic3Backlog),
+ ),
+ 0, 0
+ ))
+ val testForwarder = new LargeFirstForwardStrategy(maxEntries, additionalEntries)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 3)
+
+ assert(result("topic1") >= additionalEntries)
+ assert(result("topic2") >= additionalEntries)
+ assert(result("topic3") == additionalEntries)
+
+ }
+
+ test("additional entries to forward has a higher precedence than max allowed entries") {
+ val fakeState = Map( "topic1" ->
+ createTopicState(
+ createPersistentTopicInternalStat(),
+ 0, 0
+ ))
+
+ val testForwarder = new LargeFirstForwardStrategy(10, 20)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result("topic1") == 20)
+ }
+
+ test("forward from the middle of the first topic ledger") {
+ val fakeState = Map( "topic1" ->
+ createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 200)
+ ),
+ 1000, 20
+ ))
+ val testForwarder = new LargeFirstForwardStrategy(80, 0)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 1)
+ assert(result("topic1") == 80)
+ }
+
+ test("forward from the middle of the last topic ledger") {
+ val fakeState = Map( "topic1" ->
+ createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 200),
+ createLedgerInfo(2000, 200),
+ createLedgerInfo(3000, 200)
+ ),
+ 3000, 20
+ ))
+ val testForwarder = new LargeFirstForwardStrategy(80, 0)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 1)
+ assert(result("topic1") == 80)
+ }
+
+}
+
diff --git a/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/LinearForwardStrategySuite.scala b/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/LinearForwardStrategySuite.scala
new file mode 100644
index 00000000..087356f8
--- /dev/null
+++ b/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/LinearForwardStrategySuite.scala
@@ -0,0 +1,128 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pulsar.topicinternalstats.forward
+
+import TopicStateFixture._
+
+import org.apache.spark.SparkFunSuite
+
+class LinearForwardStrategySuite extends SparkFunSuite {
+
+ test("forward empty topics") {
+ val fakeState = Map( "topic1" ->
+ createTopicState(
+ createPersistentTopicInternalStat(),
+ 0, 0
+ ))
+ val testForwarder = new LinearForwardStrategy(10)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 1)
+ assert(result("topic1") == 10)
+ }
+
+ test("forward a single topic with a single ledger") {
+ val fakeState = Map( "topic1" ->
+ createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 200)
+ ),
+ 0, 0
+ ))
+ val testForwarder = new LinearForwardStrategy(10)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 1)
+ assert(result("topic1") == 10)
+ }
+
+ test("forward a single topic with multiple ledgers") {
+ val fakeState = Map( "topic1" ->
+ createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 200),
+ createLedgerInfo(2000, 200)
+ ),
+ 0, 0
+ ))
+ val testForwarder = new LinearForwardStrategy(350)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 1)
+ assert(result("topic1") == 350)
+ }
+
+ test("forward multiple topics with single ledger") {
+ val fakeState = Map(
+ "topic1" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 200),
+ ),
+ 0, 0
+ ),
+ "topic2" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 200),
+ ),
+ 0, 0
+ ),
+ "topic3" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 200),
+ ),
+ 0, 0
+ ))
+ val testForwarder = new LinearForwardStrategy(15)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 3)
+ assert(result("topic1") == 5)
+ assert(result("topic2") == 5)
+ assert(result("topic3") == 5)
+ }
+
+ test("forward from the middle of the first topic ledger") {
+ val fakeState = Map( "topic1" ->
+ createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 200)
+ ),
+ 1000, 20
+ ))
+ val testForwarder = new LinearForwardStrategy(80)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 1)
+ assert(result("topic1") == 80)
+ }
+
+ test("forward from the middle of the last topic ledger") {
+ val fakeState = Map( "topic1" ->
+ createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 200),
+ createLedgerInfo(2000, 200),
+ createLedgerInfo(3000, 200)
+ ),
+ 3000, 20
+ ))
+ val testForwarder = new LinearForwardStrategy(80)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 1)
+ assert(result("topic1") == 80)
+ }
+
+}
+
diff --git a/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/ProportionalForwardStrategySuite.scala b/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/ProportionalForwardStrategySuite.scala
new file mode 100644
index 00000000..d811871d
--- /dev/null
+++ b/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/ProportionalForwardStrategySuite.scala
@@ -0,0 +1,238 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pulsar.topicinternalstats.forward
+
+import TopicStateFixture._
+
+import org.apache.spark.SparkFunSuite
+
+class ProportionalForwardStrategySuite extends SparkFunSuite {
+
+ test("forward empty topics") {
+ val fakeState = Map( "topic1" ->
+ createTopicState(
+ createPersistentTopicInternalStat(),
+ 0, 0
+ ))
+ val testForwarder = new ProportionalForwardStrategy(10, 0)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 1)
+ assert(result("topic1") == 0)
+ }
+
+ test("forward a single topic with a single ledger") {
+ val fakeState = Map( "topic1" ->
+ createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 200)
+ ),
+ 0, 0
+ ))
+ val testForwarder = new ProportionalForwardStrategy(10, 0)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 1)
+ assert(result("topic1") == 10)
+ }
+
+ test("forward a single topic with multiple ledgers") {
+ val fakeState = Map( "topic1" ->
+ createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 200),
+ createLedgerInfo(2000, 200)
+ ),
+ 0, 0
+ ))
+ val testForwarder = new ProportionalForwardStrategy(350, 0)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 1)
+ assert(result("topic1") == 350)
+ }
+
+ test("forward a single topic with the biggest backlog") {
+ val maxEntries = 12
+ val fakeState = Map(
+ "topic1" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 200),
+ ),
+ 0, 0
+ ),
+ "topic2" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 400),
+ ),
+ 0, 0
+ ),
+ "topic3" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 600),
+ ),
+ 0, 0
+ ))
+ val testForwarder = new ProportionalForwardStrategy(maxEntries, 0)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 3)
+ assert(result("topic1") == (maxEntries.toFloat / 6.0).toInt)
+ assert(result("topic2") == (maxEntries.toFloat / 3.0).toInt)
+ assert(result("topic3") == (maxEntries.toFloat / 2.0).toInt)
+ }
+
+ test("forward multiple topics at the same time") {
+ val fakeState = Map(
+ "topic1" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 20),
+ ),
+ 0, 0
+ ),
+ "topic2" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 40),
+ ),
+ 0, 0
+ ),
+ "topic3" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 60),
+ ),
+ 0, 0
+ ))
+ val testForwarder = new ProportionalForwardStrategy(100, 0)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 3)
+ assert(result("topic3") > 0)
+ assert(result("topic2") > 0)
+ assert(result("topic1") > 0)
+ }
+
+ test("forward by additional entries regardless of backlog size") {
+ val maxEntries = 50
+ val additionalEntries = 10
+ val topic1Backlog = 10000
+ val topic2Backlog = 20000
+ val topic3Backlog = 10
+ val fakeState = Map(
+ "topic1" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, topic1Backlog),
+ ),
+ 0, 0
+ ),
+ "topic2" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, topic2Backlog),
+ ),
+ 0, 0
+ ),
+ "topic3" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, topic3Backlog),
+ ),
+ 0, 0
+ ))
+ val testForwarder = new ProportionalForwardStrategy(maxEntries, additionalEntries)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 3)
+
+ assert(result("topic1") >= additionalEntries)
+ assert(result("topic2") >= additionalEntries)
+ assert(result("topic3") == additionalEntries)
+
+ }
+
+ test("additional entries to forward has a higher precedence than topic backlog size") {
+ val fakeState = Map( "topic1" ->
+ createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 10)
+ ),
+ 0, 0
+ ))
+
+ val testForwarder = new ProportionalForwardStrategy(10, 20)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result("topic1") == 20)
+ }
+
+ test("take the additional entries into account when calculating individual topic forward ratio") {
+ val fakeState = Map(
+ "topic1" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 1000),
+ ),
+ 0, 0
+ ),
+ "topic2" -> createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 2000),
+ ),
+ 0, 0
+ ))
+ val numberOfFakeTopics = fakeState.size
+ val ensureAdditionalEntriesPerTopic = 500
+ val entriesOnTopOfAdditionalEntries = 100
+ val maxEntries = entriesOnTopOfAdditionalEntries + ensureAdditionalEntriesPerTopic * numberOfFakeTopics
+
+ val testForwarder = new ProportionalForwardStrategy(maxEntries, ensureAdditionalEntriesPerTopic)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result("topic1") ==
+ (entriesOnTopOfAdditionalEntries.toFloat / 4.0).toInt
+ + ensureAdditionalEntriesPerTopic)
+ assert(result("topic2") ==
+ (entriesOnTopOfAdditionalEntries.toFloat * 3.0 / 4.0).toInt
+ + ensureAdditionalEntriesPerTopic)
+ }
+
+ test("forward from the middle of the first topic ledger") {
+ val fakeState = Map( "topic1" ->
+ createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 200)
+ ),
+ 1000, 20
+ ))
+ val testForwarder = new ProportionalForwardStrategy(80, 0)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 1)
+ assert(result("topic1") == 80)
+ }
+
+ test("forward from the middle of the last topic ledger") {
+ val fakeState = Map( "topic1" ->
+ createTopicState(
+ createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 200),
+ createLedgerInfo(2000, 200),
+ createLedgerInfo(3000, 200)
+ ),
+ 3000, 20
+ ))
+ val testForwarder = new ProportionalForwardStrategy(80, 0)
+ val result = testForwarder.forward(fakeState)
+
+ assert(result.size == 1)
+ assert(result("topic1") == 80)
+ }
+
+}
diff --git a/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicInternalStatsUtilsSuite.scala b/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicInternalStatsUtilsSuite.scala
new file mode 100644
index 00000000..8573c125
--- /dev/null
+++ b/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicInternalStatsUtilsSuite.scala
@@ -0,0 +1,374 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pulsar.topicinternalstats.forward
+
+import TopicStateFixture._
+
+import org.apache.spark.SparkFunSuite
+
+class TopicInternalStatsUtilsSuite extends SparkFunSuite {
+
+ test("forward empty ledger") {
+ val fakeStats = createPersistentTopicInternalStat()
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 0, 0, 10)
+
+ assert(nextLedgerId == 0)
+ assert(nextEntryId == 0)
+ }
+
+ test("forward within a single ledger") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 500)
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 1000, 0, 10)
+
+ assert(nextLedgerId == 1000)
+ assert(nextEntryId == 10)
+ }
+
+ test("forward within a single ledger starting from the middle") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 500)
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 1000, 25, 10)
+
+ assert(nextLedgerId == 1000)
+ assert(nextEntryId == 35)
+ }
+
+ test("forward to the next ledger") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50)
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 1000, 25, 50)
+
+ assert(nextLedgerId == 2000)
+ assert(nextEntryId == 25)
+ }
+
+ test("skip over a ledger if needed") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 1000, 25, 100)
+
+ assert(nextLedgerId == 3000)
+ assert(nextEntryId == 25)
+ }
+
+ test("forward to the end of the topic if too many entries need " +
+ "to be skipped with a single ledger") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 1000, 25, 600)
+
+ assert(nextLedgerId == 1000)
+ assert(nextEntryId == 50)
+ }
+
+ test("forward to the end of the topic if too many entries need " +
+ "to be skipped with multiple ledgers") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 1000, 25, 600)
+
+ assert(nextLedgerId == 3000)
+ assert(nextEntryId == 50)
+ }
+
+ test("forward with zero elements shall give you back what was given") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 2000, 25, 0)
+
+ assert(nextLedgerId == 2000)
+ assert(nextEntryId == 25)
+ }
+
+ test("forward from beginning of the topic") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, -1, -1, 125)
+
+ assert(nextLedgerId == 3000)
+ assert(nextEntryId == 25)
+ }
+
+ test("forward from non-existent ledger id shall forward from the last ledger instead") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 6000, 0, 25)
+
+ assert(nextLedgerId == 3000)
+ assert(nextEntryId == 25)
+ }
+
+ test("forward from non-existent entry id shall forward from end of ledger instead") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, 1000, 250, 25)
+
+ assert(nextLedgerId == 2000)
+ assert(nextEntryId == 25)
+ }
+
+ test("forwarded entry id shall never be less than current entry id") {
+ val startEntryID = 200
+ val ledgerID = 1000
+ val entriesInLedger = 205
+ val forwardByEntries = 50
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(ledgerID, entriesInLedger)
+ )
+ val (nextLedgerId, nextEntryId) =
+ TopicInternalStatsUtils.forwardMessageId(fakeStats, ledgerID, startEntryID, forwardByEntries)
+ assert(nextLedgerId == ledgerID)
+ assert(nextEntryId > startEntryID)
+ }
+
+ test("number of entries until shall work with empty input") {
+ val fakeStats = createPersistentTopicInternalStat()
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, -1, -1)
+
+ assert(result == 0)
+ }
+
+ test("number of entries until with single ledger") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, 1000, 25)
+
+ assert(result == 25)
+ }
+
+ test("number of entries until with multiple ledgers") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, 2000, 25)
+
+ assert(result == 75)
+ }
+
+ test("number of entries until beginning of topic") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, -1, -1)
+
+ assert(result == 0)
+ }
+
+ test("number of entries until end of topic") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, 3000, 50)
+
+ assert(result == 150)
+ }
+
+ test("number of entries until with ledger id below boundary") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, -2, 0)
+
+ assert(result == 0)
+ }
+
+ test("number of entries until with entry id below boundary") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, 2000, -2)
+
+ assert(result == 50)
+
+ }
+
+ test("number of entries until with ledger id above boundary") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, 6000, 0)
+
+ assert(result == 150)
+ }
+
+ test("number of entries until with entry id above boundary") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesUntil(fakeStats, 2000, 200)
+
+ assert(result == 100)
+ }
+
+ test("number of entries after shall work with empty input") {
+ val fakeStats = createPersistentTopicInternalStat()
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, -1, -1)
+
+ assert(result == 0)
+ }
+
+ test("number of entries after with single ledger") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, 1000, 20)
+
+ assert(result == 30)
+ }
+
+ test("number of entries after with multiple ledgers") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, 1000, 20)
+
+ assert(result == 130)
+ }
+
+ test("number of entries after beginning of topic") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, -1, -1)
+
+ assert(result == 150)
+ }
+
+ test("number of entries after end of topic") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, 3000, 50)
+
+ assert(result == 0)
+ }
+
+ test("number of entries after with ledger id below boundary") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, -2, 0)
+
+ assert(result == 150)
+ }
+
+ test("number of entries after with entry id below boundary") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, 2000, -2)
+
+ assert(result == 100)
+ }
+
+ test("number of entries after with ledger id above boundary") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, 6000, 0)
+
+ assert(result == 0)
+ }
+
+ test("number of entries after with entry id above boundary") {
+ val fakeStats = createPersistentTopicInternalStat(
+ createLedgerInfo(1000, 50),
+ createLedgerInfo(2000, 50),
+ createLedgerInfo(3000, 50),
+ )
+ val result =
+ TopicInternalStatsUtils.numOfEntriesAfter(fakeStats, 2000, 200)
+
+ assert(result == 50)
+ }
+}
diff --git a/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicStateTestFixture.scala b/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicStateTestFixture.scala
new file mode 100644
index 00000000..ff4d4ff6
--- /dev/null
+++ b/src/test/scala/org/apache/spark/sql/pulsar/topicinternalstats/forward/TopicStateTestFixture.scala
@@ -0,0 +1,57 @@
+/**
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.pulsar.topicinternalstats.forward
+
+import java.util
+
+import org.apache.pulsar.common.policies.data.PersistentTopicInternalStats
+import org.apache.pulsar.common.policies.data.PersistentTopicInternalStats.LedgerInfo
+
+object TopicStateFixture {
+
+ def createTopicState(topicInternalStats: PersistentTopicInternalStats,
+ ledgerId: Long,
+ entryId: Long): TopicState = {
+ TopicState(topicInternalStats, ledgerId, entryId)
+ }
+
+ def createPersistentTopicInternalStat(ledgers: LedgerInfo*): PersistentTopicInternalStats = {
+ val result = new PersistentTopicInternalStats()
+
+ result.currentLedgerEntries = if (ledgers.isEmpty) {
+ 0
+ } else {
+ ledgers.last.entries
+ }
+
+ if (!ledgers.isEmpty) {
+ // simulating a bug in the Pulsar Admin interface
+ // (the last ledger in the list of ledgers has 0
+ // as entry count instead of the current entry
+ // count)
+ val modifiedLastEntryId = ledgers.last
+ modifiedLastEntryId.entries = 0
+ }
+ result.ledgers = util.Arrays.asList(ledgers: _*)
+ result
+ }
+
+ def createLedgerInfo(ledgerId: Long, entries: Long): LedgerInfo = {
+ val result = new LedgerInfo()
+ result.ledgerId = ledgerId
+ result.entries = entries
+ result
+ }
+}
+