Skip to content

Commit

Permalink
feat: add custom endpoint feature
Browse files Browse the repository at this point in the history
  • Loading branch information
zmh-program committed Dec 29, 2023
1 parent 947cab8 commit 1e93ed8
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 8 deletions.
3 changes: 3 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ go get -u github.com/Deeptrain-Community/chatnio-api-go
instance := chatnio.NewInstance("sk-...")
// or load from env
instance := chatnio.NewInstanceFromEnv("CHATNIO_TOKEN")

// set custom api endpoint (default: https://api.chatnio.net)
// instance.SetEndpoint("https://example.com/api")
```

- Chat
Expand Down
14 changes: 11 additions & 3 deletions instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,17 @@ func (i *Instance) IsAuthenticated() bool {
return strings.TrimSpace(i.ApiKey) != ""
}

func (i *Instance) GetChatEndpoint() string {
host := utils.TrimPrefixes(i.GetEndpoint(), "http://", "https://")
return fmt.Sprintf("wss://%s/chat", host)
func (i *Instance) GetChatEndpoint() (host string) {
host = i.GetEndpoint()
if strings.HasPrefix(host, "http://") {
host = fmt.Sprintf("ws://%s/chat", strings.TrimPrefix(host, "http://"))
} else if strings.HasPrefix(host, "https://") {
host = fmt.Sprintf("wss://%s/chat", strings.TrimPrefix(host, "https://"))
} else {
host = fmt.Sprintf("wss://%s/chat", host)
}

return
}

func (i *Instance) GetHeaders() utils.Headers {
Expand Down
8 changes: 4 additions & 4 deletions instance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ func init() {
}

func TestInstance_GetEndpoint(t *testing.T) {
if instance.GetEndpoint() != "https://api.chatnio.net" {
t.Error("endpoint is not https://api.chatnio.net")
if len(instance.GetEndpoint()) == 0 {
t.Error("endpoint is not not correct")
}
}

func TestInstance_GetChatEndpoint(t *testing.T) {
if instance.GetChatEndpoint() != "wss://api.chatnio.net/chat" {
if len(instance.GetChatEndpoint()) == 0 {
t.Error("chat endpoint is not correct")
}
}
Expand All @@ -30,7 +30,7 @@ func TestInstance_GetHeaders(t *testing.T) {
}

func TestInstance_Mix(t *testing.T) {
if instance.Mix("/test") != "https://api.chatnio.net/test" {
if len(instance.Mix("/test")) == 0 {
t.Error("mix is not correct")
}
}
Expand Down
2 changes: 1 addition & 1 deletion quota.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func (i *Instance) GetQuota() (float32, error) {
if err != nil {
return 0., err
} else if !quota.Status {
return 0., fmt.Errorf("quota status is false")
return 0., fmt.Errorf("quota query status is false")
}

return quota.Quota, nil
Expand Down

0 comments on commit 1e93ed8

Please sign in to comment.