From d5c570d3bb5ce039b4d05a09d1f5195c89d1d601 Mon Sep 17 00:00:00 2001 From: zhenghaoz Date: Sun, 9 Feb 2025 17:57:36 +0800 Subject: [PATCH] dashboard: support chat with LLM (#936) --- .github/workflows/build_test.yml | 6 +- common/mock/openai.go | 45 +- common/mock/openai_test.go | 38 +- config/config.go | 11 +- config/config.toml | 14 + config/config_test.go | 5 + go.mod | 16 +- go.sum | 24 +- logics/item_to_item.go | 2 +- master/master.go | 9 + master/rest.go | 52 +++ master/rest_test.go | 770 +++++++++++++++---------------- 12 files changed, 561 insertions(+), 431 deletions(-) diff --git a/.github/workflows/build_test.yml b/.github/workflows/build_test.yml index e64b52f3d..ba8ea03f0 100644 --- a/.github/workflows/build_test.yml +++ b/.github/workflows/build_test.yml @@ -287,10 +287,10 @@ jobs: name: lint runs-on: ubuntu-latest steps: - - uses: actions/setup-go@v4 - with: - go-version: 1.23.x - uses: actions/checkout@v2 + - uses: actions/setup-go@v5 + with: + go-version-file: ./go.mod - name: golangci-lint uses: golangci/golangci-lint-action@v2 with: diff --git a/common/mock/openai.go b/common/mock/openai.go index 5cf0cd87c..d347f5b33 100644 --- a/common/mock/openai.go +++ b/common/mock/openai.go @@ -15,6 +15,8 @@ package mock import ( + "bytes" + "encoding/json" "fmt" "github.com/emicklei/go-restful/v3" "github.com/sashabaranov/go-openai" @@ -28,16 +30,15 @@ type OpenAIServer struct { authToken string ready chan struct{} - mockChatCompletion string - mockEmbeddings []float32 + mockEmbeddings []float32 } func NewOpenAIServer() *OpenAIServer { s := &OpenAIServer{} ws := new(restful.WebService) ws.Path("/v1"). - Consumes(restful.MIME_XML, restful.MIME_JSON). - Produces(restful.MIME_JSON, restful.MIME_XML) + Consumes(restful.MIME_JSON). + Produces(restful.MIME_JSON, "text/event-stream") ws.Route(ws.POST("chat/completions"). Reads(openai.ChatCompletionRequest{}). Writes(openai.ChatCompletionResponse{}). @@ -80,10 +81,6 @@ func (s *OpenAIServer) Close() error { return s.httpServer.Close() } -func (s *OpenAIServer) ChatCompletion(mock string) { - s.mockChatCompletion = mock -} - func (s *OpenAIServer) Embeddings(embeddings []float32) { s.mockEmbeddings = embeddings } @@ -95,13 +92,31 @@ func (s *OpenAIServer) chatCompletion(req *restful.Request, resp *restful.Respon _ = resp.WriteError(http.StatusBadRequest, err) return } - _ = resp.WriteEntity(openai.ChatCompletionResponse{ - Choices: []openai.ChatCompletionChoice{{ - Message: openai.ChatCompletionMessage{ - Content: s.mockChatCompletion, - }, - }}, - }) + if r.Stream { + content := r.Messages[0].Content + for i := 0; i < len(content); i += 8 { + buf := bytes.NewBuffer(nil) + buf.WriteString("data: ") + encoder := json.NewEncoder(buf) + _ = encoder.Encode(openai.ChatCompletionStreamResponse{ + Choices: []openai.ChatCompletionStreamChoice{{ + Delta: openai.ChatCompletionStreamChoiceDelta{ + Content: content[i:min(i+8, len(content))], + }, + }}, + }) + buf.WriteString("\n") + _, _ = resp.Write(buf.Bytes()) + } + } else { + _ = resp.WriteEntity(openai.ChatCompletionResponse{ + Choices: []openai.ChatCompletionChoice{{ + Message: openai.ChatCompletionMessage{ + Content: r.Messages[0].Content, + }, + }}, + }) + } } func (s *OpenAIServer) embeddings(req *restful.Request, resp *restful.Response) { diff --git a/common/mock/openai_test.go b/common/mock/openai_test.go index 52205bbaa..450d2c6f6 100644 --- a/common/mock/openai_test.go +++ b/common/mock/openai_test.go @@ -16,8 +16,11 @@ package mock import ( "context" + "github.com/juju/errors" "github.com/sashabaranov/go-openai" "github.com/stretchr/testify/suite" + "io" + "strings" "testing" ) @@ -45,7 +48,6 @@ func (suite *OpenAITestSuite) TearDownSuite() { } func (suite *OpenAITestSuite) TestChatCompletion() { - suite.server.ChatCompletion("World") resp, err := suite.client.CreateChatCompletion( context.Background(), openai.ChatCompletionRequest{ @@ -59,7 +61,39 @@ func (suite *OpenAITestSuite) TestChatCompletion() { }, ) suite.NoError(err) - suite.Equal("World", resp.Choices[0].Message.Content) + suite.Equal("Hello", resp.Choices[0].Message.Content) +} + +func (suite *OpenAITestSuite) TestChatCompletionStream() { + content := "In my younger and more vulnerable years my father gave me some advice that I've been turning over in" + + " my mind ever since. Whenever you feel like criticizing anyone, he told me, just remember that all the " + + "people in this world haven't had the advantages that you've had." + stream, err := suite.client.CreateChatCompletionStream( + context.Background(), + openai.ChatCompletionRequest{ + Model: "qwen2.5", + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: content, + }, + }, + Stream: true, + }, + ) + suite.NoError(err) + defer stream.Close() + var buffer strings.Builder + for { + var resp openai.ChatCompletionStreamResponse + resp, err = stream.Recv() + if errors.Is(err, io.EOF) { + suite.Equal(content, buffer.String()) + return + } + suite.NoError(err) + buffer.WriteString(resp.Choices[0].Delta.Content) + } } func (suite *OpenAITestSuite) TestEmbeddings() { diff --git a/config/config.go b/config/config.go index d5db32424..76f718c20 100644 --- a/config/config.go +++ b/config/config.go @@ -63,6 +63,7 @@ type Config struct { Tracing TracingConfig `mapstructure:"tracing"` Experimental ExperimentalConfig `mapstructure:"experimental"` OIDC OIDCConfig `mapstructure:"oidc"` + OpenAI OpenAIConfig `mapstructure:"openai"` } // DatabaseConfig is the configuration for the database. @@ -149,8 +150,9 @@ type NeighborsConfig struct { type ItemToItemConfig struct { Name string `mapstructure:"name" json:"name"` - Type string `mapstructure:"type" json:"type" validate:"oneof=embedding tags users"` + Type string `mapstructure:"type" json:"type" validate:"oneof=embedding tags users llm"` Column string `mapstructure:"column" json:"column" validate:"item_expr"` + Prompt string `mapstructure:"prompt" json:"prompt"` } func (config *ItemToItemConfig) Hash() string { @@ -214,6 +216,13 @@ type OIDCConfig struct { RedirectURL string `mapstructure:"redirect_url" validate:"omitempty,endswith=/callback/oauth2"` } +type OpenAIConfig struct { + BaseURL string `mapstructure:"base_url"` + AuthToken string `mapstructure:"auth_token"` + ChatCompletionModel string `mapstructure:"chat_completion_model"` + EmbeddingsModel string `mapstructure:"embeddings_model"` +} + func GetDefaultConfig() *Config { return &Config{ Database: DatabaseConfig{ diff --git a/config/config.toml b/config/config.toml index 637865b7b..7bbb1a1b3 100644 --- a/config/config.toml +++ b/config/config.toml @@ -309,3 +309,17 @@ client_secret = "" # Gorse dashboard URL and "/callback/oauth2". For example, if the Gorse dashboard URL is # http://localhost:8088, the redirect URL should be: http://localhost:8088/callback/oauth2 redirect_url = "" + +[openai] + +# Base URL of OpenAI API. +base_url = "http://localhost:11434/v1" + +# API key of OpenAI API. +auth_token = "ollama" + +# Name of chat completion model. +chat_completion_model = "qwen2.5" + +# Name of embeddings model. +embeddings_model = "mxbai-embed-large" \ No newline at end of file diff --git a/config/config_test.go b/config/config_test.go index 414646f21..ea1eb9599 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -157,6 +157,11 @@ func TestUnmarshal(t *testing.T) { assert.Equal(t, "client_id", config.OIDC.ClientID) assert.Equal(t, "client_secret", config.OIDC.ClientSecret) assert.Equal(t, "http://localhost:8088/callback/oauth2", config.OIDC.RedirectURL) + // [openai] + assert.Equal(t, "http://localhost:11434/v1", config.OpenAI.BaseURL) + assert.Equal(t, "ollama", config.OpenAI.AuthToken) + assert.Equal(t, "qwen2.5", config.OpenAI.ChatCompletionModel) + assert.Equal(t, "mxbai-embed-large", config.OpenAI.EmbeddingsModel) }) } } diff --git a/go.mod b/go.mod index 82c04ff06..5d232fe0c 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,6 @@ module github.com/zhenghaoz/gorse -go 1.23.5 - -toolchain go1.23.6 +go 1.23.6 require ( github.com/XSAM/otelsql v0.35.0 @@ -24,7 +22,7 @@ require ( github.com/go-viper/mapstructure/v2 v2.2.1 github.com/google/uuid v1.6.0 github.com/gorilla/securecookie v1.1.1 - github.com/gorse-io/dashboard v0.0.0-20250206135652-01a4864452d9 + github.com/gorse-io/dashboard v0.0.0-20250209091713-a70341e78d48 github.com/gorse-io/gorse-go v0.5.0-alpha.1 github.com/haxii/go-swagger-ui v0.0.0-20210203093335-a63a6bbde946 github.com/jaswdr/faker v1.16.0 @@ -37,6 +35,7 @@ require ( github.com/madflojo/testcerts v1.3.0 github.com/mailru/go-clickhouse/v2 v2.0.1-0.20221121001540-b259988ad8e5 github.com/matttproud/golang_protobuf_extensions v1.0.1 + github.com/nikolalohinski/gonja/v2 v2.3.3 github.com/orcaman/concurrent-map v1.0.0 github.com/pkg/errors v0.9.1 github.com/prometheus/client_golang v1.13.0 @@ -66,7 +65,7 @@ require ( go.opentelemetry.io/otel/trace v1.31.0 go.uber.org/atomic v1.10.0 go.uber.org/zap v1.24.0 - golang.org/x/exp v0.0.0-20230905200255-921286631fa9 + golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 golang.org/x/oauth2 v0.22.0 google.golang.org/grpc v1.67.1 google.golang.org/grpc/security/advancedtls v1.0.0 @@ -96,7 +95,7 @@ require ( github.com/chewxy/hm v1.0.0 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect - github.com/dustin/go-humanize v1.0.0 // indirect + github.com/dustin/go-humanize v1.0.1 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect github.com/fsnotify/fsnotify v1.7.0 // indirect github.com/gabriel-vasile/mimetype v1.4.3 // indirect @@ -111,6 +110,7 @@ require ( github.com/golang/protobuf v1.5.3 // indirect github.com/golang/snappy v0.0.4 // indirect github.com/google/flatbuffers v2.0.6+incompatible // indirect + github.com/google/pprof v0.0.0-20240827171923-fa2c70bbbfe5 // indirect github.com/grpc-ecosystem/grpc-gateway/v2 v2.22.0 // indirect github.com/hashicorp/go-version v1.6.0 // indirect github.com/hashicorp/hcl v1.0.0 // indirect @@ -167,13 +167,13 @@ require ( go.uber.org/multierr v1.10.0 // indirect go4.org/unsafe/assume-no-moving-gc v0.0.0-20230525183740-e7c30c78aeb2 // indirect golang.org/x/crypto v0.31.0 // indirect - golang.org/x/mod v0.17.0 // indirect + golang.org/x/mod v0.20.0 // indirect golang.org/x/net v0.33.0 // indirect golang.org/x/sync v0.10.0 // indirect golang.org/x/sys v0.28.0 // indirect golang.org/x/term v0.27.0 // indirect golang.org/x/text v0.21.0 // indirect - golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d // indirect + golang.org/x/tools v0.24.0 // indirect golang.org/x/xerrors v0.0.0-20220907171357-04be3eba64a2 // indirect gonum.org/v1/gonum v0.11.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20241007155032-5fefd90f89a9 // indirect diff --git a/go.sum b/go.sum index f500dd45f..0c136e0ed 100644 --- a/go.sum +++ b/go.sum @@ -124,8 +124,9 @@ github.com/deckarep/golang-set/v2 v2.3.1 h1:vjmkvJt/IV27WXPyYQpAh4bRyWJc5Y435D17 github.com/deckarep/golang-set/v2 v2.3.1/go.mod h1:VAky9rY/yGXJOLEDv3OMci+7wtDpOF4IN+y82NBOac4= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= -github.com/dustin/go-humanize v1.0.0 h1:VSnTsYCnlFHaM2/igO1h6X3HA71jcobQuxemgkq4zYo= github.com/dustin/go-humanize v1.0.0/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk= +github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= +github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= github.com/emicklei/go-restful-openapi/v2 v2.9.0 h1:djsWqjhI0EVYfkLCCX6jZxUkLmYUq2q9tt09ZbixfyE= github.com/emicklei/go-restful-openapi/v2 v2.9.0/go.mod h1:VKNgZyYviM1hnyrjD9RDzP2RuE94xTXxV+u6MGN4v4k= github.com/emicklei/go-restful/v3 v3.7.3/go.mod h1:6n3XBCmQQb25CM2LCACGz8ukIrRry+4bhvbpWn3mrbc= @@ -286,8 +287,8 @@ github.com/google/pprof v0.0.0-20200212024743-f11f1df84d12/go.mod h1:ZgVRPoUq/hf github.com/google/pprof v0.0.0-20200229191704-1ebb73c60ed3/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200430221834-fc25d7d30c6d/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= github.com/google/pprof v0.0.0-20200708004538-1a94d8640e99/go.mod h1:ZgVRPoUq/hfqzAqh7sHMqb3I9Rq5C59dIz2SbBwJ4eM= -github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26 h1:Xim43kblpZXfIBQsbuBVKCudVG457BR2GZFIz3uw3hQ= -github.com/google/pprof v0.0.0-20221118152302-e6195bd50e26/go.mod h1:dDKJzRmX4S37WGHujM7tX//fmj1uioxKzKxz3lo4HJo= +github.com/google/pprof v0.0.0-20240827171923-fa2c70bbbfe5 h1:5iH8iuqE5apketRbSFBy+X1V0o+l+8NF1avt4HWl7cA= +github.com/google/pprof v0.0.0-20240827171923-fa2c70bbbfe5/go.mod h1:vavhavw2zAxS5dIdcRluK6cSGGPlZynqzFM8NdvU144= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -303,8 +304,8 @@ github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyC github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb h1:z/oOWE+Vy0PLcwIulZmIug4FtmvE3dJ1YOGprLeHwwY= github.com/gorse-io/clickhouse v0.3.3-0.20220715124633-688011a495bb/go.mod h1:iILWzbul8U+gsf4kqbheF2QzBmdvVp63mloGGK8emDI= -github.com/gorse-io/dashboard v0.0.0-20250206135652-01a4864452d9 h1:Eh7AzLEERcKDk/CczD+eUHKhNg/X9Bob+47JIcU7/3M= -github.com/gorse-io/dashboard v0.0.0-20250206135652-01a4864452d9/go.mod h1:lv2bu311bjIJeRfY+6hiIaw20M6fLxT4ma9Ye+bpwGY= +github.com/gorse-io/dashboard v0.0.0-20250209091713-a70341e78d48 h1:kfCK07ae/+NvxlcPqh0SpaXxkDlceqSmamsX7t/E4+w= +github.com/gorse-io/dashboard v0.0.0-20250209091713-a70341e78d48/go.mod h1:lv2bu311bjIJeRfY+6hiIaw20M6fLxT4ma9Ye+bpwGY= github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849 h1:Hwywr6NxzYeZYn35KwOsw7j8ZiMT60TBzpbn1MbEido= github.com/gorse-io/gorgonia v0.0.0-20230817132253-6dd1dbf95849/go.mod h1:TtVGAt7ENNmgBnC0JA68CAjIDCEtcqaRHvnkAWJ/Fu0= github.com/gorse-io/gorse-go v0.5.0-alpha.1 h1:QBWKGAbSKNAWnieXVIdQiE0lLGvKXfFFAFPOQEkPW/E= @@ -487,6 +488,7 @@ github.com/montanaflynn/stats v0.7.1/go.mod h1:etXPPgVO6n31NxCd9KQUMvCM+ve0ruNzt github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U= github.com/niemeyer/pretty v0.0.0-20200227124842-a10e7caefd8e/go.mod h1:zD1mROLANZcx1PVRCS0qkT7pwLkGfwJo4zjcN/Tysno= +github.com/nikolalohinski/gonja/v2 v2.3.3/go.mod h1:8KC3RlefxnOaY5P4rH5erdwV0/owS83U615cSnDLYFs= github.com/openzipkin/zipkin-go v0.4.1 h1:kNd/ST2yLLWhaWrkgchya40TJabe8Hioj9udfPcEO5A= github.com/openzipkin/zipkin-go v0.4.1/go.mod h1:qY0VqDSN1pOBN94dBc6w2GJlWLiovAyg7Qt6/I9HecM= github.com/orcaman/concurrent-map v1.0.0 h1:I/2A2XPCb4IuQWcQhBhSwGfiuybl/J0ev9HDbW65HOY= @@ -737,8 +739,8 @@ golang.org/x/exp v0.0.0-20191227195350-da58074b4299/go.mod h1:2RIsYlXP63K8oxa1u0 golang.org/x/exp v0.0.0-20200119233911-0405dc783f0a/go.mod h1:2RIsYlXP63K8oxa1u096TMicItID8zy7Y6sNkU49FU4= golang.org/x/exp v0.0.0-20200207192155-f17229e696bd/go.mod h1:J/WKrq2StrnmMY6+EHIKF9dgMWnmCNThgcyBT1FY9mM= golang.org/x/exp v0.0.0-20200224162631-6cc2880d07d6/go.mod h1:3jZMyOhIsHpP37uCMkUooju7aAi5cS1Q23tOzKc+0MU= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= -golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56 h1:2dVuKD2vS7b0QIHQbpyTISPd0LeHDbnYEryqj5Q1ug8= +golang.org/x/exp v0.0.0-20240719175910-8a7402abbf56/go.mod h1:M4RDyNAINzryxdtnbRXRL/OHtkFuWGRjvuhBJpk2IlY= golang.org/x/image v0.0.0-20180708004352-c73c2afc3b81/go.mod h1:ux5Hcp/YLpHSI86hEcLt0YII63i6oz57MZXIpbrjZUs= golang.org/x/image v0.0.0-20190227222117-0694c2d4d067/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js= golang.org/x/image v0.0.0-20190802002840-cff245a6509b/go.mod h1:FeLwcggjj3mMvU+oOTbSwawSJRM1uh48EjtB4UJZlP0= @@ -774,8 +776,8 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= golang.org/x/mod v0.5.1/go.mod h1:5OXOZSfqPIIbmVBIIKWRFfZjPR0E5r58TLhUjH0a2Ro= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= -golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= -golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/mod v0.20.0 h1:utOm6MM3R3dnawAiJgn0y+xvuYRsm1RKM/4giyfDgV0= +golang.org/x/mod v0.20.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= @@ -986,8 +988,8 @@ golang.org/x/tools v0.1.4/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= golang.org/x/tools v0.1.9/go.mod h1:nABZi5QlRsZVlzPpHl034qft6wpY4eDcsTt5AaioBiU= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= -golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= +golang.org/x/tools v0.24.0 h1:J1shsA93PJUEVaUSaay7UXAyE8aimq3GW0pjlolpa24= +golang.org/x/tools v0.24.0/go.mod h1:YhNqVBIfWHdzvTLs0d8LCuMhkKUgSUKldakyV7W/WDQ= golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/logics/item_to_item.go b/logics/item_to_item.go index 2e602b834..0801a25ae 100644 --- a/logics/item_to_item.go +++ b/logics/item_to_item.go @@ -105,7 +105,7 @@ type embeddingItemToItem struct { dimension int } -func newEmbeddingItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time) (ItemToItem, error) { +func newEmbeddingItemToItem(cfg config.ItemToItemConfig, n int, timestamp time.Time) (*embeddingItemToItem, error) { // Compile column expression columnFunc, err := expr.Compile(cfg.Column, expr.Env(map[string]any{ "item": data.Item{}, diff --git a/master/master.go b/master/master.go index e5a5223f2..40ce6a32c 100644 --- a/master/master.go +++ b/master/master.go @@ -29,6 +29,7 @@ import ( "github.com/emicklei/go-restful/v3" "github.com/jellydator/ttlcache/v3" "github.com/juju/errors" + "github.com/sashabaranov/go-openai" "github.com/zhenghaoz/gorse/base" "github.com/zhenghaoz/gorse/base/encoding" "github.com/zhenghaoz/gorse/base/log" @@ -71,6 +72,7 @@ type Master struct { jobsScheduler *task.JobsScheduler cacheFile string managedMode bool + openAIClient *openai.Client // cluster meta cache metaStore meta.Database @@ -116,6 +118,7 @@ type Master struct { // NewMaster creates a master node. func NewMaster(cfg *config.Config, cacheFile string, managedMode bool) *Master { rand.Seed(time.Now().UnixNano()) + // setup trace provider tp, err := cfg.Tracing.NewTracerProvider() if err != nil { @@ -124,12 +127,18 @@ func NewMaster(cfg *config.Config, cacheFile string, managedMode bool) *Master { otel.SetTracerProvider(tp) otel.SetErrorHandler(log.GetErrorHandler()) otel.SetTextMapPropagator(propagation.NewCompositeTextMapPropagator(propagation.TraceContext{}, propagation.Baggage{})) + + // setup OpenAI client + clientConfig := openai.DefaultConfig(cfg.OpenAI.AuthToken) + clientConfig.BaseURL = cfg.OpenAI.BaseURL + m := &Master{ // create task monitor cacheFile: cacheFile, managedMode: managedMode, jobsScheduler: task.NewJobsScheduler(cfg.Master.NumJobs), tracer: progress.NewTracer("master"), + openAIClient: openai.NewClientWithConfig(clientConfig), // default ranking model rankingModelName: "bpr", rankingModelSearcher: ranking.NewModelSearcher( diff --git a/master/rest.go b/master/rest.go index 54bb6212c..7bd1299fd 100644 --- a/master/rest.go +++ b/master/rest.go @@ -37,6 +37,7 @@ import ( "github.com/juju/errors" "github.com/rakyll/statik/fs" "github.com/samber/lo" + "github.com/sashabaranov/go-openai" "github.com/zhenghaoz/gorse/base" "github.com/zhenghaoz/gorse/base/log" "github.com/zhenghaoz/gorse/base/progress" @@ -221,6 +222,7 @@ func (m *Master) StartHttpServer() { container.Handle("/api/bulk/feedback", http.HandlerFunc(m.importExportFeedback)) container.Handle("/api/dump", http.HandlerFunc(m.dump)) container.Handle("/api/restore", http.HandlerFunc(m.restore)) + container.Handle("/api/chat", http.HandlerFunc(m.chat)) if m.workerScheduleHandler == nil { container.Handle("/api/admin/schedule", http.HandlerFunc(m.scheduleAPIHandler)) } else { @@ -1629,3 +1631,53 @@ func (m *Master) handleOAuth2Callback(w http.ResponseWriter, r *http.Request) { zap.String("email", claims.Email)) } } + +func (m *Master) chat(response http.ResponseWriter, request *http.Request) { + if !m.checkAdmin(request) { + writeError(response, http.StatusUnauthorized, "unauthorized") + return + } + content, err := io.ReadAll(request.Body) + if err != nil { + writeError(response, http.StatusInternalServerError, err.Error()) + return + } + stream, err := m.openAIClient.CreateChatCompletionStream( + request.Context(), + openai.ChatCompletionRequest{ + Model: m.Config.OpenAI.ChatCompletionModel, + Messages: []openai.ChatCompletionMessage{ + { + Role: openai.ChatMessageRoleUser, + Content: string(content), + }, + }, + Stream: true, + }, + ) + if err != nil { + writeError(response, http.StatusInternalServerError, err.Error()) + return + } + // read response + defer stream.Close() + for { + var resp openai.ChatCompletionStreamResponse + resp, err = stream.Recv() + if errors.Is(err, io.EOF) { + return + } + if err != nil { + writeError(response, http.StatusInternalServerError, err.Error()) + return + } + if _, err = response.Write([]byte(resp.Choices[0].Delta.Content)); err != nil { + log.Logger().Error("failed to write response", zap.Error(err)) + return + } + // flush response + if f, ok := response.(http.Flusher); ok { + f.Flush() + } + } +} diff --git a/master/rest_test.go b/master/rest_test.go index 757dc4d2b..d940031f4 100644 --- a/master/rest_test.go +++ b/master/rest_test.go @@ -31,8 +31,11 @@ import ( "github.com/go-viper/mapstructure/v2" "github.com/juju/errors" "github.com/samber/lo" + "github.com/sashabaranov/go-openai" "github.com/steinfletcher/apitest" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" + "github.com/zhenghaoz/gorse/common/mock" "github.com/zhenghaoz/gorse/config" "github.com/zhenghaoz/gorse/model/click" "github.com/zhenghaoz/gorse/model/ranking" @@ -49,59 +52,6 @@ const ( mockMasterPassword = "pass" ) -type mockServer struct { - handler *restful.Container - Master -} - -func newMockServer(t *testing.T) (*mockServer, string) { - s := new(mockServer) - // open database - var err error - s.Settings = config.NewSettings() - s.metaStore, err = meta.Open(fmt.Sprintf("sqlite://%s/meta.db", t.TempDir()), s.Config.Master.MetaTimeout) - assert.NoError(t, err) - s.DataClient, err = data.Open(fmt.Sprintf("sqlite://%s/data.db", t.TempDir()), "") - assert.NoError(t, err) - s.CacheClient, err = cache.Open(fmt.Sprintf("sqlite://%s/cache.db", t.TempDir()), "") - assert.NoError(t, err) - // init database - err = s.metaStore.Init() - assert.NoError(t, err) - err = s.DataClient.Init() - assert.NoError(t, err) - err = s.CacheClient.Init() - assert.NoError(t, err) - // create server - s.Config = config.GetDefaultConfig() - s.Config.Master.DashboardUserName = mockMasterUsername - s.Config.Master.DashboardPassword = mockMasterPassword - s.WebService = new(restful.WebService) - s.CreateWebService() - s.RestServer.CreateWebService() - // create handler - s.handler = restful.NewContainer() - s.handler.Add(s.WebService) - // login - req, err := http.NewRequest("POST", "/login", - strings.NewReader(fmt.Sprintf("user_name=%s&password=%s", mockMasterUsername, mockMasterPassword))) - assert.NoError(t, err) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - resp := httptest.NewRecorder() - s.login(resp, req) - assert.Equal(t, http.StatusFound, resp.Code) - return s, resp.Header().Get("Set-Cookie") -} - -func (s *mockServer) Close(t *testing.T) { - err := s.metaStore.Close() - assert.NoError(t, err) - err = s.DataClient.Close() - assert.NoError(t, err) - err = s.CacheClient.Close() - assert.NoError(t, err) -} - func marshal(t *testing.T, v interface{}) string { s, err := json.Marshal(v) assert.NoError(t, err) @@ -125,9 +75,73 @@ func convertToMapStructure(t *testing.T, v interface{}) map[string]interface{} { return m } -func TestMaster_ExportUsers(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) +type MasterAPITestSuite struct { + suite.Suite + Master + handler *restful.Container + openAIServer *mock.OpenAIServer + cookie string +} + +func (suite *MasterAPITestSuite) SetupTest() { + // open database + var err error + suite.Settings = config.NewSettings() + suite.metaStore, err = meta.Open(fmt.Sprintf("sqlite://%s/meta.db", suite.T().TempDir()), suite.Config.Master.MetaTimeout) + suite.NoError(err) + suite.DataClient, err = data.Open(fmt.Sprintf("sqlite://%s/data.db", suite.T().TempDir()), "") + suite.NoError(err) + suite.CacheClient, err = cache.Open(fmt.Sprintf("sqlite://%s/cache.db", suite.T().TempDir()), "") + suite.NoError(err) + // init database + err = suite.metaStore.Init() + suite.NoError(err) + err = suite.DataClient.Init() + suite.NoError(err) + err = suite.CacheClient.Init() + suite.NoError(err) + // create server + suite.Config = config.GetDefaultConfig() + suite.Config.Master.DashboardUserName = mockMasterUsername + suite.Config.Master.DashboardPassword = mockMasterPassword + suite.WebService = new(restful.WebService) + suite.CreateWebService() + suite.RestServer.CreateWebService() + // create handler + suite.handler = restful.NewContainer() + suite.handler.Add(suite.WebService) + // creat mock AI server + suite.openAIServer = mock.NewOpenAIServer() + go func() { + _ = suite.openAIServer.Start() + }() + suite.openAIServer.Ready() + clientConfig := openai.DefaultConfig(suite.openAIServer.AuthToken()) + clientConfig.BaseURL = suite.openAIServer.BaseURL() + suite.openAIClient = openai.NewClientWithConfig(clientConfig) + // login + req, err := http.NewRequest("POST", "/login", + strings.NewReader(fmt.Sprintf("user_name=%s&password=%s", mockMasterUsername, mockMasterPassword))) + suite.NoError(err) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + resp := httptest.NewRecorder() + suite.login(resp, req) + suite.Equal(http.StatusFound, resp.Code) + suite.cookie = resp.Header().Get("Set-Cookie") +} + +func (suite *MasterAPITestSuite) TearDownTest() { + err := suite.metaStore.Close() + suite.NoError(err) + err = suite.DataClient.Close() + suite.NoError(err) + err = suite.CacheClient.Close() + suite.NoError(err) + err = suite.openAIServer.Close() + suite.NoError(err) +} + +func (suite *MasterAPITestSuite) TestExportUsers() { ctx := context.Background() // insert users users := []data.User{ @@ -135,22 +149,20 @@ func TestMaster_ExportUsers(t *testing.T) { {UserId: "2", Labels: map[string]any{"gender": "male", "job": "lawyer"}}, {UserId: "3", Labels: map[string]any{"gender": "female", "job": "teacher"}}, } - err := s.DataClient.BatchInsertUsers(ctx, users) - assert.NoError(t, err) + err := suite.DataClient.BatchInsertUsers(ctx, users) + suite.NoError(err) // send request req := httptest.NewRequest("GET", "https://example.com/", nil) - req.Header.Set("Cookie", cookie) + req.Header.Set("Cookie", suite.cookie) w := httptest.NewRecorder() - s.importExportUsers(w, req) - assert.Equal(t, http.StatusOK, w.Result().StatusCode) - assert.Equal(t, "application/jsonl", w.Header().Get("Content-Type")) - assert.Equal(t, "attachment;filename=users.jsonl", w.Header().Get("Content-Disposition")) - assert.Equal(t, marshalJSONLines(t, users), w.Body.String()) + suite.importExportUsers(w, req) + suite.Equal(http.StatusOK, w.Result().StatusCode) + suite.Equal("application/jsonl", w.Header().Get("Content-Type")) + suite.Equal("attachment;filename=users.jsonl", w.Header().Get("Content-Disposition")) + suite.Equal(marshalJSONLines(suite.T(), users), w.Body.String()) } -func TestMaster_ExportItems(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) +func (suite *MasterAPITestSuite) TestExportItems() { ctx := context.Background() // insert items items := []data.Item{ @@ -179,23 +191,20 @@ func TestMaster_ExportItems(t *testing.T) { Comment: "\"three\"", }, } - err := s.DataClient.BatchInsertItems(ctx, items) - assert.NoError(t, err) + err := suite.DataClient.BatchInsertItems(ctx, items) + suite.NoError(err) // send request req := httptest.NewRequest("GET", "https://example.com/", nil) - req.Header.Set("Cookie", cookie) + req.Header.Set("Cookie", suite.cookie) w := httptest.NewRecorder() - s.importExportItems(w, req) - assert.Equal(t, http.StatusOK, w.Result().StatusCode) - assert.Equal(t, "application/jsonl", w.Header().Get("Content-Type")) - assert.Equal(t, "attachment;filename=items.jsonl", w.Header().Get("Content-Disposition")) - assert.Equal(t, marshalJSONLines(t, items), w.Body.String()) + suite.importExportItems(w, req) + suite.Equal(http.StatusOK, w.Result().StatusCode) + suite.Equal("application/jsonl", w.Header().Get("Content-Type")) + suite.Equal("attachment;filename=items.jsonl", w.Header().Get("Content-Disposition")) + suite.Equal(marshalJSONLines(suite.T(), items), w.Body.String()) } -func TestMaster_ExportFeedback(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) - +func (suite *MasterAPITestSuite) TestExportFeedback() { ctx := context.Background() // insert feedback feedbacks := []data.Feedback{ @@ -203,77 +212,73 @@ func TestMaster_ExportFeedback(t *testing.T) { {FeedbackKey: data.FeedbackKey{FeedbackType: "read", UserId: "2", ItemId: "6"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "share", UserId: "1", ItemId: "4"}}, } - err := s.DataClient.BatchInsertFeedback(ctx, feedbacks, true, true, true) - assert.NoError(t, err) + err := suite.DataClient.BatchInsertFeedback(ctx, feedbacks, true, true, true) + suite.NoError(err) // send request req := httptest.NewRequest("GET", "https://example.com/", nil) - req.Header.Set("Cookie", cookie) + req.Header.Set("Cookie", suite.cookie) w := httptest.NewRecorder() - s.importExportFeedback(w, req) - assert.Equal(t, http.StatusOK, w.Result().StatusCode) - assert.Equal(t, "application/jsonl", w.Header().Get("Content-Type")) - assert.Equal(t, "attachment;filename=feedback.jsonl", w.Header().Get("Content-Disposition")) - assert.Equal(t, marshalJSONLines(t, feedbacks), w.Body.String()) + suite.importExportFeedback(w, req) + suite.Equal(http.StatusOK, w.Result().StatusCode) + suite.Equal("application/jsonl", w.Header().Get("Content-Type")) + suite.Equal("attachment;filename=feedback.jsonl", w.Header().Get("Content-Disposition")) + suite.Equal(marshalJSONLines(suite.T(), feedbacks), w.Body.String()) } -func TestMaster_ImportUsers(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) +func (suite *MasterAPITestSuite) TestImportUsers() { ctx := context.Background() // send request buf := bytes.NewBuffer(nil) writer := multipart.NewWriter(buf) file, err := writer.CreateFormFile("file", "users.jsonl") - assert.NoError(t, err) + suite.NoError(err) _, err = file.Write([]byte(`{"UserId":"1","Labels":{"性别":"男","职业":"工程师"}} {"UserId":"2","Labels":{"性别":"男","职业":"律师"}} {"UserId":"3","Labels":{"性别":"女","职业":"教师"}}`)) - assert.NoError(t, err) + suite.NoError(err) err = writer.Close() - assert.NoError(t, err) + suite.NoError(err) req := httptest.NewRequest("POST", "https://example.com/", buf) - req.Header.Set("Cookie", cookie) + req.Header.Set("Cookie", suite.cookie) req.Header.Set("Content-Type", writer.FormDataContentType()) w := httptest.NewRecorder() - s.importExportUsers(w, req) + suite.importExportUsers(w, req) // check - assert.Equal(t, http.StatusOK, w.Result().StatusCode) - assert.JSONEq(t, marshal(t, server.Success{RowAffected: 3}), w.Body.String()) - _, items, err := s.DataClient.GetUsers(ctx, "", 100) - assert.NoError(t, err) - assert.Equal(t, []data.User{ + suite.Equal(http.StatusOK, w.Result().StatusCode) + suite.JSONEq(marshal(suite.T(), server.Success{RowAffected: 3}), w.Body.String()) + _, items, err := suite.DataClient.GetUsers(ctx, "", 100) + suite.NoError(err) + suite.Equal([]data.User{ {UserId: "1", Labels: map[string]any{"性别": "男", "职业": "工程师"}}, {UserId: "2", Labels: map[string]any{"性别": "男", "职业": "律师"}}, {UserId: "3", Labels: map[string]any{"性别": "女", "职业": "教师"}}, }, items) } -func TestMaster_ImportItems(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) +func (suite *MasterAPITestSuite) TestImportItems() { ctx := context.Background() // send request buf := bytes.NewBuffer(nil) writer := multipart.NewWriter(buf) file, err := writer.CreateFormFile("file", "items.jsonl") - assert.NoError(t, err) + suite.NoError(err) _, err = file.Write([]byte(`{"ItemId":"1","IsHidden":false,"Categories":["x"],"Timestamp":"2020-01-01 01:01:01.000000001 +0000 UTC","Labels":{"类型":["喜剧","科幻"]},"Comment":"one"} {"ItemId":"2","IsHidden":false,"Categories":["x","y"],"Timestamp":"2021-01-01 01:01:01.000000001 +0000 UTC","Labels":{"类型":["卡通","科幻"]},"Comment":"two"} {"ItemId":"3","IsHidden":true,"Timestamp":"2022-01-01 01:01:01.000000001 +0000 UTC","Comment":"three"}`)) - assert.NoError(t, err) + suite.NoError(err) err = writer.Close() - assert.NoError(t, err) + suite.NoError(err) req := httptest.NewRequest("POST", "https://example.com/", buf) - req.Header.Set("Cookie", cookie) + req.Header.Set("Cookie", suite.cookie) req.Header.Set("Content-Type", writer.FormDataContentType()) w := httptest.NewRecorder() - s.importExportItems(w, req) + suite.importExportItems(w, req) // check - assert.Equal(t, http.StatusOK, w.Result().StatusCode) - assert.JSONEq(t, marshal(t, server.Success{RowAffected: 3}), w.Body.String()) - _, items, err := s.DataClient.GetItems(ctx, "", 100, nil) - assert.NoError(t, err) - assert.Equal(t, []data.Item{ + suite.Equal(http.StatusOK, w.Result().StatusCode) + suite.JSONEq(marshal(suite.T(), server.Success{RowAffected: 3}), w.Body.String()) + _, items, err := suite.DataClient.GetItems(ctx, "", 100, nil) + suite.NoError(err) + suite.Equal([]data.Item{ { ItemId: "1", IsHidden: false, @@ -300,41 +305,37 @@ func TestMaster_ImportItems(t *testing.T) { }, items) } -func TestMaster_ImportFeedback(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) +func (suite *MasterAPITestSuite) TestImportFeedback() { // send request ctx := context.Background() buf := bytes.NewBuffer(nil) writer := multipart.NewWriter(buf) file, err := writer.CreateFormFile("file", "feedback.jsonl") - assert.NoError(t, err) + suite.NoError(err) _, err = file.Write([]byte(`{"FeedbackType":"click","UserId":"0","ItemId":"2","Timestamp":"0001-01-01 00:00:00 +0000 UTC"} {"FeedbackType":"read","UserId":"2","ItemId":"6","Timestamp":"0001-01-01 00:00:00 +0000 UTC"} {"FeedbackType":"share","UserId":"1","ItemId":"4","Timestamp":"0001-01-01 00:00:00 +0000 UTC"}`)) - assert.NoError(t, err) + suite.NoError(err) err = writer.Close() - assert.NoError(t, err) + suite.NoError(err) req := httptest.NewRequest("POST", "https://example.com/", buf) - req.Header.Set("Cookie", cookie) + req.Header.Set("Cookie", suite.cookie) req.Header.Set("Content-Type", writer.FormDataContentType()) w := httptest.NewRecorder() - s.importExportFeedback(w, req) + suite.importExportFeedback(w, req) // check - assert.Equal(t, http.StatusOK, w.Result().StatusCode) - assert.JSONEq(t, marshal(t, server.Success{RowAffected: 3}), w.Body.String()) - _, feedback, err := s.DataClient.GetFeedback(ctx, "", 100, nil, lo.ToPtr(time.Now())) - assert.NoError(t, err) - assert.Equal(t, []data.Feedback{ + suite.Equal(http.StatusOK, w.Result().StatusCode) + suite.JSONEq(marshal(suite.T(), server.Success{RowAffected: 3}), w.Body.String()) + _, feedback, err := suite.DataClient.GetFeedback(ctx, "", 100, nil, lo.ToPtr(time.Now())) + suite.NoError(err) + suite.Equal([]data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "click", UserId: "0", ItemId: "2"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "read", UserId: "2", ItemId: "6"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "share", UserId: "1", ItemId: "4"}}, }, feedback) } -func TestMaster_GetCluster(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) +func (suite *MasterAPITestSuite) TestGetCluster() { // add nodes serverNode := &meta.Node{ UUID: "alan turnin", @@ -350,45 +351,42 @@ func TestMaster_GetCluster(t *testing.T) { Version: "worker_version", UpdateTime: time.Now().UTC(), } - err := s.metaStore.UpdateNode(serverNode) - assert.NoError(t, err) - err = s.metaStore.UpdateNode(workerNode) - assert.NoError(t, err) + err := suite.metaStore.UpdateNode(serverNode) + suite.NoError(err) + err = suite.metaStore.UpdateNode(workerNode) + suite.NoError(err) // get nodes apitest.New(). - Handler(s.handler). + Handler(suite.handler). Get("/api/dashboard/cluster"). - Header("Cookie", cookie). - Expect(t). + Header("Cookie", suite.cookie). + Expect(suite.T()). Status(http.StatusOK). - Body(marshal(t, []*meta.Node{serverNode, workerNode})). + Body(marshal(suite.T(), []*meta.Node{serverNode, workerNode})). End() } -func TestMaster_GetStats(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) - +func (suite *MasterAPITestSuite) TestGetStats() { ctx := context.Background() // set stats - s.rankingScore = ranking.Score{Precision: 0.1} - s.clickScore = click.Score{Precision: 0.2} - err := s.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumUsers), 123)) - assert.NoError(t, err) - err = s.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumItems), 234)) - assert.NoError(t, err) - err = s.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumValidPosFeedbacks), 345)) - assert.NoError(t, err) - err = s.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumValidNegFeedbacks), 456)) - assert.NoError(t, err) + suite.rankingScore = ranking.Score{Precision: 0.1} + suite.clickScore = click.Score{Precision: 0.2} + err := suite.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumUsers), 123)) + suite.NoError(err) + err = suite.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumItems), 234)) + suite.NoError(err) + err = suite.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumValidPosFeedbacks), 345)) + suite.NoError(err) + err = suite.CacheClient.Set(ctx, cache.Integer(cache.Key(cache.GlobalMeta, cache.NumValidNegFeedbacks), 456)) + suite.NoError(err) // get stats apitest.New(). - Handler(s.handler). + Handler(suite.handler). Get("/api/dashboard/stats"). - Header("Cookie", cookie). - Expect(t). + Header("Cookie", suite.cookie). + Expect(suite.T()). Status(http.StatusOK). - Body(marshal(t, Status{ + Body(marshal(suite.T(), Status{ NumUsers: 123, NumItems: 234, NumValidPosFeedback: 345, @@ -400,16 +398,13 @@ func TestMaster_GetStats(t *testing.T) { End() } -func TestMaster_GetRates(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) - +func (suite *MasterAPITestSuite) TestGetRates() { ctx := context.Background() // write rates - s.Config.Recommend.DataSource.PositiveFeedbackTypes = []string{"a", "b"} + suite.Config.Recommend.DataSource.PositiveFeedbackTypes = []string{"a", "b"} // This first measurement should be overwritten. baseTimestamp := time.Now() - err := s.CacheClient.AddTimeSeriesPoints(ctx, []cache.TimeSeriesPoint{ + err := suite.CacheClient.AddTimeSeriesPoints(ctx, []cache.TimeSeriesPoint{ {Name: cache.Key(PositiveFeedbackRate, "a"), Value: 100.0, Timestamp: baseTimestamp.Add(-2 * 24 * time.Hour)}, {Name: cache.Key(PositiveFeedbackRate, "a"), Value: 2.0, Timestamp: baseTimestamp.Add(-2 * 24 * time.Hour)}, {Name: cache.Key(PositiveFeedbackRate, "a"), Value: 2.0, Timestamp: baseTimestamp.Add(-1 * 24 * time.Hour)}, @@ -418,16 +413,16 @@ func TestMaster_GetRates(t *testing.T) { {Name: cache.Key(PositiveFeedbackRate, "b"), Value: 20.0, Timestamp: baseTimestamp.Add(-1 * 24 * time.Hour)}, {Name: cache.Key(PositiveFeedbackRate, "b"), Value: 30.0, Timestamp: baseTimestamp.Add(-0 * 24 * time.Hour)}, }) - assert.NoError(t, err) + suite.NoError(err) // get rates apitest.New(). - Handler(s.handler). + Handler(suite.handler). Get("/api/dashboard/rates"). - Header("Cookie", cookie). - Expect(t). + Header("Cookie", suite.cookie). + Expect(suite.T()). Status(http.StatusOK). - Body(marshal(t, map[string][]cache.TimeSeriesPoint{ + Body(marshal(suite.T(), map[string][]cache.TimeSeriesPoint{ "a": { {Name: cache.Key(PositiveFeedbackRate, "a"), Value: 2.0, Timestamp: baseTimestamp.Add(-2 * 24 * time.Hour)}, {Name: cache.Key(PositiveFeedbackRate, "a"), Value: 2.0, Timestamp: baseTimestamp.Add(-1 * 24 * time.Hour)}, @@ -442,27 +437,23 @@ func TestMaster_GetRates(t *testing.T) { End() } -func TestMaster_GetCategories(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) +func (suite *MasterAPITestSuite) TestGetCategories() { ctx := context.Background() // insert categories - err := s.CacheClient.SetSet(ctx, cache.ItemCategories, "a", "b", "c") - assert.NoError(t, err) + err := suite.CacheClient.SetSet(ctx, cache.ItemCategories, "a", "b", "c") + suite.NoError(err) // get categories apitest.New(). - Handler(s.handler). + Handler(suite.handler). Get("/api/dashboard/categories"). - Header("Cookie", cookie). - Expect(t). + Header("Cookie", suite.cookie). + Expect(suite.T()). Status(http.StatusOK). - Body(marshal(t, []string{"a", "b", "c"})). + Body(marshal(suite.T(), []string{"a", "b", "c"})). End() } -func TestMaster_GetUsers(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) +func (suite *MasterAPITestSuite) TestGetUsers() { ctx := context.Background() // add users users := []User{ @@ -471,39 +462,37 @@ func TestMaster_GetUsers(t *testing.T) { {data.User{UserId: "2"}, time.Date(2002, 1, 1, 1, 1, 1, 1, time.UTC), time.Date(2022, 1, 1, 1, 1, 1, 1, time.UTC)}, } for _, user := range users { - err := s.DataClient.BatchInsertUsers(ctx, []data.User{user.User}) - assert.NoError(t, err) - err = s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyUserTime, user.UserId), user.LastActiveTime)) - assert.NoError(t, err) - err = s.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastUpdateUserRecommendTime, user.UserId), user.LastUpdateTime)) - assert.NoError(t, err) + err := suite.DataClient.BatchInsertUsers(ctx, []data.User{user.User}) + suite.NoError(err) + err = suite.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastModifyUserTime, user.UserId), user.LastActiveTime)) + suite.NoError(err) + err = suite.CacheClient.Set(ctx, cache.Time(cache.Key(cache.LastUpdateUserRecommendTime, user.UserId), user.LastUpdateTime)) + suite.NoError(err) } // get users apitest.New(). - Handler(s.handler). + Handler(suite.handler). Get("/api/dashboard/users"). - Header("Cookie", cookie). - Expect(t). + Header("Cookie", suite.cookie). + Expect(suite.T()). Status(http.StatusOK). - Body(marshal(t, UserIterator{ + Body(marshal(suite.T(), UserIterator{ Cursor: "", Users: users, })). End() // get a user apitest.New(). - Handler(s.handler). + Handler(suite.handler). Get("/api/dashboard/user/1"). - Header("Cookie", cookie). - Expect(t). + Header("Cookie", suite.cookie). + Expect(suite.T()). Status(http.StatusOK). - Body(marshal(t, users[1])). + Body(marshal(suite.T(), users[1])). End() } -func TestServer_SearchDocumentsOfItems(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) +func (suite *MasterAPITestSuite) TestSearchDocumentsOfItems() { type ListOperator struct { Name string Collection string @@ -521,7 +510,7 @@ func TestServer_SearchDocumentsOfItems(t *testing.T) { {"PopularItemsCategory", cache.NonPersonalized, cache.Popular, "*", "/api/dashboard/non-personalized/popular/"}, } for i, operator := range operators { - t.Run(operator.Name, func(t *testing.T) { + suite.T().Run(operator.Name, func(t *testing.T) { // Put scores scores := []cache.Score{ {Id: strconv.Itoa(i) + "0", Score: 100, Categories: []string{operator.Category}}, @@ -530,27 +519,27 @@ func TestServer_SearchDocumentsOfItems(t *testing.T) { {Id: strconv.Itoa(i) + "3", Score: 97, Categories: []string{operator.Category}}, {Id: strconv.Itoa(i) + "4", Score: 96, Categories: []string{operator.Category}}, } - err := s.CacheClient.AddScores(ctx, operator.Collection, operator.Subset, scores) - assert.NoError(t, err) + err := suite.CacheClient.AddScores(ctx, operator.Collection, operator.Subset, scores) + suite.NoError(err) items := make([]ScoredItem, 0) for _, score := range scores { items = append(items, ScoredItem{Item: data.Item{ItemId: score.Id}, Score: score.Score}) - err = s.DataClient.BatchInsertItems(ctx, []data.Item{{ItemId: score.Id}}) - assert.NoError(t, err) + err = suite.DataClient.BatchInsertItems(ctx, []data.Item{{ItemId: score.Id}}) + suite.NoError(err) } // hide item apitest.New(). - Handler(s.handler). + Handler(suite.handler). Patch("/api/item/"+strconv.Itoa(i)+"3"). - Header("Cookie", cookie). + Header("Cookie", suite.cookie). JSON(data.ItemPatch{IsHidden: proto.Bool(true)}). Expect(t). Status(http.StatusOK). End() apitest.New(). - Handler(s.handler). + Handler(suite.handler). Get(operator.Get). - Header("Cookie", cookie). + Header("Cookie", suite.cookie). Query("category", operator.Category). Expect(t). Status(http.StatusOK). @@ -560,9 +549,7 @@ func TestServer_SearchDocumentsOfItems(t *testing.T) { } } -func TestServer_SearchDocumentsOfUsers(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) +func (suite *MasterAPITestSuite) TestSearchDocumentsOfUsers() { type ListOperator struct { Prefix string Label string @@ -573,7 +560,7 @@ func TestServer_SearchDocumentsOfUsers(t *testing.T) { {cache.UserToUser, cache.Key(cache.Neighbors, "0"), "/api/dashboard/user-to-user/neighbors/0/"}, } for _, operator := range operators { - t.Logf("test RESTful API: %v", operator.Get) + suite.T().Logf("test RESTful API: %v", operator.Get) // Put scores scores := []cache.Score{ {Id: "0", Score: 100, Categories: []string{""}}, @@ -582,28 +569,26 @@ func TestServer_SearchDocumentsOfUsers(t *testing.T) { {Id: "3", Score: 97, Categories: []string{""}}, {Id: "4", Score: 96, Categories: []string{""}}, } - err := s.CacheClient.AddScores(ctx, operator.Prefix, operator.Label, scores) - assert.NoError(t, err) + err := suite.CacheClient.AddScores(ctx, operator.Prefix, operator.Label, scores) + suite.NoError(err) users := make([]ScoreUser, 0) for _, score := range scores { users = append(users, ScoreUser{User: data.User{UserId: score.Id}, Score: score.Score}) - err = s.DataClient.BatchInsertUsers(ctx, []data.User{{UserId: score.Id}}) - assert.NoError(t, err) + err = suite.DataClient.BatchInsertUsers(ctx, []data.User{{UserId: score.Id}}) + suite.NoError(err) } apitest.New(). - Handler(s.handler). + Handler(suite.handler). Get(operator.Get). - Header("Cookie", cookie). - Expect(t). + Header("Cookie", suite.cookie). + Expect(suite.T()). Status(http.StatusOK). - Body(marshal(t, users)). + Body(marshal(suite.T(), users)). End() } } -func TestServer_Feedback(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) +func (suite *MasterAPITestSuite) TestFeedback() { ctx := context.Background() // insert feedback feedback := []Feedback{ @@ -614,25 +599,23 @@ func TestServer_Feedback(t *testing.T) { {FeedbackType: "click", UserId: "0", Item: data.Item{ItemId: "8"}}, } for _, v := range feedback { - err := s.DataClient.BatchInsertFeedback(ctx, []data.Feedback{{ + err := suite.DataClient.BatchInsertFeedback(ctx, []data.Feedback{{ FeedbackKey: data.FeedbackKey{FeedbackType: v.FeedbackType, UserId: v.UserId, ItemId: v.Item.ItemId}, }}, true, true, true) - assert.NoError(t, err) + suite.NoError(err) } // get feedback apitest.New(). - Handler(s.handler). + Handler(suite.handler). Get("/api/dashboard/user/0/feedback/click"). - Header("Cookie", cookie). - Expect(t). + Header("Cookie", suite.cookie). + Expect(suite.T()). Status(http.StatusOK). - Body(marshal(t, feedback)). + Body(marshal(suite.T(), feedback)). End() } -func TestServer_GetRecommends(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) +func (suite *MasterAPITestSuite) TestGetRecommends() { // inset recommendation itemIds := []cache.Score{ {Id: "1", Score: 99, Categories: []string{""}}, @@ -645,150 +628,142 @@ func TestServer_GetRecommends(t *testing.T) { {Id: "8", Score: 92, Categories: []string{""}}, } ctx := context.Background() - err := s.CacheClient.AddScores(ctx, cache.OfflineRecommend, "0", itemIds) - assert.NoError(t, err) + err := suite.CacheClient.AddScores(ctx, cache.OfflineRecommend, "0", itemIds) + suite.NoError(err) // insert feedback feedback := []data.Feedback{ {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "2"}}, {FeedbackKey: data.FeedbackKey{FeedbackType: "a", UserId: "0", ItemId: "4"}}, } - err = s.DataClient.BatchInsertFeedback(ctx, feedback, true, true, true) - assert.NoError(t, err) + err = suite.DataClient.BatchInsertFeedback(ctx, feedback, true, true, true) + suite.NoError(err) // insert items for _, item := range itemIds { - err = s.DataClient.BatchInsertItems(ctx, []data.Item{{ItemId: item.Id}}) - assert.NoError(t, err) + err = suite.DataClient.BatchInsertItems(ctx, []data.Item{{ItemId: item.Id}}) + suite.NoError(err) } apitest.New(). - Handler(s.handler). + Handler(suite.handler). Get("/api/dashboard/recommend/0/offline"). - Header("Cookie", cookie). - Expect(t). + Header("Cookie", suite.cookie). + Expect(suite.T()). Status(http.StatusOK). - Body(marshal(t, []data.Item{ + Body(marshal(suite.T(), []data.Item{ {ItemId: "1"}, {ItemId: "3"}, {ItemId: "5"}, {ItemId: "6"}, {ItemId: "7"}, {ItemId: "8"}, })). End() - s.Config.Recommend.Online.FallbackRecommend = []string{"collaborative", "item_based", "user_based", "latest", "popular"} + suite.Config.Recommend.Online.FallbackRecommend = []string{"collaborative", "item_based", "user_based", "latest", "popular"} apitest.New(). - Handler(s.handler). + Handler(suite.handler). Get("/api/dashboard/recommend/0/_"). - Header("Cookie", cookie). - Expect(t). + Header("Cookie", suite.cookie). + Expect(suite.T()). Status(http.StatusOK). - Body(marshal(t, []data.Item{ + Body(marshal(suite.T(), []data.Item{ {ItemId: "1"}, {ItemId: "3"}, {ItemId: "5"}, {ItemId: "6"}, {ItemId: "7"}, {ItemId: "8"}, })). End() } -func TestMaster_Purge(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) - +func (suite *MasterAPITestSuite) TestPurge() { ctx := context.Background() // insert data - err := s.CacheClient.Set(ctx, cache.String("key", "value")) - assert.NoError(t, err) - ret, err := s.CacheClient.Get(ctx, "key").String() - assert.NoError(t, err) - assert.Equal(t, "value", ret) - - err = s.CacheClient.AddSet(ctx, "set", "a", "b", "c") - assert.NoError(t, err) - set, err := s.CacheClient.GetSet(ctx, "set") - assert.NoError(t, err) - assert.ElementsMatch(t, []string{"a", "b", "c"}, set) - - err = s.CacheClient.AddScores(ctx, "sorted", "", []cache.Score{ + err := suite.CacheClient.Set(ctx, cache.String("key", "value")) + suite.NoError(err) + ret, err := suite.CacheClient.Get(ctx, "key").String() + suite.NoError(err) + suite.Equal("value", ret) + + err = suite.CacheClient.AddSet(ctx, "set", "a", "b", "c") + suite.NoError(err) + set, err := suite.CacheClient.GetSet(ctx, "set") + suite.NoError(err) + suite.ElementsMatch([]string{"a", "b", "c"}, set) + + err = suite.CacheClient.AddScores(ctx, "sorted", "", []cache.Score{ {Id: "a", Score: 1, Categories: []string{""}}, {Id: "b", Score: 2, Categories: []string{""}}, {Id: "c", Score: 3, Categories: []string{""}}}) - assert.NoError(t, err) - z, err := s.CacheClient.SearchScores(ctx, "sorted", "", []string{""}, 0, -1) - assert.NoError(t, err) - assert.ElementsMatch(t, []cache.Score{ + suite.NoError(err) + z, err := suite.CacheClient.SearchScores(ctx, "sorted", "", []string{""}, 0, -1) + suite.NoError(err) + suite.ElementsMatch([]cache.Score{ {Id: "a", Score: 1, Categories: []string{""}}, {Id: "b", Score: 2, Categories: []string{""}}, {Id: "c", Score: 3, Categories: []string{""}}}, z) - err = s.DataClient.BatchInsertFeedback(ctx, lo.Map(lo.Range(100), func(t int, i int) data.Feedback { + err = suite.DataClient.BatchInsertFeedback(ctx, lo.Map(lo.Range(100), func(t int, i int) data.Feedback { return data.Feedback{FeedbackKey: data.FeedbackKey{ FeedbackType: "click", UserId: strconv.Itoa(t), ItemId: strconv.Itoa(t), }} }), true, true, true) - assert.NoError(t, err) - _, users, err := s.DataClient.GetUsers(ctx, "", 100) - assert.NoError(t, err) - assert.Equal(t, 100, len(users)) - _, items, err := s.DataClient.GetItems(ctx, "", 100, nil) - assert.NoError(t, err) - assert.Equal(t, 100, len(items)) - _, feedbacks, err := s.DataClient.GetFeedback(ctx, "", 100, nil, lo.ToPtr(time.Now())) - assert.NoError(t, err) - assert.Equal(t, 100, len(feedbacks)) + suite.NoError(err) + _, users, err := suite.DataClient.GetUsers(ctx, "", 100) + suite.NoError(err) + suite.Equal(100, len(users)) + _, items, err := suite.DataClient.GetItems(ctx, "", 100, nil) + suite.NoError(err) + suite.Equal(100, len(items)) + _, feedbacks, err := suite.DataClient.GetFeedback(ctx, "", 100, nil, lo.ToPtr(time.Now())) + suite.NoError(err) + suite.Equal(100, len(feedbacks)) // purge data req := httptest.NewRequest("POST", "https://example.com/", strings.NewReader("check_list=delete_users,delete_items,delete_feedback,delete_cache")) - req.Header.Set("Cookie", cookie) + req.Header.Set("Cookie", suite.cookie) req.Header.Set("Content-Type", "application/x-www-form-urlencoded") w := httptest.NewRecorder() - s.purge(w, req) - assert.Equal(t, http.StatusOK, w.Code) - - _, err = s.CacheClient.Get(ctx, "key").String() - assert.ErrorIs(t, err, errors.NotFound) - set, err = s.CacheClient.GetSet(ctx, "set") - assert.NoError(t, err) - assert.Empty(t, set) - z, err = s.CacheClient.SearchScores(ctx, "sorted", "", []string{""}, 0, -1) - assert.NoError(t, err) - assert.Empty(t, z) - - _, users, err = s.DataClient.GetUsers(ctx, "", 100) - assert.NoError(t, err) - assert.Empty(t, users) - _, items, err = s.DataClient.GetItems(ctx, "", 100, nil) - assert.NoError(t, err) - assert.Empty(t, items) - _, feedbacks, err = s.DataClient.GetFeedback(ctx, "", 100, nil, lo.ToPtr(time.Now())) - assert.NoError(t, err) - assert.Empty(t, feedbacks) + suite.purge(w, req) + suite.Equal(http.StatusOK, w.Code) + + _, err = suite.CacheClient.Get(ctx, "key").String() + suite.ErrorIs(err, errors.NotFound) + set, err = suite.CacheClient.GetSet(ctx, "set") + suite.NoError(err) + suite.Empty(set) + z, err = suite.CacheClient.SearchScores(ctx, "sorted", "", []string{""}, 0, -1) + suite.NoError(err) + suite.Empty(z) + + _, users, err = suite.DataClient.GetUsers(ctx, "", 100) + suite.NoError(err) + suite.Empty(users) + _, items, err = suite.DataClient.GetItems(ctx, "", 100, nil) + suite.NoError(err) + suite.Empty(items) + _, feedbacks, err = suite.DataClient.GetFeedback(ctx, "", 100, nil, lo.ToPtr(time.Now())) + suite.NoError(err) + suite.Empty(feedbacks) } -func TestMaster_GetConfig(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) - +func (suite *MasterAPITestSuite) TestGetConfig() { apitest.New(). - Handler(s.handler). + Handler(suite.handler). Get("/api/dashboard/config"). - Header("Cookie", cookie). - Expect(t). + Header("Cookie", suite.cookie). + Expect(suite.T()). Status(http.StatusOK). - Body(marshal(t, formatConfig(convertToMapStructure(t, s.Config)))). + Body(marshal(suite.T(), formatConfig(convertToMapStructure(suite.T(), suite.Config)))). End() - s.Config.Master.DashboardRedacted = true - redactedConfig := formatConfig(convertToMapStructure(t, s.Config)) + suite.Config.Master.DashboardRedacted = true + redactedConfig := formatConfig(convertToMapStructure(suite.T(), suite.Config)) delete(redactedConfig, "database") apitest.New(). - Handler(s.handler). + Handler(suite.handler). Get("/api/dashboard/config"). - Header("Cookie", cookie). - Expect(t). + Header("Cookie", suite.cookie). + Expect(suite.T()). Status(http.StatusOK). - Body(marshal(t, redactedConfig)). + Body(marshal(suite.T(), redactedConfig)). End() } -func TestDumpAndRestore(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) +func (suite *MasterAPITestSuite) TestDumpAndRestore() { ctx := context.Background() // insert users users := make([]data.User, batchSize+1) @@ -798,8 +773,8 @@ func TestDumpAndRestore(t *testing.T) { Labels: map[string]any{"a": fmt.Sprintf("%d", 2*i+1), "b": fmt.Sprintf("%d", 2*i+2)}, } } - err := s.DataClient.BatchInsertUsers(ctx, users) - assert.NoError(t, err) + err := suite.DataClient.BatchInsertUsers(ctx, users) + suite.NoError(err) // insert items items := make([]data.Item, batchSize+1) for i := range items { @@ -808,8 +783,8 @@ func TestDumpAndRestore(t *testing.T) { Labels: map[string]any{"a": fmt.Sprintf("%d", 2*i+1), "b": fmt.Sprintf("%d", 2*i+2)}, } } - err = s.DataClient.BatchInsertItems(ctx, items) - assert.NoError(t, err) + err = suite.DataClient.BatchInsertItems(ctx, items) + suite.NoError(err) // insert feedback feedback := make([]data.Feedback, batchSize+1) for i := range feedback { @@ -821,47 +796,45 @@ func TestDumpAndRestore(t *testing.T) { }, } } - err = s.DataClient.BatchInsertFeedback(ctx, feedback, true, true, true) - assert.NoError(t, err) + err = suite.DataClient.BatchInsertFeedback(ctx, feedback, true, true, true) + suite.NoError(err) // dump data req := httptest.NewRequest("GET", "https://example.com/", nil) - req.Header.Set("Cookie", cookie) + req.Header.Set("Cookie", suite.cookie) w := httptest.NewRecorder() - s.dump(w, req) - assert.Equal(t, http.StatusOK, w.Code) + suite.dump(w, req) + suite.Equal(http.StatusOK, w.Code) // restore data - err = s.DataClient.Purge() - assert.NoError(t, err) + err = suite.DataClient.Purge() + suite.NoError(err) req = httptest.NewRequest("POST", "https://example.com/", bytes.NewReader(w.Body.Bytes())) - req.Header.Set("Cookie", cookie) + req.Header.Set("Cookie", suite.cookie) req.Header.Set("Content-Type", "application/octet-stream") w = httptest.NewRecorder() - s.restore(w, req) - assert.Equal(t, http.StatusOK, w.Code) + suite.restore(w, req) + suite.Equal(http.StatusOK, w.Code) // check data - _, returnUsers, err := s.DataClient.GetUsers(ctx, "", len(users)) - assert.NoError(t, err) - if assert.Equal(t, len(users), len(returnUsers)) { - assert.Equal(t, users, returnUsers) + _, returnUsers, err := suite.DataClient.GetUsers(ctx, "", len(users)) + suite.NoError(err) + if suite.Equal(len(users), len(returnUsers)) { + suite.Equal(users, returnUsers) } - _, returnItems, err := s.DataClient.GetItems(ctx, "", len(items), nil) - assert.NoError(t, err) - if assert.Equal(t, len(items), len(returnItems)) { - assert.Equal(t, items, returnItems) + _, returnItems, err := suite.DataClient.GetItems(ctx, "", len(items), nil) + suite.NoError(err) + if suite.Equal(len(items), len(returnItems)) { + suite.Equal(items, returnItems) } - _, returnFeedback, err := s.DataClient.GetFeedback(ctx, "", len(feedback), nil, lo.ToPtr(time.Now())) - assert.NoError(t, err) - if assert.Equal(t, len(feedback), len(returnFeedback)) { - assert.Equal(t, feedback, returnFeedback) + _, returnFeedback, err := suite.DataClient.GetFeedback(ctx, "", len(feedback), nil, lo.ToPtr(time.Now())) + suite.NoError(err) + if suite.Equal(len(feedback), len(returnFeedback)) { + suite.Equal(feedback, returnFeedback) } } -func TestExportAndImport(t *testing.T) { - s, cookie := newMockServer(t) - defer s.Close(t) +func (suite *MasterAPITestSuite) TestExportAndImport() { ctx := context.Background() // insert users users := make([]data.User, batchSize+1) @@ -871,8 +844,8 @@ func TestExportAndImport(t *testing.T) { Labels: map[string]any{"a": fmt.Sprintf("%d", 2*i+1), "b": fmt.Sprintf("%d", 2*i+2)}, } } - err := s.DataClient.BatchInsertUsers(ctx, users) - assert.NoError(t, err) + err := suite.DataClient.BatchInsertUsers(ctx, users) + suite.NoError(err) // insert items items := make([]data.Item, batchSize+1) for i := range items { @@ -881,8 +854,8 @@ func TestExportAndImport(t *testing.T) { Labels: map[string]any{"a": fmt.Sprintf("%d", 2*i+1), "b": fmt.Sprintf("%d", 2*i+2)}, } } - err = s.DataClient.BatchInsertItems(ctx, items) - assert.NoError(t, err) + err = suite.DataClient.BatchInsertItems(ctx, items) + suite.NoError(err) // insert feedback feedback := make([]data.Feedback, batchSize+1) for i := range feedback { @@ -894,93 +867,110 @@ func TestExportAndImport(t *testing.T) { }, } } - err = s.DataClient.BatchInsertFeedback(ctx, feedback, true, true, true) - assert.NoError(t, err) + err = suite.DataClient.BatchInsertFeedback(ctx, feedback, true, true, true) + suite.NoError(err) // export users req := httptest.NewRequest("GET", "https://example.com/", nil) - req.Header.Set("Cookie", cookie) + req.Header.Set("Cookie", suite.cookie) w := httptest.NewRecorder() - s.importExportUsers(w, req) - assert.Equal(t, http.StatusOK, w.Code) + suite.importExportUsers(w, req) + suite.Equal(http.StatusOK, w.Code) usersData := w.Body.Bytes() // export items req = httptest.NewRequest("GET", "https://example.com/", nil) - req.Header.Set("Cookie", cookie) + req.Header.Set("Cookie", suite.cookie) w = httptest.NewRecorder() - s.importExportItems(w, req) - assert.Equal(t, http.StatusOK, w.Code) + suite.importExportItems(w, req) + suite.Equal(http.StatusOK, w.Code) itemsData := w.Body.Bytes() // export feedback req = httptest.NewRequest("GET", "https://example.com/", nil) - req.Header.Set("Cookie", cookie) + req.Header.Set("Cookie", suite.cookie) w = httptest.NewRecorder() - s.importExportFeedback(w, req) - assert.Equal(t, http.StatusOK, w.Code) + suite.importExportFeedback(w, req) + suite.Equal(http.StatusOK, w.Code) feedbackData := w.Body.Bytes() - err = s.DataClient.Purge() - assert.NoError(t, err) + err = suite.DataClient.Purge() + suite.NoError(err) // import users buf := bytes.NewBuffer(nil) writer := multipart.NewWriter(buf) file, err := writer.CreateFormFile("file", "users.jsonl") - assert.NoError(t, err) + suite.NoError(err) _, err = file.Write(usersData) - assert.NoError(t, err) + suite.NoError(err) err = writer.Close() - assert.NoError(t, err) + suite.NoError(err) req = httptest.NewRequest("POST", "https://example.com/", buf) - req.Header.Set("Cookie", cookie) + req.Header.Set("Cookie", suite.cookie) req.Header.Set("Content-Type", writer.FormDataContentType()) w = httptest.NewRecorder() - s.importExportUsers(w, req) - assert.Equal(t, http.StatusOK, w.Code) + suite.importExportUsers(w, req) + suite.Equal(http.StatusOK, w.Code) // import items buf = bytes.NewBuffer(nil) writer = multipart.NewWriter(buf) file, err = writer.CreateFormFile("file", "items.jsonl") - assert.NoError(t, err) + suite.NoError(err) _, err = file.Write(itemsData) - assert.NoError(t, err) + suite.NoError(err) err = writer.Close() - assert.NoError(t, err) + suite.NoError(err) req = httptest.NewRequest("POST", "https://example.com/", buf) - req.Header.Set("Cookie", cookie) + req.Header.Set("Cookie", suite.cookie) req.Header.Set("Content-Type", writer.FormDataContentType()) w = httptest.NewRecorder() - s.importExportItems(w, req) - assert.Equal(t, http.StatusOK, w.Code) + suite.importExportItems(w, req) + suite.Equal(http.StatusOK, w.Code) // import feedback buf = bytes.NewBuffer(nil) writer = multipart.NewWriter(buf) file, err = writer.CreateFormFile("file", "feedback.jsonl") - assert.NoError(t, err) + suite.NoError(err) _, err = file.Write(feedbackData) - assert.NoError(t, err) + suite.NoError(err) err = writer.Close() - assert.NoError(t, err) + suite.NoError(err) req = httptest.NewRequest("POST", "https://example.com/", buf) - req.Header.Set("Cookie", cookie) + req.Header.Set("Cookie", suite.cookie) req.Header.Set("Content-Type", writer.FormDataContentType()) w = httptest.NewRecorder() - s.importExportFeedback(w, req) - assert.Equal(t, http.StatusOK, w.Code) + suite.importExportFeedback(w, req) + suite.Equal(http.StatusOK, w.Code) // check data - _, returnUsers, err := s.DataClient.GetUsers(ctx, "", len(users)) - assert.NoError(t, err) - if assert.Equal(t, len(users), len(returnUsers)) { - assert.Equal(t, users, returnUsers) + _, returnUsers, err := suite.DataClient.GetUsers(ctx, "", len(users)) + suite.NoError(err) + if suite.Equal(len(users), len(returnUsers)) { + suite.Equal(users, returnUsers) } - _, returnItems, err := s.DataClient.GetItems(ctx, "", len(items), nil) - assert.NoError(t, err) - if assert.Equal(t, len(items), len(returnItems)) { - assert.Equal(t, items, returnItems) + _, returnItems, err := suite.DataClient.GetItems(ctx, "", len(items), nil) + suite.NoError(err) + if suite.Equal(len(items), len(returnItems)) { + suite.Equal(items, returnItems) } - _, returnFeedback, err := s.DataClient.GetFeedback(ctx, "", len(feedback), nil, lo.ToPtr(time.Now())) - assert.NoError(t, err) - if assert.Equal(t, len(feedback), len(returnFeedback)) { - assert.Equal(t, feedback, returnFeedback) + _, returnFeedback, err := suite.DataClient.GetFeedback(ctx, "", len(feedback), nil, lo.ToPtr(time.Now())) + suite.NoError(err) + if suite.Equal(len(feedback), len(returnFeedback)) { + suite.Equal(feedback, returnFeedback) } } + +func (suite *MasterAPITestSuite) TestChat() { + content := "In my younger and more vulnerable years my father gave me some advice that I've been turning over in" + + " my mind ever since. \"Whenever you feel like criticizing any one,\" he told me, \" just remember that all " + + "the people in this world haven't had the advantages that you've had.\"" + buf := strings.NewReader(content) + req := httptest.NewRequest("POST", "https://example.com/", buf) + req.Header.Set("Cookie", suite.cookie) + w := httptest.NewRecorder() + suite.chat(w, req) + suite.Equal(http.StatusOK, w.Code, w.Body.String()) + suite.Equal(content, w.Body.String()) +} + +func TestMasterAPI(t *testing.T) { + suite.Run(t, new(MasterAPITestSuite)) +}