-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat!: Refactor config, commands, and version
- Loading branch information
Showing
9 changed files
with
425 additions
and
326 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -19,3 +19,4 @@ | |
|
||
# Go workspace file | ||
go.work | ||
gaia |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
package api | ||
|
||
import ( | ||
"bytes" | ||
"encoding/json" | ||
"fmt" | ||
"io" | ||
"net/http" | ||
"os" | ||
"runtime" | ||
|
||
"github.com/spf13/viper" | ||
) | ||
|
||
type Message struct { | ||
Content string `json:"content"` | ||
Role string `json:"role"` | ||
} | ||
|
||
type APIRequest struct { | ||
Model string `json:"model"` | ||
Messages []Message `json:"messages"` | ||
Stream bool `json:"stream"` | ||
} | ||
|
||
type APIResponse struct { | ||
Model string `json:"model"` | ||
Response string `json:"response"` | ||
Message *Message `json:"message"` | ||
} | ||
|
||
func processStreamedResponse(body io.Reader) { | ||
decoder := json.NewDecoder(body) | ||
respChan := make(chan string) | ||
eofChan := make(chan struct{}) | ||
errorChan := make(chan error) | ||
|
||
go func() { | ||
var apiResp APIResponse | ||
for { | ||
if err := decoder.Decode(&apiResp); err == io.EOF { | ||
eofChan <- struct{}{} | ||
} else if err != nil { | ||
errorChan <- err | ||
} else { | ||
respChan <- apiResp.Message.Content | ||
} | ||
} | ||
}() | ||
|
||
for { | ||
select { | ||
case resp := <-respChan: | ||
fmt.Print(resp) | ||
case err := <-errorChan: | ||
fmt.Println("Error decoding JSON:", err) | ||
return | ||
case <-eofChan: | ||
fmt.Println() | ||
return | ||
} | ||
} | ||
} | ||
|
||
func ProcessMessage(msg string) error { | ||
systemrole := "default" | ||
|
||
if viper.IsSet("systemrole") { | ||
if viper.IsSet(fmt.Sprintf("roles.%s", viper.GetString("systemrole"))) { | ||
systemrole = viper.GetString("systemrole") | ||
} else { | ||
fmt.Printf("Error: Role '%s' not found in the configuration", viper.GetString("systemrole")) | ||
return nil | ||
} | ||
} | ||
|
||
role := fmt.Sprintf(viper.GetString(fmt.Sprintf("roles.%s", systemrole)), os.Getenv("SHELL"), runtime.GOOS) | ||
|
||
request := APIRequest{ | ||
Model: viper.GetString("model"), | ||
Messages: []Message{ | ||
{ | ||
Role: "system", | ||
Content: role, | ||
}, | ||
{ | ||
Role: "user", | ||
Content: msg, | ||
}, | ||
}, | ||
Stream: true, | ||
} | ||
|
||
requestBody, err := json.Marshal(request) | ||
if err != nil { | ||
fmt.Println("Error during call on API") | ||
return fmt.Errorf("failed to marshal JSON request: %v", err) | ||
} | ||
|
||
url := fmt.Sprintf("http://%s:%d/api/chat", viper.GetString("host"), viper.GetInt("port")) | ||
contentType := "application/json" | ||
|
||
resp, err := http.Post(url, contentType, bytes.NewBuffer(requestBody)) | ||
if err != nil { | ||
return fmt.Errorf("failed to make HTTP request: %v", err) | ||
} | ||
defer resp.Body.Close() | ||
|
||
processStreamedResponse(resp.Body) | ||
return nil | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
package commands | ||
|
||
import ( | ||
"fmt" | ||
"os" | ||
"sort" | ||
"strings" | ||
|
||
"gaia/api" | ||
"gaia/config" | ||
|
||
"github.com/spf13/cobra" | ||
"github.com/spf13/viper" | ||
) | ||
|
||
var ( | ||
version string = "dev" | ||
commitSHA string = "none" | ||
buildDate string = "unknown" | ||
) | ||
|
||
var rootCmd = &cobra.Command{Use: "gaia"} | ||
|
||
var configCmd = &cobra.Command{ | ||
Use: "config", | ||
Short: "Set configuration options", | ||
} | ||
|
||
var listCmd = &cobra.Command{ | ||
Use: "list", | ||
Short: "List configuration settings", | ||
Run: func(cmd *cobra.Command, args []string) { | ||
keys := make([]string, 0, len(viper.AllKeys())) | ||
keys = append(keys, viper.AllKeys()...) | ||
sort.Strings(keys) | ||
for _, key := range keys { | ||
fmt.Printf("%s: %v\n", key, viper.Get(key)) | ||
} | ||
}, | ||
} | ||
|
||
var setCmd = &cobra.Command{ | ||
Use: "set [key] [value]", | ||
Short: "Set configuration setting", | ||
Args: cobra.ExactArgs(2), | ||
Run: func(cmd *cobra.Command, args []string) { | ||
config.SetConfigString(args[0], args[1]) | ||
fmt.Println("Config setting updated", args[0], "to", args[1]) | ||
}, | ||
} | ||
|
||
var getCmd = &cobra.Command{ | ||
Use: "get [key]", | ||
Short: "Get configuration setting", | ||
Args: cobra.ExactArgs(1), | ||
Run: func(cmd *cobra.Command, args []string) { | ||
fmt.Println(viper.GetString(args[0])) | ||
}, | ||
} | ||
|
||
var pathCmd = &cobra.Command{ | ||
Use: "path", | ||
Short: "Get configuration path", | ||
Args: cobra.ExactArgs(0), | ||
Run: func(cmd *cobra.Command, args []string) { | ||
fmt.Println(config.CfgFile) | ||
}, | ||
} | ||
|
||
var askCmd = &cobra.Command{ | ||
Use: "ask [string]", | ||
Short: "Ask to a model", | ||
Args: cobra.MinimumNArgs(1), | ||
Run: func(cmd *cobra.Command, args []string) { | ||
msg := "" | ||
msg += readStdin() | ||
if len(args) > 0 { | ||
msg += " " + args[0] | ||
} | ||
if err := api.ProcessMessage(msg); err != nil { | ||
fmt.Println(err) | ||
} | ||
}, | ||
} | ||
|
||
var versionCmd = &cobra.Command{ | ||
Use: "version", | ||
Short: "Print the version information", | ||
Args: cobra.ExactArgs(0), | ||
Run: func(cmd *cobra.Command, args []string) { | ||
fmt.Printf("Gaia %s, commit %s, built at %s\n", version, commitSHA, buildDate) | ||
}, | ||
} | ||
|
||
func readStdin() string { | ||
var stdinLines string | ||
stat, _ := os.Stdin.Stat() | ||
if (stat.Mode() & os.ModeCharDevice) == 0 { | ||
buf := make([]byte, 4096) | ||
n, _ := os.Stdin.Read(buf) | ||
stdinLines = string(buf[:n]) | ||
} | ||
return strings.TrimSpace(stdinLines) | ||
} | ||
|
||
func Execute() error { | ||
configCmd.AddCommand(listCmd, setCmd, getCmd, pathCmd) | ||
askCmd.Flags().StringP("role", "r", "", "Specify role code (default, describe, code)") | ||
if err := viper.BindPFlag("systemrole", askCmd.Flags().Lookup("role")); err != nil { | ||
fmt.Printf("Error binding flag to Viper: %v\n", err) | ||
return err | ||
} | ||
rootCmd.AddCommand(configCmd, versionCmd, askCmd) | ||
return rootCmd.Execute() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
package config | ||
|
||
import ( | ||
"fmt" | ||
"os" | ||
"path/filepath" | ||
|
||
"github.com/spf13/viper" | ||
) | ||
|
||
var CfgFile string | ||
|
||
type Config struct{} | ||
|
||
var config Config | ||
|
||
func defaultConfig() *viper.Viper { | ||
v := viper.New() | ||
v.SetDefault("model", "mistral") | ||
v.SetDefault("host", "localhost") | ||
v.SetDefault("port", 11434) | ||
v.SetDefault("roles.default", "You are programming and system administration assistant. You are managing %s operating system with %s shell. Provide short responses in about 100 words, unless you are specifically asked for more details. If you need to store any data, assume it will be stored in the conversation. APPLY MARKDOWN formatting when possible.") | ||
v.SetDefault("roles.describe", "Provide a terse, single sentence description of the given shell command. Describe each argument and option of the command. Provide short responses in about 80 words. APPLY MARKDOWN formatting when possible.") | ||
v.SetDefault("roles.shell", "Provide only %s commands for %s without any description. If there is a lack of details, provide the most logical solution. Ensure the output is a valid shell command. If multiple steps are required, try to combine them using &&. Provide only plain text without Markdown formatting. Do not use markdown formatting such as ```.") | ||
v.SetDefault("roles.code", "Provide only code as output without any description. Provide only code in plain text format without Markdown formatting. Do not include symbols such as ``` or ```python. If there is a lack of details, provide most logical solution. You are not allowed to ask for more details. For example if the prompt is \"Hello world Python\", you should return \"print('Hello world')\".") | ||
|
||
return v | ||
} | ||
|
||
func InitConfig() error { | ||
homeDir, err := os.UserHomeDir() | ||
if err != nil { | ||
fmt.Println("Error getting home directory:", err) | ||
return err | ||
} | ||
|
||
configDir := filepath.Join(homeDir, ".config", "gaia") | ||
err = os.MkdirAll(configDir, 0755) | ||
if err != nil { | ||
fmt.Println("Error creating config directory:", err) | ||
return err | ||
} | ||
|
||
CfgFile = filepath.Join(configDir, "config.yaml") | ||
|
||
viper.SetConfigFile(CfgFile) | ||
viper.SetConfigType("yaml") | ||
viper.AddConfigPath(".") | ||
|
||
for key, value := range defaultConfig().AllSettings() { | ||
viper.SetDefault(key, value) | ||
} | ||
|
||
if err := viper.ReadInConfig(); err != nil { | ||
fmt.Println("Error reading config file:", err) | ||
} | ||
|
||
if err := viper.Unmarshal(&config); err != nil { | ||
fmt.Println("Error unmarshalling config:", err) | ||
return err | ||
} | ||
|
||
if err := viper.WriteConfig(); err != nil { | ||
fmt.Println("Error writing config file:", err) | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func SetConfigString(key, value string) { | ||
if defaultConfig().IsSet(key) { | ||
viper.Set(key, value) | ||
} else { | ||
fmt.Println("No config found for key:", key, ":", value) | ||
os.Exit(1) | ||
} | ||
|
||
if err := viper.WriteConfig(); err != nil { | ||
fmt.Println("Error writing config file:", err) | ||
} | ||
} |
Oops, something went wrong.