Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add BidirectionalAsyncStream #3

Merged
merged 3 commits into from
Apr 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions Sources/SyncStream/AsyncSemphore.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
// Copyright 2024-2024 Ruiyang Sun. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Dispatch
import Foundation

@available(macOS 10.15, *)
internal actor AsyncSemphore {
// MARK: Lifecycle

deinit {}

internal init(value: Int) {
self.value = value
}

// MARK: Internal

internal func wait() async {
value -= 1
if value < 0 {
_ = await withCheckedContinuation { continuation in
let workItem = DispatchWorkItem { continuation.resume() }
self.worksAndIDs.append((workItem, UUID()))
}
}
}

internal func wait(timeout: DispatchTime) async -> DispatchTimeoutResult {
await withCheckedContinuation { continuation in
value -= 1
if value >= 0 {
continuation.resume(returning: .success)
return
}

let id = UUID()
let workItem = DispatchWorkItem { continuation.resume(returning: .success) }
self.worksAndIDs.append((workItem, id))

queue.asyncAfter(deadline: timeout) {
Task {
if await self.removeWork(withID: id) {
continuation.resume(returning: .timedOut)
}
}
}
}
}

internal func wait(wallTimeout: DispatchWallTime) async -> DispatchTimeoutResult {
await withCheckedContinuation { continuation in
value -= 1
if value >= 0 {
continuation.resume(returning: .success)
return
}

let id = UUID()
let workItem = DispatchWorkItem { continuation.resume(returning: .success) }
self.worksAndIDs.append((workItem, id))

queue.asyncAfter(wallDeadline: wallTimeout) {
Task {
if await self.removeWork(withID: id) {
continuation.resume(returning: .timedOut)
}
}
}
}
}

internal func signal() async {
value += 1
if let work = worksAndIDs.first {
worksAndIDs.removeFirst()
queue.sync(execute: work.work)
}
}

// MARK: Private

private var value: Int
private var queue = DispatchQueue(label: "com.AsyncDispatchSemphore.\(UUID().uuidString)")
private var worksAndIDs = [(work: DispatchWorkItem, id: UUID)]()

private func removeWork(withID id: UUID) async -> Bool {
if let index = worksAndIDs.firstIndex(where: { $0.id == id }) {
worksAndIDs.remove(at: index)
value += 1
return true
}
return false
}
}
242 changes: 242 additions & 0 deletions Sources/SyncStream/BidirectionalAsyncStream.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
// Copyright 2024-2024 Ruiyang Sun. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

import Dispatch
import Foundation

// MARK: - BidirectionalAsyncStream

/// A mechanism inspired by Python's generator to allow for bidirectional communication between two
/// parties. One party can yield a value and the other party can send a value back.
@available(macOS 10.15, *)
public class BidirectionalAsyncStream<YieldT, SendT, ReturnT> {
// MARK: Lifecycle

/// Creates a new `BidirectionalSyncStream`.
///
/// - Parameters:
/// - YieldT: The type of the value to yield.
/// - SendT: The type of the value to send.
/// - ReturnT: The type of the value to return.
/// - build: A async closure that takes a `Continuation` and returns `Void`.
public init(
_: YieldT.Type = YieldT.self,
_: SendT.Type = SendT.self,
_: ReturnT.Type = ReturnT.self,
_ build: @escaping (Continuation) async -> Void
) {
self.build = build
continuation = Continuation()
}

deinit {}

// MARK: Public

/// Advances the stream to the next value. In this stream, it is used to
/// start the stream.
///
/// - Returns: The next value in the stream.
/// - Throws: `StopIteration` if the stream has finished.
/// - Throws: `WrongStreamUse` if invalid interaction with the stream is detected.
public func next() async throws -> YieldT {
if case let .finished(value) = finished {
throw StopIteration<ReturnT>(value: value)
}
if started {
throw WrongStreamUse(
message: "The BidirectionalSyncStream has already started, " +
"Use send() instead of next() to continue the stream."
)
}
await start()

await continuation.yieldSemaphore.wait()
switch continuation.state {
case let .yielded(value):
continuation.state = .waitingForSend
return value

case let .finished(value):
finished = .finished(value)
throw StopIteration(value: value)

default:
throw WrongStreamUse(message: "yield or return must be called in the continuation closure")
}
}

/// Sends a value to the stream, and returns the next value.
///
/// - Parameters:
/// - element: The value to send.
///
/// - Returns: The next value in the stream.
///
/// - Throws: `StopIteration` if the stream has finished.
/// - Throws: `WrongStreamUse` if invalid interaction with the stream is detected.
///
/// - Note: This method can only be called after calling `next()`.
public func send(_ element: SendT) async throws -> YieldT {
guard started else {
throw WrongStreamUse(
message: "The BidirectionalSyncStream has not started yet, " +
"Use next() to start the stream."
)
}

if case let .finished(value) = finished {
throw StopIteration<ReturnT>(value: value)
}

continuation.sendValue = element
continuation.state = .sended(element)
await continuation.sendSemaphore.signal()
await continuation.yieldSemaphore.wait()
switch continuation.state {
case let .yielded(value):
continuation.state = .waitingForSend
return value

case let .finished(value):
finished = .finished(value)
throw StopIteration(value: value)

default:
throw WrongStreamUse(message: "yield or return must be called in the continuation closure")
}
}

// MARK: Internal

internal enum State {
case idle
case yielded(YieldT)
case waitingForSend
case sended(SendT)
case finished(ReturnT)
}

// MARK: Private

private var started = false
private var finished: State = .idle
private var build: (Continuation) async -> Void
private var continuation: Continuation
private var queue = DispatchQueue(label: "com.BidirectionalAsyncStream.\(UUID().uuidString)")

private func start() async {
started = true
Task { await build(continuation) }
}
}

// MARK: BidirectionalAsyncStream.Continuation

@available(macOS 10.15, *)
public extension BidirectionalAsyncStream {
/// A continuation of the `BidirectionalAsyncStream`.
/// It is used to communicate between the two parties.
class Continuation {
// MARK: Lifecycle

deinit {}

// MARK: Public

/// Yields a value to the stream and waits for a value to be sent back.
///
/// - Parameters:
/// - element: The value to yield.
///
/// - Returns: The value sent back.
@discardableResult
public func yield(_ element: YieldT) async -> SendT {
if finished {
fatalError("The stream has finished. Cannot yield any more.")
}

state = .yielded(element)
await yieldSemaphore.signal()
await sendSemaphore.wait()
return sendValue!
}

/// Returns a value to the stream and finishes the stream.
/// This is the last call in the stream.
public func `return`(_ element: ReturnT) async {
if finished {
fatalError("The stream has finished. Cannot return any more.")
}

finished = true
state = .finished(element)
await yieldSemaphore.signal()
}

// MARK: Internal

internal var state: State = .idle
internal var yieldSemaphore = AsyncSemphore(value: 0)
internal var sendSemaphore = AsyncSemphore(value: 0)
internal var sendValue: SendT?

// MARK: Private

private var finished = false
}
}

@available(macOS 10.15, *)
public extension BidirectionalAsyncStream {
/// Converts the stream to a `SyncStream`.
///
/// Only works when the `SendT` type is `NoneType`, and the `YieldT` type is the same as the `ReturnT` type.
func toAsyncStream() async -> AsyncStream<YieldT> where SendT.Type == NoneType.Type, YieldT.Type == ReturnT.Type {
AsyncStream<YieldT> { continuation in
Task {
do {
let value = try await self.next()
continuation.yield(value)
while true {
let value = try await self.send(NoneType())
continuation.yield(value)
}
} catch {
if let value = (error as? StopIteration<ReturnT>)?.value {
continuation.yield(value)
}
continuation.finish()
}
}
}
}

/// Constructs an Bidrectional asynchronous stream from the Element Type
///
/// - Returns: A tuple containing the stream and its continuation. The continuation
/// should be passed to the producer while the stream should be passed to the consumer.
static func makeStream(
_: YieldT.Type = YieldT.self,
_: SendT.Type = SendT.self,
_: ReturnT.Type = ReturnT.self
) -> (
stream: BidirectionalAsyncStream<YieldT, SendT, ReturnT>,
continuation: BidirectionalAsyncStream<YieldT, SendT, ReturnT>.Continuation
) {
let stream = BidirectionalAsyncStream<YieldT, SendT, ReturnT> { _ in }
let continuation = stream.continuation
return (stream, continuation)
}
}
5 changes: 4 additions & 1 deletion Sources/SyncStream/BidirectionalSyncStream.swift
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,10 @@ public extension BidirectionalSyncStream {
_: YieldT.Type = YieldT.self,
_: SendT.Type = SendT.self,
_: ReturnT.Type = ReturnT.self
) -> (stream: BidirectionalSyncStream<YieldT, SendT, ReturnT>, continuation: BidirectionalSyncStream<YieldT, SendT, ReturnT>.Continuation) {
) -> (
stream: BidirectionalSyncStream<YieldT, SendT, ReturnT>,
continuation: BidirectionalSyncStream<YieldT, SendT, ReturnT>.Continuation
) {
let stream = BidirectionalSyncStream { _ in }
let continuation = stream.continuation
return (stream, continuation)
Expand Down
Loading