diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml new file mode 100644 index 0000000..7eea50f --- /dev/null +++ b/.github/workflows/test.yaml @@ -0,0 +1,15 @@ +name: Test +on: [push, pull_request] +jobs: + go-test: + runs-on: ubuntu-latest + steps: + - name: Check out source code + uses: actions/checkout@v3 + - name: Setup + uses: actions/setup-go@v3 + with: + go-version-file: "go.mod" + cache: true + - name: Test + run: go test -v ./... \ No newline at end of file diff --git a/README.md b/README.md index cb154c9..e8f287e 100644 --- a/README.md +++ b/README.md @@ -18,6 +18,7 @@ After you download the file, extract it into a folder and open the `env.example` - `TELEGRAM_ID` (Optional): Your Telegram User ID - If you set this, only you will be able to interact with the bot. - To get your ID, message `@userinfobot` on Telegram. + - Multiple IDs can be provided, separated by commas. - `EDIT_WAIT_SECONDS` (Optional): Amount of seconds to wait between edits - This is set to `1` by default, but you can increase if you start getting a lot of `Too Many Requests` errors. - Save the file, and rename it to `.env`. diff --git a/go.mod b/go.mod index 586700e..e6ef859 100644 --- a/go.mod +++ b/go.mod @@ -5,14 +5,15 @@ go 1.19 require ( github.com/go-telegram-bot-api/telegram-bot-api/v5 v5.5.1 github.com/google/uuid v1.3.0 - github.com/joho/godotenv v1.4.0 github.com/launchdarkly/eventsource v1.7.1 github.com/playwright-community/playwright-go v0.2000.1 github.com/spf13/viper v1.14.0 + github.com/stretchr/testify v1.8.1 ) require ( github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/fsnotify/fsnotify v1.6.0 // indirect github.com/go-stack/stack v1.8.1 // indirect github.com/hashicorp/hcl v1.0.0 // indirect @@ -20,6 +21,7 @@ require ( github.com/mitchellh/mapstructure v1.5.0 // indirect github.com/pelletier/go-toml v1.9.5 // indirect github.com/pelletier/go-toml/v2 v2.0.5 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/spf13/afero v1.9.2 // indirect github.com/spf13/cast v1.5.0 // indirect github.com/spf13/jwalterweatherman v1.1.0 // indirect diff --git a/go.sum b/go.sum index f45b846..895f011 100644 --- a/go.sum +++ b/go.sum @@ -132,8 +132,6 @@ github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc= -github.com/joho/godotenv v1.4.0 h1:3l4+N6zfMWnkbPEXKng2o2/MR5mSwTrBih4ZEkkz1lg= -github.com/joho/godotenv v1.4.0/go.mod h1:f4LDr5Voq0i2e/R5DDNOoa2zzDfwtkZa6DnEwAbqwq4= github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1:6v2b51hI/fHJwM22ozAgKL4VKDeJcHhJFhtBdhmNjmU= github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= @@ -176,6 +174,7 @@ github.com/spf13/viper v1.14.0 h1:Rg7d3Lo706X9tHsJMUjdiwMpHB7W8WnSVOssIY+JElU= github.com/spf13/viper v1.14.0/go.mod h1:WT//axPky3FdvXHzGw33dNdXXXfFQqmEalje+egj8As= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= @@ -184,6 +183,7 @@ github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/subosito/gotenv v1.4.1 h1:jyEFiXpy21Wm81FBN71l9VoMMV8H8jG+qIK3GCpY6Qs= github.com/subosito/gotenv v1.4.1/go.mod h1:ayKnFf/c6rvx/2iiLrJUk1e6plDbT3edrFNGqEflhK0= github.com/yuin/goldmark v1.1.25/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74= diff --git a/main.go b/main.go index 6503292..ccc1e77 100644 --- a/main.go +++ b/main.go @@ -5,11 +5,9 @@ import ( "log" "os" "os/signal" - "strconv" "syscall" "time" - "github.com/joho/godotenv" "github.com/m1guelpf/chatgpt-telegram/src/chatgpt" "github.com/m1guelpf/chatgpt-telegram/src/config" "github.com/m1guelpf/chatgpt-telegram/src/session" @@ -17,42 +15,34 @@ import ( ) func main() { - config, err := config.Init() + persistentConfig, err := config.LoadOrCreatePersistentConfig() if err != nil { log.Fatalf("Couldn't load config: %v", err) } - if config.OpenAISession == "" { - session, err := session.GetSession() + if persistentConfig.OpenAISession == "" { + token, err := session.GetSession() if err != nil { log.Fatalf("Couldn't get OpenAI session: %v", err) } - err = config.Set("OpenAISession", session) - if err != nil { + if err = persistentConfig.SetSessionToken(token); err != nil { log.Fatalf("Couldn't save OpenAI session: %v", err) } } - chatGPT := chatgpt.Init(config) + chatGPT := chatgpt.Init(persistentConfig) log.Println("Started ChatGPT") - err = godotenv.Load() + envConfig, err := config.LoadEnvConfig(".env") if err != nil { - log.Printf("Couldn't load .env file: %v. Using shell exposed env variables...", err) + log.Fatalf("Couldn't load .env config: %v", err) } - - editInterval := 1 * time.Second - if os.Getenv("EDIT_WAIT_SECONDS") != "" { - editSecond, err := strconv.ParseInt(os.Getenv("EDIT_WAIT_SECONDS"), 10, 64) - if err != nil { - log.Printf("Couldn't convert your edit seconds setting into int: %v", err) - editSecond = 1 - } - editInterval = time.Duration(editSecond) * time.Second + if err := envConfig.ValidateWithDefaults(); err != nil { + log.Fatalf("Invalid .env config: %v", err) } - bot, err := tgbot.New(os.Getenv("TELEGRAM_TOKEN"), editInterval) + bot, err := tgbot.New(envConfig.TelegramToken, time.Duration(envConfig.EditWaitSeconds)) if err != nil { log.Fatalf("Couldn't start Telegram bot: %v", err) } @@ -76,10 +66,11 @@ func main() { updateText = update.Message.Text updateChatID = update.Message.Chat.ID updateMessageID = update.Message.MessageID + updateUserID = update.Message.From.ID ) - userId := strconv.FormatInt(update.Message.Chat.ID, 10) - if os.Getenv("TELEGRAM_ID") != "" && userId != os.Getenv("TELEGRAM_ID") { + if len(envConfig.TelegramID) != 0 && !envConfig.HasTelegramID(updateUserID) { + log.Printf("User %d is not allowed to use this bot", updateUserID) bot.Send(updateChatID, updateMessageID, "You are not authorized to use this bot.") continue } diff --git a/src/chatgpt/chatgpt.go b/src/chatgpt/chatgpt.go index 0b4b88d..20a7ad5 100644 --- a/src/chatgpt/chatgpt.go +++ b/src/chatgpt/chatgpt.go @@ -48,7 +48,7 @@ type ChatResponse struct { Message string } -func Init(config config.Config) *ChatGPT { +func Init(config *config.Config) *ChatGPT { return &ChatGPT{ AccessTokenMap: expirymap.New(), SessionToken: config.OpenAISession, diff --git a/src/config/config.go b/src/config/config.go index 3109651..b74e52b 100644 --- a/src/config/config.go +++ b/src/config/config.go @@ -9,46 +9,46 @@ import ( ) type Config struct { + v *viper.Viper + OpenAISession string } -// init tries to read the config from the file, and creates it if it doesn't exist. -func Init() (Config, error) { +// LoadOrCreatePersistentConfig uses the default config directory for the current OS +// to load or create a config file named "chatgpt.json" +func LoadOrCreatePersistentConfig() (*Config, error) { configPath, err := os.UserConfigDir() if err != nil { - return Config{}, errors.New(fmt.Sprintf("Couldn't get user config dir: %v", err)) + return nil, errors.New(fmt.Sprintf("Couldn't get user config dir: %v", err)) } - viper.SetConfigType("json") - viper.SetConfigName("chatgpt") - viper.AddConfigPath(configPath) + v := viper.New() + v.SetConfigType("json") + v.SetConfigName("chatgpt") + v.AddConfigPath(configPath) - if err := viper.ReadInConfig(); err != nil { + if err := v.ReadInConfig(); err != nil { if _, ok := err.(viper.ConfigFileNotFoundError); ok { - if err := viper.SafeWriteConfig(); err != nil { - return Config{}, errors.New(fmt.Sprintf("Couldn't create config file: %v", err)) + if err := v.SafeWriteConfig(); err != nil { + return nil, errors.New(fmt.Sprintf("Couldn't create config file: %v", err)) } } else { - return Config{}, errors.New(fmt.Sprintf("Couldn't read config file: %v", err)) + return nil, errors.New(fmt.Sprintf("Couldn't read config file: %v", err)) } } var cfg Config - err = viper.Unmarshal(&cfg) + err = v.Unmarshal(&cfg) if err != nil { - return Config{}, errors.New(fmt.Sprintf("Error parsing config: %v", err)) + return nil, errors.New(fmt.Sprintf("Error parsing config: %v", err)) } + cfg.v = v - return cfg, nil + return &cfg, nil } -// key should be part of the Config struct -func (cfg *Config) Set(key string, value interface{}) error { - viper.Set(key, value) - - err := viper.Unmarshal(&cfg) - if err != nil { - return errors.New(fmt.Sprintf("Error parsing config: %v", err)) - } - - return viper.WriteConfig() +func (cfg *Config) SetSessionToken(token string) error { + // key must match the struct field name + cfg.v.Set("OpenAISession", token) + cfg.OpenAISession = token + return cfg.v.WriteConfig() } diff --git a/src/config/config_test.go b/src/config/config_test.go new file mode 100644 index 0000000..c6f5888 --- /dev/null +++ b/src/config/config_test.go @@ -0,0 +1,133 @@ +package config + +import ( + "fmt" + "os" + "testing" + + "github.com/stretchr/testify/require" +) + +func createFile(name string, content string) (remove func(), err error) { + f, err := os.Create(name) + if err != nil { + return nil, err + } + defer f.Close() + + if _, err := f.WriteString(content); err != nil { + return nil, err + } + + return func() { + if err := os.Remove(name); err != nil { + panic(fmt.Sprintf("failed to remove file: %s", err)) + } + }, nil +} + +func setEnvVariables(vals map[string]string) func() { + for k, v := range vals { + os.Setenv(k, v) + } + return func() { + for k := range vals { + os.Unsetenv(k) + } + } +} + +func TestLoadEnvConfig(t *testing.T) { + for label, test := range map[string]struct { + fileContent string + envVars map[string]string + want *EnvConfig + }{ + "all values empty in file and env": { + fileContent: `TELEGRAM_ID= +TELEGRAM_TOKEN= +EDIT_WAIT_SECONDS=`, + want: &EnvConfig{ + TelegramID: []int64{}, + TelegramToken: "", + EditWaitSeconds: 0, + }, + }, + "no file, all values through env": { + envVars: map[string]string{ + "TELEGRAM_ID": "123,456", + "TELEGRAM_TOKEN": "token", + "EDIT_WAIT_SECONDS": "10", + }, + want: &EnvConfig{ + TelegramID: []int64{123, 456}, + TelegramToken: "token", + EditWaitSeconds: 10, + }, + }, + "all values provided in file, single TELEGRAM_ID": { + fileContent: `TELEGRAM_ID=123 +TELEGRAM_TOKEN=abc +EDIT_WAIT_SECONDS=10`, + want: &EnvConfig{ + TelegramID: []int64{123}, + TelegramToken: "abc", + EditWaitSeconds: 10, + }, + }, + "multiple TELEGRAM_IDs provided in file": { + fileContent: `TELEGRAM_ID=123,456 +TELEGRAM_TOKEN=abc +EDIT_WAIT_SECONDS=10`, + envVars: map[string]string{}, + want: &EnvConfig{ + TelegramID: []int64{123, 456}, + TelegramToken: "abc", + EditWaitSeconds: 10, + }, + }, + "env variables should override file values": { + fileContent: `TELEGRAM_ID=123 +TELEGRAM_TOKEN=abc +EDIT_WAIT_SECONDS=10`, + envVars: map[string]string{ + "TELEGRAM_ID": "456", + "TELEGRAM_TOKEN": "def", + "EDIT_WAIT_SECONDS": "20", + }, + want: &EnvConfig{ + TelegramID: []int64{456}, + TelegramToken: "def", + EditWaitSeconds: 20, + }, + }, + "multiple TELEGRAM_IDs provided in env": { + fileContent: `TELEGRAM_ID=123 +TELEGRAM_TOKEN=abc +EDIT_WAIT_SECONDS=10`, + envVars: map[string]string{ + "TELEGRAM_ID": "456,789", + }, + want: &EnvConfig{ + TelegramID: []int64{456, 789}, + TelegramToken: "abc", + EditWaitSeconds: 10, + }, + }, + } { + t.Run(label, func(t *testing.T) { + unset := setEnvVariables(test.envVars) + t.Cleanup(unset) + + if test.fileContent != "" { + remove, err := createFile("test.env", test.fileContent) + require.NoError(t, err) + t.Cleanup(remove) + } + + cfg, err := LoadEnvConfig("test.env") + require.NoError(t, err) + require.Equal(t, test.want, cfg) + }) + } +} diff --git a/src/config/env_config.go b/src/config/env_config.go new file mode 100644 index 0000000..26a7485 --- /dev/null +++ b/src/config/env_config.go @@ -0,0 +1,80 @@ +package config + +import ( + "bytes" + "errors" + "log" + "os" + + "github.com/spf13/viper" +) + +type EnvConfig struct { + TelegramID []int64 `mapstructure:"TELEGRAM_ID"` + TelegramToken string `mapstructure:"TELEGRAM_TOKEN"` + EditWaitSeconds int `mapstructure:"EDIT_WAIT_SECONDS"` +} + +// emptyConfig is used to initialize viper. +// It is required to register config keys with viper when in case no config file is provided. +const emptyConfig = `TELEGRAM_ID= +TELEGRAM_TOKEN= +EDIT_WAIT_SECONDS=` + +func (e *EnvConfig) HasTelegramID(id int64) bool { + for _, v := range e.TelegramID { + if v == id { + return true + } + } + return false +} + +// LoadEnvConfig loads config from .env file, variables from environment take precedence if provided. +// If no .env file is provided, config is loaded from environment variables. +func LoadEnvConfig(path string) (*EnvConfig, error) { + fileExists := fileExists(path) + if !fileExists { + log.Printf("config file %s does not exist, using env variables", path) + } + + v := viper.New() + v.SetConfigType("env") + v.AutomaticEnv() + if err := v.ReadConfig(bytes.NewBufferString(emptyConfig)); err != nil { + return nil, err + } + if fileExists { + v.SetConfigFile(path) + if err := v.ReadInConfig(); err != nil { + return nil, err + } + } + + var cfg EnvConfig + if err := v.Unmarshal(&cfg); err != nil { + return nil, err + } + return &cfg, nil +} + +func fileExists(path string) bool { + if _, err := os.Stat(path); err != nil { + return os.IsExist(err) + } + return true +} + +func (e *EnvConfig) ValidateWithDefaults() error { + if e.TelegramToken == "" { + return errors.New("TELEGRAM_TOKEN is not set") + } + if len(e.TelegramID) == 0 { + log.Printf("TELEGRAM_ID is not set, all users will be able to use the bot") + } + if e.EditWaitSeconds < 0 { + log.Printf("EDIT_WAIT_SECONDS not set, defaulting to 1") + e.EditWaitSeconds = 1 + } + return nil +}