chore: make state transitions uncancelable
Ensure that updates to `KafkaConsumerActor`'s internal state are applied
together with the side effects that support them, by executing both
inside an uncancelable block. This guarantees that changes made to the
consumer are reflected in the internal state, and that changes to the
internal state are made visible to clients.

In the simplest cases, a sequence of `Ref.update`/`modify` and
`Async.flatMap` was refactored to use `Ref.flatModify`, which offers the
desired guarantees. In other cases, the relevant actions were explicitly
wrapped in an `Async.uncancelable` block.
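
A minimal sketch of the `flatModify` shape, using illustrative names
rather than the actual fs2-kafka code:

```scala
import cats.effect.{Async, Ref}
import cats.syntax.all.*

// Before: the Ref update and the follow-up effect are two separate steps, so a
// cancelation between them can apply the state change without its side effect.
def storeBefore[F[_]: Async, A](ref: Ref[F, List[A]], a: A, logStored: F[Unit]): F[Unit] =
  ref.modify(as => (a :: as, logStored)).flatten

// After: Ref.flatModify evaluates the returned effect inside an uncancelable
// region, so the state change and its side effect are applied together.
def storeAfter[F[_]: Async, A](ref: Ref[F, List[A]], a: A, logStored: F[Unit]): F[Unit] =
  ref.flatModify(as => (a :: as, logStored))
```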

Where relevant, logging of state changes was treated as part of the
uncancelable block. This is to avoid silent state changes, which may be
hard to track down while debugging. Invocation of user-supplied
callbacks on changes to assigned partitions is left out of these blocks,
as failures in those callbacks do not affect the internal state of the
consumer.
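
A rough shape for the explicitly wrapped cases, with the log entry kept
inside the block (the names below are illustrative, not the actual
fs2-kafka API):

```scala
import cats.effect.{Async, Ref}
import cats.syntax.all.*

// sideEffect stands in for the blocking consumer call, update for the state
// transition, and log for the corresponding log entry. Running all three in
// one uncancelable block means none of them can be observed in isolation, so
// the state never changes silently.
def transition[F[_], S](
  ref: Ref[F, S],
  sideEffect: F[Unit],
  update: S => S,
  log: S => F[Unit]
)(implicit F: Async[F]): F[Unit] =
  F.uncancelable { _ =>
    sideEffect >> ref.updateAndGet(update).flatMap(log)
  }
```
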
João Abecasis authored and biochimia committed Jan 3, 2025
1 parent 41be508 commit 7d227f5
Showing 2 changed files with 104 additions and 98 deletions.
124 changes: 73 additions & 51 deletions modules/core/src/main/scala/fs2/kafka/KafkaConsumer.scala
@@ -185,7 +185,7 @@ object KafkaConsumer {
def storeFetch: F[Unit] =
actor
.ref
.modify { state =>
.flatModify { state =>
val (newState, oldFetches) =
state.withFetch(partition, streamId, callback)
newState ->
@@ -195,7 +195,6 @@
logging.log(RevokedPreviousFetch(partition, streamId))
})
}
.flatten

def completeRevoked: F[Unit] =
callback((Chunk.empty, FetchCompletedReason.TopicPartitionRevoked))
@@ -268,25 +267,32 @@ object KafkaConsumer {
): OnRebalance[F] =
OnRebalance(
onRevoked = revoked => {
for {
finishers <- assignmentRef.modify(_.partition(entry => !revoked.contains(entry._1)))
_ <- finishers.toVector.traverse { case (_, finisher) => finisher.complete(()) }
} yield ()
assignmentRef.flatModify { assignment =>
val (newAssignment, finishers) =
assignment.partition(entry => !revoked.contains(entry._1))

(
newAssignment,
finishers.toVector.traverse_ { case (_, finisher) => finisher.complete(()) }
)
}
},
onAssigned = assignedPartitions => {
for {
assignment <- assignedPartitions
.toVector
.traverse { partition =>
Deferred[F, Unit].map(partition -> _)
}
.map(_.toMap)
_ <- assignmentRef.update(_ ++ assignment)
_ <- enqueueAssignment(
streamId = streamId,
assigned = assignment,
partitionsMapQueue = partitionsMapQueue
)
newAssignment <- assignedPartitions
.toVector
.traverse { partition =>
Deferred[F, Unit].map(partition -> _)
}
.map(_.toMap)
_ <- assignmentRef.flatModify { assignment =>
(assignment ++ newAssignment) ->
enqueueAssignment(
streamId = streamId,
assigned = newAssignment,
partitionsMapQueue = partitionsMapQueue
)
}
} yield ()
}
)
@@ -407,9 +413,9 @@ object KafkaConsumer {
.fold(actor.ref.updateAndGet(_.asStreaming)) { on =>
actor
.ref
.updateAndGet(_.withOnRebalance(on).asStreaming)
.flatTap { newState =>
logging.log(LogEntry.StoredOnRebalance(on, newState))
.flatModify { state =>
val newState = state.withOnRebalance(on).asStreaming
newState -> logging.log(LogEntry.StoredOnRebalance(on, newState)).as(newState)
}
}
.ensure(NotSubscribedException())(_.subscribed) >>
@@ -429,10 +435,16 @@
OnRebalance(
onAssigned = assigned =>
initialAssignmentDone >>
assignmentRef.updateAndGet(_ ++ assigned).flatMap(updateQueue.offer),
assignmentRef.flatModify { oldAssignment =>
val newAssignment = oldAssignment ++ assigned
newAssignment -> updateQueue.offer(newAssignment)
},
onRevoked = revoked =>
initialAssignmentDone >>
assignmentRef.updateAndGet(_ -- revoked).flatMap(updateQueue.offer)
assignmentRef.flatModify { oldAssignment =>
val newAssignment = oldAssignment -- revoked
newAssignment -> updateQueue.offer(newAssignment)
}
)

Stream
@@ -514,15 +526,17 @@ object KafkaConsumer {

override def subscribe[G[_]](topics: G[String])(implicit G: Reducible[G]): F[Unit] =
withPermit {
withConsumer.blocking {
_.subscribe(
topics.toList.asJava,
actor.consumerRebalanceListener
)
} >> actor
.ref
.updateAndGet(_.asSubscribed)
.log(LogEntry.SubscribedTopics(topics.toNonEmptyList, _))
F.uncancelable { _ =>
withConsumer.blocking {
_.subscribe(
topics.toList.asJava,
actor.consumerRebalanceListener
)
} >> actor
.ref
.updateAndGet(_.asSubscribed)
.log(LogEntry.SubscribedTopics(topics.toNonEmptyList, _))
}
}

private def withPermit[A](fa: F[A]): F[A] =
@@ -535,36 +549,44 @@

override def subscribe(regex: Regex): F[Unit] =
withPermit {
withConsumer.blocking {
_.subscribe(
regex.pattern,
actor.consumerRebalanceListener
)
} >>
actor.ref.updateAndGet(_.asSubscribed).log(LogEntry.SubscribedPattern(regex.pattern, _))
F.uncancelable { _ =>
withConsumer.blocking {
_.subscribe(
regex.pattern,
actor.consumerRebalanceListener
)
} >> actor
.ref
.updateAndGet(_.asSubscribed)
.log(LogEntry.SubscribedPattern(regex.pattern, _))
}
}

override def unsubscribe: F[Unit] =
withPermit {
withConsumer.blocking(_.unsubscribe()) >> actor
.ref
.updateAndGet(_.asUnsubscribed)
.log(LogEntry.Unsubscribed(_))
F.uncancelable { _ =>
withConsumer.blocking(_.unsubscribe()) >> actor
.ref
.updateAndGet(_.asUnsubscribed)
.log(LogEntry.Unsubscribed(_))
}
}

override def stopConsuming: F[Unit] =
stopConsumingDeferred.complete(()).attempt.void

override def assign(partitions: NonEmptySet[TopicPartition]): F[Unit] =
withPermit {
withConsumer.blocking {
_.assign(
partitions.toList.asJava
)
} >> actor
.ref
.updateAndGet(_.asSubscribed)
.log(LogEntry.ManuallyAssignedPartitions(partitions, _))
F.uncancelable { _ =>
withConsumer.blocking {
_.assign(
partitions.toList.asJava
)
} >> actor
.ref
.updateAndGet(_.asSubscribed)
.log(LogEntry.ManuallyAssignedPartitions(partitions, _))
}
}

override def assign(topic: String): F[Unit] =
78 changes: 31 additions & 47 deletions modules/core/src/main/scala/fs2/kafka/internal/KafkaConsumerActor.scala
@@ -58,8 +58,6 @@ final private[kafka] class KafkaConsumerActor[F[_], K, V](
jitter: Jitter[F]
) {

import logging.*

private[this] type ConsumerRecords =
Map[TopicPartition, NonEmptyVector[CommittableConsumerRecord[F, K, V]]]

@@ -91,17 +89,13 @@ final private[kafka] class KafkaConsumerActor[F[_], K, V](
.handleErrorWith(e => F.delay(callback(Left(e))))

private[this] def commit(request: Request.Commit[F]): F[Unit] =
ref
.modify { state =>
if (state.rebalancing) {
val newState = state.withPendingCommit(request)
(newState, Some(StoredPendingCommit(request, newState)))
} else (state, None)
}
.flatMap {
case Some(log) => logging.log(log)
case None => commitAsync(request.offsets, request.callback)
}
ref.flatModify { state =>
if (state.rebalancing) {
val newState = state.withPendingCommit(request)
(newState, logging.log(StoredPendingCommit(request, newState)))
} else
(state, commitAsync(request.offsets, request.callback))
}

private[this] def manualCommitSync(request: Request.ManualCommitSync[F]): F[Unit] = {
val commit =
@@ -142,11 +136,14 @@ final private[kafka] class KafkaConsumerActor[F[_], K, V](

private[this] def assigned(assigned: SortedSet[TopicPartition]): F[Unit] =
ref
.updateAndGet(_.withRebalancing(false))
.flatMap { state =>
log(AssignedPartitions(assigned, state)) >>
state.onRebalances.foldLeft(F.unit)(_ >> _.onAssigned(assigned))
.flatModify { state =>
val newState = state.withRebalancing(false)
(
newState,
logging.log(AssignedPartitions(assigned, state)).as(newState.onRebalances)
)
}
.flatMap(_.traverse_(_.onAssigned(assigned)))

private[this] def revoked(revoked: SortedSet[TopicPartition]): F[Unit] = {
def withState[A] = StateT.apply[Id, State[F, K, V], A](_)
@@ -198,7 +195,7 @@ }
}

ref
.modify { state =>
.flatModify { state =>
val withRebalancing = state.withRebalancing(true)

val fetches = withRebalancing.fetches.keySetStrict
@@ -214,24 +211,14 @@ final private[kafka] class KafkaConsumerActor[F[_], K, V](
completeWithRecords <- completeWithRecords(withRecords)
completeWithoutRecords <- completeWithoutRecords(withoutRecords)
removeRevokedRecords <- removeRevokedRecords(revokedNonFetches)
} yield RevokedResult(
logRevoked = logging.log(RevokedPartitions(revoked, withRebalancing)),
completeWithRecords = completeWithRecords,
completeWithoutRecords = completeWithoutRecords,
removeRevokedRecords = removeRevokedRecords,
onRebalances = withRebalancing.onRebalances
)).run(withRebalancing)
}
.flatMap { res =>
val onRevoked =
res.onRebalances.foldLeft(F.unit)(_ >> _.onRevoked(revoked))

res.logRevoked >>
res.completeWithRecords >>
res.completeWithoutRecords >>
res.removeRevokedRecords >>
onRevoked
} yield for {
_ <- logging.log(RevokedPartitions(revoked, withRebalancing))
_ <- completeWithRecords
_ <- completeWithoutRecords
_ <- removeRevokedRecords
} yield withRebalancing.onRebalances).run(withRebalancing)
}
.flatMap(_.traverse_(_.onRevoked(revoked)))
}

private[this] val offsetCommit: Map[TopicPartition, OffsetAndMetadata] => F[Unit] =
@@ -415,13 +402,18 @@ final private[kafka] class KafkaConsumerActor[F[_], K, V](
}) >> result.pendingCommits.traverse_(_.commit)
}
}

ref
.get
.flatMap { state =>
if (state.subscribed && state.streaming) {
val initialRebalancing = state.rebalancing
pollConsumer(state).flatMap(handlePoll(_, initialRebalancing))
} else F.unit
F.uncancelable { poll =>
val initialRebalancing = state.rebalancing
for {
records <- poll(pollConsumer(state))
_ <- handlePoll(records, initialRebalancing)
} yield ()
}
.whenA(state.subscribed && state.streaming)
}
}

@@ -434,14 +426,6 @@ final private[kafka] class KafkaConsumerActor[F[_], K, V](
case Request.WithPermit(fa, cb) => fa.attempt >>= cb
}

private[this] case class RevokedResult(
logRevoked: F[Unit],
completeWithRecords: F[Unit],
completeWithoutRecords: F[Unit],
removeRevokedRecords: F[Unit],
onRebalances: Chain[OnRebalance[F]]
)

sealed private[this] trait HandlePollResult {
def pendingCommits: Option[HandlePollResult.PendingCommits]
}
