-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add
BidirectionalAsyncStream
(#3)
* chore(BidrectionalSyncStream): too long line * feat: add `BidirectionalAsyncStream` * test: add tests for `BidirectionalAsyncStream`
- Loading branch information
1 parent
c686681
commit d86854a
Showing
5 changed files
with
600 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
// Copyright 2024-2024 Ruiyang Sun. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import Dispatch | ||
import Foundation | ||
|
||
@available(macOS 10.15, *) | ||
internal actor AsyncSemphore { | ||
// MARK: Lifecycle | ||
|
||
deinit {} | ||
|
||
internal init(value: Int) { | ||
self.value = value | ||
} | ||
|
||
// MARK: Internal | ||
|
||
internal func wait() async { | ||
value -= 1 | ||
if value < 0 { | ||
_ = await withCheckedContinuation { continuation in | ||
let workItem = DispatchWorkItem { continuation.resume() } | ||
self.worksAndIDs.append((workItem, UUID())) | ||
} | ||
} | ||
} | ||
|
||
internal func wait(timeout: DispatchTime) async -> DispatchTimeoutResult { | ||
await withCheckedContinuation { continuation in | ||
value -= 1 | ||
if value >= 0 { | ||
continuation.resume(returning: .success) | ||
return | ||
} | ||
|
||
let id = UUID() | ||
let workItem = DispatchWorkItem { continuation.resume(returning: .success) } | ||
self.worksAndIDs.append((workItem, id)) | ||
|
||
queue.asyncAfter(deadline: timeout) { | ||
Task { | ||
if await self.removeWork(withID: id) { | ||
continuation.resume(returning: .timedOut) | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
internal func wait(wallTimeout: DispatchWallTime) async -> DispatchTimeoutResult { | ||
await withCheckedContinuation { continuation in | ||
value -= 1 | ||
if value >= 0 { | ||
continuation.resume(returning: .success) | ||
return | ||
} | ||
|
||
let id = UUID() | ||
let workItem = DispatchWorkItem { continuation.resume(returning: .success) } | ||
self.worksAndIDs.append((workItem, id)) | ||
|
||
queue.asyncAfter(wallDeadline: wallTimeout) { | ||
Task { | ||
if await self.removeWork(withID: id) { | ||
continuation.resume(returning: .timedOut) | ||
} | ||
} | ||
} | ||
} | ||
} | ||
|
||
internal func signal() async { | ||
value += 1 | ||
if let work = worksAndIDs.first { | ||
worksAndIDs.removeFirst() | ||
queue.sync(execute: work.work) | ||
} | ||
} | ||
|
||
// MARK: Private | ||
|
||
private var value: Int | ||
private var queue = DispatchQueue(label: "com.AsyncDispatchSemphore.\(UUID().uuidString)") | ||
private var worksAndIDs = [(work: DispatchWorkItem, id: UUID)]() | ||
|
||
private func removeWork(withID id: UUID) async -> Bool { | ||
if let index = worksAndIDs.firstIndex(where: { $0.id == id }) { | ||
worksAndIDs.remove(at: index) | ||
value += 1 | ||
return true | ||
} | ||
return false | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,242 @@ | ||
// Copyright 2024-2024 Ruiyang Sun. All Rights Reserved. | ||
// | ||
// Licensed under the Apache License, Version 2.0 (the "License"); | ||
// you may not use this file except in compliance with the License. | ||
// You may obtain a copy of the License at | ||
// | ||
// http://www.apache.org/licenses/LICENSE-2.0 | ||
// | ||
// Unless required by applicable law or agreed to in writing, software | ||
// distributed under the License is distributed on an "AS IS" BASIS, | ||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
// See the License for the specific language governing permissions and | ||
// limitations under the License. | ||
|
||
import Dispatch | ||
import Foundation | ||
|
||
// MARK: - BidirectionalAsyncStream | ||
|
||
/// A mechanism inspired by Python's generator to allow for bidirectional communication between two | ||
/// parties. One party can yield a value and the other party can send a value back. | ||
@available(macOS 10.15, *) | ||
public class BidirectionalAsyncStream<YieldT, SendT, ReturnT> { | ||
// MARK: Lifecycle | ||
|
||
/// Creates a new `BidirectionalSyncStream`. | ||
/// | ||
/// - Parameters: | ||
/// - YieldT: The type of the value to yield. | ||
/// - SendT: The type of the value to send. | ||
/// - ReturnT: The type of the value to return. | ||
/// - build: A async closure that takes a `Continuation` and returns `Void`. | ||
public init( | ||
_: YieldT.Type = YieldT.self, | ||
_: SendT.Type = SendT.self, | ||
_: ReturnT.Type = ReturnT.self, | ||
_ build: @escaping (Continuation) async -> Void | ||
) { | ||
self.build = build | ||
continuation = Continuation() | ||
} | ||
|
||
deinit {} | ||
|
||
// MARK: Public | ||
|
||
/// Advances the stream to the next value. In this stream, it is used to | ||
/// start the stream. | ||
/// | ||
/// - Returns: The next value in the stream. | ||
/// - Throws: `StopIteration` if the stream has finished. | ||
/// - Throws: `WrongStreamUse` if invalid interaction with the stream is detected. | ||
public func next() async throws -> YieldT { | ||
if case let .finished(value) = finished { | ||
throw StopIteration<ReturnT>(value: value) | ||
} | ||
if started { | ||
throw WrongStreamUse( | ||
message: "The BidirectionalSyncStream has already started, " + | ||
"Use send() instead of next() to continue the stream." | ||
) | ||
} | ||
await start() | ||
|
||
await continuation.yieldSemaphore.wait() | ||
switch continuation.state { | ||
case let .yielded(value): | ||
continuation.state = .waitingForSend | ||
return value | ||
|
||
case let .finished(value): | ||
finished = .finished(value) | ||
throw StopIteration(value: value) | ||
|
||
default: | ||
throw WrongStreamUse(message: "yield or return must be called in the continuation closure") | ||
} | ||
} | ||
|
||
/// Sends a value to the stream, and returns the next value. | ||
/// | ||
/// - Parameters: | ||
/// - element: The value to send. | ||
/// | ||
/// - Returns: The next value in the stream. | ||
/// | ||
/// - Throws: `StopIteration` if the stream has finished. | ||
/// - Throws: `WrongStreamUse` if invalid interaction with the stream is detected. | ||
/// | ||
/// - Note: This method can only be called after calling `next()`. | ||
public func send(_ element: SendT) async throws -> YieldT { | ||
guard started else { | ||
throw WrongStreamUse( | ||
message: "The BidirectionalSyncStream has not started yet, " + | ||
"Use next() to start the stream." | ||
) | ||
} | ||
|
||
if case let .finished(value) = finished { | ||
throw StopIteration<ReturnT>(value: value) | ||
} | ||
|
||
continuation.sendValue = element | ||
continuation.state = .sended(element) | ||
await continuation.sendSemaphore.signal() | ||
await continuation.yieldSemaphore.wait() | ||
switch continuation.state { | ||
case let .yielded(value): | ||
continuation.state = .waitingForSend | ||
return value | ||
|
||
case let .finished(value): | ||
finished = .finished(value) | ||
throw StopIteration(value: value) | ||
|
||
default: | ||
throw WrongStreamUse(message: "yield or return must be called in the continuation closure") | ||
} | ||
} | ||
|
||
// MARK: Internal | ||
|
||
internal enum State { | ||
case idle | ||
case yielded(YieldT) | ||
case waitingForSend | ||
case sended(SendT) | ||
case finished(ReturnT) | ||
} | ||
|
||
// MARK: Private | ||
|
||
private var started = false | ||
private var finished: State = .idle | ||
private var build: (Continuation) async -> Void | ||
private var continuation: Continuation | ||
private var queue = DispatchQueue(label: "com.BidirectionalAsyncStream.\(UUID().uuidString)") | ||
|
||
private func start() async { | ||
started = true | ||
Task { await build(continuation) } | ||
} | ||
} | ||
|
||
// MARK: BidirectionalAsyncStream.Continuation | ||
|
||
@available(macOS 10.15, *) | ||
public extension BidirectionalAsyncStream { | ||
/// A continuation of the `BidirectionalAsyncStream`. | ||
/// It is used to communicate between the two parties. | ||
class Continuation { | ||
// MARK: Lifecycle | ||
|
||
deinit {} | ||
|
||
// MARK: Public | ||
|
||
/// Yields a value to the stream and waits for a value to be sent back. | ||
/// | ||
/// - Parameters: | ||
/// - element: The value to yield. | ||
/// | ||
/// - Returns: The value sent back. | ||
@discardableResult | ||
public func yield(_ element: YieldT) async -> SendT { | ||
if finished { | ||
fatalError("The stream has finished. Cannot yield any more.") | ||
} | ||
|
||
state = .yielded(element) | ||
await yieldSemaphore.signal() | ||
await sendSemaphore.wait() | ||
return sendValue! | ||
} | ||
|
||
/// Returns a value to the stream and finishes the stream. | ||
/// This is the last call in the stream. | ||
public func `return`(_ element: ReturnT) async { | ||
if finished { | ||
fatalError("The stream has finished. Cannot return any more.") | ||
} | ||
|
||
finished = true | ||
state = .finished(element) | ||
await yieldSemaphore.signal() | ||
} | ||
|
||
// MARK: Internal | ||
|
||
internal var state: State = .idle | ||
internal var yieldSemaphore = AsyncSemphore(value: 0) | ||
internal var sendSemaphore = AsyncSemphore(value: 0) | ||
internal var sendValue: SendT? | ||
|
||
// MARK: Private | ||
|
||
private var finished = false | ||
} | ||
} | ||
|
||
@available(macOS 10.15, *) | ||
public extension BidirectionalAsyncStream { | ||
/// Converts the stream to a `SyncStream`. | ||
/// | ||
/// Only works when the `SendT` type is `NoneType`, and the `YieldT` type is the same as the `ReturnT` type. | ||
func toAsyncStream() async -> AsyncStream<YieldT> where SendT.Type == NoneType.Type, YieldT.Type == ReturnT.Type { | ||
AsyncStream<YieldT> { continuation in | ||
Task { | ||
do { | ||
let value = try await self.next() | ||
continuation.yield(value) | ||
while true { | ||
let value = try await self.send(NoneType()) | ||
continuation.yield(value) | ||
} | ||
} catch { | ||
if let value = (error as? StopIteration<ReturnT>)?.value { | ||
continuation.yield(value) | ||
} | ||
continuation.finish() | ||
} | ||
} | ||
} | ||
} | ||
|
||
/// Constructs an Bidrectional asynchronous stream from the Element Type | ||
/// | ||
/// - Returns: A tuple containing the stream and its continuation. The continuation | ||
/// should be passed to the producer while the stream should be passed to the consumer. | ||
static func makeStream( | ||
_: YieldT.Type = YieldT.self, | ||
_: SendT.Type = SendT.self, | ||
_: ReturnT.Type = ReturnT.self | ||
) -> ( | ||
stream: BidirectionalAsyncStream<YieldT, SendT, ReturnT>, | ||
continuation: BidirectionalAsyncStream<YieldT, SendT, ReturnT>.Continuation | ||
) { | ||
let stream = BidirectionalAsyncStream<YieldT, SendT, ReturnT> { _ in } | ||
let continuation = stream.continuation | ||
return (stream, continuation) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.