diff --git a/internal/db/redis.go b/internal/db/redis.go index f5b210e..5bfa9b9 100644 --- a/internal/db/redis.go +++ b/internal/db/redis.go @@ -2,24 +2,23 @@ package db import ( "context" + "github.com/MirrorChyan/resource-backend/internal/config" "github.com/pkg/errors" "github.com/redis/go-redis/v9" ) -var ( - IRS *redis.Client -) - -func NewRedis(conf *config.Config) { - IRS = redis.NewClient(&redis.Options{ +func NewRedis(conf *config.Config) *redis.Client { + rdb := redis.NewClient(&redis.Options{ Addr: conf.Redis.Addr, DB: conf.Redis.DB, Username: conf.Redis.Username, Password: conf.Redis.Password, }) - _, err := IRS.Ping(context.Background()).Result() + _, err := rdb.Ping(context.Background()).Result() if err != nil { panic(errors.WithMessage(err, "failed to ping redis")) } + + return rdb } diff --git a/internal/handler/version.go b/internal/handler/version.go index 7136289..d23774e 100644 --- a/internal/handler/version.go +++ b/internal/handler/version.go @@ -10,16 +10,13 @@ import ( "os" "path/filepath" "strings" - "time" "github.com/MirrorChyan/resource-backend/internal/config" - "github.com/MirrorChyan/resource-backend/internal/db" "github.com/MirrorChyan/resource-backend/internal/ent" "github.com/MirrorChyan/resource-backend/internal/handler/response" "github.com/MirrorChyan/resource-backend/internal/logic" . "github.com/MirrorChyan/resource-backend/internal/model" "github.com/gofiber/fiber/v2" - "github.com/segmentio/ksuid" "go.uber.org/zap" ) @@ -395,70 +392,39 @@ func (h *VersionHandler) GetLatest(c *fiber.Ctx) error { h.logger.Info("CDK validation success") - var isFull = req.CurrentVersion == "" - - // if current version is not provided, we will download the full version - var current *ent.Version - if !isFull { - getVersionByNameParam := GetVersionByNameParam{ - ResourceID: resID, - Name: req.CurrentVersion, - } - current, err = h.versionLogic.GetByName(ctx, getVersionByNameParam) - if err != nil { - if !ent.IsNotFound(err) { - h.logger.Error("Failed to get current version", - zap.Error(err), - ) - resp := response.UnexpectedError() - return c.Status(fiber.StatusInternalServerError).JSON(resp) - } - isFull = true - } + storeTempDownloadInfoParam := StoreTempDownloadInfoParam{ + ResourceID: resID, + CurrentVersionName: req.CurrentVersion, + LatestVersion: latest, } - var info = TempDownloadInfo{ - ResourceID: resID, - Full: isFull, - TargetVersionID: latest.ID, - } - - if !isFull { - info.TargetVersionFileHashes = latest.FileHashes - info.CurrentVersionID = current.ID - info.CurrentVersionFileHashes = current.FileHashes - } - - rk := ksuid.New().String() - - if buf, err := json.Marshal(info); err != nil { - h.logger.Error("Failed to marshal JSON", + key, err := h.versionLogic.StoreTempDownloadInfo(ctx, storeTempDownloadInfoParam) + if err != nil { + h.logger.Error("Failed to store temp download info", zap.Error(err), ) - return c.Status(fiber.StatusInternalServerError).JSON(response.UnexpectedError()) - } else { - db.IRS.Set(ctx, fmt.Sprintf("RES:%v", rk), string(buf), 20*time.Minute) - - url := strings.Join([]string{h.conf.Extra.DownloadPrefix, rk}, "/") - data.Url = url - return c.Status(fiber.StatusOK).JSON(response.Success(data, "success")) + resp := response.UnexpectedError() + return c.Status(fiber.StatusInternalServerError).JSON(resp) } + url := strings.Join([]string{h.conf.Extra.DownloadPrefix, key}, "/") + data.Url = url + resp := response.Success(data) + return c.Status(fiber.StatusOK).JSON(resp) } func (h *VersionHandler) Download(c *fiber.Ctx) error { key := c.Params("key", "") if key == "" { - return c.Status(fiber.StatusNotFound).JSON(response.BusinessError("missing key")) + resp := response.BusinessError("missing key") + return c.Status(fiber.StatusNotFound).JSON(resp) } ctx := c.UserContext() - val, err := db.IRS.GetDel(ctx, fmt.Sprintf("RES:%v", key)).Result() - - var info TempDownloadInfo - if err != nil || val == "" || json.Unmarshal([]byte(val), &info) != nil { + info, err := h.versionLogic.GetTempDownloadInfo(ctx, key) + if err != nil { h.logger.Warn("invalid key or resource not found", zap.String("key", key), ) diff --git a/internal/logic/version.go b/internal/logic/version.go index 4a0cb7f..a3af86b 100644 --- a/internal/logic/version.go +++ b/internal/logic/version.go @@ -7,6 +7,7 @@ import ( "path/filepath" "strconv" "strings" + "time" "github.com/MirrorChyan/resource-backend/internal/ent" . "github.com/MirrorChyan/resource-backend/internal/model" @@ -16,22 +17,31 @@ import ( "github.com/MirrorChyan/resource-backend/internal/pkg/fileops" "github.com/MirrorChyan/resource-backend/internal/pkg/stg" "github.com/MirrorChyan/resource-backend/internal/repo" + "github.com/segmentio/ksuid" "go.uber.org/zap" ) type VersionLogic struct { - logger *zap.Logger - versionRepo *repo.Version - storageRepo *repo.Storage - storage *stg.Storage + logger *zap.Logger + versionRepo *repo.Version + storageRepo *repo.Storage + tempDownloadInfoRepo *repo.TempDownloadInfo + storage *stg.Storage } -func NewVersionLogic(logger *zap.Logger, versionRepo *repo.Version, storageRepo *repo.Storage, storage *stg.Storage) *VersionLogic { +func NewVersionLogic( + logger *zap.Logger, + versionRepo *repo.Version, + storageRepo *repo.Storage, + tempDownloadInfoRepo *repo.TempDownloadInfo, + storage *stg.Storage, +) *VersionLogic { return &VersionLogic{ - logger: logger, - versionRepo: versionRepo, - storageRepo: storageRepo, - storage: storage, + logger: logger, + versionRepo: versionRepo, + storageRepo: storageRepo, + tempDownloadInfoRepo: tempDownloadInfoRepo, + storage: storage, } } @@ -196,6 +206,71 @@ func (l *VersionLogic) GetByName(ctx context.Context, param GetVersionByNamePara return l.versionRepo.GetVersionByName(ctx, param.ResourceID, param.Name) } +func (l *VersionLogic) StoreTempDownloadInfo(ctx context.Context, param StoreTempDownloadInfoParam) (string, error) { + isFull := param.CurrentVersionName == "" + + // if current version is not provided, we will download the full version + var ( + current *ent.Version + err error + ) + if !isFull { + getVersionByNameParam := GetVersionByNameParam{ + ResourceID: param.ResourceID, + Name: param.CurrentVersionName, + } + current, err = l.GetByName(ctx, getVersionByNameParam) + if err != nil { + if !ent.IsNotFound(err) { + l.logger.Error("Failed to get current version", + zap.Error(err), + ) + return "", err + } + isFull = true + } + } + + var info = &TempDownloadInfo{ + ResourceID: param.ResourceID, + Full: isFull, + TargetVersionID: param.LatestVersion.ID, + } + + if !isFull { + info.TargetVersionFileHashes = param.LatestVersion.FileHashes + info.CurrentVersionID = current.ID + info.CurrentVersionFileHashes = current.FileHashes + } + + key := ksuid.New().String() + rk := fmt.Sprintf("RES:%v", key) + + err = l.tempDownloadInfoRepo.SetTempDownloadInfo(ctx, rk, info, 20*time.Minute) + if err != nil { + l.logger.Error("Failed to set temp download info", + zap.Error(err), + ) + return "", err + } + + return key, nil +} + +func (l *VersionLogic) GetTempDownloadInfo(ctx context.Context, key string) (*TempDownloadInfo, error) { + rk := fmt.Sprintf("RES:%v", key) + + info, err := l.tempDownloadInfoRepo.GetDelTempDownloadInfo(ctx, rk) + if err != nil { + l.logger.Error("Failed to get temp download info", + zap.Error(err), + ) + return nil, err + } + + return info, nil +} + func (l *VersionLogic) GetResourcePath(param GetResourcePathParam) string { return l.storage.ResourcePath(param.ResourceID, param.VersionID) } diff --git a/internal/model/model.go b/internal/model/model.go index c0219b0..323045b 100644 --- a/internal/model/model.go +++ b/internal/model/model.go @@ -1,5 +1,7 @@ package model +import "github.com/MirrorChyan/resource-backend/internal/ent" + type UpdateResourceParam struct { ID string Name string @@ -109,6 +111,12 @@ type BillingCheckinRequest struct { UserAgent string `json:"user_agent"` } +type StoreTempDownloadInfoParam struct { + ResourceID string + CurrentVersionName string + LatestVersion *ent.Version +} + type GetResourcePathParam struct { ResourceID string VersionID int diff --git a/internal/repo/temp_download_info.go b/internal/repo/temp_download_info.go new file mode 100644 index 0000000..0c693eb --- /dev/null +++ b/internal/repo/temp_download_info.go @@ -0,0 +1,44 @@ +package repo + +import ( + "context" + "encoding/json" + "time" + + "github.com/MirrorChyan/resource-backend/internal/model" + "github.com/redis/go-redis/v9" +) + +type TempDownloadInfo struct { + rdb *redis.Client +} + +func NewTempDownloadInfo(rdb *redis.Client) *TempDownloadInfo { + return &TempDownloadInfo{ + rdb: rdb, + } +} + +func (r *TempDownloadInfo) GetDelTempDownloadInfo(ctx context.Context, key string) (*model.TempDownloadInfo, error) { + val, err := r.rdb.GetDel(ctx, key).Result() + if err != nil { + return nil, err + } + + info := &model.TempDownloadInfo{} + err = json.Unmarshal([]byte(val), info) + if err != nil { + return nil, err + } + + return info, nil +} + +func (r *TempDownloadInfo) SetTempDownloadInfo(ctx context.Context, key string, info *model.TempDownloadInfo, expiration time.Duration) error { + buf, err := json.Marshal(info) + if err != nil { + return err + } + + return r.rdb.Set(ctx, key, buf, expiration).Err() +} diff --git a/internal/wire/wire.go b/internal/wire/wire.go index e170b7f..01e7bd6 100644 --- a/internal/wire/wire.go +++ b/internal/wire/wire.go @@ -11,6 +11,7 @@ import ( "github.com/MirrorChyan/resource-backend/internal/pkg/stg" "github.com/MirrorChyan/resource-backend/internal/repo" "github.com/google/wire" + "github.com/redis/go-redis/v9" "go.uber.org/zap" ) @@ -18,6 +19,7 @@ var repoProviderSet = wire.NewSet( repo.NewResource, repo.NewVersion, repo.NewStorage, + repo.NewTempDownloadInfo, ) var logicProviderSet = wire.NewSet( @@ -42,6 +44,6 @@ func provideHandlerSet(resourceHandler *handler.ResourceHandler, versionHandler } } -func NewHandlerSet(conf *config.Config, logger *zap.Logger, db *ent.Client, storage *stg.Storage) *HandlerSet { +func NewHandlerSet(conf *config.Config, logger *zap.Logger, db *ent.Client, rdb *redis.Client, storage *stg.Storage) *HandlerSet { panic(wire.Build(repoProviderSet, logicProviderSet, handlerProviderSet, provideHandlerSet)) } diff --git a/internal/wire/wire_gen.go b/internal/wire/wire_gen.go index bce403c..96a543a 100644 --- a/internal/wire/wire_gen.go +++ b/internal/wire/wire_gen.go @@ -14,18 +14,20 @@ import ( "github.com/MirrorChyan/resource-backend/internal/pkg/stg" "github.com/MirrorChyan/resource-backend/internal/repo" "github.com/google/wire" + "github.com/redis/go-redis/v9" "go.uber.org/zap" ) // Injectors from wire.go: -func NewHandlerSet(conf *config.Config, logger *zap.Logger, db *ent.Client, storage *stg.Storage) *HandlerSet { +func NewHandlerSet(conf *config.Config, logger *zap.Logger, db *ent.Client, rdb *redis.Client, storage *stg.Storage) *HandlerSet { resource := repo.NewResource(db) resourceLogic := logic.NewResourceLogic(logger, resource) resourceHandler := handler.NewResourceHandler(logger, resourceLogic) version := repo.NewVersion(db) repoStorage := repo.NewStorage(db) - versionLogic := logic.NewVersionLogic(logger, version, repoStorage, storage) + tempDownloadInfo := repo.NewTempDownloadInfo(rdb) + versionLogic := logic.NewVersionLogic(logger, version, repoStorage, tempDownloadInfo, storage) versionHandler := handler.NewVersionHandler(conf, logger, resourceLogic, versionLogic) handlerSet := provideHandlerSet(resourceHandler, versionHandler) return handlerSet @@ -33,7 +35,7 @@ func NewHandlerSet(conf *config.Config, logger *zap.Logger, db *ent.Client, stor // wire.go: -var repoProviderSet = wire.NewSet(repo.NewResource, repo.NewVersion, repo.NewStorage) +var repoProviderSet = wire.NewSet(repo.NewResource, repo.NewVersion, repo.NewStorage, repo.NewTempDownloadInfo) var logicProviderSet = wire.NewSet(logic.NewResourceLogic, logic.NewVersionLogic) diff --git a/main.go b/main.go index 6e5b001..c240ad5 100644 --- a/main.go +++ b/main.go @@ -37,8 +37,6 @@ func main() { l := logger.New(conf) zap.ReplaceGlobals(l) - db.NewRedis(conf) - mySQL, err := db.NewMySQL(conf) if err != nil { @@ -60,9 +58,11 @@ func main() { ) } + redis := db.NewRedis(conf) + storage := stg.New(cwd) - handlerSet := wire.NewHandlerSet(conf, l, mySQL, storage) + handlerSet := wire.NewHandlerSet(conf, l, mySQL, redis, storage) app := fiber.New(fiber.Config{ BodyLimit: BodyLimit,