Skip to content

Commit

Permalink
feat: add token count support (#85)
Browse files Browse the repository at this point in the history
  • Loading branch information
FelixKuhnAnsys authored Oct 10, 2024
1 parent 27f633c commit 043ae70
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 28 deletions.
5 changes: 3 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@ go 1.22.0
toolchain go1.22.5

require (
github.com/ansys/allie-sharedtypes v0.0.0-20240926111235-0e4d60bb076e
github.com/ansys/allie-sharedtypes v0.0.0-20241008142654-5943fea15868
github.com/google/go-github/v56 v56.0.0
github.com/google/uuid v1.6.0
github.com/pandodao/tokenizer-go v0.2.0
github.com/tiktoken-go/tokenizer v0.2.0
github.com/tmc/langchaingo v0.1.12
golang.org/x/oauth2 v0.22.0
google.golang.org/grpc v1.67.0
google.golang.org/grpc v1.67.1
nhooyr.io/websocket v1.8.17
)

Expand Down
10 changes: 6 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ github.com/PuerkitoBio/goquery v1.9.2 h1:4/wZksC3KgkQw7SQgkKotmKljk0M6V8TUvA8Wb4
github.com/PuerkitoBio/goquery v1.9.2/go.mod h1:GHPCaP0ODyyxqcNoFGYlAprUFH81NuRPd0GX3Zu2Mvk=
github.com/andybalholm/cascadia v1.3.2 h1:3Xi6Dw5lHF15JtdcmAHD3i1+T8plmv7BQ/nsViSLyss=
github.com/andybalholm/cascadia v1.3.2/go.mod h1:7gtRlve5FxPPgIgX36uWBX58OdBsSS6lUvCFb+h7KvU=
github.com/ansys/allie-sharedtypes v0.0.0-20240926111235-0e4d60bb076e h1:3wKulKpODo+jhCIWmyifD/KwNdDw0SYKQToU9WxiUwY=
github.com/ansys/allie-sharedtypes v0.0.0-20240926111235-0e4d60bb076e/go.mod h1:puqr6W4OTBvupiDiyhURguCHH9cyATSQk0tWimLSIvM=
github.com/ansys/allie-sharedtypes v0.0.0-20241008142654-5943fea15868 h1:ECGCl5gelLGjfdFB1XKcRFV1/b6y1lsSDhfNA+NWli8=
github.com/ansys/allie-sharedtypes v0.0.0-20241008142654-5943fea15868/go.mod h1:zHdqbofXR7x2yDBB8r0KIb/foEx9UsNnJd9yEJ1wE50=
github.com/aymerick/douceur v0.2.0 h1:Mv+mAeH1Q+n9Fr+oyamOlAkUNPWPlA8PPGR0QAaYuPk=
github.com/aymerick/douceur v0.2.0/go.mod h1:wlT5vV2O3h55X9m7iVYN0TBM0NH/MmbLnd30/FjWUq4=
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
Expand Down Expand Up @@ -73,6 +73,8 @@ github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsT
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/texttheater/golang-levenshtein v1.0.1 h1:+cRNoVrfiwufQPhoMzB6N0Yf/Mqajr6t1lOv8GyGE2U=
github.com/texttheater/golang-levenshtein v1.0.1/go.mod h1:PYAKrbF5sAiq9wd+H82hs7gNaen0CplQ9uvm6+enD/8=
github.com/tiktoken-go/tokenizer v0.2.0 h1:MqBlDeE5LRIEpapZk5s7COS9taGtRRIwM8bPxq13rI8=
github.com/tiktoken-go/tokenizer v0.2.0/go.mod h1:7SZW3pZUKWLJRilTvWCa86TOVIiiJhYj3FQ5V3alWcg=
github.com/tmc/langchaingo v0.1.12 h1:yXwSu54f3b1IKw0jJ5/DWu+qFVH1NBblwC0xddBzGJE=
github.com/tmc/langchaingo v0.1.12/go.mod h1:cd62xD6h+ouk8k/QQFhOsjRYBSA1JJ5UVKXSIgm7Ni4=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
Expand Down Expand Up @@ -145,8 +147,8 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142 h1:e7S5W7MGGLaSu8j3YjdezkZ+m1/Nm0uRVRMEMGk26Xs=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240814211410-ddb44dafa142/go.mod h1:UqMtugtsSgubUsoxbuAoiCXvqvErP7Gf0so0mK9tHxU=
google.golang.org/grpc v1.67.0 h1:IdH9y6PF5MPSdAntIcpjQ+tXO41pcQsfZV2RxtQgVcw=
google.golang.org/grpc v1.67.0/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA=
google.golang.org/grpc v1.67.1 h1:zWnc1Vrcno+lHZCOofnIMvycFcc0QRGIzm9dhnDX68E=
google.golang.org/grpc v1.67.1/go.mod h1:1gLDyUQU7CTLJI90u3nXZ9ekeghjeM7pTDZlqFNg2AA=
google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg=
google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
Expand Down
28 changes: 24 additions & 4 deletions pkg/externalfunctions/ansysgpt.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ func AnsysGPTPerformLLMRequest(finalQuery string, history []sharedtypes.Historic
streamChannel := make(chan string, 400)

// Start a goroutine to transfer the data from the response channel to the stream channel
go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false)
go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false, false, 0, 0, "")

// Return the stream channel
return "", &streamChannel
Expand Down Expand Up @@ -510,7 +510,7 @@ func AnsysGPTGetSystemPrompt(query string, prohibitedWords []string, template st
//
// Returns:
// - rephrasedQuery: the rephrased query
func AisPerformLLMRephraseRequest(systemTemplate string, userTemplate string, query string, history []sharedtypes.HistoricMessage) (rephrasedQuery string) {
func AisPerformLLMRephraseRequest(systemTemplate string, userTemplate string, query string, history []sharedtypes.HistoricMessage, tokenCountModelName string) (rephrasedQuery string, inputTokenCount int, outputTokenCount int) {
logging.Log.Debugf(internalstates.Ctx, "Performing LLM rephrase request")

// create "chat_history" string
Expand Down Expand Up @@ -549,9 +549,19 @@ func AisPerformLLMRephraseRequest(systemTemplate string, userTemplate string, qu
panic(err)
}

// calculate input and output token count
inputTokenCount, err = openAiTokenCount(tokenCountModelName, userPrompt+systemPrompt)
if err != nil {
panic(err)
}
outputTokenCount, err = openAiTokenCount(tokenCountModelName, rephrasedQuery)
if err != nil {
panic(err)
}

logging.Log.Debugf(internalstates.Ctx, "Rephrased query: %v", rephrasedQuery)

return rephrasedQuery
return rephrasedQuery, inputTokenCount, outputTokenCount
}

// AisReturnIndexList returns the index list for AIS
Expand Down Expand Up @@ -675,6 +685,9 @@ func AisPerformLLMFinalRequest(systemTemplate string,
prohibitedWords []string,
errorList1 []string,
errorList2 []string,
previousInputTokenCount int,
previousOutputTokenCount int,
tokenCountModelName string,
isStream bool) (message string, stream *chan string) {

logging.Log.Debugf(internalstates.Ctx, "Performing LLM final request")
Expand Down Expand Up @@ -751,8 +764,15 @@ func AisPerformLLMFinalRequest(systemTemplate string,
// Create a stream channel
streamChannel := make(chan string, 400)

// calculate input token count
inputTokenCount, err := openAiTokenCount(tokenCountModelName, userPrompt+systemPrompt)
if err != nil {
panic(err)
}
totalInputTokenCount := previousInputTokenCount + inputTokenCount

// Start a goroutine to transfer the data from the response channel to the stream channel.
go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false)
go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false, true, totalInputTokenCount, previousOutputTokenCount, tokenCountModelName)

return "", &streamChannel
}
6 changes: 3 additions & 3 deletions pkg/externalfunctions/llmhandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ func PerformGeneralRequest(input string, history []sharedtypes.HistoricMessage,
streamChannel := make(chan string, 400)

// Start a goroutine to transfer the data from the response channel to the stream channel
go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false)
go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false, false, 0, 0, "")

// Return the stream channel
return "", &streamChannel
Expand Down Expand Up @@ -314,7 +314,7 @@ func PerformGeneralRequestSpecificModel(input string, history []sharedtypes.Hist
streamChannel := make(chan string, 400)

// Start a goroutine to transfer the data from the response channel to the stream channel
go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false)
go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false, false, 0, 0, "")

// Return the stream channel
return "", &streamChannel
Expand Down Expand Up @@ -370,7 +370,7 @@ func PerformCodeLLMRequest(input string, history []sharedtypes.HistoricMessage,
streamChannel := make(chan string, 400)

// Start a goroutine to transfer the data from the response channel to the stream channel
go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, validateCode)
go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, validateCode, false, 0, 0, "")

// Return the stream channel
return "", &streamChannel
Expand Down
87 changes: 76 additions & 11 deletions pkg/externalfunctions/privatefunctions.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ import (
"github.com/ansys/allie-sharedtypes/pkg/sharedtypes"
"github.com/google/go-github/v56/github"
"github.com/google/uuid"
"github.com/tiktoken-go/tokenizer"
"golang.org/x/oauth2"
"nhooyr.io/websocket"
)
Expand All @@ -34,13 +35,28 @@ import (
// - responseChannel: the response channel
// - streamChannel: the stream channel
// - validateCode: the flag to indicate whether the code should be validated
func transferDatafromResponseToStreamChannel(responseChannel *chan sharedtypes.HandlerResponse, streamChannel *chan string, validateCode bool) {
func transferDatafromResponseToStreamChannel(
responseChannel *chan sharedtypes.HandlerResponse,
streamChannel *chan string,
validateCode bool,
sendTokenCount bool,
previousInputTokenCount int,
previousOutputTokenCount int,
tokenCountModelName string) {

// Defer the closing of the channels
defer close(*responseChannel)
defer close(*streamChannel)

// Loop through the response channel
responseAsStr := ""
for response := range *responseChannel {
// Check if the response is an error
if response.Type == "error" {
*streamChannel <- response.Error.Message
break
logging.Log.Errorf(internalstates.Ctx, "Error in request %v: %v\n", response.InstructionGuid, response.Error.Message)
// send the error message to the stream channel and exit function
*streamChannel <- fmt.Sprintf("$$error$$:$$%v$$", response.Error.Message)
return
}

// append the response to the responseAsStr
Expand All @@ -52,6 +68,23 @@ func transferDatafromResponseToStreamChannel(responseChannel *chan sharedtypes.H
// check for last response
if *(response.IsLast) {

finalMessage := ""
// check for token count
if sendTokenCount {

// get the output token count
outputTokenCount, err := openAiTokenCount(tokenCountModelName, responseAsStr)
if err != nil {
logging.Log.Errorf(internalstates.Ctx, "Error getting token count: %v\n", err)
// send the error message to the stream channel and exit function
*streamChannel <- fmt.Sprintf("$$error$$:$$Error getting token count: %v$$", err)
}
totalOuputTokenCount := previousOutputTokenCount + outputTokenCount

// append the token count message to the final message
finalMessage += fmt.Sprintf("$$input_token_count$$:$$%d$$;$$output_token_count$$:$$%d$$;", previousInputTokenCount, totalOuputTokenCount)
}

// check for code validation
if validateCode {
// Extract the code from the response
Expand All @@ -67,23 +100,26 @@ func transferDatafromResponseToStreamChannel(responseChannel *chan sharedtypes.H
} else {
if valid {
if warnings {
*streamChannel <- "Code has warnings."
finalMessage += "$$code_validation$$:$$warnings$$;"
} else {
*streamChannel <- "Code is valid."
finalMessage += "$$code_validation$$:$$valid$$;"
}
} else {
*streamChannel <- "Code is invalid."
finalMessage += "$$code_validation$$:$$invalid$$;"
}
}
}
}

// exit the loop
break
// send the final message to the stream channel
if finalMessage != "" {
*streamChannel <- finalMessage
}

// exit the function
return
}
}
close(*responseChannel)
close(*streamChannel)
}

// sendChatRequestNoHistory sends a chat request to LLM without history
Expand Down Expand Up @@ -1456,7 +1492,7 @@ func performGeneralRequest(input string, history []sharedtypes.HistoricMessage,
streamChannel := make(chan string, 400)

// Start a goroutine to transfer the data from the response channel to the stream channel.
go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false)
go transferDatafromResponseToStreamChannel(&responseChannel, &streamChannel, false, false, 0, 0, "")

// Return the stream channel.
return "", &streamChannel, nil
Expand Down Expand Up @@ -1706,3 +1742,32 @@ func createPayloadAndSendHttpRequest(url string, requestObject interface{}, resp

return nil, 0
}

// openAiTokenCount returns the number of tokens that message occupies when
// encoded with the tokenizer associated with the given OpenAI model name.
//
// Parameters:
//   - modelName: the OpenAI model name (e.g. "gpt-4o", "gpt-3.5-turbo")
//   - message: the text to tokenize
//
// Returns:
//   - the token count
//   - an error if the model is unknown or tokenization fails
func openAiTokenCount(modelName string, message string) (int, error) {
	// Map the model name to its token encoding.
	var encoding tokenizer.Encoding
	switch modelName {
	case "gpt-4-turbo", "gpt-4", "gpt-3.5-turbo":
		encoding = tokenizer.Cl100kBase
	case "gpt-4o", "gpt-4o-mini":
		encoding = tokenizer.O200kBase
	default:
		return 0, fmt.Errorf("model %s not found", modelName)
	}

	// Load the tokenizer for the chosen encoding. The local is named enc so it
	// does not shadow the imported tokenizer package.
	enc, err := tokenizer.Get(encoding)
	if err != nil {
		return 0, fmt.Errorf("failed to load tokenizer for model %s: %w", modelName, err)
	}

	// Tokenize the message; only the token IDs are needed for the count.
	tokens, _, err := enc.Encode(message)
	if err != nil {
		return 0, fmt.Errorf("failed to tokenize message: %w", err)
	}

	// Return the number of tokens.
	return len(tokens), nil
}
14 changes: 10 additions & 4 deletions pkg/grpcserver/grpcserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ func (s *server) StreamFunction(req *allieflowkitgrpc.FunctionInputs, stream all

// listen to channel and send to stream
var counter int32
var previousOutput *allieflowkitgrpc.StreamOutput
for message := range *streamChannel {
// create output
output := &allieflowkitgrpc.StreamOutput{
Expand All @@ -281,11 +282,16 @@ func (s *server) StreamFunction(req *allieflowkitgrpc.FunctionInputs, stream all
}

// send output to stream
err := stream.Send(output)
if err != nil {
return err
if counter > 0 {
err := stream.Send(previousOutput)
if err != nil {
return err
}
}

// save output to previous output
previousOutput = output

// increment counter
counter++
}
Expand All @@ -294,7 +300,7 @@ func (s *server) StreamFunction(req *allieflowkitgrpc.FunctionInputs, stream all
output := &allieflowkitgrpc.StreamOutput{
MessageCounter: counter,
IsLast: true,
Value: "",
Value: previousOutput.Value,
}
err = stream.Send(output)
if err != nil {
Expand Down

0 comments on commit 043ae70

Please sign in to comment.