diff --git a/source/multi.go b/source/multi.go index d3611af4..500fc587 100644 --- a/source/multi.go +++ b/source/multi.go @@ -14,7 +14,10 @@ package source -import "context" +import ( + "context" + "sync" +) type MultiSource struct { sources []Sourcer @@ -24,16 +27,59 @@ func NewMultiSource(sources ...Sourcer) *MultiSource { return &MultiSource{sources} } +// Search concurrently queries all sources and returns the combined results. func (s *MultiSource) Search(ctx context.Context, collectionName string, subjectDigests, attestations []string) ([]CollectionEnvelope, error) { - results := make([]CollectionEnvelope, 0) - for _, source := range s.sources { - res, err := source.Search(ctx, collectionName, subjectDigests, attestations) - if err != nil { - return results, err + results := []CollectionEnvelope{} + errors := []error{} + + errs := make(chan error) // Channel for collecting errors from each source + resChan := make(chan []CollectionEnvelope) // Channel for collecting results from each source + + errdone := make(chan bool) // Signal channel indicating when error collection is done + readerDone := make(chan bool) // Signal channel indicating when result collection is done + + // Goroutine for collecting results from the result channel + go func() { + for item := range resChan { + results = append(results, item...) } + readerDone <- true + }() - results = append(results, res...) + // Goroutine for collecting errors from the error channel + go func() { + for err := range errs { + errors = append(errors, err) + } + errdone <- true + }() + + var wg sync.WaitGroup // WaitGroup for waiting on all source queries to finish + for _, source := range s.sources { + source := source + wg.Add(1) + // Goroutine for querying a source and collecting the results or error + go func(src Sourcer) { + defer wg.Done() + res, err := src.Search(ctx, collectionName, subjectDigests, attestations) + if err != nil { + errs <- err + } else { + resChan <- res + } + }(source) } + wg.Wait() // Wait for all source queries to finish + close(resChan) // Close the result channel + close(errs) // Close the error channel + <-errdone // Wait for error collection to finish + <-readerDone // Wait for result collection to finish + + // If any errors occurred, return the first error and discard the results + if len(errors) > 0 { + return nil, errors[0] + } + // Return the combined results from all sources return results, nil }