From 7a87e0330ece132a77b5c4d8f845bab273ebe381 Mon Sep 17 00:00:00 2001 From: Adrian Liechti Date: Sat, 21 Sep 2024 23:20:30 +0200 Subject: [PATCH] add retriever tool --- config/config_tool.go | 31 ++++++--- pkg/tool/bing/client.go | 7 ++- pkg/tool/bing/models.go | 7 ++- pkg/tool/crawler/config.go | 10 --- pkg/tool/crawler/{client.go => tool.go} | 5 -- pkg/tool/duckduckgo/models.go | 7 ++- pkg/tool/retriever/config.go | 3 + pkg/tool/retriever/models.go | 7 +++ pkg/tool/retriever/tool.go | 84 +++++++++++++++++++++++++ pkg/tool/searxng/client.go | 7 ++- pkg/tool/searxng/models.go | 7 ++- pkg/tool/tavily/client.go | 7 ++- pkg/tool/tavily/models.go | 7 ++- 13 files changed, 146 insertions(+), 43 deletions(-) rename pkg/tool/crawler/{client.go => tool.go} (95%) create mode 100644 pkg/tool/retriever/config.go create mode 100644 pkg/tool/retriever/models.go create mode 100644 pkg/tool/retriever/tool.go diff --git a/config/config_tool.go b/config/config_tool.go index bceec3ee..d6c307b2 100644 --- a/config/config_tool.go +++ b/config/config_tool.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/adrianliechti/llama/pkg/extractor" + "github.com/adrianliechti/llama/pkg/index" "github.com/adrianliechti/llama/pkg/provider" "github.com/adrianliechti/llama/pkg/tool" "github.com/adrianliechti/llama/pkg/tool/bing" @@ -12,6 +13,7 @@ import ( "github.com/adrianliechti/llama/pkg/tool/custom" "github.com/adrianliechti/llama/pkg/tool/draw" "github.com/adrianliechti/llama/pkg/tool/duckduckgo" + "github.com/adrianliechti/llama/pkg/tool/retriever" "github.com/adrianliechti/llama/pkg/tool/searxng" "github.com/adrianliechti/llama/pkg/tool/tavily" @@ -43,9 +45,13 @@ type toolConfig struct { Token string `yaml:"token"` Model string `yaml:"model"` + + Index string `yaml:"index"` + Extractor string `yaml:"extractor"` } type toolContext struct { + Index index.Provider Renderer provider.Renderer Extractor extractor.Provider } @@ -56,13 +62,15 @@ func (cfg *Config) registerTools(f *configFile) error { context := toolContext{} - if t.Model != "" { - if r, err := cfg.Renderer(t.Model); err == nil { - context.Renderer = r - } + if i, err := cfg.Index(t.Index); err == nil { + context.Index = i + } + + if r, err := cfg.Renderer(t.Model); err == nil { + context.Renderer = r } - if e, err := cfg.Extractor(""); err == nil { + if e, err := cfg.Extractor(t.Extractor); err == nil { context.Extractor = e } @@ -96,12 +104,15 @@ func createTool(cfg toolConfig, context toolContext) (tool.Tool, error) { case "duckduckgo": return duckduckgoTool(cfg, context) - case "tavily": - return tavilyTool(cfg, context) + case "retriever": + return retrieverTool(cfg, context) case "searxng": return searxngTool(cfg, context) + case "tavily": + return tavilyTool(cfg, context) + case "custom": return customTool(cfg, context) @@ -138,6 +149,12 @@ func duckduckgoTool(cfg toolConfig, context toolContext) (tool.Tool, error) { return duckduckgo.New(options...) } +func retrieverTool(cfg toolConfig, context toolContext) (tool.Tool, error) { + var options []retriever.Option + + return retriever.New(context.Index, options...) +} + func searxngTool(cfg toolConfig, context toolContext) (tool.Tool, error) { var options []searxng.Option diff --git a/pkg/tool/bing/client.go b/pkg/tool/bing/client.go index 83eb5891..95112874 100644 --- a/pkg/tool/bing/client.go +++ b/pkg/tool/bing/client.go @@ -94,9 +94,10 @@ func (t *Tool) Execute(ctx context.Context, parameters map[string]any) (any, err for _, p := range data.WebPages.Value { result := Result{ - Title: p.Name, - Content: p.Snippet, - Location: p.URL, + URL: p.URL, + + Title: p.Name, + Content: p.Snippet, } results = append(results, result) diff --git a/pkg/tool/bing/models.go b/pkg/tool/bing/models.go index 8d2a8358..110e2abe 100644 --- a/pkg/tool/bing/models.go +++ b/pkg/tool/bing/models.go @@ -1,9 +1,10 @@ package bing type Result struct { - Title string - Content string - Location string + URL string `json:"url"` + + Title string `json:"title"` + Content string `json:"content"` } type SearchResponse struct { diff --git a/pkg/tool/crawler/config.go b/pkg/tool/crawler/config.go index 7bee0e06..c62b06ab 100644 --- a/pkg/tool/crawler/config.go +++ b/pkg/tool/crawler/config.go @@ -1,13 +1,3 @@ package crawler -import ( - "net/http" -) - type Option func(*Tool) - -func WithClient(client *http.Client) Option { - return func(t *Tool) { - t.client = client - } -} diff --git a/pkg/tool/crawler/client.go b/pkg/tool/crawler/tool.go similarity index 95% rename from pkg/tool/crawler/client.go rename to pkg/tool/crawler/tool.go index c46cb19e..310d9ecc 100644 --- a/pkg/tool/crawler/client.go +++ b/pkg/tool/crawler/tool.go @@ -3,7 +3,6 @@ package crawler import ( "context" "errors" - "net/http" "github.com/adrianliechti/llama/pkg/extractor" "github.com/adrianliechti/llama/pkg/tool" @@ -12,15 +11,11 @@ import ( var _ tool.Tool = &Tool{} type Tool struct { - client *http.Client - extractor extractor.Provider } func New(extractor extractor.Provider, options ...Option) (*Tool, error) { t := &Tool{ - client: http.DefaultClient, - extractor: extractor, } diff --git a/pkg/tool/duckduckgo/models.go b/pkg/tool/duckduckgo/models.go index cfd0fceb..193780b4 100644 --- a/pkg/tool/duckduckgo/models.go +++ b/pkg/tool/duckduckgo/models.go @@ -1,7 +1,8 @@ package duckduckgo type Result struct { - Title string - Content string - Location string + URL string `json:"url"` + + Title string `json:"title"` + Content string `json:"content"` } diff --git a/pkg/tool/retriever/config.go b/pkg/tool/retriever/config.go new file mode 100644 index 00000000..57fd2677 --- /dev/null +++ b/pkg/tool/retriever/config.go @@ -0,0 +1,3 @@ +package retriever + +type Option func(*Tool) diff --git a/pkg/tool/retriever/models.go b/pkg/tool/retriever/models.go new file mode 100644 index 00000000..56d302a4 --- /dev/null +++ b/pkg/tool/retriever/models.go @@ -0,0 +1,7 @@ +package retriever + +type Result struct { + Title string `json:"title,omitempty"` + Content string `json:"content,omitempty"` + Location string `json:"location,omitempty"` +} diff --git a/pkg/tool/retriever/tool.go b/pkg/tool/retriever/tool.go new file mode 100644 index 00000000..98e71c80 --- /dev/null +++ b/pkg/tool/retriever/tool.go @@ -0,0 +1,84 @@ +package retriever + +import ( + "context" + "errors" + + "github.com/adrianliechti/llama/pkg/index" + "github.com/adrianliechti/llama/pkg/tool" +) + +var _ tool.Tool = &Tool{} + +type Tool struct { + index index.Provider +} + +func New(index index.Provider, options ...Option) (*Tool, error) { + t := &Tool{ + index: index, + } + + for _, option := range options { + option(t) + } + + if t.index == nil { + return nil, errors.New("missing index provider") + } + + return t, nil +} + +func (t *Tool) Name() string { + return "retriever" +} + +func (t *Tool) Description() string { + return "Query the knowledge base to find relevant documents to answer questions" +} + +func (*Tool) Parameters() any { + return map[string]any{ + "type": "object", + + "properties": map[string]any{ + "query": map[string]any{ + "type": "string", + "description": "The natural language query input. The query input should be clear and standalone", + }, + }, + + "required": []string{"query"}, + } +} + +func (t *Tool) Execute(ctx context.Context, parameters map[string]any) (any, error) { + query, ok := parameters["query"].(string) + + if !ok { + return nil, errors.New("missing query parameter") + } + + options := &index.QueryOptions{} + + data, err := t.index.Query(ctx, query, options) + + if err != nil { + return nil, err + } + + results := []Result{} + + for _, r := range data { + result := Result{ + Title: r.Title, + Content: r.Content, + Location: r.Location, + } + + results = append(results, result) + } + + return results, nil +} diff --git a/pkg/tool/searxng/client.go b/pkg/tool/searxng/client.go index e8dad473..7bb07517 100644 --- a/pkg/tool/searxng/client.go +++ b/pkg/tool/searxng/client.go @@ -96,9 +96,10 @@ func (t *Tool) Execute(ctx context.Context, parameters map[string]any) (any, err for _, r := range data.Results { result := Result{ - Title: r.Title, - Content: r.Content, - Location: r.URL, + URL: r.URL, + + Title: r.Title, + Content: r.Content, } results = append(results, result) diff --git a/pkg/tool/searxng/models.go b/pkg/tool/searxng/models.go index 950051e6..d45fd277 100644 --- a/pkg/tool/searxng/models.go +++ b/pkg/tool/searxng/models.go @@ -1,9 +1,10 @@ package searxng type Result struct { - Title string - Content string - Location string + URL string `json:"url"` + + Title string `json:"title"` + Content string `json:"content"` } type SearchResult struct { diff --git a/pkg/tool/tavily/client.go b/pkg/tool/tavily/client.go index 264624d2..bb2e6f8d 100644 --- a/pkg/tool/tavily/client.go +++ b/pkg/tool/tavily/client.go @@ -99,9 +99,10 @@ func (t *Tool) Execute(ctx context.Context, parameters map[string]any) (any, err for _, r := range data.Results { result := Result{ - Title: r.Title, - Content: r.Content, - Location: r.URL, + URL: r.URL, + + Title: r.Title, + Content: r.Content, } results = append(results, result) diff --git a/pkg/tool/tavily/models.go b/pkg/tool/tavily/models.go index d53322d0..41c52076 100644 --- a/pkg/tool/tavily/models.go +++ b/pkg/tool/tavily/models.go @@ -1,9 +1,10 @@ package tavily type Result struct { - Title string - Content string - Location string + URL string `json:"url"` + + Title string `json:"title"` + Content string `json:"content"` } type SearchResult struct {