Skip to content

Commit

Permalink
fix(instillmodel): fix bug
Browse files Browse the repository at this point in the history
  • Loading branch information
chuang8511 committed Oct 16, 2024
1 parent 65b1177 commit 7493ab5
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pkg/component/ai/instill/v0/config/tasks.json
Original file line number Diff line number Diff line change
Expand Up @@ -593,7 +593,7 @@
"type": "object",
"properties": {
"model": {
"description": "The model to be used for generating embeddings.",
"description": "The model to be used for generating embeddings. It should be `namespace/model-name/version`. i.e. `abrc/yolov7-stomata/v0.1.0`. You can see the version from the Versions tab of Model page.",
"instillShortDescription": "The model to be used.",
"instillAcceptFormats": [
"string"
Expand Down
39 changes: 36 additions & 3 deletions pkg/component/ai/instill/v0/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"

"github.com/instill-ai/pipeline-backend/pkg/component/ai"
"github.com/instill-ai/pipeline-backend/pkg/component/base"
"github.com/instill-ai/pipeline-backend/pkg/component/internal/util"

Expand Down Expand Up @@ -101,12 +102,46 @@ func (e *execution) Execute(ctx context.Context, jobs []*base.Job) error {
defer gRPCCientConn.Close()
}

var result []*structpb.Struct

// We will refactor the component soon to align the data structure with Instill Model.
// For now, we will use the task field to determine the task.
if e.Task == "TASK_EMBEDDING" {
var inputStruct ai.EmbeddingInput
err := base.ConvertFromStructpb(inputs[0], &inputStruct)
if err != nil {
return fmt.Errorf("convert to defined struct: %w", err)
}

model := inputStruct.Data.Model
modelNameSplits := strings.Split(model, "/")

nsID := modelNameSplits[0]
modelID := modelNameSplits[1]
version := modelNameSplits[2]

result, err = e.executeEmbedding(gRPCClient, nsID, modelID, version, inputs)

if err != nil {
return fmt.Errorf("execute embedding: %w", err)
}

for idx, job := range jobs {
err = job.Output.Write(ctx, result[idx])
if err != nil {
job.Error.Error(ctx, err)
continue
}
}
return nil
}

modelNameSplits := strings.Split(inputs[0].GetFields()["model-name"].GetStringValue(), "/")

nsID := modelNameSplits[0]
modelID := modelNameSplits[1]
version := modelNameSplits[2]
var result []*structpb.Struct

switch e.Task {
case "TASK_CLASSIFICATION":
result, err = e.executeVisionTask(gRPCClient, nsID, modelID, version, inputs)
Expand All @@ -126,8 +161,6 @@ func (e *execution) Execute(ctx context.Context, jobs []*base.Job) error {
result, err = e.executeTextGeneration(gRPCClient, nsID, modelID, version, inputs)
case "TASK_TEXT_GENERATION_CHAT", "TASK_VISUAL_QUESTION_ANSWERING", "TASK_CHAT":
result, err = e.executeTextGenerationChat(gRPCClient, nsID, modelID, version, inputs)
case "TASK_EMBEDDING":
result, err = e.executeEmbedding(gRPCClient, nsID, modelID, version, inputs)
default:
return fmt.Errorf("unsupported task: %s", e.Task)
}
Expand Down

0 comments on commit 7493ab5

Please sign in to comment.