diff --git a/cmd/api/api.go b/cmd/api/api.go
index 8f21656..99efaad 100644
--- a/cmd/api/api.go
+++ b/cmd/api/api.go
@@ -1,6 +1,7 @@
package api
import (
+ "context"
"fmt"
"github.com/celestiaorg/knuu/internal/api/v1"
@@ -66,16 +67,6 @@ func NewAPICmd() *cobra.Command {
}
func runAPIServer(cmd *cobra.Command, args []string) error {
- port, err := cmd.Flags().GetInt(flagPort)
- if err != nil {
- return fmt.Errorf("failed to get port: %v", err)
- }
-
- logLevel, err := cmd.Flags().GetString(flagLogLevel)
- if err != nil {
- return fmt.Errorf("failed to get log level: %v", err)
- }
-
dbOpts, err := getDBOptions(cmd.Flags())
if err != nil {
return fmt.Errorf("failed to get database options: %v", err)
@@ -86,29 +77,16 @@ func runAPIServer(cmd *cobra.Command, args []string) error {
return fmt.Errorf("failed to connect to database: %v", err)
}
- secretKey, err := cmd.Flags().GetString(flagSecretKey)
+ apiOpts, err := getAPIOptions(cmd.Flags())
if err != nil {
- return fmt.Errorf("failed to get secret key: %v", err)
+ return fmt.Errorf("failed to get API options: %v", err)
}
- adminUser, err := cmd.Flags().GetString(flagAdminUser)
+ apiServer, err := api.New(context.Background(), db, apiOpts)
if err != nil {
- return fmt.Errorf("failed to get admin user: %v", err)
+ return fmt.Errorf("failed to create API server: %v", err)
}
- adminPass, err := cmd.Flags().GetString(flagAdminPass)
- if err != nil {
- return fmt.Errorf("failed to get admin password: %v", err)
- }
-
- apiServer := api.New(db, api.Options{
- Port: port,
- LogMode: logLevel,
- SecretKey: secretKey,
- AdminUser: adminUser,
- AdminPass: adminPass,
- })
-
return apiServer.Start()
}
@@ -146,3 +124,38 @@ func getDBOptions(flags *pflag.FlagSet) (database.Options, error) {
Port: dbPort,
}, nil
}
+
+func getAPIOptions(flags *pflag.FlagSet) (api.Options, error) {
+ port, err := flags.GetInt(flagPort)
+ if err != nil {
+ return api.Options{}, fmt.Errorf("failed to get port: %v", err)
+ }
+
+ logLevel, err := flags.GetString(flagLogLevel)
+ if err != nil {
+ return api.Options{}, fmt.Errorf("failed to get log level: %v", err)
+ }
+
+ secretKey, err := flags.GetString(flagSecretKey)
+ if err != nil {
+ return api.Options{}, fmt.Errorf("failed to get secret key: %v", err)
+ }
+
+ adminUser, err := flags.GetString(flagAdminUser)
+ if err != nil {
+ return api.Options{}, fmt.Errorf("failed to get admin user: %v", err)
+ }
+
+ adminPass, err := flags.GetString(flagAdminPass)
+ if err != nil {
+ return api.Options{}, fmt.Errorf("failed to get admin password: %v", err)
+ }
+
+ return api.Options{
+ Port: port,
+ LogMode: logLevel,
+ SecretKey: secretKey,
+ AdminUser: adminUser,
+ AdminPass: adminPass,
+ }, nil
+}
diff --git a/go.mod b/go.mod
index b0c168e..a9011b1 100644
--- a/go.mod
+++ b/go.mod
@@ -5,12 +5,14 @@ go 1.22.5
require (
github.com/celestiaorg/bittwister v0.0.0-20231213180407-65cdbaf5b8c7
github.com/gin-gonic/gin v1.10.0
+ github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/golang-jwt/jwt/v4 v4.5.1
github.com/google/uuid v1.6.0
github.com/minio/minio-go/v7 v7.0.74
github.com/rs/cors v1.11.1
github.com/sirupsen/logrus v1.9.3
github.com/spf13/cobra v1.8.0
+ github.com/spf13/pflag v1.0.5
github.com/stretchr/testify v1.9.0
golang.org/x/crypto v0.29.0
gopkg.in/yaml.v2 v2.4.0
@@ -78,7 +80,6 @@ require (
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect
github.com/rs/xid v1.5.0 // indirect
- github.com/spf13/pflag v1.0.5 // indirect
github.com/stretchr/objx v0.5.2 // indirect
github.com/twitchyliquid64/golang-asm v0.15.1 // indirect
github.com/ugorji/go/codec v1.2.12 // indirect
diff --git a/go.sum b/go.sum
index 93ece83..ea9aba2 100644
--- a/go.sum
+++ b/go.sum
@@ -65,6 +65,8 @@ github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEe
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572/go.mod h1:9Pwr4B2jHnOSGXyyzV8ROjYa2ojvAY6HCGYYfMoC3Ls=
github.com/goccy/go-json v0.10.3 h1:KZ5WoDbxAIgm2HNbYckL0se1fHD6rz5j4ywS6ebzDqA=
github.com/goccy/go-json v0.10.3/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
+github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY=
+github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I=
github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo=
github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
diff --git a/internal/api/v1/api.go b/internal/api/v1/api.go
index 8d7c82a..f4f4442 100644
--- a/internal/api/v1/api.go
+++ b/internal/api/v1/api.go
@@ -1,12 +1,14 @@
package api
import (
+ "context"
"fmt"
"net/http"
"github.com/celestiaorg/knuu/internal/api/v1/handlers"
"github.com/celestiaorg/knuu/internal/api/v1/middleware"
"github.com/celestiaorg/knuu/internal/api/v1/services"
+ "github.com/celestiaorg/knuu/internal/database/models"
"github.com/celestiaorg/knuu/internal/database/repos"
"github.com/gin-gonic/gin"
@@ -29,21 +31,44 @@ type Options struct {
LogMode string // gin.DebugMode, gin.ReleaseMode(default), gin.TestMode
OriginAllowed string
SecretKey string
+
+ AdminUser string // default admin username
+ AdminPass string // default admin password
}
-func New(db *gorm.DB, opts Options) *API {
+func New(ctx context.Context, db *gorm.DB, opts Options) (*API, error) {
opts = setDefaults(opts)
gin.SetMode(opts.LogMode)
rt := gin.Default()
+ auth := middleware.NewAuth(opts.SecretKey)
+ uh, err := getUserHandler(ctx, opts, db, auth)
+ if err != nil {
+ return nil, err
+ }
+
public := rt.Group("/")
{
- uh := handlers.NewUserHandler(services.NewUserService(opts.SecretKey, repos.NewUserRepository(db)))
- public.POST(pathsUserRegister, uh.Register)
public.POST(pathsUserLogin, uh.Login)
}
- protected := rt.Group("/", middleware.AuthMiddleware())
+
+ protected := rt.Group("/", auth.AuthMiddleware())
+ {
+ protected.POST(pathsUserRegister, auth.RequireRole(models.RoleAdmin), uh.Register)
+
+ th, err := getTestHandler(ctx, db)
+ if err != nil {
+ return nil, err
+ }
+
+ protected.POST(pathsTests, th.CreateTest)
+ // protected.GET(pathsTestDetails, th.GetTestDetails)
+ // protected.GET(pathsTestInstances, th.GetInstances)
+ protected.GET(pathsTestInstanceDetails, th.GetInstance)
+ protected.POST(pathsTestInstanceDetails, th.CreateInstance) // Need to do something about updating an instance
+ // protected.POST(pathsTestInstanceExecute, th.ExecuteInstance)
+ }
_ = protected
@@ -59,7 +84,7 @@ func New(db *gorm.DB, opts Options) *API {
public.GET("/", a.IndexPage)
}
- return a
+ return a, nil
}
func (a *API) Start() error {
@@ -103,3 +128,21 @@ func handleOrigin(router *gin.Engine, originAllowed string) http.Handler {
AllowedMethods: methodsOk,
}).Handler(router)
}
+
+func getUserHandler(ctx context.Context, opts Options, db *gorm.DB, auth *middleware.Auth) (*handlers.UserHandler, error) {
+ us, err := services.NewUserService(ctx, opts.AdminUser, opts.AdminPass, repos.NewUserRepository(db))
+ if err != nil {
+ return nil, err
+ }
+
+ return handlers.NewUserHandler(us, auth), nil
+}
+
+func getTestHandler(ctx context.Context, db *gorm.DB) (*handlers.TestHandler, error) {
+ ts, err := services.NewTestService(ctx, repos.NewTestRepository(db))
+ if err != nil {
+ return nil, err
+ }
+
+ return handlers.NewTestHandler(ts), nil
+}
diff --git a/internal/api/v1/handlers/instance.go b/internal/api/v1/handlers/instance.go
new file mode 100644
index 0000000..27ca523
--- /dev/null
+++ b/internal/api/v1/handlers/instance.go
@@ -0,0 +1,44 @@
+package handlers
+
+import (
+ "net/http"
+
+ "github.com/celestiaorg/knuu/internal/api/v1/services"
+ "github.com/gin-gonic/gin"
+)
+
+func (h *TestHandler) CreateInstance(c *gin.Context) {
+ user, err := getUserFromContext(c)
+ if err != nil {
+ c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
+ return
+ }
+
+ var input services.Instance
+ if err := c.ShouldBindJSON(&input); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid input"})
+ return
+ }
+
+ err = h.testService.CreateInstance(c.Request.Context(), user.ID, &input)
+ if err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+ return
+ }
+ c.JSON(http.StatusCreated, gin.H{"message": "Instance created successfully"})
+}
+
+func (h *TestHandler) GetInstance(c *gin.Context) {
+ user, err := getUserFromContext(c)
+ if err != nil {
+ c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
+ return
+ }
+
+ instance, err := h.testService.GetInstance(c.Request.Context(), user.ID, c.Param("scope"), c.Param("instance_name"))
+ if err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+ return
+ }
+ c.JSON(http.StatusOK, instance)
+}
diff --git a/internal/api/v1/handlers/test.go b/internal/api/v1/handlers/test.go
new file mode 100644
index 0000000..2037849
--- /dev/null
+++ b/internal/api/v1/handlers/test.go
@@ -0,0 +1,38 @@
+package handlers
+
+import (
+ "net/http"
+
+ "github.com/celestiaorg/knuu/internal/api/v1/services"
+ "github.com/celestiaorg/knuu/internal/database/models"
+ "github.com/gin-gonic/gin"
+)
+
+type TestHandler struct {
+ testService *services.TestService
+}
+
+func NewTestHandler(ts *services.TestService) *TestHandler {
+ return &TestHandler{testService: ts}
+}
+
+func (h *TestHandler) CreateTest(c *gin.Context) {
+ user, err := getUserFromContext(c)
+ if err != nil {
+ c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
+ return
+ }
+
+ var input models.Test
+ if err := c.ShouldBindJSON(&input); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid input"})
+ return
+ }
+
+ input.UserID = user.ID
+ if err := h.testService.Create(c.Request.Context(), &input); err != nil {
+ c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
+ return
+ }
+ c.JSON(http.StatusCreated, gin.H{"message": "Test created successfully"})
+}
diff --git a/internal/api/v1/handlers/test_handler.go b/internal/api/v1/handlers/test_handler.go
deleted file mode 100644
index 5ac8282..0000000
--- a/internal/api/v1/handlers/test_handler.go
+++ /dev/null
@@ -1 +0,0 @@
-package handlers
diff --git a/internal/api/v1/handlers/token.go b/internal/api/v1/handlers/token.go
new file mode 100644
index 0000000..ee3dd57
--- /dev/null
+++ b/internal/api/v1/handlers/token.go
@@ -0,0 +1,5 @@
+package handlers
+
+// Users can request for new tokens with some permissions
+// A user can have multiple tokens with different permissions
+// A token can be revoked by the user or by the admin
diff --git a/internal/api/v1/handlers/token_handler.go b/internal/api/v1/handlers/token_handler.go
deleted file mode 100644
index 5ac8282..0000000
--- a/internal/api/v1/handlers/token_handler.go
+++ /dev/null
@@ -1 +0,0 @@
-package handlers
diff --git a/internal/api/v1/handlers/user.go b/internal/api/v1/handlers/user.go
index 35ea342..227ac96 100644
--- a/internal/api/v1/handlers/user.go
+++ b/internal/api/v1/handlers/user.go
@@ -3,6 +3,7 @@ package handlers
import (
"net/http"
+ "github.com/celestiaorg/knuu/internal/api/v1/middleware"
"github.com/celestiaorg/knuu/internal/api/v1/services"
"github.com/celestiaorg/knuu/internal/database/models"
@@ -11,10 +12,14 @@ import (
type UserHandler struct {
userService services.UserService
+ auth *middleware.Auth
}
-func NewUserHandler(userService services.UserService) *UserHandler {
- return &UserHandler{userService: userService}
+func NewUserHandler(userService services.UserService, auth *middleware.Auth) *UserHandler {
+ return &UserHandler{
+ userService: userService,
+ auth: auth,
+ }
}
func (h *UserHandler) Register(c *gin.Context) {
@@ -24,12 +29,12 @@ func (h *UserHandler) Register(c *gin.Context) {
return
}
- user, err := h.userService.Register(&input)
+ _, err := h.userService.Register(c.Request.Context(), &input)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
- c.JSON(http.StatusCreated, user)
+ c.JSON(http.StatusCreated, gin.H{"message": "User registered successfully"})
}
func (h *UserHandler) Login(c *gin.Context) {
@@ -42,10 +47,16 @@ func (h *UserHandler) Login(c *gin.Context) {
return
}
- token, err := h.userService.Authenticate(input.Username, input.Password)
+ user, err := h.userService.Authenticate(c.Request.Context(), input.Username, input.Password)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
+
+ token, err := h.auth.GenerateToken(user)
+ if err != nil {
+ c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
+ return
+ }
c.JSON(http.StatusOK, gin.H{"token": token})
}
diff --git a/internal/api/v1/handlers/utils.go b/internal/api/v1/handlers/utils.go
new file mode 100644
index 0000000..0c8db9e
--- /dev/null
+++ b/internal/api/v1/handlers/utils.go
@@ -0,0 +1,21 @@
+package handlers
+
+import (
+ "errors"
+
+ "github.com/celestiaorg/knuu/internal/api/v1/middleware"
+ "github.com/celestiaorg/knuu/internal/database/models"
+ "github.com/gin-gonic/gin"
+)
+
+func getUserFromContext(c *gin.Context) (*models.User, error) {
+ user, ok := c.Get(middleware.UserContextKey)
+ if !ok {
+ return nil, errors.New("user not found in context")
+ }
+ authUser, ok := user.(*models.User)
+ if !ok {
+ return nil, errors.New("invalid user data in context")
+ }
+ return authUser, nil
+}
diff --git a/internal/api/v1/index.go b/internal/api/v1/index.go
index 6c8fad9..c42fd37 100644
--- a/internal/api/v1/index.go
+++ b/internal/api/v1/index.go
@@ -46,7 +46,7 @@ func (a *API) IndexPage(c *gin.Context) {
for _, a := range allAPIs {
href := strings.TrimPrefix(a.Path, "/") // it fixes the links if the service is running under a path
- html += fmt.Sprintf(`%s
`, href, a.Path)
+ html += fmt.Sprintf(`%s [ %s ]
`, href, a.Path, a.Method)
}
html += buildInfo
diff --git a/internal/api/v1/middleware/auth.go b/internal/api/v1/middleware/auth.go
index 65d69b8..0f286cb 100644
--- a/internal/api/v1/middleware/auth.go
+++ b/internal/api/v1/middleware/auth.go
@@ -1,24 +1,127 @@
package middleware
import (
+ "errors"
"net/http"
+ "time"
+ "github.com/celestiaorg/knuu/internal/database/models"
"github.com/gin-gonic/gin"
+ "github.com/golang-jwt/jwt"
)
-func AuthMiddleware() gin.HandlerFunc {
+const (
+ UserTokenDuration = 24 * time.Hour
+ UserContextKey = "user"
+
+ authTokenPrefix = "Bearer "
+ userTokenClaimsUserID = "user_id"
+ userTokenClaimsUsername = "username"
+ userTokenClaimsRole = "role"
+ userTokenClaimsExp = "exp"
+)
+
+type Auth struct {
+ secretKey string
+}
+
+func NewAuth(secretKey string) *Auth {
+ return &Auth{
+ secretKey: secretKey,
+ }
+}
+
+func (a *Auth) AuthMiddleware() gin.HandlerFunc {
return func(c *gin.Context) {
- token := c.GetHeader("Authorization")
- if token == "" || !isValidToken(token) {
+ token := a.getAuthToken(c)
+ if token == "" || !a.isValidToken(token) {
c.JSON(http.StatusUnauthorized, gin.H{"error": "Unauthorized"})
c.Abort()
return
}
+
+ user, err := a.getUserFromToken(token)
+ if err != nil {
+ c.JSON(http.StatusUnauthorized, gin.H{"error": "Invalid token err: " + err.Error()})
+ c.Abort()
+ return
+ }
+ c.Set(UserContextKey, user)
c.Next()
}
}
-func isValidToken(token string) bool {
- // Implement token validation logic
- return token == "valid-token"
+func (a *Auth) RequireRole(requiredRole models.UserRole) gin.HandlerFunc {
+ return func(c *gin.Context) {
+ user, err := a.getUserFromToken(a.getAuthToken(c))
+ if err != nil {
+ c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
+ c.Abort()
+ return
+ }
+ if user.Role != requiredRole {
+ c.JSON(http.StatusForbidden, gin.H{"error": "Insufficient permissions"})
+ c.Abort()
+ return
+ }
+ c.Next()
+ }
+}
+
+func (a *Auth) GenerateToken(user *models.User) (string, error) {
+ claims := jwt.MapClaims{
+ userTokenClaimsUserID: user.ID,
+ userTokenClaimsUsername: user.Username,
+ userTokenClaimsRole: user.Role,
+ userTokenClaimsExp: time.Now().Add(UserTokenDuration).Unix(),
+ }
+
+ token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
+ return token.SignedString([]byte(a.secretKey))
+}
+
+func (a *Auth) getUserFromToken(token string) (*models.User, error) {
+ claims := jwt.MapClaims{}
+ _, err := jwt.ParseWithClaims(token, &claims, func(token *jwt.Token) (interface{}, error) {
+ return []byte(a.secretKey), nil
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ userID, ok := claims[userTokenClaimsUserID].(float64)
+ if !ok {
+ return nil, errors.New("invalid user ID")
+ }
+ username, ok := claims[userTokenClaimsUsername].(string)
+ if !ok {
+ return nil, errors.New("invalid username")
+ }
+ role, ok := claims[userTokenClaimsRole].(float64)
+ if !ok {
+ return nil, errors.New("invalid role")
+ }
+
+ return &models.User{ID: uint(userID), Username: username, Role: models.UserRole(role)}, nil
+}
+
+func (a *Auth) isValidToken(token string) bool {
+ claims := &jwt.MapClaims{}
+ parsedToken, err := jwt.ParseWithClaims(token, claims, func(token *jwt.Token) (interface{}, error) {
+ return []byte(a.secretKey), nil
+ })
+
+ if err != nil {
+ return false
+ }
+
+ return parsedToken.Valid
+}
+
+func (a *Auth) getAuthToken(c *gin.Context) string {
+ token := c.GetHeader("Authorization")
+ if len(token) > len(authTokenPrefix) && token[:len(authTokenPrefix)] == authTokenPrefix {
+ token = token[len(authTokenPrefix):]
+ }
+ return token
}
diff --git a/internal/api/v1/middleware/roles.go b/internal/api/v1/middleware/roles.go
deleted file mode 100644
index 670bd55..0000000
--- a/internal/api/v1/middleware/roles.go
+++ /dev/null
@@ -1,19 +0,0 @@
-package middleware
-
-import (
- "net/http"
-
- "github.com/gin-gonic/gin"
-)
-
-func RequireRole(requiredRole string) gin.HandlerFunc {
- return func(c *gin.Context) {
- role := c.GetString("role") // Set during JWT parsing
- if role != requiredRole {
- c.JSON(http.StatusForbidden, gin.H{"error": "Insufficient permissions"})
- c.Abort()
- return
- }
- c.Next()
- }
-}
diff --git a/internal/api/v1/paths.go b/internal/api/v1/paths.go
index a18e452..e92c29c 100644
--- a/internal/api/v1/paths.go
+++ b/internal/api/v1/paths.go
@@ -6,4 +6,10 @@ const (
pathsUser = pathsPrefix + "/user"
pathsUserRegister = pathsUser + "/register"
pathsUserLogin = pathsUser + "/login"
+
+ pathsTests = pathsPrefix + "/tests"
+ pathsTestDetails = pathsTests + "/{scope}"
+ pathsTestInstances = pathsTestDetails + "/instances"
+ pathsTestInstanceDetails = pathsTestInstances + "/{instance_id}"
+ pathsTestInstanceExecute = pathsTestInstanceDetails + "/execute"
)
diff --git a/internal/api/v1/services/auth_service.go b/internal/api/v1/services/auth_service.go
deleted file mode 100644
index 5e568ea..0000000
--- a/internal/api/v1/services/auth_service.go
+++ /dev/null
@@ -1 +0,0 @@
-package services
diff --git a/internal/api/v1/services/errors.go b/internal/api/v1/services/errors.go
new file mode 100644
index 0000000..fb52022
--- /dev/null
+++ b/internal/api/v1/services/errors.go
@@ -0,0 +1,16 @@
+package services
+
+import "github.com/celestiaorg/knuu/pkg/errors"
+
+type Error = errors.Error
+
+var (
+ ErrUsernameAlreadyTaken = errors.New("UsernameAlreadyTaken", "username already taken")
+ ErrUserNotFound = errors.New("UserNotFound", "user not found")
+ ErrCreatingAdminUser = errors.New("CreatingAdminUser", "error creating admin user")
+ ErrUserIDRequired = errors.New("UserIDRequired", "user ID is required")
+ ErrTestAlreadyExists = errors.New("TestAlreadyExists", "test already exists")
+ ErrTestNotFound = errors.New("TestNotFound", "test not found")
+ ErrInvalidCredentials = errors.New("InvalidCredentials", "invalid credentials")
+ ErrScopeRequired = errors.New("ScopeRequired", "scope is required")
+)
diff --git a/internal/api/v1/services/instance.go b/internal/api/v1/services/instance.go
new file mode 100644
index 0000000..0e6cd80
--- /dev/null
+++ b/internal/api/v1/services/instance.go
@@ -0,0 +1,105 @@
+package services
+
+import (
+ "context"
+ "fmt"
+
+ "github.com/celestiaorg/knuu/pkg/builder"
+)
+
+type Instance struct {
+ Name string `json:"name" binding:"required"`
+ Scope string `json:"scope" binding:"required"`
+ Image string `json:"image"`
+ GitContext builder.GitContext `json:"git_context"`
+ BuildArgs []string `json:"build_args"`
+ StartCommand []string `json:"start_command"`
+ Args []string `json:"args"`
+ Status string `json:"status"` // Readonly
+ StartNow bool `json:"start_now"`
+ Env map[string]string `json:"env"`
+ TCPPorts []int `json:"tcp_ports"`
+ UDPPorts []int `json:"udp_ports"`
+ Hostname string `json:"hostname"` // Readonly
+
+ // Volumes []k8s.Volume `json:"volumes"`
+}
+
+func (s *TestService) CreateInstance(ctx context.Context, userID uint, instance *Instance) error {
+ if userID == 0 {
+ return ErrUserIDRequired
+ }
+
+ kn, err := s.Knuu(userID, instance.Scope)
+ if err != nil {
+ return err
+ }
+
+ ins, err := kn.NewInstance(instance.Name)
+ if err != nil {
+ return err
+ }
+
+ buildArgs := []builder.ArgInterface{}
+ for _, arg := range instance.BuildArgs {
+ buildArgs = append(buildArgs, &builder.BuildArg{Value: arg})
+ }
+
+ if instance.Image != "" {
+ if err := ins.Build().SetImage(ctx, instance.Image, buildArgs...); err != nil {
+ return err
+ }
+ }
+
+ if len(instance.StartCommand) > 0 {
+ if err := ins.Build().SetStartCommand(instance.StartCommand...); err != nil {
+ return err
+ }
+ }
+
+ if len(instance.Args) > 0 {
+ if err := ins.Build().SetArgs(instance.Args...); err != nil {
+ return err
+ }
+ }
+
+ for k, v := range instance.Env {
+ if err := ins.Build().SetEnvironmentVariable(k, v); err != nil {
+ return err
+ }
+ }
+
+ if instance.GitContext.Repo != "" {
+ if err := ins.Build().SetGitRepo(ctx, instance.GitContext, buildArgs...); err != nil {
+ return err
+ }
+ }
+
+ for _, port := range instance.TCPPorts {
+ if err := ins.Network().AddPortTCP(port); err != nil {
+ return err
+ }
+ }
+
+ for _, port := range instance.UDPPorts {
+ if err := ins.Network().AddPortUDP(port); err != nil {
+ return err
+ }
+ }
+
+ if instance.StartNow {
+ return ins.Execution().StartAsync(ctx)
+ }
+ return nil
+}
+
+func (s *TestService) GetInstance(ctx context.Context, userID uint, scope, instanceName string) (*Instance, error) {
+ kn, err := s.Knuu(userID, scope)
+ if err != nil {
+ return nil, err
+ }
+
+ _ = kn
+
+ return nil, fmt.Errorf("not implemented")
+}
diff --git a/internal/api/v1/services/resource_service.go b/internal/api/v1/services/resource_service.go
deleted file mode 100644
index 5e568ea..0000000
--- a/internal/api/v1/services/resource_service.go
+++ /dev/null
@@ -1 +0,0 @@
-package services
diff --git a/internal/api/v1/services/test.go b/internal/api/v1/services/test.go
new file mode 100644
index 0000000..b678136
--- /dev/null
+++ b/internal/api/v1/services/test.go
@@ -0,0 +1,161 @@
+package services
+
+import (
+ "context"
+ "fmt"
+ "sync"
+ "time"
+
+ "github.com/celestiaorg/knuu/internal/database/models"
+ "github.com/celestiaorg/knuu/internal/database/repos"
+ "github.com/celestiaorg/knuu/pkg/k8s"
+ "github.com/celestiaorg/knuu/pkg/knuu"
+ "github.com/celestiaorg/knuu/pkg/minio"
+ "github.com/sirupsen/logrus"
+)
+
+type TestService struct {
+ repo *repos.TestRepository
+ knuuList map[uint]map[string]*knuu.Knuu // key is the user ID, second key is the scope
+ knuuListMu sync.RWMutex
+}
+
+func NewTestService(ctx context.Context, repo *repos.TestRepository) (*TestService, error) {
+ s := &TestService{
+ repo: repo,
+ knuuList: make(map[uint]map[string]*knuu.Knuu),
+ }
+
+ if err := s.loadKnuuFromDB(ctx); err != nil {
+ return nil, err
+ }
+
+ return s, nil
+}
+
+func (s *TestService) Create(ctx context.Context, test *models.Test) error {
+ if test.UserID == 0 {
+ return ErrUserIDRequired
+ }
+
+ if err := s.prepareKnuu(ctx, test); err != nil {
+ return err
+ }
+
+ return s.repo.Create(ctx, test)
+}
+
+func (s *TestService) Knuu(userID uint, scope string) (*knuu.Knuu, error) {
+ s.knuuListMu.RLock()
+ defer s.knuuListMu.RUnlock()
+
+ kn, ok := s.knuuList[userID][scope]
+ if !ok {
+ return nil, ErrTestNotFound
+ }
+
+ return kn, nil
+}
+
+func (s *TestService) Delete(ctx context.Context, userID uint, scope string) error {
+ s.knuuListMu.Lock()
+ defer s.knuuListMu.Unlock()
+
+ kn, ok := s.knuuList[userID][scope]
+ if !ok {
+ return nil
+ }
+
+ if err := kn.CleanUp(ctx); err != nil {
+ return err
+ }
+
+ delete(s.knuuList[userID], scope)
+ if len(s.knuuList[userID]) == 0 {
+ delete(s.knuuList, userID)
+ }
+
+ return s.repo.Delete(ctx, scope)
+}
+
+func (s *TestService) Details(ctx context.Context, userID uint, scope string) (*models.Test, error) {
+ return s.repo.Get(ctx, userID, scope)
+}
+
+func (s *TestService) List(ctx context.Context, userID uint, limit int, offset int) ([]models.Test, error) {
+ return s.repo.List(ctx, userID, limit, offset)
+}
+
+func (s *TestService) Count(ctx context.Context, userID uint) (int64, error) {
+ return s.repo.Count(ctx, userID)
+}
+
+func (s *TestService) Update(test *models.Test) error {
+ // Update the knuu object if needed e.g. deadline(timeout),...
+ // return s.repo.Update(test)
+ return fmt.Errorf("not implemented")
+}
+
+func (s *TestService) loadKnuuFromDB(ctx context.Context) error {
+ tests, err := s.repo.ListAllAlive(ctx)
+ if err != nil {
+ return err
+ }
+
+ for _, test := range tests {
+ err := s.prepareKnuu(ctx, &test)
+ if err != nil && err != ErrTestAlreadyExists {
+ return err
+ }
+ }
+
+ return nil
+}
+
+func (s *TestService) prepareKnuu(ctx context.Context, test *models.Test) error {
+ s.knuuListMu.Lock()
+ defer s.knuuListMu.Unlock()
+
+ if _, ok := s.knuuList[test.UserID]; !ok {
+ s.knuuList[test.UserID] = make(map[string]*knuu.Knuu)
+ }
+
+ if test.Scope == "" {
+ return ErrScopeRequired
+ }
+
+ _, ok := s.knuuList[test.UserID][test.Scope]
+ if ok {
+ return ErrTestAlreadyExists
+ }
+
+ var (
+ logger = logrus.New()
+ minioClient *minio.Minio
+ )
+
+ k8sClient, err := k8s.NewClient(ctx, test.Scope, logger)
+ if err != nil {
+ return err
+ }
+
+ if test.MinioEnabled {
+ minioClient, err = minio.New(ctx, k8sClient, logger)
+ if err != nil {
+ return err
+ }
+ }
+
+ kn, err := knuu.New(ctx, knuu.Options{
+ ProxyEnabled: test.ProxyEnabled,
+ K8sClient: k8sClient,
+ MinioClient: minioClient,
+ Timeout: time.Until(test.Deadline), // TODO: replace it with deadline when the deadline PR is merged
+ })
+ if err != nil {
+ return err
+ }
+ s.knuuList[test.UserID][test.Scope] = kn
+
+ return nil
+}
diff --git a/internal/api/v1/services/user.go b/internal/api/v1/services/user.go
index ce2fbd1..d6eefe6 100644
--- a/internal/api/v1/services/user.go
+++ b/internal/api/v1/services/user.go
@@ -1,76 +1,76 @@
package services
import (
- "errors"
- "time"
+ "context"
+ "fmt"
"github.com/celestiaorg/knuu/internal/database/models"
"github.com/celestiaorg/knuu/internal/database/repos"
- "github.com/golang-jwt/jwt/v4"
"golang.org/x/crypto/bcrypt"
)
-const (
- UserTokenDuration = 1 * time.Hour
-)
-
type UserService interface {
- Register(user *models.User) (*models.User, error)
- Authenticate(username, password string) (string, error)
+ Register(ctx context.Context, user *models.User) (*models.User, error)
+ Authenticate(ctx context.Context, username, password string) (*models.User, error)
}
type userServiceImpl struct {
- secretKey string
- userRepo repos.UserRepository
+ repo repos.UserRepository
}
var _ UserService = &userServiceImpl{}
-// TODO: need to add the admin user for the first time
-func NewUserService(secretKey string, userRepo repos.UserRepository) UserService {
- return &userServiceImpl{
- secretKey: secretKey,
- userRepo: userRepo,
+// This function is used to create the admin user and the user service.
+// It is called when the API is initialized.
+func NewUserService(ctx context.Context, adminUser, adminPass string, userRepo repos.UserRepository) (UserService, error) {
+ us := &userServiceImpl{
+ repo: userRepo,
+ }
+
+ _, err := us.Register(ctx,
+ &models.User{
+ Username: adminUser,
+ Password: adminPass,
+ Role: models.RoleAdmin,
+ })
+ if err != nil && err != ErrUsernameAlreadyTaken {
+ return nil, ErrCreatingAdminUser.Wrap(err)
}
+
+ return us, nil
}
-func (s *userServiceImpl) Register(user *models.User) (*models.User, error) {
- if _, err := s.userRepo.FindUserByUsername(user.Username); err == nil {
- return nil, errors.New("username already taken")
+func (s *userServiceImpl) Register(ctx context.Context, user *models.User) (*models.User, error) {
+ if _, err := s.repo.FindUserByUsername(ctx, user.Username); err == nil {
+ return nil, ErrUsernameAlreadyTaken
}
+ fmt.Printf("user: %#v\n", user)
+
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(user.Password), bcrypt.DefaultCost)
if err != nil {
return nil, err
}
user.Password = string(hashedPassword)
- if err := s.userRepo.CreateUser(user); err != nil {
+ if err := s.repo.CreateUser(ctx, user); err != nil {
return nil, err
}
return user, nil
}
-func (s *userServiceImpl) Authenticate(username, password string) (string, error) {
- user, err := s.userRepo.FindUserByUsername(username)
+func (s *userServiceImpl) Authenticate(ctx context.Context, username, password string) (*models.User, error) {
+ user, err := s.repo.FindUserByUsername(ctx, username)
if err != nil {
- return "", err
+ return nil, err
}
+
+ fmt.Printf("user.Password: `%s`\n", user.Password)
+ fmt.Printf("password: `%s`\n", password)
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(password)); err != nil {
- return "", errors.New("invalid credentials")
+ return nil, ErrInvalidCredentials.Wrap(err)
}
- // Generate JWT token
- token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
- "user_id": user.ID,
- "username": user.Username,
- "role": user.Role,
- "exp": time.Now().Add(UserTokenDuration).Unix(),
- })
- tokenString, err := token.SignedString([]byte(s.secretKey))
- if err != nil {
- return "", err
- }
- return tokenString, nil
+ return user, nil
}
diff --git a/internal/database/db.go b/internal/database/db.go
index fdb3e06..90450e1 100644
--- a/internal/database/db.go
+++ b/internal/database/db.go
@@ -9,25 +9,31 @@ import (
)
const (
- DefaultHost = "localhost"
- DefaultUser = "postgres"
- DefaultPassword = "postgres"
- DefaultDBName = "postgres"
- DefaultPort = 5432
+ DefaultHost = "localhost"
+ DefaultUser = "postgres"
+ DefaultPassword = "postgres"
+ DefaultDBName = "postgres"
+ DefaultPort = 5432
+ DefaultSSLEnabled = false
)
type Options struct {
- Host string
- User string
- Password string
- DBName string
- Port int
+ Host string
+ User string
+ Password string
+ DBName string
+ Port int
+ SSLEnabled *bool
}
func New(opts Options) (*gorm.DB, error) {
opts = setDefaults(opts)
- dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=disable",
- opts.Host, opts.User, opts.Password, opts.DBName, opts.Port)
+ sslMode := "disable"
+ if opts.SSLEnabled != nil && *opts.SSLEnabled {
+ sslMode = "enable"
+ }
+ dsn := fmt.Sprintf("host=%s user=%s password=%s dbname=%s port=%d sslmode=%s",
+ opts.Host, opts.User, opts.Password, opts.DBName, opts.Port, sslMode)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return nil, err
@@ -54,6 +60,10 @@ func setDefaults(opts Options) Options {
if opts.Port == 0 {
opts.Port = DefaultPort
}
+ if opts.SSLEnabled == nil {
+ sslMode := DefaultSSLEnabled
+ opts.SSLEnabled = &sslMode
+ }
return opts
}
@@ -62,5 +72,6 @@ func migrate(db *gorm.DB) error {
&models.User{},
&models.Token{},
&models.Permission{},
+ &models.Test{},
)
}
diff --git a/internal/database/models/test.go b/internal/database/models/test.go
index 2640e7f..a7fe1e7 100644
--- a/internal/database/models/test.go
+++ b/internal/database/models/test.go
@@ -1 +1,22 @@
package models
+
+import (
+ "time"
+)
+
+const (
+ TestFinishedField = "finished"
+ TestCreatedAtField = "created_at"
+)
+
+type Test struct {
+ Scope string `json:"scope" gorm:"primaryKey"`
+ UserID uint `json:"-" gorm:"index"` // the owner of the test
+ Title string `json:"title" gorm:""`
+ MinioEnabled bool `json:"minio_enabled" gorm:""`
+ ProxyEnabled bool `json:"proxy_enabled" gorm:""`
+ Deadline time.Time `json:"deadline" gorm:"index"`
+ CreatedAt time.Time `json:"created_at" gorm:"index"`
+ UpdatedAt time.Time `json:"updated_at"`
+ Finished bool `json:"finished" gorm:"index"`
+}
diff --git a/internal/database/models/user.go b/internal/database/models/user.go
index 4b467d3..27e304b 100644
--- a/internal/database/models/user.go
+++ b/internal/database/models/user.go
@@ -5,14 +5,14 @@ import "time"
type UserRole int
const (
- RoleAdmin UserRole = iota + 1
- RoleUser
+ RoleUser UserRole = iota
+ RoleAdmin
)
type User struct {
ID uint `json:"-" gorm:"primaryKey"`
Username string `json:"username" gorm:"unique;not null"`
- Password string `json:"-" gorm:"not null"`
+ Password string `json:"password" gorm:"not null"`
Role UserRole `json:"role" gorm:"not null"`
CreatedAt time.Time `json:"created_at"`
}
diff --git a/internal/database/repos/test.go b/internal/database/repos/test.go
index 2b0d3a9..0ce02fe 100644
--- a/internal/database/repos/test.go
+++ b/internal/database/repos/test.go
@@ -1 +1,59 @@
package repos
+
+import (
+ "context"
+
+ "github.com/celestiaorg/knuu/internal/database/models"
+ "gorm.io/gorm"
+)
+
+type TestRepository struct {
+ db *gorm.DB
+}
+
+func NewTestRepository(db *gorm.DB) *TestRepository {
+ return &TestRepository{
+ db: db,
+ }
+}
+
+func (r *TestRepository) Create(ctx context.Context, test *models.Test) error {
+ return r.db.WithContext(ctx).Create(test).Error
+}
+
+func (r *TestRepository) Get(ctx context.Context, userID uint, scope string) (*models.Test, error) {
+ var test models.Test
+ err := r.db.WithContext(ctx).Where(&models.Test{UserID: userID, Scope: scope}).First(&test).Error
+ return &test, err
+}
+
+func (r *TestRepository) Delete(ctx context.Context, scope string) error {
+ return r.db.WithContext(ctx).Delete(&models.Test{Scope: scope}).Error
+}
+
+func (r *TestRepository) Update(ctx context.Context, test *models.Test) error {
+ return r.db.WithContext(ctx).Model(&models.Test{}).Where(&models.Test{Scope: test.Scope}).Updates(test).Error
+}
+
+func (r *TestRepository) List(ctx context.Context, userID uint, limit int, offset int) ([]models.Test, error) {
+ var tests []models.Test
+ err := r.db.WithContext(ctx).
+ Where(&models.Test{UserID: userID}).
+ Limit(limit).Offset(offset).
+ Order(models.TestFinishedField + " ASC").
+ Order(models.TestCreatedAtField + " DESC").
+ Find(&tests).Error
+ return tests, err
+}
+
+func (r *TestRepository) Count(ctx context.Context, userID uint) (int64, error) {
+ var count int64
+ err := r.db.WithContext(ctx).Model(&models.Test{}).Where(&models.Test{UserID: userID}).Count(&count).Error
+ return count, err
+}
+
+func (r *TestRepository) ListAllAlive(ctx context.Context) ([]models.Test, error) {
+ var tests []models.Test
+ err := r.db.WithContext(ctx).Where(&models.Test{Finished: false}).Find(&tests).Error
+ return tests, err
+}
diff --git a/internal/database/repos/user.go b/internal/database/repos/user.go
index cc4164c..30514c8 100644
--- a/internal/database/repos/user.go
+++ b/internal/database/repos/user.go
@@ -1,17 +1,19 @@
package repos
import (
+ "context"
+
"github.com/celestiaorg/knuu/internal/database/models"
"gorm.io/gorm"
)
type UserRepository interface {
- CreateUser(user *models.User) error
- FindUserByUsername(username string) (*models.User, error)
- FindUserByID(id uint) (*models.User, error)
- UpdatePassword(id uint, password string) error
- DeleteUserById(id uint) error
+ CreateUser(ctx context.Context, user *models.User) error
+ FindUserByUsername(ctx context.Context, username string) (*models.User, error)
+ FindUserByID(ctx context.Context, id uint) (*models.User, error)
+ UpdatePassword(ctx context.Context, id uint, password string) error
+ DeleteUserById(ctx context.Context, id uint) error
}
type userRepositoryImpl struct {
@@ -22,30 +24,30 @@ func NewUserRepository(db *gorm.DB) UserRepository {
return &userRepositoryImpl{db: db}
}
-func (r *userRepositoryImpl) CreateUser(user *models.User) error {
- return r.db.Create(user).Error
+func (r *userRepositoryImpl) CreateUser(ctx context.Context, user *models.User) error {
+ return r.db.WithContext(ctx).Create(user).Error
}
-func (r *userRepositoryImpl) FindUserByUsername(username string) (*models.User, error) {
+func (r *userRepositoryImpl) FindUserByUsername(ctx context.Context, username string) (*models.User, error) {
var user models.User
- err := r.db.Where(&models.User{Username: username}).First(&user).Error
+ err := r.db.WithContext(ctx).Where(&models.User{Username: username}).First(&user).Error
return &user, err
}
-func (r *userRepositoryImpl) FindUserByID(id uint) (*models.User, error) {
+func (r *userRepositoryImpl) FindUserByID(ctx context.Context, id uint) (*models.User, error) {
var user models.User
- err := r.db.Where(&models.User{ID: id}).First(&user).Error
+ err := r.db.WithContext(ctx).Where(&models.User{ID: id}).First(&user).Error
return &user, err
}
-func (r *userRepositoryImpl) UpdatePassword(id uint, password string) error {
+func (r *userRepositoryImpl) UpdatePassword(ctx context.Context, id uint, password string) error {
updatedUser := &models.User{
Password: password,
}
- return r.db.Model(&models.User{}).
+ return r.db.WithContext(ctx).Model(&models.User{}).
Where(&models.User{ID: id}).Updates(updatedUser).Error
}
-func (r *userRepositoryImpl) DeleteUserById(id uint) error {
- return r.db.Delete(&models.User{ID: id}).Error
+func (r *userRepositoryImpl) DeleteUserById(ctx context.Context, id uint) error {
+ return r.db.WithContext(ctx).Delete(&models.User{ID: id}).Error
}
diff --git a/pkg/builder/git.go b/pkg/builder/git.go
index a6418e7..3b313dd 100644
--- a/pkg/builder/git.go
+++ b/pkg/builder/git.go
@@ -12,11 +12,11 @@ const (
)
type GitContext struct {
- Repo string
- Branch string
- Commit string
- Username string
- Password string
+ Repo string `json:"repo"`
+ Branch string `json:"branch"`
+ Commit string `json:"commit"`
+ Username string `json:"username"`
+ Password string `json:"password"`
}
// This build context follows Kaniko build context pattern
diff --git a/pkg/k8s/pod.go b/pkg/k8s/pod.go
index 035923a..d6f61ce 100644
--- a/pkg/k8s/pod.go
+++ b/pkg/k8s/pod.go
@@ -66,9 +66,9 @@ type PodConfig struct {
}
type Volume struct {
- Path string
- Size resource.Quantity
- Owner int64
+ Path string `json:"path"`
+ Size resource.Quantity `json:"size"`
+ Owner int64 `json:"owner"`
}
type File struct {