Skip to content

Commit

Permalink
milvus: Honor WithEmbedder option in AddDocuments
Browse files Browse the repository at this point in the history
  • Loading branch information
HomayoonAlimohammadi committed Feb 13, 2025
1 parent d3e43b6 commit 477a26d
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 1 deletion.
16 changes: 15 additions & 1 deletion vectorstores/milvus/milvus.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,10 @@ func (s *Store) load(ctx context.Context) error {
// AddDocuments adds the text and metadata from the documents to the Milvus collection associated with 'Store'.
// and returns the ids of the added documents.
func (s Store) AddDocuments(ctx context.Context, docs []schema.Document,
_ ...vectorstores.Option,
opts ...vectorstores.Option,
) ([]string, error) {
s.applyOptions(opts)

texts := make([]string, 0, len(docs))
for _, doc := range docs {
texts = append(texts, doc.PageContent)
Expand Down Expand Up @@ -347,3 +349,15 @@ func (s Store) getFilters(opts vectorstores.Options) (string, error) {
}
return "", nil
}

// applyOptions applies the options to the store.
func (s *Store) applyOptions(opts []vectorstores.Option) {
var options vectorstores.Options
for _, o := range opts {
o(&options)
}

if options.Embedder != nil {
s.embedder = options.Embedder
}
}
33 changes: 33 additions & 0 deletions vectorstores/milvus/milvus_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package milvus

import (
"context"
"errors"
"log"
"os"
"strings"
Expand Down Expand Up @@ -102,6 +103,18 @@ func TestMilvusConnection(t *testing.T) {
_, err = storer.AddDocuments(context.Background(), data)
require.NoError(t, err)

unusedData := []schema.Document{
{PageContent: "MockCity", Metadata: map[string]any{"population": 100, "area": 100}},
}

embedderM := &mockEmbedder{
embedDocumentsErr: errEmbeddingDocuments,
embedQueryErr: errEmbeddingQuery,
}

_, err = storer.AddDocuments(context.Background(), unusedData, vectorstores.WithEmbedder(embedderM))
require.ErrorIs(t, err, errEmbeddingDocuments)

// search docs with filter
filterRes, err := storer.SimilaritySearch(context.Background(),
"Tokyo", 10,
Expand All @@ -117,3 +130,23 @@ func TestMilvusConnection(t *testing.T) {
require.NoError(t, err)
require.Len(t, japanRes, 1)
}

type mockEmbedder struct {
embedDocumentsReturn [][]float32
embedDocumentsErr error
embedQueryReturn []float32
embedQueryErr error
}

func (m *mockEmbedder) EmbedDocuments(_ context.Context, _ []string) ([][]float32, error) {
return m.embedDocumentsReturn, m.embedDocumentsErr
}

func (m *mockEmbedder) EmbedQuery(_ context.Context, _ string) ([]float32, error) {
return m.embedQueryReturn, m.embedQueryErr
}

var (
errEmbeddingDocuments = errors.New("error embedding documents")
errEmbeddingQuery = errors.New("error embedding query")
)

0 comments on commit 477a26d

Please sign in to comment.