diff --git a/Sources/FalClient/Client+Request.swift b/Sources/FalClient/Client+Request.swift index 1917894..4e6a248 100644 --- a/Sources/FalClient/Client+Request.swift +++ b/Sources/FalClient/Client+Request.swift @@ -57,6 +57,37 @@ extension Client { return data } + func fetchTemporaryAuthToken(for endpointId: String) async throws -> String { + let url = "https://rest.alpha.fal.ai/tokens/" + let body: Payload = try [ + "allowed_apps": [.string(appAlias(fromId: endpointId))], + "token_expiration": 120, + ] + + let response = try await sendRequest( + to: url, + input: body.json(), + options: .withMethod(.post) + ) + if let token = String(data: response, encoding: .utf8) { + return token.replacingOccurrences(of: "\"", with: "") + } else { + throw FalError.unauthorized(message: "Cannot generate token for \(endpointId)") + } + } + + func buildEndpointUrl(fromId endpointId: String, path: String? = nil, scheme: String = "https", queryParams: [String: String] = [:]) -> URL { + let appId = (try? ensureAppIdFormat(endpointId)) ?? endpointId + guard var components = URLComponents(string: buildUrl(fromId: appId, path: path)) else { + preconditionFailure("Invalid URL. This is unexpected and likely a problem in the client library.") + } + components.scheme = scheme + components.queryItems = queryParams.map { URLQueryItem(name: $0.key, value: $0.value) } + + // swiftlint:disable:next force_unwrapping + return components.url! + } + func checkResponseStatus(for response: URLResponse, withData data: Data) throws { guard response is HTTPURLResponse else { throw FalError.invalidResultFormat @@ -77,6 +108,7 @@ extension Client { var userAgent: String { let osVersion = ProcessInfo.processInfo.operatingSystemVersionString - return "fal.ai/swift-client 0.1.0 - \(osVersion)" + // TODO: figure out how to parametrize this + return "fal.ai/swift-client 0.6.0 - \(osVersion)" } } diff --git a/Sources/FalClient/Client.swift b/Sources/FalClient/Client.swift index b97a060..44dd59f 100644 --- a/Sources/FalClient/Client.swift +++ b/Sources/FalClient/Client.swift @@ -37,6 +37,8 @@ public protocol Client { var realtime: Realtime { get } + var streaming: Streaming { get } + var storage: Storage { get } func run(_ id: String, input: Payload?, options: RunOptions) async throws -> Payload @@ -60,6 +62,12 @@ public protocol Client { includeLogs: Bool, onQueueUpdate: OnQueueUpdate? ) async throws -> Payload + + func stream( + from endpointId: String, + input: Input, + timeout: DispatchTimeInterval + ) async throws -> FalStream where Input: Encodable, Output: Decodable } public extension Client { @@ -125,4 +133,12 @@ public extension Client { includeLogs: includeLogs, onQueueUpdate: onQueueUpdate) } + + func stream( + from endpointId: String, + input: Input, + timeout: DispatchTimeInterval = .seconds(60) + ) async throws -> FalStream where Input: Encodable, Output: Decodable { + try await streaming.stream(from: endpointId, input: input, timeout: timeout) + } } diff --git a/Sources/FalClient/FalClient.swift b/Sources/FalClient/FalClient.swift index e0ad6db..2832f30 100644 --- a/Sources/FalClient/FalClient.swift +++ b/Sources/FalClient/FalClient.swift @@ -27,6 +27,8 @@ public struct FalClient: Client { public var realtime: Realtime { RealtimeClient(client: self) } + public var streaming: Streaming { StreamingClient(client: self) } + public var storage: Storage { StorageClient(client: self) } public func run(_ app: String, input: Payload?, options: RunOptions) async throws -> Payload { diff --git a/Sources/FalClient/Payload.swift b/Sources/FalClient/Payload.swift index c348d1c..728b18d 100644 --- a/Sources/FalClient/Payload.swift +++ b/Sources/FalClient/Payload.swift @@ -268,3 +268,23 @@ extension Payload { } } } + +extension Encodable { + func asPayload() throws -> Payload { + if let payload = self as? Payload { + return payload + } + let data = try JSONEncoder().encode(self) + return try Payload.create(fromJSON: data) + } +} + +extension Payload { + func asType(_: T.Type) throws -> T { + if T.self == Payload.self { + return self as! T + } + let data = try JSONEncoder().encode(self) + return try JSONDecoder().decode(T.self, from: data) + } +} diff --git a/Sources/FalClient/Realtime.swift b/Sources/FalClient/Realtime.swift index 8769104..ec3acd0 100644 --- a/Sources/FalClient/Realtime.swift +++ b/Sources/FalClient/Realtime.swift @@ -126,23 +126,6 @@ let LegacyApps = [ "sd-turbo-real-time-high-fps-msgpack", ] -func buildRealtimeUrl(forApp app: String, token: String? = nil) -> URL { - // Some basic support for old ids, this should be removed during 1.0.0 release - // For full-support of old ids, users can point to version 0.4.x - let appAlias = (try? appAlias(fromId: app)) ?? app - let path = LegacyApps.contains(appAlias) || !app.contains("/") ? "/ws" : "/realtime" - guard var components = URLComponents(string: buildUrl(fromId: app, path: path)) else { - preconditionFailure("Invalid URL. This is unexpected and likely a problem in the client library.") - } - components.scheme = "wss" - - if let token { - components.queryItems = [URLQueryItem(name: "fal_jwt_token", value: token)] - } - // swiftlint:disable:next force_unwrapping - return components.url! -} - typealias RefreshTokenFunction = (String, (Result) -> Void) -> Void private let TokenExpirationInterval: DispatchTimeInterval = .minutes(1) @@ -205,10 +188,7 @@ class WebSocketConnection: NSObject, URLSessionWebSocketDelegate { return } - let url = buildRealtimeUrl( - forApp: app, - token: token - ) + let url = buildRealtimeUrl(token: token) let webSocketTask = session.webSocketTask(with: url) webSocketTask.delegate = self task = webSocketTask @@ -221,22 +201,9 @@ class WebSocketConnection: NSObject, URLSessionWebSocketDelegate { func refreshToken(_ app: String, completion: @escaping (Result) -> Void) { Task { - let url = "https://rest.alpha.fal.ai/tokens/" - let body: Payload = try [ - "allowed_apps": [.string(appAlias(fromId: app))], - "token_expiration": 300, - ] do { - let response = try await self.client.sendRequest( - to: url, - input: body.json(), - options: .withMethod(.post) - ) - if let token = String(data: response, encoding: .utf8) { - completion(.success(token.replacingOccurrences(of: "\"", with: ""))) - } else { - completion(.failure(FalRealtimeError.unauthorized)) - } + let token = try await self.client.fetchTemporaryAuthToken(for: app) + completion(.success(token.replacingOccurrences(of: "\"", with: ""))) } catch { completion(.failure(error)) } @@ -324,6 +291,18 @@ class WebSocketConnection: NSObject, URLSessionWebSocketDelegate { } task = nil } + + private func buildRealtimeUrl(token: String? = nil) -> URL { + // Some basic support for old ids, this should be removed during 1.0.0 release + // For full-support of old ids, users can point to version 0.4.x + let appAlias = (try? appAlias(fromId: app)) ?? app + let path = LegacyApps.contains(appAlias) || !app.contains("/") ? "/ws" : "/realtime" + var queryParams: [String: String] = [:] + if let token { + queryParams["fal_jwt_token"] = token + } + return client.buildEndpointUrl(fromId: app, path: path, scheme: "wss", queryParams: queryParams) + } } var connectionPool: [String: WebSocketConnection] = [:] diff --git a/Sources/FalClient/Streaming.swift b/Sources/FalClient/Streaming.swift new file mode 100644 index 0000000..97a7429 --- /dev/null +++ b/Sources/FalClient/Streaming.swift @@ -0,0 +1,185 @@ +import Combine +import Foundation + +public class FalStream: AsyncSequence { + public typealias Element = Output + + private let url: URL + private let input: Input + private let timeout: DispatchTimeInterval + + private let subject = PassthroughSubject() + private var buffer: [Output] = [] + private var currentData: Output? + private var lastEventTimestamp: Date = .init() + private var streamClosed = false + private var doneFuture: Future? = nil + + private var cancellables: Set = [] + + public init(url: URL, input: Input, timeout: DispatchTimeInterval) { + self.url = url + self.input = input + self.timeout = timeout + } + + public var publisher: AnyPublisher { + subject.eraseToAnyPublisher() + } + + func start() { + doneFuture = Future { promise in + self.subject + .last() + .sink( + receiveCompletion: { completion in + switch completion { + case .finished: + if let lastValue = self.currentData { + promise(.success(lastValue)) + } else { + promise(.failure(StreamingApiError.emptyResponse)) + } + case let .failure(error): + promise(.failure(error)) + } + }, + receiveValue: { _ in } + ) + .store(in: &self.cancellables) + } + + Task { + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.addValue("text/event-stream", forHTTPHeaderField: "Accept") + request.addValue("application/json", forHTTPHeaderField: "Content-Type") + request.addValue("Keep-Alive", forHTTPHeaderField: "Connection") + + do { + request.httpBody = try JSONEncoder().encode(input) + let (data, _) = try await URLSession.shared.bytes(for: request) + + for try await content in data.lines { + // NOTE: naive approach that relies on each chunk to be a complete event + // revisit this in case endpoints start to handle SSE differently + if content.starts(with: "data:") { + let payloadData = content.dropFirst("data:".count) + .trimmingCharacters(in: .whitespacesAndNewlines) + .data(using: .utf8) ?? Data() + let eventData = try JSONDecoder().decode(Output.self, from: payloadData) + self.currentData = eventData + subject.send(eventData) + } + } + subject.send(completion: .finished) + } catch { + subject.send(completion: .failure(error)) + return + } + } + } + + private func checkTimeout() { + let currentTime = Date() + let timeInterval = currentTime.timeIntervalSince(lastEventTimestamp) + switch timeout { + case let .seconds(seconds): + if timeInterval > TimeInterval(seconds) { + handleError(StreamingApiError.timeout) + } + case let .milliseconds(milliseconds): + if timeInterval > TimeInterval(milliseconds) / 1000.0 { + handleError(StreamingApiError.timeout) + } + default: + break + } + } + + private func handleError(_ error: Error) { + streamClosed = true + subject.send(completion: .failure(error)) + } + + public func makeAsyncIterator() -> AsyncThrowingStream.AsyncIterator { + AsyncThrowingStream { continuation in + self.subject.sink( + receiveCompletion: { completion in + switch completion { + case .finished: + continuation.finish() + case let .failure(error): + continuation.finish(throwing: error) + } + }, + receiveValue: { value in + continuation.yield(value) + } + ).store(in: &self.cancellables) + }.makeAsyncIterator() + } + + public func done() async throws -> Output { + guard let doneFuture else { + throw StreamingApiError.invalidState + } + return try await doneFuture.value + } +} + +public typealias UntypedFalStream = FalStream + +enum StreamingApiError: Error { + case invalidState + case invalidResponse + case httpError(statusCode: Int) + case emptyResponse + case timeout +} + +public protocol Streaming { + func stream( + from endpointId: String, + input: Input, + timeout: DispatchTimeInterval + ) async throws -> FalStream +} + +public extension Streaming { + func stream( + from endpointId: String, + input: Input, + timeout: DispatchTimeInterval = .seconds(60) + ) async throws -> FalStream { + try await stream(from: endpointId, input: input, timeout: timeout) + } +} + +public struct StreamingClient: Streaming { + public let client: Client + + public func stream( + from endpointId: String, + input: Input, + timeout: DispatchTimeInterval + ) async throws -> FalStream where Input: Codable, Output: Codable { + let token = try await client.fetchTemporaryAuthToken(for: endpointId) + let url = client.buildEndpointUrl(fromId: endpointId, path: "/stream", queryParams: [ + "fal_jwt_token": token, + ]) + + // TODO: improve auto-upload handling across different APIs + var inputPayload = input is EmptyInput ? nil : try input.asPayload() + if let storage = client.storage as? StorageClient, + inputPayload != nil, + inputPayload.hasBinaryData + { + inputPayload = try await storage.autoUpload(input: inputPayload) + } + + let stream: FalStream = try FalStream(url: url, input: inputPayload.asType(Input.self), timeout: timeout) + stream.start() + return stream + } +} diff --git a/Sources/FalClient/Utility.swift b/Sources/FalClient/Utility.swift index e4da1c3..572a53d 100644 --- a/Sources/FalClient/Utility.swift +++ b/Sources/FalClient/Utility.swift @@ -28,10 +28,13 @@ func appAlias(fromId id: String) throws -> String { try AppId.parse(id: id).appAlias } +let ReservedNamescapes = ["workflows", "comfy"] + struct AppId { let ownerId: String let appAlias: String let path: String? + let namespace: String? static func parse(id: String) throws -> Self { let appId = try ensureAppIdFormat(id) @@ -39,10 +42,18 @@ struct AppId { guard parts.count > 1 else { throw FalError.invalidAppId(id: id) } + + var namespace: String? = nil + if ReservedNamescapes.contains(parts[0]) { + namespace = parts[0] + } + + let startIndex = namespace != nil ? 1 : 0 return Self( - ownerId: parts[0], - appAlias: parts[1], - path: parts.endIndex > 2 ? parts.dropFirst(2).joined(separator: "/") : nil + ownerId: parts[startIndex], + appAlias: parts[startIndex + 1], + path: parts.endIndex > startIndex + 2 ? parts.dropFirst(startIndex + 2).joined(separator: "/") : nil, + namespace: namespace ) } } diff --git a/Sources/FalClient/Workflow.swift b/Sources/FalClient/Workflow.swift new file mode 100644 index 0000000..d934b92 --- /dev/null +++ b/Sources/FalClient/Workflow.swift @@ -0,0 +1,65 @@ + +public enum WorkflowEventType: String, Codable { + case submit + case completion + case output + case error +} + +protocol WorkflowEvent: Codable { + var type: WorkflowEventType { get } + var nodeId: String { get } +} + +struct WorkflowSubmitEvent: WorkflowEvent { + let type: WorkflowEventType = .submit + let nodeId: String + let appId: String + let requestId: String + + enum CodingKeys: String { + case type + case nodeId = "node_id" + case appId = "app_id" + case requestId = "request_id" + } +} + +struct WorkflowOutputEvent: WorkflowEvent { + let type: WorkflowEventType = .output + let nodeId: String + + enum CodingKeys: String { + case type + case nodeId = "node_id" + } +} + +struct WorkflowCompletionEvent: WorkflowEvent { + let type: WorkflowEventType = .completion + let nodeId: String + let appId: String + let output: Payload + + enum CodingKeys: String { + case type + case nodeId = "node_id" + case appId = "app_id" + case output + } +} + +struct WorkflowErrorEvent: WorkflowEvent { + let type: WorkflowEventType = .error + let nodeId: String + let message: String + // TODO: decode the underlying error to a more specific type + let error: Payload +} + +public enum WorkflowEventData { + case submit(WorkflowSubmitEvent) + case completion(WorkflowCompletionEvent) + case output(WorkflowOutputEvent) + case error(WorkflowErrorEvent) +} diff --git a/Tests/FalClientTests/UtilitySpec.swift b/Tests/FalClientTests/UtilitySpec.swift index e32705a..be5de45 100644 --- a/Tests/FalClientTests/UtilitySpec.swift +++ b/Tests/FalClientTests/UtilitySpec.swift @@ -29,6 +29,20 @@ class UtilitySpec: QuickSpec { expect(appId.appAlias).to(equal("fast-sdxl")) expect(appId.path).to(equal("image-to-image")) } + it ("should parse an id with a namespace") { + let appId = try AppId.parse(id: "workflows/fal-ai/fast-sdxl") + expect(appId.ownerId).to(equal("fal-ai")) + expect(appId.appAlias).to(equal("fast-sdxl")) + expect(appId.path).to(beNil()) + expect(appId.namespace).to(equal("workflows")) + } + it ("should parse an id with a namespace and a path") { + let appId = try AppId.parse(id: "comfy/fal-ai/fast-sdxl/image-to-image") + expect(appId.ownerId).to(equal("fal-ai")) + expect(appId.appAlias).to(equal("fast-sdxl")) + expect(appId.path).to(equal("image-to-image")) + expect(appId.namespace).to(equal("comfy")) + } } } }