Skip to content

Commit

Permalink
Merge branch 'feat/assistants-api-4' of github.com:bwhtmn/OpenAI into…
Browse files Browse the repository at this point in the history
… feat/assistants-api-3

# Conflicts:
#	Demo/DemoChat/Sources/ChatStore.swift
#	Sources/OpenAI/Public/Models/ChatQuery.swift
#	Sources/OpenAI/Public/Models/ChatResult.swift
#	Sources/OpenAI/Public/Models/ImagesQuery.swift
#	Sources/OpenAI/Public/Models/ThreadsQuery.swift
#	Tests/OpenAITests/OpenAITests.swift
#	Tests/OpenAITests/OpenAITestsCombine.swift
#	Tests/OpenAITests/OpenAITestsDecoder.swift
  • Loading branch information
cdillard committed Feb 22, 2024
2 parents fa6d830 + a0d1126 commit 1f74731
Show file tree
Hide file tree
Showing 31 changed files with 1,515 additions and 396 deletions.
37 changes: 23 additions & 14 deletions Demo/DemoChat/Sources/AssistantStore.swift
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ public final class AssistantStore: ObservableObject {
// MARK: Models

@MainActor
func createAssistant(name: String, description: String, instructions: String, codeInterpreter: Bool, retrievel: Bool, fileIds: [String]? = nil) async -> String? {
func createAssistant(name: String, description: String, instructions: String, codeInterpreter: Bool, retrieval: Bool, functions: [FunctionDeclaration], fileIds: [String]? = nil) async -> String? {
do {
let tools = createToolsArray(codeInterpreter: codeInterpreter, retrieval: retrievel)
let tools = createToolsArray(codeInterpreter: codeInterpreter, retrieval: retrieval, functions: functions)
let query = AssistantsQuery(model: Model.gpt4_1106_preview, name: name, description: description, instructions: instructions, tools:tools, fileIds: fileIds)
let response = try await openAIClient.assistants(query: query, method: "POST", after: nil)
let response = try await openAIClient.assistantCreate(query: query)

// Refresh assistants with one just created (or modified)
let _ = await getAssistants()
Expand All @@ -47,11 +47,11 @@ public final class AssistantStore: ObservableObject {
}

@MainActor
func modifyAssistant(asstId: String, name: String, description: String, instructions: String, codeInterpreter: Bool, retrievel: Bool, fileIds: [String]? = nil) async -> String? {
func modifyAssistant(asstId: String, name: String, description: String, instructions: String, codeInterpreter: Bool, retrieval: Bool, functions: [FunctionDeclaration], fileIds: [String]? = nil) async -> String? {
do {
let tools = createToolsArray(codeInterpreter: codeInterpreter, retrieval: retrievel)
let tools = createToolsArray(codeInterpreter: codeInterpreter, retrieval: retrieval, functions: functions)
let query = AssistantsQuery(model: Model.gpt4_1106_preview, name: name, description: description, instructions: instructions, tools:tools, fileIds: fileIds)
let response = try await openAIClient.assistantModify(query: query, asstId: asstId)
let response = try await openAIClient.assistantModify(query: query, assistantId: asstId)

// Returns assistantId
return response.id
Expand All @@ -66,15 +66,24 @@ public final class AssistantStore: ObservableObject {
@MainActor
func getAssistants(limit: Int = 20, after: String? = nil) async -> [Assistant] {
do {
let response = try await openAIClient.assistants(query: nil, method: "GET", after: after)
let response = try await openAIClient.assistants(after: after)

var assistants = [Assistant]()
for result in response.data ?? [] {
let codeInterpreter = result.tools?.filter { $0.toolType == "code_interpreter" }.first != nil
let retrieval = result.tools?.filter { $0.toolType == "retrieval" }.first != nil
let tools = result.tools ?? []
let codeInterpreter = tools.contains { $0 == .codeInterpreter }
let retrieval = tools.contains { $0 == .retrieval }
let functions = tools.compactMap {
switch $0 {
case let .function(declaration):
return declaration
default:
return nil
}
}
let fileIds = result.fileIds ?? []

assistants.append(Assistant(id: result.id, name: result.name, description: result.description, instructions: result.instructions, codeInterpreter: codeInterpreter, retrieval: retrieval, fileIds: fileIds))
assistants.append(Assistant(id: result.id, name: result.name ?? "", description: result.description, instructions: result.instructions, codeInterpreter: codeInterpreter, retrieval: retrieval, fileIds: fileIds, functions: functions))
}
if after == nil {
availableAssistants = assistants
Expand Down Expand Up @@ -112,14 +121,14 @@ public final class AssistantStore: ObservableObject {
}
}

func createToolsArray(codeInterpreter: Bool, retrieval: Bool) -> [Tool] {
func createToolsArray(codeInterpreter: Bool, retrieval: Bool, functions: [FunctionDeclaration]) -> [Tool] {
var tools = [Tool]()
if codeInterpreter {
tools.append(Tool(toolType: "code_interpreter"))
tools.append(.codeInterpreter)
}
if retrieval {
tools.append(Tool(toolType: "retrieval"))
tools.append(.retrieval)
}
return tools
return tools + functions.map { .function($0) }
}
}
121 changes: 78 additions & 43 deletions Demo/DemoChat/Sources/ChatStore.swift
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ public final class ChatStore: ObservableObject {
guard let currentThreadId else { return print("No thread to add message to.")}

let _ = try await openAIClient.threadsAddMessage(threadId: currentThreadId,
query: ThreadAddMessageQuery(role: message.role.rawValue, content: message.content))
query: MessageQuery(role: message.role, content: message.content))

guard let currentAssistantId = conversations[conversationIndex].assistantId else { return print("No assistant selected.")}

Expand Down Expand Up @@ -249,19 +249,19 @@ public final class ChatStore: ObservableObject {
let result = try await openAIClient.runRetrieve(threadId: currentThreadId ?? "", runId: currentRunId ?? "")

// TESTING RETRIEVAL OF RUN STEPS
handleRunRetrieveSteps()
try await handleRunRetrieveSteps()

switch result.status {
// Get threadsMesages.
case "completed":
case .completed:
handleCompleted()
break
case "failed":
case .failed:
// Handle more gracefully with a popup dialog or failure indicator
await MainActor.run {
self.stopPolling()
}
break
case .requiresAction:
try await handleRequiresAction(result)
default:
// Handle additional statuses "requires_action", "queued" ?, "expired", "cancelled"
// https://platform.openai.com/docs/assistants/how-it-works/runs-and-run-steps
Expand Down Expand Up @@ -293,7 +293,7 @@ public final class ChatStore: ObservableObject {
for innerItem in item.content {
let message = Message(
id: item.id,
role: ChatQuery.ChatCompletionMessageParam.Role(rawValue: role) ?? .user,
role: role,
content: innerItem.text?.value ?? "",
createdAt: Date(),
isLocal: false // Messages from the server are not local
Expand All @@ -314,54 +314,89 @@ public final class ChatStore: ObservableObject {
}
}

// Store the function call as a message and submit tool outputs with a simple done message.
private func handleRequiresAction(_ result: RunResult) async throws {
guard let currentThreadId, let currentRunId else {
return
}

guard let toolCalls = result.requiredAction?.submitToolOutputs.toolCalls else {
return
}

var toolOutputs = [RunToolOutputsQuery.ToolOutput]()

for toolCall in toolCalls {
let msgContent = "function\nname: \(toolCall.function.name ?? "")\nargs: \(toolCall.function.arguments ?? "{}")"

let runStepMessage = Message(
id: toolCall.id,
role: .assistant,
content: msgContent,
createdAt: Date(),
isRunStep: true
)
await addOrUpdateRunStepMessage(runStepMessage)

// Just return a generic "Done" output for now
toolOutputs.append(.init(toolCallId: toolCall.id, output: "Done"))
}

let query = RunToolOutputsQuery(toolOutputs: toolOutputs)
_ = try await openAIClient.runSubmitToolOutputs(threadId: currentThreadId, runId: currentRunId, query: query)
}

// The run retrieval steps are fetched in a separate task. This request is fetched, checking for new run steps, each time the run is fetched.
private func handleRunRetrieveSteps() {
Task {
guard let conversationIndex = conversations.firstIndex(where: { $0.id == currentConversationId }) else {
return
}
var before: String?
private func handleRunRetrieveSteps() async throws {
var before: String?
// if let lastRunStepMessage = self.conversations[conversationIndex].messages.last(where: { $0.isRunStep == true }) {
// before = lastRunStepMessage.id
// }

let stepsResult = try await openAIClient.runRetrieveSteps(threadId: currentThreadId ?? "", runId: currentRunId ?? "", before: before)
let stepsResult = try await openAIClient.runRetrieveSteps(threadId: currentThreadId ?? "", runId: currentRunId ?? "", before: before)

for item in stepsResult.data.reversed() {
let toolCalls = item.stepDetails.toolCalls?.reversed() ?? []
for item in stepsResult.data.reversed() {
let toolCalls = item.stepDetails.toolCalls?.reversed() ?? []

for step in toolCalls {
// TODO: Depending on the type of tool tha is used we can add additional information here
// ie: if its a retrieval: add file information, code_interpreter: add inputs and outputs info, or function: add arguemts and additional info.
let msgContent: String
switch step.type {
case "retrieval":
msgContent = "RUN STEP: \(step.type)"
for step in toolCalls {
// TODO: Depending on the type of tool tha is used we can add additional information here
// ie: if its a retrieval: add file information, code_interpreter: add inputs and outputs info, or function: add arguemts and additional info.
let msgContent: String
switch step.type {
case .retrieval:
msgContent = "RUN STEP: \(step.type)"

case "code_interpreter":
msgContent = "code_interpreter\ninput:\n\(step.code?.input ?? "")\noutputs: \(step.code?.outputs?.first?.logs ?? "")"
case .codeInterpreter:
let code = step.codeInterpreter
msgContent = "code_interpreter\ninput:\n\(code?.input ?? "")\noutputs: \(code?.outputs?.first?.logs ?? "")"

default:
msgContent = "RUN STEP: \(step.type)"
case .function:
msgContent = "function\nname: \(step.function?.name ?? "")\nargs: \(step.function?.arguments ?? "{}")"

}
let runStepMessage = Message(
id: step.id,
role: .assistant,
content: msgContent,
createdAt: Date(),
isRunStep: true
)
await MainActor.run {
if let localMessageIndex = self.conversations[conversationIndex].messages.firstIndex(where: { $0.isRunStep == true && $0.id == step.id }) {
self.conversations[conversationIndex].messages[localMessageIndex] = runStepMessage
}
else {
self.conversations[conversationIndex].messages.append(runStepMessage)
}
}
}
let runStepMessage = Message(
id: step.id,
role: .assistant,
content: msgContent,
createdAt: Date(),
isRunStep: true
)
await addOrUpdateRunStepMessage(runStepMessage)
}
}
}

@MainActor
private func addOrUpdateRunStepMessage(_ message: Message) async {
guard let conversationIndex = conversations.firstIndex(where: { $0.id == currentConversationId }) else {
return
}

if let localMessageIndex = conversations[conversationIndex].messages.firstIndex(where: { $0.isRunStep == true && $0.id == message.id }) {
conversations[conversationIndex].messages[localMessageIndex] = message
}
else {
conversations[conversationIndex].messages.append(message)
}
}
}
13 changes: 12 additions & 1 deletion Demo/DemoChat/Sources/Models/Assistant.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@
//

import Foundation
import OpenAI

struct Assistant: Hashable {
init(id: String, name: String, description: String? = nil, instructions: String? = nil, codeInterpreter: Bool, retrieval: Bool, fileIds: [String]? = nil) {
init(id: String, name: String, description: String? = nil, instructions: String? = nil, codeInterpreter: Bool, retrieval: Bool, fileIds: [String]? = nil, functions: [FunctionDeclaration] = []) {
self.id = id
self.name = name
self.description = description
self.instructions = instructions
self.codeInterpreter = codeInterpreter
self.retrieval = retrieval
self.fileIds = fileIds
self.functions = functions
}

typealias ID = String
Expand All @@ -27,7 +29,16 @@ struct Assistant: Hashable {
let fileIds: [String]?
var codeInterpreter: Bool
var retrieval: Bool
var functions: [FunctionDeclaration]
}


extension Assistant: Equatable, Identifiable {}

extension FunctionDeclaration: Hashable {
public func hash(into hasher: inout Hasher) {
hasher.combine(name)
hasher.combine(description)
hasher.combine(parameters)
}
}
Loading

0 comments on commit 1f74731

Please sign in to comment.