Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improved the search to be concurrent #62

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 53 additions & 7 deletions source/multi.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

package source

import "context"
import (
"context"
"sync"
)

type MultiSource struct {
sources []Sourcer
Expand All @@ -24,16 +27,59 @@ func NewMultiSource(sources ...Sourcer) *MultiSource {
return &MultiSource{sources}
}

// Search concurrently queries all sources and returns the combined results.
func (s *MultiSource) Search(ctx context.Context, collectionName string, subjectDigests, attestations []string) ([]CollectionEnvelope, error) {
results := make([]CollectionEnvelope, 0)
for _, source := range s.sources {
res, err := source.Search(ctx, collectionName, subjectDigests, attestations)
if err != nil {
return results, err
results := []CollectionEnvelope{}
errors := []error{}

errs := make(chan error) // Channel for collecting errors from each source
resChan := make(chan []CollectionEnvelope) // Channel for collecting results from each source

errdone := make(chan bool) // Signal channel indicating when error collection is done
readerDone := make(chan bool) // Signal channel indicating when result collection is done

// Goroutine for collecting results from the result channel
go func() {
for item := range resChan {
results = append(results, item...)
}
readerDone <- true
}()

results = append(results, res...)
// Goroutine for collecting errors from the error channel
go func() {
for err := range errs {
errors = append(errors, err)
}
errdone <- true
}()

var wg sync.WaitGroup // WaitGroup for waiting on all source queries to finish
for _, source := range s.sources {
source := source
wg.Add(1)
// Goroutine for querying a source and collecting the results or error
go func(src Sourcer) {
defer wg.Done()
res, err := src.Search(ctx, collectionName, subjectDigests, attestations)
if err != nil {
errs <- err
} else {
resChan <- res
}
}(source)
}
wg.Wait() // Wait for all source queries to finish
close(resChan) // Close the result channel
close(errs) // Close the error channel

<-errdone // Wait for error collection to finish
<-readerDone // Wait for result collection to finish

// If any errors occurred, return the first error and discard the results
if len(errors) > 0 {
return nil, errors[0]
}
// Return the combined results from all sources
return results, nil
}
Loading