From 0b7b565783e3bb0444c72291ba803102a25cf89f Mon Sep 17 00:00:00 2001 From: Alex Hoppen Date: Thu, 24 Oct 2024 16:22:41 -0700 Subject: [PATCH] Fix a race condition that caused `withTimeout` to not escalate the priority of the body rdar://137678566 --- Sources/SwiftExtensions/AsyncUtils.swift | 11 ++++++++--- Tests/SKSupportTests/AsyncUtilsTests.swift | 1 - 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/Sources/SwiftExtensions/AsyncUtils.swift b/Sources/SwiftExtensions/AsyncUtils.swift index 9e57635fd..1d0d4466e 100644 --- a/Sources/SwiftExtensions/AsyncUtils.swift +++ b/Sources/SwiftExtensions/AsyncUtils.swift @@ -176,9 +176,14 @@ package func withTimeout( _ duration: Duration, _ body: @escaping @Sendable () async throws -> T ) async throws -> T { + // Get the priority with which to launch the body task here so that we can pass the same priority as the initial + // priority to `withTaskPriorityChangedHandler`. Otherwise, we can get into a race condition where bodyTask gets + // launched with a low priority, then the priority gets elevated before we call with `withTaskPriorityChangedHandler`, + // we thus don't receive a `taskPriorityChanged` and hence never increase the priority of `bodyTask`. + let priority = Task.currentPriority var mutableTasks: [Task] = [] let stream = AsyncThrowingStream { continuation in - let bodyTask = Task { + let bodyTask = Task(priority: priority) { do { let result = try await body() continuation.yield(result) @@ -187,7 +192,7 @@ package func withTimeout( } } - let timeoutTask = Task { + let timeoutTask = Task(priority: priority) { try await Task.sleep(for: duration) continuation.yield(with: .failure(TimeoutError())) bodyTask.cancel() @@ -197,7 +202,7 @@ package func withTimeout( let tasks = mutableTasks - return try await withTaskPriorityChangedHandler { + return try await withTaskPriorityChangedHandler(initialPriority: priority) { for try await value in stream { return value } diff --git a/Tests/SKSupportTests/AsyncUtilsTests.swift b/Tests/SKSupportTests/AsyncUtilsTests.swift index 21a3a409f..41af3d94f 100644 --- a/Tests/SKSupportTests/AsyncUtilsTests.swift +++ b/Tests/SKSupportTests/AsyncUtilsTests.swift @@ -52,7 +52,6 @@ final class AsyncUtilsTests: XCTestCase { } func testWithTimeoutEscalatesPriority() async throws { - try XCTSkipIf(true, "Flakey test: rdar://137640122") let expectation = self.expectation(description: "Timeout started") let task = Task(priority: .background) { // We don't actually hit the timeout. It's just a large value.