Skip to content

Commit

Permalink
add jina embedder
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianliechti committed Sep 15, 2024
1 parent 610a0b8 commit d9d9170
Show file tree
Hide file tree
Showing 5 changed files with 172 additions and 0 deletions.
14 changes: 14 additions & 0 deletions config/config_embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/adrianliechti/llama/pkg/provider/azure"
"github.com/adrianliechti/llama/pkg/provider/cohere"
"github.com/adrianliechti/llama/pkg/provider/huggingface"
"github.com/adrianliechti/llama/pkg/provider/jina"
"github.com/adrianliechti/llama/pkg/provider/llama"
"github.com/adrianliechti/llama/pkg/provider/ollama"
"github.com/adrianliechti/llama/pkg/provider/openai"
Expand Down Expand Up @@ -54,6 +55,9 @@ func createEmbedder(cfg providerConfig, model modelContext) (provider.Embedder,
case "huggingface":
return huggingfaceEmbedder(cfg, model)

case "jina":
return jinaEmbedder(cfg, model)

case "llama":
return llamaEmbedder(cfg, model)

Expand Down Expand Up @@ -98,6 +102,16 @@ func huggingfaceEmbedder(cfg providerConfig, model modelContext) (provider.Embed
return huggingface.NewEmbedder(cfg.URL, options...)
}

// jinaEmbedder builds a Jina-backed embedder from the provider configuration.
//
// cfg.URL is handed to jina.NewEmbedder unchanged. When cfg.Token is set it is
// forwarded through jina.WithToken; otherwise no options are supplied. The
// model argument is accepted for signature parity with the other embedder
// factories but is not consulted here.
func jinaEmbedder(cfg providerConfig, model modelContext) (provider.Embedder, error) {
	if cfg.Token == "" {
		return jina.NewEmbedder(cfg.URL)
	}

	return jina.NewEmbedder(cfg.URL, jina.WithToken(cfg.Token))
}

func llamaEmbedder(cfg providerConfig, model modelContext) (provider.Embedder, error) {
var options []llama.Option

Expand Down
1 change: 1 addition & 0 deletions config/config_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ func DetectModelType(id string) ModelType {

embedders := []string{
"bge",
"clip",
"embed",
"gte",
"minilm",
Expand Down
28 changes: 28 additions & 0 deletions pkg/provider/jina/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
package jina

import (
"net/http"
)

// Config holds the settings shared by the Jina provider types. It is filled
// in by NewEmbedder and then adjusted by any Option values passed to it.
type Config struct {
	// url is the API base URL; NewEmbedder stores it with trailing slashes
	// and a trailing "/v1" removed.
	url string

	// token, when non-empty, is sent as a Bearer token in the Authorization
	// header.
	token string
	// model is the model identifier; NewEmbedder initializes it to
	// "jina-clip-v1".
	model string

	// client is the HTTP client used to perform requests.
	client *http.Client
}

// Option mutates a Config; NewEmbedder applies options in the order given,
// after the defaults have been set.
type Option func(*Config)

// WithClient returns an Option that replaces the HTTP client stored in the
// Config with the given one.
func WithClient(client *http.Client) Option {
	return func(cfg *Config) {
		cfg.client = client
	}
}

// WithToken returns an Option that stores the given API token in the Config.
func WithToken(token string) Option {
	return func(cfg *Config) {
		cfg.token = token
	}
}
100 changes: 100 additions & 0 deletions pkg/provider/jina/embedder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
package jina

import (
"context"
"encoding/json"
"errors"
"net/http"
"net/url"
"strings"

"github.com/adrianliechti/llama/pkg/provider"
)

// Compile-time assertion that *Embedder satisfies provider.Embedder.
var _ provider.Embedder = (*Embedder)(nil)

// Embedder requests embeddings over HTTP using the settings held in the
// embedded Config.
type Embedder struct {
	*Config
}

// NewEmbedder creates an Embedder for the API rooted at url.
//
// An empty url falls back to "https://api.jina.ai". Trailing slashes and a
// trailing "/v1" segment are stripped so the versioned path can be appended
// later. Defaults (http.DefaultClient, model "jina-clip-v1") are set first and
// the options are then applied in order. The returned error is always nil.
func NewEmbedder(url string, options ...Option) (*Embedder, error) {
	base := url

	if base == "" {
		base = "https://api.jina.ai"
	}

	base = strings.TrimSuffix(strings.TrimRight(base, "/"), "/v1")

	embedder := &Embedder{
		Config: &Config{
			url:    base,
			model:  "jina-clip-v1",
			client: http.DefaultClient,
		},
	}

	for i := range options {
		options[i](embedder.Config)
	}

	return embedder, nil
}

// Embed requests an embedding vector for content.
//
// The whitespace-trimmed content is POSTed as a single-element "input" array
// to <url>/v1/embeddings, together with the configured model when one is set.
// If a token is configured it is sent as a Bearer Authorization header.
//
// It returns an error when the request cannot be built or sent, when the
// response status is not 200 OK (the error text is produced by convertError),
// when the response body cannot be decoded, or when the response contains no
// embeddings. On success the first embedding in the response is returned.
func (e *Embedder) Embed(ctx context.Context, content string) (*provider.Embedding, error) {
	body := map[string]any{
		"input": []string{
			strings.TrimSpace(content),
		},
	}

	// The configured model was previously never transmitted; include it so
	// the service embeds with the model this Embedder was set up for.
	if e.model != "" {
		body["model"] = e.model
	}

	u, err := url.JoinPath(e.url, "/v1/embeddings")

	if err != nil {
		return nil, err
	}

	// A malformed URL or nil context yields a nil request; bail out instead
	// of dereferencing it below.
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, u, jsonReader(body))

	if err != nil {
		return nil, err
	}

	req.Header.Set("Content-Type", "application/json")

	if e.token != "" {
		req.Header.Set("Authorization", "Bearer "+e.token)
	}

	resp, err := e.client.Do(req)

	if err != nil {
		return nil, err
	}

	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, convertError(resp)
	}

	var result EmbeddingList

	if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
		return nil, err
	}

	if len(result.Data) == 0 {
		return nil, errors.New("no embeddings found")
	}

	// Only one input is sent, so only the first entry is of interest.
	return &provider.Embedding{
		Data: result.Data[0].Embedding,
	}, nil
}

// EmbeddingList is the JSON document decoded from a successful response of
// the embeddings endpoint.
type EmbeddingList struct {
	Object string `json:"object"`

	Model string `json:"model"`
	// Data holds the returned embeddings; Embed uses the first entry.
	Data []Embedding `json:"data"`
}

// Embedding is a single entry of EmbeddingList.Data.
type Embedding struct {
	Object string `json:"object"`

	// Index is presumably the position of the matching input in the request
	// (TODO confirm against the API documentation).
	Index int `json:"index"`
	// Embedding is the vector returned for the input.
	Embedding []float32 `json:"embedding"`
}
29 changes: 29 additions & 0 deletions pkg/provider/jina/util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package jina

import (
"bytes"
"encoding/json"
"errors"
"io"
"net/http"
)

// convertError turns a non-successful HTTP response into an error.
//
// The error message is the raw response body. If the body is empty (or could
// not be read at all), the standard status text for the response's status
// code is used instead. Read failures are deliberately ignored: whatever was
// read before the failure becomes the message.
func convertError(resp *http.Response) error {
	payload, _ := io.ReadAll(resp.Body)

	message := string(payload)

	if message == "" {
		message = http.StatusText(resp.StatusCode)
	}

	return errors.New(message)
}

// jsonReader encodes v as JSON and returns a reader over the encoded bytes.
//
// HTML escaping is disabled, so characters such as '<', '>' and '&' are
// written verbatim. The encoder appends a trailing newline. Encoding errors
// are not reported; in that case the reader yields whatever the encoder wrote
// to the buffer.
func jsonReader(v any) io.Reader {
	var buf bytes.Buffer

	encoder := json.NewEncoder(&buf)
	encoder.SetEscapeHTML(false)

	// Best effort: the signature has no way to surface an encoding failure.
	_ = encoder.Encode(v)

	return &buf
}

0 comments on commit d9d9170

Please sign in to comment.