From 1e93ed87aeb819ed040910db262120c36bc6e64c Mon Sep 17 00:00:00 2001 From: Zhang Minghan Date: Fri, 29 Dec 2023 15:35:49 +0800 Subject: [PATCH] feat: add custom endpoint feature --- README.md | 3 +++ instance.go | 14 +++++++++++--- instance_test.go | 8 ++++---- quota.go | 2 +- 4 files changed, 19 insertions(+), 8 deletions(-) diff --git a/README.md b/README.md index 657e6c6..903f8a9 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,9 @@ go get -u github.com/Deeptrain-Community/chatnio-api-go instance := chatnio.NewInstance("sk-...") // or load from env instance := chatnio.NewInstanceFromEnv("CHATNIO_TOKEN") + +// set custom api endpoint (default: https://api.chatnio.net) +// instance.SetEndpoint("https://example.com/api") ``` - Chat diff --git a/instance.go b/instance.go index 52f6e09..298be5c 100644 --- a/instance.go +++ b/instance.go @@ -45,9 +45,17 @@ func (i *Instance) IsAuthenticated() bool { return strings.TrimSpace(i.ApiKey) != "" } -func (i *Instance) GetChatEndpoint() string { - host := utils.TrimPrefixes(i.GetEndpoint(), "http://", "https://") - return fmt.Sprintf("wss://%s/chat", host) +func (i *Instance) GetChatEndpoint() (host string) { + host = i.GetEndpoint() + if strings.HasPrefix(host, "http://") { + host = fmt.Sprintf("ws://%s/chat", strings.TrimPrefix(host, "http://")) + } else if strings.HasPrefix(host, "https://") { + host = fmt.Sprintf("wss://%s/chat", strings.TrimPrefix(host, "https://")) + } else { + host = fmt.Sprintf("wss://%s/chat", host) + } + + return } func (i *Instance) GetHeaders() utils.Headers { diff --git a/instance_test.go b/instance_test.go index ef92de8..719b707 100644 --- a/instance_test.go +++ b/instance_test.go @@ -11,13 +11,13 @@ func init() { } func TestInstance_GetEndpoint(t *testing.T) { - if instance.GetEndpoint() != "https://api.chatnio.net" { - t.Error("endpoint is not https://api.chatnio.net") + if len(instance.GetEndpoint()) == 0 { + t.Error("endpoint is not not correct") } } func TestInstance_GetChatEndpoint(t *testing.T) { - if instance.GetChatEndpoint() != "wss://api.chatnio.net/chat" { + if len(instance.GetChatEndpoint()) == 0 { t.Error("chat endpoint is not correct") } } @@ -30,7 +30,7 @@ func TestInstance_GetHeaders(t *testing.T) { } func TestInstance_Mix(t *testing.T) { - if instance.Mix("/test") != "https://api.chatnio.net/test" { + if len(instance.Mix("/test")) == 0 { t.Error("mix is not correct") } } diff --git a/quota.go b/quota.go index 71fbb04..6cd3c02 100644 --- a/quota.go +++ b/quota.go @@ -40,7 +40,7 @@ func (i *Instance) GetQuota() (float32, error) { if err != nil { return 0., err } else if !quota.Status { - return 0., fmt.Errorf("quota status is false") + return 0., fmt.Errorf("quota query status is false") } return quota.Quota, nil