Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add streaming support #19

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion Sources/FalClient/Client+Request.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,37 @@ extension Client {
return data
}

/// Requests a temporary auth token scoped to the given endpoint.
///
/// Sends a `POST` to the token REST endpoint with the endpoint's app alias as
/// the only allowed app and a `token_expiration` of `120`, then returns the
/// UTF-8 response body with every double-quote character removed.
///
/// - Parameter endpointId: Identifier of the endpoint the token is issued for.
/// - Returns: The token string with quotes stripped.
/// - Throws: Whatever `appAlias(fromId:)`, the payload serialization or
///   `sendRequest` throw, and `FalError.unauthorized` when the response body
///   cannot be read as UTF-8.
func fetchTemporaryAuthToken(for endpointId: String) async throws -> String {
    let tokenEndpoint = "https://rest.alpha.fal.ai/tokens/"
    let alias = try appAlias(fromId: endpointId)
    let requestBody: Payload = [
        "allowed_apps": [.string(alias)],
        "token_expiration": 120,
    ]

    let rawResponse = try await sendRequest(
        to: tokenEndpoint,
        input: requestBody.json(),
        options: .withMethod(.post)
    )

    guard let token = String(data: rawResponse, encoding: .utf8) else {
        throw FalError.unauthorized(message: "Cannot generate token for \(endpointId)")
    }
    // The body arrives wrapped in quotes; drop them before handing the token out.
    return token.replacingOccurrences(of: "\"", with: "")
}

/// Builds the URL used to reach an endpoint.
///
/// - Parameters:
///   - endpointId: Endpoint identifier. It is normalized with
///     `ensureAppIdFormat`; when that throws, the raw value is used instead.
///   - path: Optional path forwarded to `buildUrl(fromId:path:)`.
///   - scheme: URL scheme to apply to the result. Defaults to `"https"`.
///   - queryParams: Query parameters to attach. When empty, the query of the
///     base URL is left untouched.
/// - Returns: The assembled URL.
func buildEndpointUrl(fromId endpointId: String, path: String? = nil, scheme: String = "https", queryParams: [String: String] = [:]) -> URL {
    let appId = (try? ensureAppIdFormat(endpointId)) ?? endpointId
    guard var components = URLComponents(string: buildUrl(fromId: appId, path: path)) else {
        preconditionFailure("Invalid URL. This is unexpected and likely a problem in the client library.")
    }
    components.scheme = scheme

    // Assigning an empty array to `queryItems` produces an empty query string,
    // i.e. a URL with a dangling "?". Only touch the query when there is
    // something to add, and sort by key so the resulting URL is deterministic
    // (dictionary iteration order is not).
    if !queryParams.isEmpty {
        components.queryItems = queryParams
            .sorted { $0.key < $1.key }
            .map { URLQueryItem(name: $0.key, value: $0.value) }
    }

    guard let url = components.url else {
        preconditionFailure("Invalid URL. This is unexpected and likely a problem in the client library.")
    }
    return url
}

func checkResponseStatus(for response: URLResponse, withData data: Data) throws {
guard response is HTTPURLResponse else {
throw FalError.invalidResultFormat
Expand All @@ -77,6 +108,7 @@ extension Client {

/// User-agent string identifying the client library version and the host
/// operating system version.
var userAgent: String {
    let osVersion = ProcessInfo.processInfo.operatingSystemVersionString
    // TODO: figure out how to parametrize this
    return "fal.ai/swift-client 0.6.0 - \(osVersion)"
}
}
16 changes: 16 additions & 0 deletions Sources/FalClient/Client.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,8 @@ public protocol Client {

var realtime: Realtime { get }

/// Entry point to the streaming API; the default `stream(from:input:timeout:)`
/// implementation forwards to it.
var streaming: Streaming { get }

var storage: Storage { get }

func run(_ id: String, input: Payload?, options: RunOptions) async throws -> Payload
Expand All @@ -60,6 +62,12 @@ public protocol Client {
includeLogs: Bool,
onQueueUpdate: OnQueueUpdate?
) async throws -> Payload

/// Opens a stream of `Output` events from the given endpoint.
///
/// - Parameters:
///   - endpointId: Identifier of the endpoint to stream from.
///   - input: Request input; must be `Encodable`.
///   - timeout: Timeout handed to the resulting stream.
/// - Returns: A `FalStream` over the endpoint's events.
///
/// NOTE(review): the constraints here are `Encodable`/`Decodable` while
/// `FalStream` and `Streaming.stream` are declared with `Codable` — confirm
/// these are meant to differ.
func stream<Input, Output>(
from endpointId: String,
input: Input,
timeout: DispatchTimeInterval
) async throws -> FalStream<Input, Output> where Input: Encodable, Output: Decodable
}

public extension Client {
Expand Down Expand Up @@ -125,4 +133,12 @@ public extension Client {
includeLogs: includeLogs,
onQueueUpdate: onQueueUpdate)
}

/// Opens a stream of `Output` events from the given endpoint by delegating to
/// the `streaming` API.
///
/// - Parameters:
///   - endpointId: Identifier of the endpoint to stream from.
///   - input: Request input.
///   - timeout: Timeout handed to the stream. Defaults to 60 seconds.
/// - Returns: The stream produced by `streaming`.
/// - Throws: Whatever `streaming.stream(from:input:timeout:)` throws.
///
/// NOTE(review): `Streaming.stream` is declared with `Codable` constraints
/// while this method only requires `Encodable`/`Decodable` — confirm.
func stream<Input, Output>(
    from endpointId: String,
    input: Input,
    timeout: DispatchTimeInterval = .seconds(60)
) async throws -> FalStream<Input, Output> where Input: Encodable, Output: Decodable {
    let streamingApi = streaming
    let result: FalStream<Input, Output> = try await streamingApi.stream(
        from: endpointId,
        input: input,
        timeout: timeout
    )
    return result
}
}
2 changes: 2 additions & 0 deletions Sources/FalClient/FalClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ public struct FalClient: Client {

public var realtime: Realtime { RealtimeClient(client: self) }

/// Streaming API for this client; every access creates a new
/// `StreamingClient` bound to `self`.
public var streaming: Streaming { StreamingClient(client: self) }

public var storage: Storage { StorageClient(client: self) }

public func run(_ app: String, input: Payload?, options: RunOptions) async throws -> Payload {
Expand Down
20 changes: 20 additions & 0 deletions Sources/FalClient/Payload.swift
Original file line number Diff line number Diff line change
Expand Up @@ -268,3 +268,23 @@ extension Payload {
}
}
}

extension Encodable {
    /// Converts this value into a `Payload`.
    ///
    /// A value that already is a `Payload` is returned unchanged. Any other
    /// value is JSON-encoded and turned into a payload through
    /// `Payload.create(fromJSON:)`.
    ///
    /// - Throws: Encoding errors from `JSONEncoder` or errors from
    ///   `Payload.create(fromJSON:)`.
    func asPayload() throws -> Payload {
        guard let existing = self as? Payload else {
            let encoded = try JSONEncoder().encode(self)
            return try Payload.create(fromJSON: encoded)
        }
        return existing
    }
}

extension Payload {
    /// Converts this payload into a value of type `T`.
    ///
    /// When the payload itself already is a `T` (i.e. `T` is `Payload`), it is
    /// returned directly. Otherwise the payload is JSON-encoded and decoded
    /// back as `T`.
    ///
    /// - Parameter _: The target type.
    /// - Returns: The payload expressed as `T`.
    /// - Throws: Encoding errors from `JSONEncoder` or decoding errors from
    ///   `JSONDecoder`.
    func asType<T: Decodable>(_: T.Type) throws -> T {
        // A conditional cast covers the `T == Payload` case without the
        // force cast (`as!`) the type-equality check previously required.
        if let alreadyTyped = self as? T {
            return alreadyTyped
        }
        let data = try JSONEncoder().encode(self)
        return try JSONDecoder().decode(T.self, from: data)
    }
}
51 changes: 15 additions & 36 deletions Sources/FalClient/Realtime.swift
Original file line number Diff line number Diff line change
Expand Up @@ -126,23 +126,6 @@ let LegacyApps = [
"sd-turbo-real-time-high-fps-msgpack",
]

/// Builds the `wss` URL used to open a realtime connection to an app.
///
/// - Parameters:
///   - app: App identifier.
///   - token: Optional token; when present it is attached as the
///     `fal_jwt_token` query item.
/// - Returns: The WebSocket URL for the app.
func buildRealtimeUrl(forApp app: String, token: String? = nil) -> URL {
    // Some basic support for old ids, this should be removed during 1.0.0 release
    // For full-support of old ids, users can point to version 0.4.x
    let alias = (try? appAlias(fromId: app)) ?? app
    let usesLegacyPath = LegacyApps.contains(alias) || !app.contains("/")
    let path: String
    if usesLegacyPath {
        path = "/ws"
    } else {
        path = "/realtime"
    }

    guard var components = URLComponents(string: buildUrl(fromId: app, path: path)) else {
        preconditionFailure("Invalid URL. This is unexpected and likely a problem in the client library.")
    }
    components.scheme = "wss"

    if let jwt = token {
        components.queryItems = [URLQueryItem(name: "fal_jwt_token", value: jwt)]
    }
    // swiftlint:disable:next force_unwrapping
    return components.url!
}

typealias RefreshTokenFunction = (String, (Result<String, Error>) -> Void) -> Void

private let TokenExpirationInterval: DispatchTimeInterval = .minutes(1)
Expand Down Expand Up @@ -205,10 +188,7 @@ class WebSocketConnection: NSObject, URLSessionWebSocketDelegate {
return
}

let url = buildRealtimeUrl(
forApp: app,
token: token
)
let url = buildRealtimeUrl(token: token)
let webSocketTask = session.webSocketTask(with: url)
webSocketTask.delegate = self
task = webSocketTask
Expand All @@ -221,22 +201,9 @@ class WebSocketConnection: NSObject, URLSessionWebSocketDelegate {

func refreshToken(_ app: String, completion: @escaping (Result<String, Error>) -> Void) {
Task {
let url = "https://rest.alpha.fal.ai/tokens/"
let body: Payload = try [
"allowed_apps": [.string(appAlias(fromId: app))],
"token_expiration": 300,
]
do {
let response = try await self.client.sendRequest(
to: url,
input: body.json(),
options: .withMethod(.post)
)
if let token = String(data: response, encoding: .utf8) {
completion(.success(token.replacingOccurrences(of: "\"", with: "")))
} else {
completion(.failure(FalRealtimeError.unauthorized))
}
let token = try await self.client.fetchTemporaryAuthToken(for: app)
completion(.success(token.replacingOccurrences(of: "\"", with: "")))
} catch {
completion(.failure(error))
}
Expand Down Expand Up @@ -324,6 +291,18 @@ class WebSocketConnection: NSObject, URLSessionWebSocketDelegate {
}
task = nil
}

/// Builds the `wss` URL for this connection's app.
///
/// - Parameter token: Optional token; when present it is passed along as the
///   `fal_jwt_token` query parameter.
/// - Returns: The URL produced by `client.buildEndpointUrl`.
private func buildRealtimeUrl(token: String? = nil) -> URL {
    // Some basic support for old ids, this should be removed during 1.0.0 release
    // For full-support of old ids, users can point to version 0.4.x
    let alias = (try? appAlias(fromId: app)) ?? app
    let usesLegacyPath = LegacyApps.contains(alias) || !app.contains("/")
    let path = usesLegacyPath ? "/ws" : "/realtime"

    let queryParams: [String: String]
    if let jwt = token {
        queryParams = ["fal_jwt_token": jwt]
    } else {
        queryParams = [:]
    }
    return client.buildEndpointUrl(fromId: app, path: path, scheme: "wss", queryParams: queryParams)
}
}

var connectionPool: [String: WebSocketConnection] = [:]
Expand Down
185 changes: 185 additions & 0 deletions Sources/FalClient/Streaming.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
import Combine
import Foundation

/// A stream of events received from an endpoint over server-sent events.
///
/// The stream can be consumed as an `AsyncSequence`, observed through
/// `publisher`, or awaited as a whole with `done()`, which yields the last
/// event received.
public class FalStream<Input: Codable, Output: Codable>: AsyncSequence {
    public typealias Element = Output

    private let url: URL
    private let input: Input
    private let timeout: DispatchTimeInterval

    private let subject = PassthroughSubject<Element, Error>()
    /// Most recent event decoded from the response; used to resolve `done()`.
    private var currentData: Output?
    private var doneFuture: Future<Output, Error>? = nil

    private var cancellables: Set<AnyCancellable> = []

    /// Creates a stream. No request is sent until `start()` is called.
    ///
    /// - Parameters:
    ///   - url: URL the request is posted to.
    ///   - input: Value that is JSON-encoded as the request body.
    ///   - timeout: Maximum time the request may stay idle (see `start()`).
    public init(url: URL, input: Input, timeout: DispatchTimeInterval) {
        self.url = url
        self.input = input
        self.timeout = timeout
    }

    /// Publisher emitting each decoded event, followed by the completion.
    public var publisher: AnyPublisher<Element, Error> {
        subject.eraseToAnyPublisher()
    }

    /// The configured timeout in seconds, or `nil` when it does not describe a
    /// positive, finite interval (e.g. `.never`), in which case the request
    /// keeps the session's default timeout.
    private var requestTimeout: TimeInterval? {
        let seconds: TimeInterval
        switch timeout {
        case let .seconds(value):
            seconds = TimeInterval(value)
        case let .milliseconds(value):
            seconds = TimeInterval(value) / 1_000
        case let .microseconds(value):
            seconds = TimeInterval(value) / 1_000_000
        case let .nanoseconds(value):
            seconds = TimeInterval(value) / 1_000_000_000
        case .never:
            return nil
        @unknown default:
            return nil
        }
        return seconds > 0 ? seconds : nil
    }

    /// Sends the request and starts forwarding events to subscribers.
    ///
    /// Each response line starting with `data:` is decoded as an `Output` and
    /// published. When the response ends the stream finishes; any error
    /// (encoding, transport, non-2xx status, decoding) fails the stream.
    func start() {
        // `self` is captured weakly: the subscription is stored in
        // `self.cancellables`, so a strong capture would form a retain cycle.
        doneFuture = Future { [weak self] promise in
            guard let self else { return }
            self.subject
                .last()
                .sink(
                    receiveCompletion: { [weak self] completion in
                        switch completion {
                        case .finished:
                            if let lastValue = self?.currentData {
                                promise(.success(lastValue))
                            } else {
                                promise(.failure(StreamingApiError.emptyResponse))
                            }
                        case let .failure(error):
                            promise(.failure(error))
                        }
                    },
                    receiveValue: { _ in }
                )
                .store(in: &self.cancellables)
        }

        Task {
            var request = URLRequest(url: url)
            request.httpMethod = "POST"
            request.addValue("text/event-stream", forHTTPHeaderField: "Accept")
            request.addValue("application/json", forHTTPHeaderField: "Content-Type")
            request.addValue("Keep-Alive", forHTTPHeaderField: "Connection")
            // Enforce the caller-provided timeout as the request's idle timeout.
            if let requestTimeout {
                request.timeoutInterval = requestTimeout
            }

            do {
                request.httpBody = try JSONEncoder().encode(input)
                let (bytes, response) = try await URLSession.shared.bytes(for: request)

                // Surface HTTP failures explicitly instead of letting an error
                // body end up as an empty stream or a decoding error.
                guard let httpResponse = response as? HTTPURLResponse else {
                    throw StreamingApiError.invalidResponse
                }
                guard (200 ..< 300).contains(httpResponse.statusCode) else {
                    throw StreamingApiError.httpError(statusCode: httpResponse.statusCode)
                }

                for try await content in bytes.lines {
                    // NOTE: naive approach that relies on each chunk to be a complete event
                    // revisit this in case endpoints start to handle SSE differently
                    if content.starts(with: "data:") {
                        let payloadData = content.dropFirst("data:".count)
                            .trimmingCharacters(in: .whitespacesAndNewlines)
                            .data(using: .utf8) ?? Data()
                        let eventData = try JSONDecoder().decode(Output.self, from: payloadData)
                        currentData = eventData
                        subject.send(eventData)
                    }
                }
                subject.send(completion: .finished)
            } catch {
                subject.send(completion: .failure(error))
            }
        }
    }

    /// Creates an iterator that yields events as they are published.
    ///
    /// NOTE(review): the iterator subscribes to a `PassthroughSubject`, which
    /// does not replay; events published before iteration starts are not
    /// delivered to it.
    public func makeAsyncIterator() -> AsyncThrowingStream<Element, Error>.AsyncIterator {
        AsyncThrowingStream { continuation in
            self.subject.sink(
                receiveCompletion: { completion in
                    switch completion {
                    case .finished:
                        continuation.finish()
                    case let .failure(error):
                        continuation.finish(throwing: error)
                    }
                },
                receiveValue: { value in
                    continuation.yield(value)
                }
            ).store(in: &self.cancellables)
        }.makeAsyncIterator()
    }

    /// Waits for the stream to end and returns the last event received.
    ///
    /// - Throws: `StreamingApiError.invalidState` when `start()` has not been
    ///   called, `StreamingApiError.emptyResponse` when the stream finished
    ///   without any event, or the error that failed the stream.
    public func done() async throws -> Output {
        guard let doneFuture else {
            throw StreamingApiError.invalidState
        }
        return try await doneFuture.value
    }
}

public typealias UntypedFalStream = FalStream<Payload, Payload>

/// Errors raised by the streaming API.
enum StreamingApiError: Error {
/// `FalStream.done()` was called while no completion future exists.
case invalidState
/// NOTE(review): presumably a response that is not usable as an HTTP response — confirm.
case invalidResponse
/// NOTE(review): presumably a non-success HTTP status, carried in `statusCode` — confirm.
case httpError(statusCode: Int)
/// The stream finished without any event having been received.
case emptyResponse
/// NOTE(review): presumably the stream exceeded its timeout — confirm.
case timeout
}

/// API for opening event streams from endpoints.
public protocol Streaming {
/// Opens a stream of `Output` events from the given endpoint.
///
/// - Parameters:
///   - endpointId: Identifier of the endpoint to stream from.
///   - input: Request input.
///   - timeout: Timeout handed to the resulting stream.
/// - Returns: A `FalStream` over the endpoint's events.
func stream<Input: Codable, Output: Codable>(
from endpointId: String,
input: Input,
timeout: DispatchTimeInterval
) async throws -> FalStream<Input, Output>
}

public extension Streaming {
/// Convenience overload that supplies a default `timeout` of 60 seconds and
/// forwards to `stream(from:input:timeout:)`.
///
/// NOTE(review): the call below has the same signature as this method, so it
/// presumably relies on resolving to the conforming type's implementation of
/// the protocol requirement — confirm it cannot recurse.
func stream<Input: Codable, Output: Codable>(
from endpointId: String,
input: Input,
timeout: DispatchTimeInterval = .seconds(60)
) async throws -> FalStream<Input, Output> {
try await stream(from: endpointId, input: input, timeout: timeout)
}
}

/// `Streaming` implementation that works on top of a `Client`.
public struct StreamingClient: Streaming {
    /// Client used to fetch tokens, build URLs and reach storage.
    public let client: Client

    /// Opens and starts a stream against the endpoint's `/stream` path.
    ///
    /// A temporary token is fetched and attached as the `fal_jwt_token` query
    /// parameter. Unless the input is an `EmptyInput`, it is converted to a
    /// payload; when the client's storage is a `StorageClient` and the payload
    /// has binary data, the payload is passed through `autoUpload` first.
    ///
    /// - Parameters:
    ///   - endpointId: Identifier of the endpoint to stream from.
    ///   - input: Request input.
    ///   - timeout: Timeout handed to the resulting stream.
    /// - Returns: A stream that has already been started.
    /// - Throws: Errors from token retrieval, payload conversion or upload.
    public func stream<Input, Output>(
        from endpointId: String,
        input: Input,
        timeout: DispatchTimeInterval
    ) async throws -> FalStream<Input, Output> where Input: Codable, Output: Codable {
        let jwt = try await client.fetchTemporaryAuthToken(for: endpointId)
        let streamUrl = client.buildEndpointUrl(
            fromId: endpointId,
            path: "/stream",
            queryParams: ["fal_jwt_token": jwt]
        )

        // TODO: improve auto-upload handling across different APIs
        var payload = input is EmptyInput ? nil : try input.asPayload()
        if let uploader = client.storage as? StorageClient,
           payload != nil,
           payload.hasBinaryData
        {
            payload = try await uploader.autoUpload(input: payload)
        }

        let typedInput = try payload.asType(Input.self)
        let stream = FalStream<Input, Output>(url: streamUrl, input: typedInput, timeout: timeout)
        stream.start()
        return stream
    }
}
Loading
Loading