Skip to content

Commit

Permalink
feat(image) : add image generation
Browse files Browse the repository at this point in the history
  • Loading branch information
LordPax committed Oct 14, 2024
1 parent 2eeaf2c commit 107d3a9
Show file tree
Hide file tree
Showing 9 changed files with 333 additions and 2 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Changelog

## [Unreleased]

### Added

* Add sdk openai for image generation command
* Add image generation command

## [0.5.0]

### Added
Expand Down
98 changes: 98 additions & 0 deletions commands/image.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
package commands

import (
"strconv"

"github.com/LordPax/aicli/lang"
"github.com/LordPax/aicli/sdk"
cli "github.com/urfave/cli/v2"
)

func ImageCommand() (*cli.Command, error) {
l := lang.GetLocalize()

if err := sdk.InitSdkImage(""); err != nil {
return nil, err
}

return &cli.Command{
Name: "image",
Usage: l.Get("image-usage"),
ArgsUsage: "[image|-]",
Aliases: []string{"i"},
Action: imageAction,
Flags: imageFlags(),
}, nil
}

func imageFlags() []cli.Flag {
l := lang.GetLocalize()
imageSdk := sdk.GetSdkImage()

return []cli.Flag{
&cli.StringFlag{
Name: "sdk",
Aliases: []string{"S"},
Usage: l.Get("sdk-usage"),
DefaultText: imageSdk.GetName(),
Category: "global",
Action: func(c *cli.Context, value string) error {
if err := sdk.InitSdkImage(value); err != nil {
return err
}
return nil
},
},
&cli.StringFlag{
Name: "model",
Aliases: []string{"m"},
Usage: l.Get("sdk-model-usage"),
DefaultText: imageSdk.GetModel(),
Category: "image",
Action: func(c *cli.Context, value string) error {
imageSdk := sdk.GetSdkImage()
imageSdk.SetModel(value)
return nil
},
},
&cli.StringFlag{
Name: "size",
Aliases: []string{"s"},
Usage: l.Get("image-size-usage"),
DefaultText: imageSdk.GetSize(),
Category: "image",
Action: func(c *cli.Context, value string) error {
imageSdk := sdk.GetSdkImage()
imageSdk.SetSize(value)
return nil
},
},
&cli.IntFlag{
Name: "image-nb",
Aliases: []string{"n"},
Usage: l.Get("image-nb-usage"),
DefaultText: strconv.Itoa(imageSdk.GetImageNb()),
Category: "image",
Action: func(c *cli.Context, value int) error {
imageSdk := sdk.GetSdkImage()
imageSdk.SetImageNb(value)
return nil
},
},
&cli.StringFlag{
Name: "output",
Aliases: []string{"o"},
Usage: l.Get("image-output-usage"),
Category: "image",
Action: func(c *cli.Context, value string) error {
imageSdk := sdk.GetSdkImage()
imageSdk.SetOutput(value)
return nil
},
},
}
}

func imageAction(c *cli.Context) error {
return nil
}
4 changes: 4 additions & 0 deletions lang/en.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,8 @@ var EN_STRINGS = LangString{
"translate-source-usage": "Source language",
"translate-target-usage": "Target language",
"translate-target-required": "Target language is required",
"image-usage": "Generate an image from a prompt",
"image-size-usage": "Set the size of the image",
"image-nb-usage": "Set the number of images",
"image-output-usage": "Set the name of the output file",
}
4 changes: 4 additions & 0 deletions lang/fr.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,8 @@ var FR_STRINGS = LangString{
"translate-source-usage": "Langue source",
"translate-target-usage": "Langue cible",
"translate-target-required": "La langue cible est requise",
"image-usage": "Générer une image à partir d'un prompt",
"image-size-usage": "Définir la taille de l'image",
"image-nb-usage": "Définir le nombre d'images",
"image-output-usage": "Définir le nom du fichier de sortie",
}
9 changes: 8 additions & 1 deletion main.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,17 @@ func main() {
os.Exit(1)
}

imageCmd, err := commands.ImageCommand()
if err != nil {
log.PrintfErr("%v\n", err)
os.Exit(1)
}

app.Commands = []*cli.Command{
textCmd,
translateCmd,
// TODO : add command for image, audio and translate
imageCmd,
// TODO : add command for audio
}

if err := app.Run(os.Args); err != nil {
Expand Down
117 changes: 117 additions & 0 deletions sdk/image.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
package sdk

import (
"errors"
"fmt"

"github.com/LordPax/aicli/config"
"github.com/LordPax/aicli/lang"
)

var sdkImageInstance IImageService

type IImageService interface {
ISdk
ISdkImage
}

type ISdkImage interface {
SendRequest(prompt string) (OpenaiImageResponse, error)
SetModel(model string)
GetModel() string
SetSize(size string)
GetSize() string
SetImageNb(imageNb int)
GetImageNb() int
SetOutput(output string)
GetOutput() string
}

type SdkImage struct {
Model string
Size string
ImageNb int
Output string
}

func InitSdkImage(sdk string) error {
var err error

l := lang.GetLocalize()
sdkType, apiKey, err := getConfigImage(sdk)
if err != nil {
return err
}

switch sdkType {
case "openai":
sdkImageInstance, err = NewOpenaiImage(apiKey)
default:
return fmt.Errorf(l.Get("unknown-sdk"), sdk)
}

if err != nil {
return err
}

return nil
}

func getConfigImage(sdkType string) (string, string, error) {
l := lang.GetLocalize()
configImage := config.CONFIG_INI.Section("image")

if sdkType == "" {
sdkType = configImage.Key("type").String()
if sdkType == "" {
return "", "", errors.New(l.Get("type-required"))
}
}

apiKey := configImage.Key("apiKey").String()
if apiKey == "" {
apiKey = configImage.Key(sdkType + "-apiKey").String()
}

return sdkType, apiKey, nil
}

func GetSdkImage() IImageService {
return sdkImageInstance
}

func SetSdkImage(s IImageService) {
sdkImageInstance = s
}

func (s *SdkImage) SetModel(model string) {
s.Model = model
}

func (s *SdkImage) GetModel() string {
return s.Model
}

func (s *SdkImage) SetSize(size string) {
s.Size = size
}

func (s *SdkImage) GetSize() string {
return s.Size
}

func (s *SdkImage) SetImageNb(imageNb int) {
s.ImageNb = imageNb
}

func (s *SdkImage) GetImageNb() int {
return s.ImageNb
}

func (s *SdkImage) SetOutput(output string) {
s.Output = output
}

func (s *SdkImage) GetOutput() string {
return s.Output
}
94 changes: 94 additions & 0 deletions sdk/openai-image.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
package sdk

import (
"encoding/json"
"errors"
"io"
"net/http"

"github.com/LordPax/aicli/utils"
)

type OpenaiImageBody struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n"`
Size string `json:"size"`
ResponseFormat string `json:"response_format"`
}

type OpenaiImageResponse struct {
Images []struct {
Url string `json:"url"`
B64Json string `json:"b64_json"`
} `json:"images"`
}

type OpenaiImage struct {
Sdk
SdkImage
}

func NewOpenaiImage(apiKey string) (*OpenaiImage, error) {
return &OpenaiImage{
Sdk: Sdk{
Name: "openai",
ApiUrl: "https://api.openai.com/v1/images/generations",
ApiKey: apiKey,
Inerte: false,
},
SdkImage: SdkImage{
Model: "dall-e-3",
Size: "1024x1024",
ImageNb: 1,
},
}, nil
}

func (o *OpenaiImage) SendRequest(prompt string) (OpenaiImageResponse, error) {
var openaiResponse OpenaiImageResponse
format := "url"

if o.GetOutput() != "" {
format = "b4_json"
}

jsonBody, err := json.Marshal(OpenaiImageBody{
Model: o.GetModel(),
Prompt: prompt,
N: o.GetImageNb(),
Size: o.GetSize(),
ResponseFormat: format,
})
if err != nil {
return OpenaiImageResponse{}, err
}

resp, err := utils.PostRequest(o.ApiUrl, jsonBody, map[string]string{
"Content-Type": "application/json",
"Authorization": "Bearer " + o.ApiKey,
})
if err != nil {
return OpenaiImageResponse{}, err
}
defer resp.Body.Close()

respBody, err := io.ReadAll(resp.Body)
if err != nil {
return OpenaiImageResponse{}, err
}

if resp.StatusCode != http.StatusOK {
var errorMsg ErrorMsg
if err := json.Unmarshal(respBody, &errorMsg); err != nil {
return OpenaiImageResponse{}, err
}
return OpenaiImageResponse{}, errors.New(errorMsg.Error.Message)
}

if err := json.Unmarshal(respBody, &openaiResponse); err != nil {
return OpenaiImageResponse{}, err
}

return openaiResponse, nil
}
File renamed without changes.
2 changes: 1 addition & 1 deletion sdk/translate.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func getConfigTranslate(sdkType string) (string, string, error) {
if apiKey == "" {
apiKey = configTranslate.Key(sdkType + "-apiKey").String()
if apiKey == "" {
return "", "", fmt.Errorf(l.Get("api-key-required"), sdkType)
return "", "", errors.New(l.Get("api-key-required"))
}
}

Expand Down

0 comments on commit 107d3a9

Please sign in to comment.