diff --git a/pkg/component/ai/instill/v0/config/tasks.json b/pkg/component/ai/instill/v0/config/tasks.json index 964591ae3..1ed30ed16 100644 --- a/pkg/component/ai/instill/v0/config/tasks.json +++ b/pkg/component/ai/instill/v0/config/tasks.json @@ -593,7 +593,7 @@ "type": "object", "properties": { "model": { - "description": "The model to be used for generating embeddings.", + "description": "The model to be used for generating embeddings. It should be `namespace/model-name/version`. i.e. `abrc/yolov7-stomata/v0.1.0`. You can see the version from the Versions tab of Model page.", "instillShortDescription": "The model to be used.", "instillAcceptFormats": [ "string" diff --git a/pkg/component/ai/instill/v0/main.go b/pkg/component/ai/instill/v0/main.go index dd32775d7..1408fe778 100644 --- a/pkg/component/ai/instill/v0/main.go +++ b/pkg/component/ai/instill/v0/main.go @@ -14,6 +14,7 @@ import ( "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/structpb" + "github.com/instill-ai/pipeline-backend/pkg/component/ai" "github.com/instill-ai/pipeline-backend/pkg/component/base" "github.com/instill-ai/pipeline-backend/pkg/component/internal/util" @@ -101,12 +102,46 @@ func (e *execution) Execute(ctx context.Context, jobs []*base.Job) error { defer gRPCCientConn.Close() } + var result []*structpb.Struct + + // We will refactor the component soon to align the data structure with Instill Model. + // For now, we will use the task field to determine the task. + if e.Task == "TASK_EMBEDDING" { + var inputStruct ai.EmbeddingInput + err := base.ConvertFromStructpb(inputs[0], &inputStruct) + if err != nil { + return fmt.Errorf("convert to defined struct: %w", err) + } + + model := inputStruct.Data.Model + modelNameSplits := strings.Split(model, "/") + + nsID := modelNameSplits[0] + modelID := modelNameSplits[1] + version := modelNameSplits[2] + + result, err = e.executeEmbedding(gRPCClient, nsID, modelID, version, inputs) + + if err != nil { + return fmt.Errorf("execute embedding: %w", err) + } + + for idx, job := range jobs { + err = job.Output.Write(ctx, result[idx]) + if err != nil { + job.Error.Error(ctx, err) + continue + } + } + return nil + } + modelNameSplits := strings.Split(inputs[0].GetFields()["model-name"].GetStringValue(), "/") nsID := modelNameSplits[0] modelID := modelNameSplits[1] version := modelNameSplits[2] - var result []*structpb.Struct + switch e.Task { case "TASK_CLASSIFICATION": result, err = e.executeVisionTask(gRPCClient, nsID, modelID, version, inputs) @@ -126,8 +161,6 @@ func (e *execution) Execute(ctx context.Context, jobs []*base.Job) error { result, err = e.executeTextGeneration(gRPCClient, nsID, modelID, version, inputs) case "TASK_TEXT_GENERATION_CHAT", "TASK_VISUAL_QUESTION_ANSWERING", "TASK_CHAT": result, err = e.executeTextGenerationChat(gRPCClient, nsID, modelID, version, inputs) - case "TASK_EMBEDDING": - result, err = e.executeEmbedding(gRPCClient, nsID, modelID, version, inputs) default: return fmt.Errorf("unsupported task: %s", e.Task) }