Skip to content

Commit

Permalink
add retriever tool
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianliechti committed Sep 21, 2024
1 parent d6e39be commit 7a87e03
Show file tree
Hide file tree
Showing 13 changed files with 146 additions and 43 deletions.
31 changes: 24 additions & 7 deletions config/config_tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ import (
"strings"

"github.com/adrianliechti/llama/pkg/extractor"
"github.com/adrianliechti/llama/pkg/index"
"github.com/adrianliechti/llama/pkg/provider"
"github.com/adrianliechti/llama/pkg/tool"
"github.com/adrianliechti/llama/pkg/tool/bing"
"github.com/adrianliechti/llama/pkg/tool/crawler"
"github.com/adrianliechti/llama/pkg/tool/custom"
"github.com/adrianliechti/llama/pkg/tool/draw"
"github.com/adrianliechti/llama/pkg/tool/duckduckgo"
"github.com/adrianliechti/llama/pkg/tool/retriever"
"github.com/adrianliechti/llama/pkg/tool/searxng"
"github.com/adrianliechti/llama/pkg/tool/tavily"

Expand Down Expand Up @@ -43,9 +45,13 @@ type toolConfig struct {
Token string `yaml:"token"`

Model string `yaml:"model"`

Index string `yaml:"index"`
Extractor string `yaml:"extractor"`
}

type toolContext struct {
Index index.Provider
Renderer provider.Renderer
Extractor extractor.Provider
}
Expand All @@ -56,13 +62,15 @@ func (cfg *Config) registerTools(f *configFile) error {

context := toolContext{}

if t.Model != "" {
if r, err := cfg.Renderer(t.Model); err == nil {
context.Renderer = r
}
if i, err := cfg.Index(t.Index); err == nil {
context.Index = i
}

if r, err := cfg.Renderer(t.Model); err == nil {
context.Renderer = r
}

if e, err := cfg.Extractor(""); err == nil {
if e, err := cfg.Extractor(t.Extractor); err == nil {
context.Extractor = e
}

Expand Down Expand Up @@ -96,12 +104,15 @@ func createTool(cfg toolConfig, context toolContext) (tool.Tool, error) {
case "duckduckgo":
return duckduckgoTool(cfg, context)

case "tavily":
return tavilyTool(cfg, context)
case "retriever":
return retrieverTool(cfg, context)

case "searxng":
return searxngTool(cfg, context)

case "tavily":
return tavilyTool(cfg, context)

case "custom":
return customTool(cfg, context)

Expand Down Expand Up @@ -138,6 +149,12 @@ func duckduckgoTool(cfg toolConfig, context toolContext) (tool.Tool, error) {
return duckduckgo.New(options...)
}

func retrieverTool(cfg toolConfig, context toolContext) (tool.Tool, error) {
var options []retriever.Option

return retriever.New(context.Index, options...)
}

func searxngTool(cfg toolConfig, context toolContext) (tool.Tool, error) {
var options []searxng.Option

Expand Down
7 changes: 4 additions & 3 deletions pkg/tool/bing/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,10 @@ func (t *Tool) Execute(ctx context.Context, parameters map[string]any) (any, err

for _, p := range data.WebPages.Value {
result := Result{
Title: p.Name,
Content: p.Snippet,
Location: p.URL,
URL: p.URL,

Title: p.Name,
Content: p.Snippet,
}

results = append(results, result)
Expand Down
7 changes: 4 additions & 3 deletions pkg/tool/bing/models.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package bing

type Result struct {
Title string
Content string
Location string
URL string `json:"url"`

Title string `json:"title"`
Content string `json:"content"`
}

type SearchResponse struct {
Expand Down
10 changes: 0 additions & 10 deletions pkg/tool/crawler/config.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,3 @@
package crawler

import (
"net/http"
)

type Option func(*Tool)

func WithClient(client *http.Client) Option {
return func(t *Tool) {
t.client = client
}
}
5 changes: 0 additions & 5 deletions pkg/tool/crawler/client.go → pkg/tool/crawler/tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package crawler
import (
"context"
"errors"
"net/http"

"github.com/adrianliechti/llama/pkg/extractor"
"github.com/adrianliechti/llama/pkg/tool"
Expand All @@ -12,15 +11,11 @@ import (
var _ tool.Tool = &Tool{}

type Tool struct {
client *http.Client

extractor extractor.Provider
}

func New(extractor extractor.Provider, options ...Option) (*Tool, error) {
t := &Tool{
client: http.DefaultClient,

extractor: extractor,
}

Expand Down
7 changes: 4 additions & 3 deletions pkg/tool/duckduckgo/models.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package duckduckgo

type Result struct {
Title string
Content string
Location string
URL string `json:"url"`

Title string `json:"title"`
Content string `json:"content"`
}
3 changes: 3 additions & 0 deletions pkg/tool/retriever/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
package retriever

type Option func(*Tool)
7 changes: 7 additions & 0 deletions pkg/tool/retriever/models.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package retriever

type Result struct {
Title string `json:"title,omitempty"`
Content string `json:"content,omitempty"`
Location string `json:"location,omitempty"`
}
84 changes: 84 additions & 0 deletions pkg/tool/retriever/tool.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package retriever

import (
"context"
"errors"

"github.com/adrianliechti/llama/pkg/index"
"github.com/adrianliechti/llama/pkg/tool"
)

var _ tool.Tool = &Tool{}

type Tool struct {
index index.Provider
}

func New(index index.Provider, options ...Option) (*Tool, error) {
t := &Tool{
index: index,
}

for _, option := range options {
option(t)
}

if t.index == nil {
return nil, errors.New("missing index provider")
}

return t, nil
}

func (t *Tool) Name() string {
return "retriever"
}

func (t *Tool) Description() string {
return "Query the knowledge base to find relevant documents to answer questions"
}

func (*Tool) Parameters() any {
return map[string]any{
"type": "object",

"properties": map[string]any{
"query": map[string]any{
"type": "string",
"description": "The natural language query input. The query input should be clear and standalone",
},
},

"required": []string{"query"},
}
}

func (t *Tool) Execute(ctx context.Context, parameters map[string]any) (any, error) {
query, ok := parameters["query"].(string)

if !ok {
return nil, errors.New("missing query parameter")
}

options := &index.QueryOptions{}

data, err := t.index.Query(ctx, query, options)

if err != nil {
return nil, err
}

results := []Result{}

for _, r := range data {
result := Result{
Title: r.Title,
Content: r.Content,
Location: r.Location,
}

results = append(results, result)
}

return results, nil
}
7 changes: 4 additions & 3 deletions pkg/tool/searxng/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,10 @@ func (t *Tool) Execute(ctx context.Context, parameters map[string]any) (any, err

for _, r := range data.Results {
result := Result{
Title: r.Title,
Content: r.Content,
Location: r.URL,
URL: r.URL,

Title: r.Title,
Content: r.Content,
}

results = append(results, result)
Expand Down
7 changes: 4 additions & 3 deletions pkg/tool/searxng/models.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package searxng

type Result struct {
Title string
Content string
Location string
URL string `json:"url"`

Title string `json:"title"`
Content string `json:"content"`
}

type SearchResult struct {
Expand Down
7 changes: 4 additions & 3 deletions pkg/tool/tavily/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,10 @@ func (t *Tool) Execute(ctx context.Context, parameters map[string]any) (any, err

for _, r := range data.Results {
result := Result{
Title: r.Title,
Content: r.Content,
Location: r.URL,
URL: r.URL,

Title: r.Title,
Content: r.Content,
}

results = append(results, result)
Expand Down
7 changes: 4 additions & 3 deletions pkg/tool/tavily/models.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
package tavily

type Result struct {
Title string
Content string
Location string
URL string `json:"url"`

Title string `json:"title"`
Content string `json:"content"`
}

type SearchResult struct {
Expand Down

0 comments on commit 7a87e03

Please sign in to comment.