diff --git a/CHANGELOG.md b/CHANGELOG.md index bf43720..00fe37e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [Unreleased] + +### Added + +* Add sdk openai for image generation command +* Add image generation command + ## [0.5.0] ### Added diff --git a/commands/image.go b/commands/image.go new file mode 100644 index 0000000..360245b --- /dev/null +++ b/commands/image.go @@ -0,0 +1,98 @@ +package commands + +import ( + "strconv" + + "github.com/LordPax/aicli/lang" + "github.com/LordPax/aicli/sdk" + cli "github.com/urfave/cli/v2" +) + +func ImageCommand() (*cli.Command, error) { + l := lang.GetLocalize() + + if err := sdk.InitSdkImage(""); err != nil { + return nil, err + } + + return &cli.Command{ + Name: "image", + Usage: l.Get("image-usage"), + ArgsUsage: "[image|-]", + Aliases: []string{"i"}, + Action: imageAction, + Flags: imageFlags(), + }, nil +} + +func imageFlags() []cli.Flag { + l := lang.GetLocalize() + imageSdk := sdk.GetSdkImage() + + return []cli.Flag{ + &cli.StringFlag{ + Name: "sdk", + Aliases: []string{"S"}, + Usage: l.Get("sdk-usage"), + DefaultText: imageSdk.GetName(), + Category: "global", + Action: func(c *cli.Context, value string) error { + if err := sdk.InitSdkImage(value); err != nil { + return err + } + return nil + }, + }, + &cli.StringFlag{ + Name: "model", + Aliases: []string{"m"}, + Usage: l.Get("sdk-model-usage"), + DefaultText: imageSdk.GetModel(), + Category: "image", + Action: func(c *cli.Context, value string) error { + imageSdk := sdk.GetSdkImage() + imageSdk.SetModel(value) + return nil + }, + }, + &cli.StringFlag{ + Name: "size", + Aliases: []string{"s"}, + Usage: l.Get("image-size-usage"), + DefaultText: imageSdk.GetSize(), + Category: "image", + Action: func(c *cli.Context, value string) error { + imageSdk := sdk.GetSdkImage() + imageSdk.SetSize(value) + return nil + }, + }, + &cli.IntFlag{ + Name: "image-nb", + Aliases: []string{"n"}, + Usage: l.Get("image-nb-usage"), + DefaultText: strconv.Itoa(imageSdk.GetImageNb()), + Category: "image", + Action: func(c *cli.Context, value int) error { + imageSdk := sdk.GetSdkImage() + imageSdk.SetImageNb(value) + return nil + }, + }, + &cli.StringFlag{ + Name: "output", + Aliases: []string{"o"}, + Usage: l.Get("image-output-usage"), + Category: "image", + Action: func(c *cli.Context, value string) error { + imageSdk := sdk.GetSdkImage() + imageSdk.SetOutput(value) + return nil + }, + }, + } +} + +func imageAction(c *cli.Context) error { + return nil +} diff --git a/lang/en.go b/lang/en.go index e776101..db44be9 100644 --- a/lang/en.go +++ b/lang/en.go @@ -33,4 +33,8 @@ var EN_STRINGS = LangString{ "translate-source-usage": "Source language", "translate-target-usage": "Target language", "translate-target-required": "Target language is required", + "image-usage": "Generate an image from a prompt", + "image-size-usage": "Set the size of the image", + "image-nb-usage": "Set the number of images", + "image-output-usage": "Set the name of the output file", } diff --git a/lang/fr.go b/lang/fr.go index da86872..050f355 100644 --- a/lang/fr.go +++ b/lang/fr.go @@ -33,4 +33,8 @@ var FR_STRINGS = LangString{ "translate-source-usage": "Langue source", "translate-target-usage": "Langue cible", "translate-target-required": "La langue cible est requise", + "image-usage": "Générer une image à partir d'un prompt", + "image-size-usage": "Définir la taille de l'image", + "image-nb-usage": "Définir le nombre d'images", + "image-output-usage": "Définir le nom du fichier de sortie", } diff --git a/main.go b/main.go index 436b39b..3ab33f3 100644 --- a/main.go +++ b/main.go @@ -56,10 +56,17 @@ func main() { os.Exit(1) } + imageCmd, err := commands.ImageCommand() + if err != nil { + log.PrintfErr("%v\n", err) + os.Exit(1) + } + app.Commands = []*cli.Command{ textCmd, translateCmd, - // TODO : add command for image, audio and translate + imageCmd, + // TODO : add command for audio } if err := app.Run(os.Args); err != nil { diff --git a/sdk/image.go b/sdk/image.go new file mode 100644 index 0000000..cd470af --- /dev/null +++ b/sdk/image.go @@ -0,0 +1,117 @@ +package sdk + +import ( + "errors" + "fmt" + + "github.com/LordPax/aicli/config" + "github.com/LordPax/aicli/lang" +) + +var sdkImageInstance IImageService + +type IImageService interface { + ISdk + ISdkImage +} + +type ISdkImage interface { + SendRequest(prompt string) (OpenaiImageResponse, error) + SetModel(model string) + GetModel() string + SetSize(size string) + GetSize() string + SetImageNb(imageNb int) + GetImageNb() int + SetOutput(output string) + GetOutput() string +} + +type SdkImage struct { + Model string + Size string + ImageNb int + Output string +} + +func InitSdkImage(sdk string) error { + var err error + + l := lang.GetLocalize() + sdkType, apiKey, err := getConfigImage(sdk) + if err != nil { + return err + } + + switch sdkType { + case "openai": + sdkImageInstance, err = NewOpenaiImage(apiKey) + default: + return fmt.Errorf(l.Get("unknown-sdk"), sdk) + } + + if err != nil { + return err + } + + return nil +} + +func getConfigImage(sdkType string) (string, string, error) { + l := lang.GetLocalize() + configImage := config.CONFIG_INI.Section("image") + + if sdkType == "" { + sdkType = configImage.Key("type").String() + if sdkType == "" { + return "", "", errors.New(l.Get("type-required")) + } + } + + apiKey := configImage.Key("apiKey").String() + if apiKey == "" { + apiKey = configImage.Key(sdkType + "-apiKey").String() + } + + return sdkType, apiKey, nil +} + +func GetSdkImage() IImageService { + return sdkImageInstance +} + +func SetSdkImage(s IImageService) { + sdkImageInstance = s +} + +func (s *SdkImage) SetModel(model string) { + s.Model = model +} + +func (s *SdkImage) GetModel() string { + return s.Model +} + +func (s *SdkImage) SetSize(size string) { + s.Size = size +} + +func (s *SdkImage) GetSize() string { + return s.Size +} + +func (s *SdkImage) SetImageNb(imageNb int) { + s.ImageNb = imageNb +} + +func (s *SdkImage) GetImageNb() int { + return s.ImageNb +} + +func (s *SdkImage) SetOutput(output string) { + s.Output = output +} + +func (s *SdkImage) GetOutput() string { + return s.Output +} diff --git a/sdk/openai-image.go b/sdk/openai-image.go new file mode 100644 index 0000000..cc96a34 --- /dev/null +++ b/sdk/openai-image.go @@ -0,0 +1,94 @@ +package sdk + +import ( + "encoding/json" + "errors" + "io" + "net/http" + + "github.com/LordPax/aicli/utils" +) + +type OpenaiImageBody struct { + Model string `json:"model"` + Prompt string `json:"prompt"` + N int `json:"n"` + Size string `json:"size"` + ResponseFormat string `json:"response_format"` +} + +type OpenaiImageResponse struct { + Images []struct { + Url string `json:"url"` + B64Json string `json:"b64_json"` + } `json:"images"` +} + +type OpenaiImage struct { + Sdk + SdkImage +} + +func NewOpenaiImage(apiKey string) (*OpenaiImage, error) { + return &OpenaiImage{ + Sdk: Sdk{ + Name: "openai", + ApiUrl: "https://api.openai.com/v1/images/generations", + ApiKey: apiKey, + Inerte: false, + }, + SdkImage: SdkImage{ + Model: "dall-e-3", + Size: "1024x1024", + ImageNb: 1, + }, + }, nil +} + +func (o *OpenaiImage) SendRequest(prompt string) (OpenaiImageResponse, error) { + var openaiResponse OpenaiImageResponse + format := "url" + + if o.GetOutput() != "" { + format = "b4_json" + } + + jsonBody, err := json.Marshal(OpenaiImageBody{ + Model: o.GetModel(), + Prompt: prompt, + N: o.GetImageNb(), + Size: o.GetSize(), + ResponseFormat: format, + }) + if err != nil { + return OpenaiImageResponse{}, err + } + + resp, err := utils.PostRequest(o.ApiUrl, jsonBody, map[string]string{ + "Content-Type": "application/json", + "Authorization": "Bearer " + o.ApiKey, + }) + if err != nil { + return OpenaiImageResponse{}, err + } + defer resp.Body.Close() + + respBody, err := io.ReadAll(resp.Body) + if err != nil { + return OpenaiImageResponse{}, err + } + + if resp.StatusCode != http.StatusOK { + var errorMsg ErrorMsg + if err := json.Unmarshal(respBody, &errorMsg); err != nil { + return OpenaiImageResponse{}, err + } + return OpenaiImageResponse{}, errors.New(errorMsg.Error.Message) + } + + if err := json.Unmarshal(respBody, &openaiResponse); err != nil { + return OpenaiImageResponse{}, err + } + + return openaiResponse, nil +} diff --git a/sdk/openai.go b/sdk/openai-text.go similarity index 100% rename from sdk/openai.go rename to sdk/openai-text.go diff --git a/sdk/translate.go b/sdk/translate.go index b1118a9..d86269f 100644 --- a/sdk/translate.go +++ b/sdk/translate.go @@ -66,7 +66,7 @@ func getConfigTranslate(sdkType string) (string, string, error) { if apiKey == "" { apiKey = configTranslate.Key(sdkType + "-apiKey").String() if apiKey == "" { - return "", "", fmt.Errorf(l.Get("api-key-required"), sdkType) + return "", "", errors.New(l.Get("api-key-required")) } }