Skip to content

Commit

Permalink
Implement Search method
Browse files Browse the repository at this point in the history
  • Loading branch information
lalinsky committed Feb 13, 2024
1 parent 28465d5 commit 953d228
Showing 1 changed file with 44 additions and 22 deletions.
66 changes: 44 additions & 22 deletions pkg/fpstore/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,26 @@ func (s *FingerprintStoreService) getFingerprint(ctx context.Context, id uint64)
return fp, nil
}

func (s *FingerprintStoreService) compareFingerprints(ctx context.Context, a, b *pb.Fingerprint) (float64, error) {
if len(a.Hashes) == 0 || len(b.Hashes) == 0 {
return 0, nil
func (s *FingerprintStoreService) compareFingerprints(ctx context.Context, query *pb.Fingerprint, ids []uint64) ([]*pb.MatchingFingerprint, error) {
var results []*pb.MatchingFingerprint
for _, id := range ids {
if ctx.Err() == context.Canceled {
return nil, status.Error(codes.Canceled, "request canceled")
}
if id == 0 {
return nil, status.Error(codes.InvalidArgument, "id is required")
}
fp, err := s.getFingerprint(ctx, id)
if err != nil {
return nil, status.Error(codes.Internal, fmt.Sprintf("failed to get fingerprint %d", id))
}
if fp == nil {
continue
}
score, err := chromaprint.CompareFingerprints(query, fp)
results = append(results, &pb.MatchingFingerprint{Id: id, Similarity: float32(score)})
}
return chromaprint.CompareFingerprints(a, b)
return results, nil
}

// Implement Get method
Expand All @@ -104,23 +119,30 @@ func (s *FingerprintStoreService) Get(ctx context.Context, req *pb.GetFingerprin
}

func (s *FingerprintStoreService) Compare(ctx context.Context, req *pb.CompareFingerprintRequest) (*pb.CompareFingerprintResponse, error) {
var resp pb.CompareFingerprintResponse
for _, id := range req.Ids {
if ctx.Err() == context.Canceled {
return nil, status.Error(codes.Canceled, "request canceled")
}
if id == 0 {
return nil, status.Error(codes.InvalidArgument, "id is required")
}
fp, err := s.getFingerprint(ctx, id)
if err != nil {
return nil, status.Error(codes.Internal, fmt.Sprintf("failed to get fingerprint %d", id))
}
if fp == nil {
continue
}
score, err := s.compareFingerprints(ctx, req.Fingerprint, fp)
resp.Results = append(resp.Results, &pb.MatchingFingerprint{Id: id, Similarity: float32(score)})
if len(req.Fingerprint.Hashes) == 0 {
return nil, status.Error(codes.InvalidArgument, "fingerprint can't be empty")
}
if len(req.Ids) == 0 {
return nil, status.Error(codes.InvalidArgument, "ids can't be empty")
}
results, err := s.compareFingerprints(ctx, req.Fingerprint, req.Ids)
if err != nil {
return nil, err
}
return &pb.CompareFingerprintResponse{Results: results}, nil
}

func (s *FingerprintStoreService) Search(ctx context.Context, req *pb.SearchFingerprintRequest) (*pb.SearchFingerprintResponse, error) {
if len(req.Fingerprint.Hashes) == 0 {
return nil, status.Error(codes.InvalidArgument, "fingerprint can't be empty")
}
candidateIds, err := s.index.Search(ctx, req.Fingerprint, 10)
if err != nil {
return nil, status.Error(codes.Internal, "failed to search index")
}
results, err := s.compareFingerprints(ctx, req.Fingerprint, candidateIds)
if err != nil {
return nil, err
}
return &resp, nil
return &pb.SearchFingerprintResponse{Results: results}, nil
}

0 comments on commit 953d228

Please sign in to comment.