diff --git a/.gitignore b/.gitignore index 3b735ec..4bfef32 100644 --- a/.gitignore +++ b/.gitignore @@ -19,3 +19,4 @@ # Go workspace file go.work +gaia diff --git a/README.md b/README.md index 77cb9a1..0519de6 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,13 @@ # gaia +## Requirements + +```sh +brew install ollama +ollama serve +ollama pull openhermes2.5-mistral +``` + ## Installation ```sh @@ -13,24 +21,22 @@ brew install vonglasow/tap/gaia ```sh $ gaia -h -gaia is a CLI tool - Usage: - app [options] [message] [flags] + gaia [command] + +Available Commands: + ask Ask to a model + completion Generate the autocompletion script for the specified shell + config Set configuration options + help Help about any command + version Print the version information Flags: - -c, --code string message for code option - -t, --create-config create config file if it doesn't exist - -d, --description string message for description option - -h, --help help for app - -s, --shell string message for shell option - -g, --show-config display current config - -v, --verbose verbose output - -V, --version version + -h, --help help for gaia ``` ```sh -$ cat CVE-2021-4034.py | gaia -d "Analyze and explain this code" +$ cat CVE-2021-4034.py | gaia ask "Analyze and explain this code" This code is a Python script that exploits the CVE-2021-4034 vulnerability in Python. It was originally written by Joe Ammond, who used it as an experiment to see if he could get it to work in Python while also playing around with ctypes. The code starts by importing necessary libraries and defining variables. The `base64` library is imported to decode the payload, while the `os` library is needed for certain file operations. The `sys` library is used to handle system-level interactions, and the `ctypes` library is used to call the `execve()` function directly. @@ -41,4 +47,3 @@ An environment list is set to configure the call to `execve()`. The code also fi The code ends with calling the `execve()` function using the C library found earlier, passing in NULL arguments as required by `execve()`. ``` - diff --git a/api/api.go b/api/api.go new file mode 100644 index 0000000..55a2752 --- /dev/null +++ b/api/api.go @@ -0,0 +1,111 @@ +package api + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "os" + "runtime" + + "github.com/spf13/viper" +) + +type Message struct { + Content string `json:"content"` + Role string `json:"role"` +} + +type APIRequest struct { + Model string `json:"model"` + Messages []Message `json:"messages"` + Stream bool `json:"stream"` +} + +type APIResponse struct { + Model string `json:"model"` + Response string `json:"response"` + Message *Message `json:"message"` +} + +func processStreamedResponse(body io.Reader) { + decoder := json.NewDecoder(body) + respChan := make(chan string) + eofChan := make(chan struct{}) + errorChan := make(chan error) + + go func() { + var apiResp APIResponse + for { + if err := decoder.Decode(&apiResp); err == io.EOF { + eofChan <- struct{}{} + } else if err != nil { + errorChan <- err + } else { + respChan <- apiResp.Message.Content + } + } + }() + + for { + select { + case resp := <-respChan: + fmt.Print(resp) + case err := <-errorChan: + fmt.Println("Error decoding JSON:", err) + return + case <-eofChan: + fmt.Println() + return + } + } +} + +func ProcessMessage(msg string) error { + systemrole := "default" + + if viper.IsSet("systemrole") { + if viper.IsSet(fmt.Sprintf("roles.%s", viper.GetString("systemrole"))) { + systemrole = viper.GetString("systemrole") + } else { + fmt.Printf("Error: Role '%s' not found in the configuration", viper.GetString("systemrole")) + return nil + } + } + + role := fmt.Sprintf(viper.GetString(fmt.Sprintf("roles.%s", systemrole)), os.Getenv("SHELL"), runtime.GOOS) + + request := APIRequest{ + Model: viper.GetString("model"), + Messages: []Message{ + { + Role: "system", + Content: role, + }, + { + Role: "user", + Content: msg, + }, + }, + Stream: true, + } + + requestBody, err := json.Marshal(request) + if err != nil { + fmt.Println("Error during call on API") + return fmt.Errorf("failed to marshal JSON request: %v", err) + } + + url := fmt.Sprintf("http://%s:%d/api/chat", viper.GetString("host"), viper.GetInt("port")) + contentType := "application/json" + + resp, err := http.Post(url, contentType, bytes.NewBuffer(requestBody)) + if err != nil { + return fmt.Errorf("failed to make HTTP request: %v", err) + } + defer resp.Body.Close() + + processStreamedResponse(resp.Body) + return nil +} diff --git a/commands/commands.go b/commands/commands.go new file mode 100644 index 0000000..46d0db1 --- /dev/null +++ b/commands/commands.go @@ -0,0 +1,115 @@ +package commands + +import ( + "fmt" + "os" + "sort" + "strings" + + "gaia/api" + "gaia/config" + + "github.com/spf13/cobra" + "github.com/spf13/viper" +) + +var ( + version string = "dev" + commitSHA string = "none" + buildDate string = "unknown" +) + +var rootCmd = &cobra.Command{Use: "gaia"} + +var configCmd = &cobra.Command{ + Use: "config", + Short: "Set configuration options", +} + +var listCmd = &cobra.Command{ + Use: "list", + Short: "List configuration settings", + Run: func(cmd *cobra.Command, args []string) { + keys := make([]string, 0, len(viper.AllKeys())) + keys = append(keys, viper.AllKeys()...) + sort.Strings(keys) + for _, key := range keys { + fmt.Printf("%s: %v\n", key, viper.Get(key)) + } + }, +} + +var setCmd = &cobra.Command{ + Use: "set [key] [value]", + Short: "Set configuration setting", + Args: cobra.ExactArgs(2), + Run: func(cmd *cobra.Command, args []string) { + config.SetConfigString(args[0], args[1]) + fmt.Println("Config setting updated", args[0], "to", args[1]) + }, +} + +var getCmd = &cobra.Command{ + Use: "get [key]", + Short: "Get configuration setting", + Args: cobra.ExactArgs(1), + Run: func(cmd *cobra.Command, args []string) { + fmt.Println(viper.GetString(args[0])) + }, +} + +var pathCmd = &cobra.Command{ + Use: "path", + Short: "Get configuration path", + Args: cobra.ExactArgs(0), + Run: func(cmd *cobra.Command, args []string) { + fmt.Println(config.CfgFile) + }, +} + +var askCmd = &cobra.Command{ + Use: "ask [string]", + Short: "Ask to a model", + Args: cobra.MinimumNArgs(1), + Run: func(cmd *cobra.Command, args []string) { + msg := "" + msg += readStdin() + if len(args) > 0 { + msg += " " + args[0] + } + if err := api.ProcessMessage(msg); err != nil { + fmt.Println(err) + } + }, +} + +var versionCmd = &cobra.Command{ + Use: "version", + Short: "Print the version information", + Args: cobra.ExactArgs(0), + Run: func(cmd *cobra.Command, args []string) { + fmt.Printf("Gaia %s, commit %s, built at %s\n", version, commitSHA, buildDate) + }, +} + +func readStdin() string { + var stdinLines string + stat, _ := os.Stdin.Stat() + if (stat.Mode() & os.ModeCharDevice) == 0 { + buf := make([]byte, 4096) + n, _ := os.Stdin.Read(buf) + stdinLines = string(buf[:n]) + } + return strings.TrimSpace(stdinLines) +} + +func Execute() error { + configCmd.AddCommand(listCmd, setCmd, getCmd, pathCmd) + askCmd.Flags().StringP("role", "r", "", "Specify role code (default, describe, code)") + if err := viper.BindPFlag("systemrole", askCmd.Flags().Lookup("role")); err != nil { + fmt.Printf("Error binding flag to Viper: %v\n", err) + return err + } + rootCmd.AddCommand(configCmd, versionCmd, askCmd) + return rootCmd.Execute() +} diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..fdfefd5 --- /dev/null +++ b/config/config.go @@ -0,0 +1,81 @@ +package config + +import ( + "fmt" + "os" + "path/filepath" + + "github.com/spf13/viper" +) + +var CfgFile string + +type Config struct{} + +var config Config + +func defaultConfig() *viper.Viper { + v := viper.New() + v.SetDefault("model", "mistral") + v.SetDefault("host", "localhost") + v.SetDefault("port", 11434) + v.SetDefault("roles.default", "You are programming and system administration assistant. You are managing %s operating system with %s shell. Provide short responses in about 100 words, unless you are specifically asked for more details. If you need to store any data, assume it will be stored in the conversation. APPLY MARKDOWN formatting when possible.") + v.SetDefault("roles.describe", "Provide a terse, single sentence description of the given shell command. Describe each argument and option of the command. Provide short responses in about 80 words. APPLY MARKDOWN formatting when possible.") + v.SetDefault("roles.shell", "Provide only %s commands for %s without any description. If there is a lack of details, provide the most logical solution. Ensure the output is a valid shell command. If multiple steps are required, try to combine them using &&. Provide only plain text without Markdown formatting. Do not use markdown formatting such as ```.") + v.SetDefault("roles.code", "Provide only code as output without any description. Provide only code in plain text format without Markdown formatting. Do not include symbols such as ``` or ```python. If there is a lack of details, provide most logical solution. You are not allowed to ask for more details. For example if the prompt is \"Hello world Python\", you should return \"print('Hello world')\".") + + return v +} + +func InitConfig() error { + homeDir, err := os.UserHomeDir() + if err != nil { + fmt.Println("Error getting home directory:", err) + return err + } + + configDir := filepath.Join(homeDir, ".config", "gaia") + err = os.MkdirAll(configDir, 0755) + if err != nil { + fmt.Println("Error creating config directory:", err) + return err + } + + CfgFile = filepath.Join(configDir, "config.yaml") + + viper.SetConfigFile(CfgFile) + viper.SetConfigType("yaml") + viper.AddConfigPath(".") + + for key, value := range defaultConfig().AllSettings() { + viper.SetDefault(key, value) + } + + if err := viper.ReadInConfig(); err != nil { + fmt.Println("Error reading config file:", err) + } + + if err := viper.Unmarshal(&config); err != nil { + fmt.Println("Error unmarshalling config:", err) + return err + } + + if err := viper.WriteConfig(); err != nil { + fmt.Println("Error writing config file:", err) + } + + return nil +} + +func SetConfigString(key, value string) { + if defaultConfig().IsSet(key) { + viper.Set(key, value) + } else { + fmt.Println("No config found for key:", key, ":", value) + os.Exit(1) + } + + if err := viper.WriteConfig(); err != nil { + fmt.Println("Error writing config file:", err) + } +} diff --git a/go.mod b/go.mod index 48bb0f0..653d5e2 100644 --- a/go.mod +++ b/go.mod @@ -1,13 +1,31 @@ -module main.go +module gaia -go 1.22 +go 1.22.2 require ( github.com/spf13/cobra v1.8.0 - gopkg.in/yaml.v2 v2.4.0 + github.com/spf13/viper v1.18.2 ) require ( + github.com/fsnotify/fsnotify v1.7.0 // indirect + github.com/hashicorp/hcl v1.0.0 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/magiconair/properties v1.8.7 // indirect + github.com/mitchellh/mapstructure v1.5.0 // indirect + github.com/pelletier/go-toml/v2 v2.1.0 // indirect + github.com/sagikazarmark/locafero v0.4.0 // indirect + github.com/sagikazarmark/slog-shim v0.1.0 // indirect + github.com/sourcegraph/conc v0.3.0 // indirect + github.com/spf13/afero v1.11.0 // indirect + github.com/spf13/cast v1.6.0 // indirect github.com/spf13/pflag v1.0.5 // indirect + github.com/subosito/gotenv v1.6.0 // indirect + go.uber.org/atomic v1.9.0 // indirect + go.uber.org/multierr v1.9.0 // indirect + golang.org/x/exp v0.0.0-20230905200255-921286631fa9 // indirect + golang.org/x/sys v0.15.0 // indirect + golang.org/x/text v0.14.0 // indirect + gopkg.in/ini.v1 v1.67.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index d90dbd8..b6a7dcc 100644 --- a/go.sum +++ b/go.sum @@ -1,13 +1,75 @@ github.com/cpuguy83/go-md2man/v2 v2.0.3/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= +github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= +github.com/frankban/quicktest v1.14.6/go.mod h1:4ptaffx2x8+WTWXmUCuVU6aPUX1/Mz7zb5vbUoiM6w0= +github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= +github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= +github.com/hashicorp/hcl v1.0.0 h1:0Anlzjpi4vEasTeNFn2mLJgTSwt0+6sfsiTG8qcWGx4= +github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE= +github.com/kr/pretty v0.3.1/go.mod h1:hoEshYVHaxMs3cyo3Yncou5ZscifuDolrwPKZanG3xk= +github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/magiconair/properties v1.8.7 h1:IeQXZAiQcpL9mgcAe1Nu6cX9LLw6ExEHKjN0VQdvPDY= +github.com/magiconair/properties v1.8.7/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0= +github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= +github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/pelletier/go-toml/v2 v2.1.0 h1:FnwAJ4oYMvbT/34k9zzHuZNrhlz48GB3/s6at6/MHO4= +github.com/pelletier/go-toml/v2 v2.1.0/go.mod h1:tJU2Z3ZkXwnxa4DPO899bsyIoywizdUvyaeZurnPPDc= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 h1:Jamvg5psRIccs7FGNTlIRMkT8wgtp5eCXdBlqhYGL6U= +github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= +github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/sagikazarmark/locafero v0.4.0 h1:HApY1R9zGo4DBgr7dqsTH/JJxLTTsOt7u6keLGt6kNQ= +github.com/sagikazarmark/locafero v0.4.0/go.mod h1:Pe1W6UlPYUk/+wc/6KFhbORCfqzgYEpgQ3O5fPuL3H4= +github.com/sagikazarmark/slog-shim v0.1.0 h1:diDBnUNK9N/354PgrxMywXnAwEr1QZcOr6gto+ugjYE= +github.com/sagikazarmark/slog-shim v0.1.0/go.mod h1:SrcSrq8aKtyuqEI1uvTDTK1arOWRIczQRv+GVI1AkeQ= +github.com/sourcegraph/conc v0.3.0 h1:OQTbbt6P72L20UqAkXXuLOj79LfEanQ+YQFNpLA9ySo= +github.com/sourcegraph/conc v0.3.0/go.mod h1:Sdozi7LEKbFPqYX2/J+iBAM6HpqSLTASQIKqDmF7Mt0= +github.com/spf13/afero v1.11.0 h1:WJQKhtpdm3v2IzqG8VMqrr6Rf3UYpEF239Jy9wNepM8= +github.com/spf13/afero v1.11.0/go.mod h1:GH9Y3pIexgf1MTIWtNGyogA5MwRIDXGUr+hbWNoBjkY= +github.com/spf13/cast v1.6.0 h1:GEiTHELF+vaR5dhz3VqZfFSzZjYbgeKDpBxQVS4GYJ0= +github.com/spf13/cast v1.6.0/go.mod h1:ancEpBxwJDODSW/UG4rDrAqiKolqNNh2DX3mk86cAdo= github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0= github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyhcho= github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= -gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +github.com/spf13/viper v1.18.2 h1:LUXCnvUvSM6FXAsj6nnfc8Q2tp1dIgUfY9Kc8GsSOiQ= +github.com/spf13/viper v1.18.2/go.mod h1:EKmWIqdnk5lOcmR72yw6hS+8OPYcwD0jteitLMVB+yk= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/subosito/gotenv v1.6.0 h1:9NlTDc1FTs4qu0DDq7AEtTPNw6SVm7uBMsUCUjABIf8= +github.com/subosito/gotenv v1.6.0/go.mod h1:Dk4QP5c2W3ibzajGcXpNraDfq2IrhjMIvMSWPKKo0FU= +go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE= +go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc= +go.uber.org/multierr v1.9.0 h1:7fIwc/ZtS0q++VgcfqFDxSBZVv/Xo49/SYnDFupUwlI= +go.uber.org/multierr v1.9.0/go.mod h1:X2jQV1h+kxSjClGpnseKVIxpmcjrj7MNnI0bnlfKTVQ= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9 h1:GoHiUyI/Tp2nVkLI2mCxVkOjsbSXD66ic0XW0js0R9g= +golang.org/x/exp v0.0.0-20230905200255-921286631fa9/go.mod h1:S2oDrQGGwySpoQPVqRShND87VCbxmc6bL1Yd2oYrm6k= +golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc= +golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= -gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/ini.v1 v1.67.0 h1:Dgnx+6+nfE+IfzjUEISNeydPJh9AXNNsWbGP9KzCsOA= +gopkg.in/ini.v1 v1.67.0/go.mod h1:pNLf8WUiyNEtQjuu5G5vTm06TEv9tsIgeAvK8hOrP4k= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/main.go b/main.go index 6de223d..2d29f01 100644 --- a/main.go +++ b/main.go @@ -1,305 +1,24 @@ package main import ( - "bytes" - "encoding/json" "fmt" - "io" - "net/http" + "gaia/commands" + "gaia/config" "os" - "os/user" - "path/filepath" - "runtime" - "strings" - - "github.com/spf13/cobra" - "gopkg.in/yaml.v2" -) - -const ( - roleCode = "Provide only code as output without any description." + - "Provide only code in plain text format without Markdown formatting. Do not" + - "include symbols such as ``` or ```python. If there is a lack of details," + - "provide most logical solution. You are not allowed to ask for more details." + - "For example if the prompt is \"Hello world Python\", you should return" + - "\"print('Hello world')\"." - - roleShell = "Provide only %s commands for %s without any description." + - "If there is a lack of details, provide most logical solution." + - "Ensure the output is a valid shell command." + - "If multiple steps required try to combine them together using &&." + - "Provide only plain text without Markdown formatting." + - "Do not provide markdown formatting such as ```." - - roleDescribeShell = "Provide a terse, single sentence description of the given shell command." + - "Describe each argument and option of the command." + - "Provide short responses in about 80 words." + - "APPLY MARKDOWN formatting when possible." - - roleDefault = "You are programming and system administration assistant." + - "You are managing %s operating system with %s shell." + - "Provide short responses in about 100 words, unless you are specifically asked for more details." + - "If you need to store any data, assume it will be stored in the conversation." + - "APPLY MARKDOWN formatting when possible." - - ollamaChatURL = "/api/chat" ) -var ( - shellMsg string - codeMsg string - descMsg string - verbose bool - showConfig bool - createConfig bool - versionFlag bool - version string = "dev" - commitSHA string = "none" - buildDate string = "unknown" -) - -type Config struct { - OLLAMA_BASE_URL string `yaml:"OLLAMA_BASE_URL"` - OLLAMA_MODEL string `yaml:"OLLAMA_MODEL"` -} - -type Message struct { - Content string `json:"content"` - Role string `json:"role"` -} - -type APIRequest struct { - Model string `json:"model"` - Messages []Message `json:"messages"` - Stream bool `json:"stream"` -} - -type APIResponse struct { - Model string `json:"model"` - Response string `json:"response"` - Message *Message `json:"message"` -} - -func loadConfig(createIfNotExists bool) (Config, error) { - var config Config - - // Set default values - config.OLLAMA_MODEL = "openhermes2.5-mistral" - config.OLLAMA_BASE_URL = "http://localhost:11434" - - usr, err := user.Current() - if err != nil { - return config, fmt.Errorf("failed to get user: %v", err) - } - - configFilePath := filepath.Join(usr.HomeDir, ".config", "gaia", "config.yaml") - - //yamlFile, err := ioutil.ReadFile(configFilePath) - yamlFile, err := os.ReadFile(configFilePath) - if err != nil { - if createIfNotExists { - // Config file doesn't exist, create it with default values - if err := createDefaultConfig(configFilePath, config); err != nil { - return config, fmt.Errorf("failed to create config file: %v", err) - } - } else { - // Config file doesn't exist, return default values - return config, nil - } - } - - // Load config from file - err = yaml.Unmarshal(yamlFile, &config) - if err != nil { - return config, fmt.Errorf("failed to unmarshal YAML: %v", err) - } - - return config, nil -} - -func createDefaultConfig(filePath string, config Config) error { - // Create config directory if it doesn't exist - err := os.MkdirAll(filepath.Dir(filePath), 0755) - if err != nil { - return err - } - - // Marshal default config to YAML - defaultConfig, err := yaml.Marshal(config) - if err != nil { - return err - } - - // Write default config to file - //err = ioutil.WriteFile(filePath, defaultConfig, 0644) - err = os.WriteFile(filePath, defaultConfig, 0644) - if err != nil { - return err - } - - return nil -} +var () func main() { - var rootCmd = &cobra.Command{ - Use: "app [options] [message]", - Short: "gaia is a CLI tool", - Run: func(cmd *cobra.Command, args []string) { - if versionFlag { - fmt.Printf("Gaia %s, commit %s, built at %s\n", version, commitSHA, buildDate) - os.Exit(0) - } - - config, err := loadConfig(createConfig) - if err != nil { - fmt.Println(err) - os.Exit(1) - } - - if showConfig || createConfig { - displayConfig() - os.Exit(0) - } - - msg := "" - stat, _ := os.Stdin.Stat() - if (stat.Mode() & os.ModeCharDevice) == 0 { - buf := make([]byte, 4096) - n, _ := os.Stdin.Read(buf) - msg = string(buf[:n]) - } - - var role string - message := strings.TrimSpace(msg) - - if shellMsg != "" { - message += " " + shellMsg - role = fmt.Sprintf(roleShell, os.Getenv("SHELL"), runtime.GOOS) - } else if codeMsg != "" { - message += " " + codeMsg - role = roleCode - } else if descMsg != "" { - message += " " + descMsg - role = roleDescribeShell - } else { - message += strings.Join(args, " ") - role = fmt.Sprintf(roleDefault, runtime.GOOS, os.Getenv("SHELL")) - } - - if message == "" || strings.TrimSpace(message) == "" { - if err := cmd.Usage(); err != nil { - fmt.Printf("error displaying usage: %v\n", err) - } - return - } - - if verbose { - fmt.Println(message) - fmt.Println(config) - fmt.Println(role) - } - - processMessage(message, config, role) - }, - } - - rootCmd.Flags().StringVarP(&shellMsg, "shell", "s", "", "message for shell option") - rootCmd.Flags().StringVarP(&codeMsg, "code", "c", "", "message for code option") - rootCmd.Flags().StringVarP(&descMsg, "description", "d", "", "message for description option") - rootCmd.Flags().BoolVarP(&verbose, "verbose", "v", false, "verbose output") - rootCmd.Flags().BoolVarP(&showConfig, "show-config", "g", false, "display current config") - rootCmd.Flags().BoolVarP(&createConfig, "create-config", "t", false, "create config file if it doesn't exist") - rootCmd.Flags().BoolVarP(&versionFlag, "version", "V", false, "version") - - if err := rootCmd.Execute(); err != nil { + err := config.InitConfig() + if err != nil { fmt.Println(err) os.Exit(1) } -} -func processMessage(msg string, config Config, role string) { - err := callAPI(config, "POST", "application/json", msg, role) + err = commands.Execute() if err != nil { fmt.Println(err) os.Exit(1) } } - -func callAPI(config Config, method, contentType string, body string, role string) error { - request := APIRequest{ - Model: config.OLLAMA_MODEL, - Messages: []Message{ - { - Role: "system", - Content: role, - }, - { - Role: "user", - Content: body, - }, - }, - Stream: true, - } - - requestBody, err := json.Marshal(request) - if err != nil { - fmt.Println("Error during call on API") - return fmt.Errorf("failed to marshal JSON request: %v", err) - } - - resp, err := http.Post(config.OLLAMA_BASE_URL+ollamaChatURL, contentType, bytes.NewBuffer(requestBody)) - if err != nil { - return fmt.Errorf("failed to make HTTP request: %v", err) - } - defer resp.Body.Close() - - processStreamedResponse(resp.Body) - return nil -} - -func displayConfig() { - config, err := loadConfig(false) - if err != nil { - fmt.Println("Error loading config:", err) - return - } - - fmt.Println("Current Config:") - fmt.Println("OLLAMA_MODEL:", config.OLLAMA_MODEL) - fmt.Println("OLLAMA_BASE_URL:", config.OLLAMA_BASE_URL) -} - -func processStreamedResponse(body io.Reader) { - decoder := json.NewDecoder(body) - - fmt.Println() - - respChan := make(chan string) - eofChan := make(chan struct{}) - errorChan := make(chan error) - go func() { - var apiResp APIResponse - for { - if err := decoder.Decode(&apiResp); err == io.EOF { - eofChan <- struct{}{} - } else if err != nil { - errorChan <- err - } else { - respChan <- apiResp.Message.Content - } - } - }() - - for { - select { - case resp := <-respChan: - fmt.Print(resp) - case err := <-errorChan: - fmt.Println("Error decoding JSON:", err) - return - case <-eofChan: - fmt.Println() - return - } - } -} diff --git a/main_test.go b/main_test.go index 3ba9ef1..57a1676 100644 --- a/main_test.go +++ b/main_test.go @@ -1,26 +1,13 @@ package main -import "testing" +import ( + "gaia/config" + "testing" +) -func TestConvert(t *testing.T) { - config, err := loadConfig(false) +func TestInitConfig(t *testing.T) { + err := config.InitConfig() if err != nil { - t.Logf("error loading config: %v", err) - t.Fail() - } - - expectedConfig := Config{ - OLLAMA_MODEL: "openhermes2.5-mistral", - OLLAMA_BASE_URL: "http://localhost:11434", - } - - if config.OLLAMA_MODEL != expectedConfig.OLLAMA_MODEL { - t.Logf("incorrect OLLAMA_MODEL, got: %s, want: %s", config.OLLAMA_MODEL, expectedConfig.OLLAMA_MODEL) - t.Fail() - } - - if config.OLLAMA_BASE_URL != expectedConfig.OLLAMA_BASE_URL { - t.Logf("incorrect OLLAMA_BASE_URL, got: %s, want: %s", config.OLLAMA_BASE_URL, expectedConfig.OLLAMA_BASE_URL) - t.Fail() + t.Fatalf("Error initializing config: %v", err) } }