diff --git a/Sources/FalClient/Client+Request.swift b/Sources/FalClient/Client+Request.swift
index 1917894..4e6a248 100644
--- a/Sources/FalClient/Client+Request.swift
+++ b/Sources/FalClient/Client+Request.swift
@@ -57,6 +57,37 @@ extension Client {
return data
}
+ func fetchTemporaryAuthToken(for endpointId: String) async throws -> String {
+ let url = "https://rest.alpha.fal.ai/tokens/"
+ let body: Payload = try [
+ "allowed_apps": [.string(appAlias(fromId: endpointId))],
+ "token_expiration": 120,
+ ]
+
+ let response = try await sendRequest(
+ to: url,
+ input: body.json(),
+ options: .withMethod(.post)
+ )
+ if let token = String(data: response, encoding: .utf8) {
+ return token.replacingOccurrences(of: "\"", with: "")
+ } else {
+ throw FalError.unauthorized(message: "Cannot generate token for \(endpointId)")
+ }
+ }
+
+ func buildEndpointUrl(fromId endpointId: String, path: String? = nil, scheme: String = "https", queryParams: [String: String] = [:]) -> URL {
+ let appId = (try? ensureAppIdFormat(endpointId)) ?? endpointId
+ guard var components = URLComponents(string: buildUrl(fromId: appId, path: path)) else {
+ preconditionFailure("Invalid URL. This is unexpected and likely a problem in the client library.")
+ }
+ components.scheme = scheme
+ components.queryItems = queryParams.map { URLQueryItem(name: $0.key, value: $0.value) }
+
+ // swiftlint:disable:next force_unwrapping
+ return components.url!
+ }
+
func checkResponseStatus(for response: URLResponse, withData data: Data) throws {
guard response is HTTPURLResponse else {
throw FalError.invalidResultFormat
@@ -77,6 +108,7 @@ extension Client {
var userAgent: String {
let osVersion = ProcessInfo.processInfo.operatingSystemVersionString
- return "fal.ai/swift-client 0.1.0 - \(osVersion)"
+ // TODO: figure out how to parametrize this
+ return "fal.ai/swift-client 0.6.0 - \(osVersion)"
}
}
diff --git a/Sources/FalClient/Client.swift b/Sources/FalClient/Client.swift
index b97a060..44dd59f 100644
--- a/Sources/FalClient/Client.swift
+++ b/Sources/FalClient/Client.swift
@@ -37,6 +37,8 @@ public protocol Client {
var realtime: Realtime { get }
+ var streaming: Streaming { get }
+
var storage: Storage { get }
func run(_ id: String, input: Payload?, options: RunOptions) async throws -> Payload
@@ -60,6 +62,12 @@ public protocol Client {
includeLogs: Bool,
onQueueUpdate: OnQueueUpdate?
) async throws -> Payload
+
+ func stream(
+ from endpointId: String,
+ input: Input,
+ timeout: DispatchTimeInterval
+ ) async throws -> FalStream where Input: Encodable, Output: Decodable
}
public extension Client {
@@ -125,4 +133,12 @@ public extension Client {
includeLogs: includeLogs,
onQueueUpdate: onQueueUpdate)
}
+
+ func stream(
+ from endpointId: String,
+ input: Input,
+ timeout: DispatchTimeInterval = .seconds(60)
+ ) async throws -> FalStream where Input: Encodable, Output: Decodable {
+ try await streaming.stream(from: endpointId, input: input, timeout: timeout)
+ }
}
diff --git a/Sources/FalClient/FalClient.swift b/Sources/FalClient/FalClient.swift
index e0ad6db..2832f30 100644
--- a/Sources/FalClient/FalClient.swift
+++ b/Sources/FalClient/FalClient.swift
@@ -27,6 +27,8 @@ public struct FalClient: Client {
public var realtime: Realtime { RealtimeClient(client: self) }
+ public var streaming: Streaming { StreamingClient(client: self) }
+
public var storage: Storage { StorageClient(client: self) }
public func run(_ app: String, input: Payload?, options: RunOptions) async throws -> Payload {
diff --git a/Sources/FalClient/Payload.swift b/Sources/FalClient/Payload.swift
index c348d1c..728b18d 100644
--- a/Sources/FalClient/Payload.swift
+++ b/Sources/FalClient/Payload.swift
@@ -268,3 +268,23 @@ extension Payload {
}
}
}
+
+extension Encodable {
+ func asPayload() throws -> Payload {
+ if let payload = self as? Payload {
+ return payload
+ }
+ let data = try JSONEncoder().encode(self)
+ return try Payload.create(fromJSON: data)
+ }
+}
+
+extension Payload {
+ func asType(_: T.Type) throws -> T {
+ if T.self == Payload.self {
+ return self as! T
+ }
+ let data = try JSONEncoder().encode(self)
+ return try JSONDecoder().decode(T.self, from: data)
+ }
+}
diff --git a/Sources/FalClient/Realtime.swift b/Sources/FalClient/Realtime.swift
index 8769104..ec3acd0 100644
--- a/Sources/FalClient/Realtime.swift
+++ b/Sources/FalClient/Realtime.swift
@@ -126,23 +126,6 @@ let LegacyApps = [
"sd-turbo-real-time-high-fps-msgpack",
]
-func buildRealtimeUrl(forApp app: String, token: String? = nil) -> URL {
- // Some basic support for old ids, this should be removed during 1.0.0 release
- // For full-support of old ids, users can point to version 0.4.x
- let appAlias = (try? appAlias(fromId: app)) ?? app
- let path = LegacyApps.contains(appAlias) || !app.contains("/") ? "/ws" : "/realtime"
- guard var components = URLComponents(string: buildUrl(fromId: app, path: path)) else {
- preconditionFailure("Invalid URL. This is unexpected and likely a problem in the client library.")
- }
- components.scheme = "wss"
-
- if let token {
- components.queryItems = [URLQueryItem(name: "fal_jwt_token", value: token)]
- }
- // swiftlint:disable:next force_unwrapping
- return components.url!
-}
-
typealias RefreshTokenFunction = (String, (Result) -> Void) -> Void
private let TokenExpirationInterval: DispatchTimeInterval = .minutes(1)
@@ -205,10 +188,7 @@ class WebSocketConnection: NSObject, URLSessionWebSocketDelegate {
return
}
- let url = buildRealtimeUrl(
- forApp: app,
- token: token
- )
+ let url = buildRealtimeUrl(token: token)
let webSocketTask = session.webSocketTask(with: url)
webSocketTask.delegate = self
task = webSocketTask
@@ -221,22 +201,9 @@ class WebSocketConnection: NSObject, URLSessionWebSocketDelegate {
func refreshToken(_ app: String, completion: @escaping (Result) -> Void) {
Task {
- let url = "https://rest.alpha.fal.ai/tokens/"
- let body: Payload = try [
- "allowed_apps": [.string(appAlias(fromId: app))],
- "token_expiration": 300,
- ]
do {
- let response = try await self.client.sendRequest(
- to: url,
- input: body.json(),
- options: .withMethod(.post)
- )
- if let token = String(data: response, encoding: .utf8) {
- completion(.success(token.replacingOccurrences(of: "\"", with: "")))
- } else {
- completion(.failure(FalRealtimeError.unauthorized))
- }
+ let token = try await self.client.fetchTemporaryAuthToken(for: app)
+ completion(.success(token.replacingOccurrences(of: "\"", with: "")))
} catch {
completion(.failure(error))
}
@@ -324,6 +291,18 @@ class WebSocketConnection: NSObject, URLSessionWebSocketDelegate {
}
task = nil
}
+
+ private func buildRealtimeUrl(token: String? = nil) -> URL {
+ // Some basic support for old ids, this should be removed during 1.0.0 release
+ // For full-support of old ids, users can point to version 0.4.x
+ let appAlias = (try? appAlias(fromId: app)) ?? app
+ let path = LegacyApps.contains(appAlias) || !app.contains("/") ? "/ws" : "/realtime"
+ var queryParams: [String: String] = [:]
+ if let token {
+ queryParams["fal_jwt_token"] = token
+ }
+ return client.buildEndpointUrl(fromId: app, path: path, scheme: "wss", queryParams: queryParams)
+ }
}
var connectionPool: [String: WebSocketConnection] = [:]
diff --git a/Sources/FalClient/Streaming.swift b/Sources/FalClient/Streaming.swift
new file mode 100644
index 0000000..97a7429
--- /dev/null
+++ b/Sources/FalClient/Streaming.swift
@@ -0,0 +1,185 @@
+import Combine
+import Foundation
+
+public class FalStream: AsyncSequence {
+ public typealias Element = Output
+
+ private let url: URL
+ private let input: Input
+ private let timeout: DispatchTimeInterval
+
+ private let subject = PassthroughSubject()
+ private var buffer: [Output] = []
+ private var currentData: Output?
+ private var lastEventTimestamp: Date = .init()
+ private var streamClosed = false
+ private var doneFuture: Future