Skip to content

Commit

Permalink
reranker config
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianliechti committed Sep 15, 2024
1 parent e2bd322 commit a35fcc7
Show file tree
Hide file tree
Showing 5 changed files with 150 additions and 0 deletions.
1 change: 1 addition & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type Config struct {

completer map[string]provider.Completer
embedder map[string]provider.Embedder
reranker map[string]provider.Reranker
renderer map[string]provider.Renderer
synthesizer map[string]provider.Synthesizer
transcriber map[string]provider.Transcriber
Expand Down
11 changes: 11 additions & 0 deletions config/config_model.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ const (
ModelTypeAuto ModelType = ""
ModelTypeCompleter ModelType = "completer"
ModelTypeEmbedder ModelType = "embedder"
ModelTypeReranker ModelType = "reranker"
ModelTypeRenderer ModelType = "renderer"
ModelTypeSynthesizer ModelType = "synthesizer"
ModelTypeTranscriber ModelType = "transcriber"
Expand Down Expand Up @@ -114,6 +115,10 @@ func DetectModelType(id string) ModelType {
"minilm",
}

rerankers := []string{
"rerank",
}

renderers := []string{
"dall-e",
"flux-dev",
Expand Down Expand Up @@ -149,6 +154,12 @@ func DetectModelType(id string) ModelType {
}
}

for _, val := range rerankers {
if strings.Contains(strings.ToLower(id), strings.ToLower(val)) {
return ModelTypeReranker
}
}

for _, val := range renderers {
if strings.Contains(strings.ToLower(id), strings.ToLower(val)) {
return ModelTypeRenderer
Expand Down
9 changes: 9 additions & 0 deletions config/config_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,15 @@ func (cfg *Config) registerProviders(f *configFile) error {

cfg.RegisterEmbedder(p.Type, id, embedder)

case ModelTypeReranker:
reranker, err := createReranker(p, context)

if err != nil {
return err
}

cfg.RegisterReranker(p.Type, id, reranker)

case ModelTypeRenderer:
renderer, err := createRenderer(p, context)

Expand Down
70 changes: 70 additions & 0 deletions config/config_reranker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
package config

import (
"errors"
"strings"

"github.com/adrianliechti/llama/pkg/otel"
"github.com/adrianliechti/llama/pkg/provider"
"github.com/adrianliechti/llama/pkg/provider/huggingface"
"github.com/adrianliechti/llama/pkg/provider/jina"
)

func (cfg *Config) RegisterReranker(name, model string, p provider.Reranker) {
cfg.RegisterModel(model)

if cfg.reranker == nil {
cfg.reranker = make(map[string]provider.Reranker)
}

reranker, ok := p.(otel.ObservableReranker)

if !ok {
reranker = otel.NewReranker(name, model, p)
}

cfg.reranker[model] = reranker
}

func (cfg *Config) Reranker(model string) (provider.Reranker, error) {
if cfg.reranker != nil {
if e, ok := cfg.reranker[model]; ok {
return e, nil
}
}

return nil, errors.New("reranker not found: " + model)
}

func createReranker(cfg providerConfig, model modelContext) (provider.Reranker, error) {
switch strings.ToLower(cfg.Type) {
case "huggingface":
return huggingfaceReranker(cfg, model)

case "jina":
return jinaReranker(cfg, model)

default:
return nil, errors.New("invalid reranker type: " + cfg.Type)
}
}

func huggingfaceReranker(cfg providerConfig, model modelContext) (provider.Reranker, error) {
var options []huggingface.Option

if cfg.Token != "" {
options = append(options, huggingface.WithToken(cfg.Token))
}

return huggingface.NewReranker(cfg.URL, options...)
}

func jinaReranker(cfg providerConfig, model modelContext) (provider.Reranker, error) {
var options []jina.Option

if cfg.Token != "" {
options = append(options, jina.WithToken(cfg.Token))
}

return jina.NewReranker(cfg.URL, options...)
}
59 changes: 59 additions & 0 deletions pkg/otel/provider_reranker.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
package otel

import (
"context"
"strings"

"github.com/adrianliechti/llama/pkg/provider"

"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
)

type ObservableReranker interface {
Observable
provider.Reranker
}

type observableReranker struct {
name string
library string

model string
provider string

reranker provider.Reranker
}

func NewReranker(provider, model string, p provider.Reranker) ObservableReranker {
library := strings.ToLower(provider)

return &observableReranker{
reranker: p,

name: strings.TrimSuffix(strings.ToLower(provider), "-reranker") + "-reranker",
library: library,

model: model,
provider: provider,
}
}

func (p *observableReranker) otelSetup() {
}

func (p *observableReranker) Rerank(ctx context.Context, query string, inputs []string) ([]provider.Result, error) {
ctx, span := otel.Tracer(p.library).Start(ctx, p.name)
defer span.End()

result, err := p.reranker.Rerank(ctx, query, inputs)

meterRequest(ctx, p.library, p.provider, "rerank", p.model)

if EnableDebug {
span.SetAttributes(attribute.String("query", query))
span.SetAttributes(attribute.StringSlice("inputs", inputs))
}

return result, err
}

0 comments on commit a35fcc7

Please sign in to comment.