Skip to content

Commit

Permalink
fix(artifact): fix retry file process (#97)
Browse files Browse the repository at this point in the history
Because

1. When all processing workers are busy, the dispatcher continues
to send tasks to the queue. These tasks are not yet held by any worker, so
the dispatcher may keep sending the same tasks to the queue.

2. When a chunk's text is empty, it is still inserted into the DB.

This commit

1. **Add a check at the beginning of each process step** to verify the
file’s processing state in the database. If the state does not match
the current process state, discard the file process task.

2. **Set the buffer of the worker queue to zero** to prevent out-of-date file
statuses from being queued

3. **Skip a chunk if its text is empty**
  • Loading branch information
Yougigun authored Sep 19, 2024
1 parent 013b8df commit 97ff707
Show file tree
Hide file tree
Showing 4 changed files with 107 additions and 29 deletions.
8 changes: 8 additions & 0 deletions pkg/repository/chunk.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ import (
"time"

"github.com/gofrs/uuid"
"github.com/instill-ai/artifact-backend/pkg/logger"
"go.uber.org/zap"
"gorm.io/gorm"
"gorm.io/gorm/clause"
)
Expand Down Expand Up @@ -100,6 +102,7 @@ func (r *Repository) TextChunkTableName() string {
// a certain source table and sourceUID, then batch inserts the new chunks
// within a transaction.
func (r *Repository) DeleteAndCreateChunks(ctx context.Context, sourceTable string, sourceUID uuid.UUID, chunks []*TextChunk, externalServiceCall func(chunkUIDs []string) (map[string]any, error)) ([]*TextChunk, error) {
logger, _ := logger.GetZapLogger(ctx)
// Start a transaction
err := r.db.Transaction(func(tx *gorm.DB) error {
// Delete existing chunks
Expand All @@ -108,8 +111,13 @@ func (r *Repository) DeleteAndCreateChunks(ctx context.Context, sourceTable stri
return err
}

if len(chunks) == 0 {
logger.Warn("no chunks to create")
return nil
}
// Batch insert new chunks
if err := tx.WithContext(ctx).Create(&chunks).Error; err != nil {
logger.Error("error creating chunks: ", zap.Error(err))
return err
}

Expand Down
4 changes: 4 additions & 0 deletions pkg/repository/embedding.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,10 @@ func (r *Repository) UpsertEmbeddings(
logger, _ := logger.GetZapLogger(ctx)
// Start a transaction
err := r.db.WithContext(ctx).Transaction(func(tx *gorm.DB) error {
if len(embeddings) == 0 {
logger.Warn("no embeddings to upsert")
return nil
}
// Upsert the embeddings
if err := tx.Clauses(clause.OnConflict{
Columns: []clause.Column{{Name: EmbeddingColumn.SourceTable}, {Name: EmbeddingColumn.SourceUID}}, // Unique column that triggers the upsert
Expand Down
9 changes: 8 additions & 1 deletion pkg/service/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,14 @@ func (s *Service) SplitMarkdownPipe(ctx context.Context, caller uuid.UUID, reque
if err != nil {
return nil, err
}
return result, nil
// remove the empty chunk
var filteredResult []Chunk
for _, chunk := range result {
if chunk.Text != "" {
filteredResult = append(filteredResult, chunk)
}
}
return filteredResult, nil
}

// GetChunksFromResponse converts the pipeline response into a slice of Chunk.
Expand Down
115 changes: 87 additions & 28 deletions pkg/worker/worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package worker
import (
"context"
"encoding/base64"
"errors"
"fmt"
"runtime/debug"
"sync"
Expand All @@ -25,6 +26,8 @@ const extensionHelperPeriod = 5 * time.Second
const workerLifetime = 45 * time.Second
const workerPrefix = "worker-processing-file-"

// ErrFileStatusNotMatch is the sentinel error wrapped by checkFileStatus when
// the process status carried by a queued file differs from the status stored
// in the database. Callers can detect it with errors.Is.
var ErrFileStatusNotMatch = errors.New("file status not match")

type fileToEmbWorkerPool struct {
numberOfWorkers int
svc *service.Service
Expand All @@ -39,10 +42,11 @@ func NewFileToEmbWorkerPool(ctx context.Context, svc *service.Service, nums int)
return &fileToEmbWorkerPool{
numberOfWorkers: nums,
svc: svc,
channel: make(chan repository.KnowledgeBaseFile, 100),
wg: sync.WaitGroup{},
ctx: ctx,
cancel: cancel,
// channel is un-buffered because we dont want the out of date file to be processed
channel: make(chan repository.KnowledgeBaseFile),
wg: sync.WaitGroup{},
ctx: ctx,
cancel: cancel,
}
}

Expand Down Expand Up @@ -80,11 +84,11 @@ func (wp *fileToEmbWorkerPool) startDispatcher() {
// Periodically check for incomplete files
incompleteFiles := wp.svc.Repository.GetNeedProcessFiles(wp.ctx)
// Check if any of the incomplete files have active workers
fileUID := make([]string, len(incompleteFiles))
fileUIDs := make([]string, len(incompleteFiles))
for i, file := range incompleteFiles {
fileUID[i] = file.UID.String()
fileUIDs[i] = file.UID.String()
}
nonExistentKeys := wp.checkRegisteredFilesWorker(wp.ctx, fileUID)
nonExistentKeys := wp.checkRegisteredFilesWorker(wp.ctx, fileUIDs)

// Dispatch the files that do not have active workers
incompleteAndNonRegisteredFiles := make([]repository.KnowledgeBaseFile, 0)
Expand All @@ -94,15 +98,23 @@ func (wp *fileToEmbWorkerPool) startDispatcher() {
}
}

dispatchLoop:
for _, file := range incompleteAndNonRegisteredFiles {
select {
case <-wp.ctx.Done():
fmt.Println("Dispatcher received termination signal while dispatching")
return
case wp.channel <- file:
fmt.Printf("Dispatcher dispatched file. fileUID: %s\n", file.UID.String())
default:
select {
case wp.channel <- file:
logger.Info("Dispatcher dispatched file.", zap.String("fileUID", file.UID.String()))
default:
logger.Debug("channel is full, skip dispatching remaining files.", zap.String("fileUID", file.UID.String()))
break dispatchLoop
}
}
}

}
}
}
Expand Down Expand Up @@ -144,18 +156,33 @@ func (wp *fileToEmbWorkerPool) startWorker(ctx context.Context, workerID int) {

// register file process worker in redis and extend the lifetime
ok, stopRegisterFunc := wp.registerFileWorker(ctx, file.UID.String(), extensionHelperPeriod, workerLifetime)

if !ok {
if stopRegisterFunc != nil {
stopRegisterFunc()
}
continue
}
// check if the file is already processed
// Because the file is from the dispatcher, the file status is guaranteed to be incomplete
// but when the worker wakes up and tries to process the file, the file status might have been updated by other workers.
// So we need to check the file status again to ensure the file is still same as when the worker wakes up
err := wp.checkFileStatus(ctx, file)
if err != nil {
logger.Warn("File status not match. skip processing", zap.String("file uid", file.UID.String()), zap.Error(err))
if stopRegisterFunc != nil {
stopRegisterFunc()
}
continue
}
// start file processing tracing
fmt.Printf("Worker %d processing file: %s\n", workerID, file.UID.String())

// process
t0 := time.Now()
err := wp.processFile(ctx, file)
err = wp.processFile(ctx, file)

if err != nil {
fmt.Printf("Error processing file: %s, error: %v\n", file.UID.String(), err)
logger.Error("Error processing file", zap.String("file uid", file.UID.String()), zap.Error(err))
err = wp.svc.Repository.UpdateKbFileExtraMetaData(ctx, file.UID, err.Error(), "", "", "", nil, nil, nil, nil)
if err != nil {
fmt.Printf("Error marshaling extra metadata: %v\n", err)
Expand Down Expand Up @@ -197,15 +224,19 @@ type stopRegisterWorkerFunc func()
// It returns a boolean indicating success and a stopRegisterWorkerFunc that can be used to cancel the worker's lifetime extension and remove the worker key from Redis.
// period: duration between lifetime extensions
// workerLifetime: total duration the worker key should be kept in Redis
func (wp *fileToEmbWorkerPool) registerFileWorker(ctx context.Context, fileUID string, period time.Duration, workerLifetime time.Duration) (bool, stopRegisterWorkerFunc) {
func (wp *fileToEmbWorkerPool) registerFileWorker(ctx context.Context, fileUID string, period time.Duration, workerLifetime time.Duration) (ok bool, stopRegisterWorker stopRegisterWorkerFunc) {
logger, _ := logger.GetZapLogger(ctx)
stopRegisterWorker = func() {
logger.Warn("stopRegisterWorkerFunc is not implemented yet")
}
ok, err := wp.svc.RedisClient.SetNX(ctx, getWorkerKey(fileUID), "1", workerLifetime).Result()
if err != nil {
fmt.Printf("Error when setting worker key in redis. Error: %v\n", err)
return false, nil
logger.Error("Error when setting worker key in redis", zap.Error(err))
return
}
if !ok {
fmt.Printf("File is already being processed in redis. fileUID: %s\n", fileUID)
return false, nil
logger.Warn("Key exists in redis, file is already being processed by worker", zap.String("fileUID", fileUID))
return
}
ctx, lifetimeHelperCancel := context.WithCancel(ctx)

Expand All @@ -217,14 +248,14 @@ func (wp *fileToEmbWorkerPool) registerFileWorker(ctx context.Context, fileUID s
select {
case <-ctx.Done():
// Context is done, exit the worker
fmt.Printf("Finish %v's lifetime extend helper received termination signal\n", getWorkerKey(fileUID))
logger.Debug("Finish worker lifetime extend helper received termination signal", zap.String("worker", getWorkerKey(fileUID)))
return
case <-ticker.C:
// extend the lifetime of the worker
fmt.Printf("Extending %v's lifetime: %v \n", getWorkerKey(fileUID), workerLifetime)
logger.Debug("Extending worker lifetime", zap.String("worker", getWorkerKey(fileUID)), zap.Duration("lifetime", workerLifetime))
err := wp.svc.RedisClient.Expire(ctx, getWorkerKey(fileUID), workerLifetime).Err()
if err != nil {
fmt.Printf("Error when extending worker lifetime in redis. Error: %v, worker: %v\n", err, getWorkerKey(fileUID))
logger.Error("Error when extending worker lifetime in redis", zap.Error(err), zap.String("worker", getWorkerKey(fileUID)))
return
}
}
Expand All @@ -233,7 +264,7 @@ func (wp *fileToEmbWorkerPool) registerFileWorker(ctx context.Context, fileUID s
go lifetimeExtHelper(ctx)

// stopRegisterWorker function will cancel the lifetimeExtHelper and remove the worker key in redis
stopRegisterWorker := func() {
stopRegisterWorker = func() {
lifetimeHelperCancel()
wp.svc.RedisClient.Del(ctx, getWorkerKey(fileUID))
}
Expand All @@ -243,6 +274,7 @@ func (wp *fileToEmbWorkerPool) registerFileWorker(ctx context.Context, fileUID s

// checkRegisteredFilesWorker checks if any of the provided fileUIDs have active workers
func (wp *fileToEmbWorkerPool) checkRegisteredFilesWorker(ctx context.Context, fileUIDs []string) map[string]struct{} {
logger, _ := logger.GetZapLogger(ctx)
pipe := wp.svc.RedisClient.Pipeline()

// Create a map to hold the results
Expand All @@ -257,7 +289,7 @@ func (wp *fileToEmbWorkerPool) checkRegisteredFilesWorker(ctx context.Context, f
// Execute the pipeline
_, err := pipe.Exec(ctx)
if err != nil {
fmt.Println("Error executing pipeline:", err)
logger.Error("Error executing redis pipeline", zap.Error(err))
return nil
}

Expand All @@ -266,7 +298,7 @@ func (wp *fileToEmbWorkerPool) checkRegisteredFilesWorker(ctx context.Context, f
for fileUID, result := range results {
exists, err := result.Result()
if err != nil {
fmt.Printf("Error getting result for %s: %v\n", fileUID, err)
logger.Error("Error getting result for %s", zap.String("fileUID", fileUID), zap.Error(err))
return nil
}
if exists == 0 {
Expand All @@ -278,13 +310,20 @@ func (wp *fileToEmbWorkerPool) checkRegisteredFilesWorker(ctx context.Context, f

// processFile handles the processing of a file through various stages using a state machine.
func (wp *fileToEmbWorkerPool) processFile(ctx context.Context, file repository.KnowledgeBaseFile) error {
// logger, _ := logger.GetZapLogger(ctx)
logger, _ := logger.GetZapLogger(ctx)
var status artifactpb.FileProcessStatus
if statusInt, ok := artifactpb.FileProcessStatus_value[file.ProcessStatus]; !ok {
return fmt.Errorf("invalid process status: %v", file.ProcessStatus)
} else {
status = artifactpb.FileProcessStatus(statusInt)
}

// check if the file is already processed
err := wp.checkFileStatus(ctx, file)
if err != nil {
return err
}

for {
switch status {
case artifactpb.FileProcessStatus_FILE_PROCESS_STATUS_WAITING:
Expand All @@ -305,7 +344,7 @@ func (wp *fileToEmbWorkerPool) processFile(ctx context.Context, file repository.
convertingTime := int64(time.Since(t0).Seconds())
err = wp.svc.Repository.UpdateKbFileExtraMetaData(ctx, file.UID, "", "", "", "", nil, &convertingTime, nil, nil)
if err != nil {
fmt.Printf("Error updating file extra metadata: %v\n", err)
logger.Error("Error updating file extra metadata", zap.Error(err))
}

case artifactpb.FileProcessStatus_FILE_PROCESS_STATUS_CHUNKING:
Expand All @@ -319,7 +358,7 @@ func (wp *fileToEmbWorkerPool) processFile(ctx context.Context, file repository.
chunkingTime := int64(time.Since(t0).Seconds())
err = wp.svc.Repository.UpdateKbFileExtraMetaData(ctx, file.UID, "", "", "", "", nil, nil, &chunkingTime, nil)
if err != nil {
fmt.Printf("Error updating file extra metadata: %v\n", err)
logger.Error("Error updating file extra metadata", zap.Error(err))
}

case artifactpb.FileProcessStatus_FILE_PROCESS_STATUS_EMBEDDING:
Expand All @@ -333,7 +372,7 @@ func (wp *fileToEmbWorkerPool) processFile(ctx context.Context, file repository.
embeddingTime := int64(time.Since(t0).Seconds())
err = wp.svc.Repository.UpdateKbFileExtraMetaData(ctx, file.UID, "", "", "", "", nil, nil, nil, &embeddingTime)
if err != nil {
fmt.Printf("Error updating file extra metadata: %v\n", err)
logger.Error("Error updating file extra metadata", zap.Error(err))
}

case artifactpb.FileProcessStatus_FILE_PROCESS_STATUS_COMPLETED:
Expand Down Expand Up @@ -401,6 +440,7 @@ func (wp *fileToEmbWorkerPool) processWaitingFile(ctx context.Context, file repo
// If the file is not a PDF, it returns an error.
func (wp *fileToEmbWorkerPool) processConvertingFile(ctx context.Context, file repository.KnowledgeBaseFile) (updatedFile *repository.KnowledgeBaseFile, nextStatus artifactpb.FileProcessStatus, err error) {
logger, _ := logger.GetZapLogger(ctx)

fileInMinIOPath := file.Destination
data, err := wp.svc.MinIO.GetFile(ctx, fileInMinIOPath)
if err != nil {
Expand Down Expand Up @@ -571,6 +611,7 @@ func (wp *fileToEmbWorkerPool) processChunkingFile(ctx context.Context, file rep
logger.Error("Failed to get chunks from original file.", zap.String("File uid", file.UID.String()))
return nil, artifactpb.FileProcessStatus_FILE_PROCESS_STATUS_UNSPECIFIED, err
}

// Save the chunks into object storage(minIO) and metadata into database
err = wp.saveChunks(ctx, file.KnowledgeBaseUID.String(), file.UID, wp.svc.Repository.KnowledgeBaseFileTableName(), file.UID, chunks)
if err != nil {
Expand Down Expand Up @@ -720,7 +761,7 @@ func (wp *fileToEmbWorkerPool) saveChunks(ctx context.Context, kbUID string, kbF
logger, _ := logger.GetZapLogger(ctx)
textChunks := make([]*repository.TextChunk, len(chunks))

// turn kbuid to uuid no must parse
// turn kbUid to uuid no must parse
kbUIDuuid, err := uuid.FromString(kbUID)
if err != nil {
logger.Error("Failed to parse kbUID to uuid.", zap.String("KbUID", kbUID))
Expand Down Expand Up @@ -805,6 +846,24 @@ func (wp *fileToEmbWorkerPool) saveEmbeddings(ctx context.Context, kbUID string,
return nil
}

// checkFileStatus verifies that the process status carried by file still
// matches the status recorded for the same file UID in the database.
//
// It returns:
//   - the repository error, unchanged, when the lookup fails;
//   - a plain error when no row exists for the file UID;
//   - an error wrapping ErrFileStatusNotMatch when the two statuses differ;
//   - nil when the statuses are equal.
func (wp *fileToEmbWorkerPool) checkFileStatus(ctx context.Context, file repository.KnowledgeBaseFile) error {
	stored, lookupErr := wp.svc.Repository.GetKnowledgeBaseFilesByFileUIDs(ctx, []uuid.UUID{file.UID})
	switch {
	case lookupErr != nil:
		return lookupErr
	case len(stored) == 0:
		return fmt.Errorf("file uid not found in database. file uid: %s", file.UID)
	}

	// The file passed in comes from the queue, so it may be older than the
	// database row; only an exact status match means it is safe to proceed.
	current := stored[0].ProcessStatus
	if current == file.ProcessStatus {
		return nil
	}
	return fmt.Errorf("%w - file uid: %s, database file status: %v, file status in argument: %v", ErrFileStatusNotMatch, file.UID, current, file.ProcessStatus)
}

// getWorkerKey builds the key under which a file's processing worker is
// registered in Redis: workerPrefix followed by the file UID.
func getWorkerKey(fileUID string) string {
	key := workerPrefix + fileUID
	return key
}
Expand Down

0 comments on commit 97ff707

Please sign in to comment.