Skip to content

Commit

Permalink
Merge pull request #1277 from traPtitech/fix/openapi-middleware
Browse files Browse the repository at this point in the history
feat: implement middleware
  • Loading branch information
Eraxyso authored Oct 31, 2024
2 parents 5680dba + b9a1da8 commit 4ec607e
Show file tree
Hide file tree
Showing 3 changed files with 316 additions and 0 deletions.
222 changes: 222 additions & 0 deletions handler/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,28 @@ package handler
import (
"errors"
"fmt"
"net/http"
"strconv"

"github.com/go-playground/validator/v10"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
"github.com/traPtitech/anke-to/model"
)

// Middleware Middlewareの構造体
type Middleware struct {
model.IAdministrator
model.IRespondent
model.IQuestion
model.IQuestionnaire
}

// NewMiddleware Middlewareのコンストラクタ
func NewMiddleware() *Middleware {
return &Middleware{}
}

const (
validatorKey = "validator"
userIDKey = "userID"
Expand All @@ -16,6 +33,13 @@ const (
questionIDKey = "questionID"
)

/*
消せないアンケートの発生を防ぐための管理者
暫定的にハードコーディングで対応
*/
var adminUserIDs = []string{"ryoha", "xxarupakaxx", "kaitoyama", "cp20", "itzmeowww"}

// SetUserIDMiddleware X-Showcase-UserからユーザーIDを取得しセットする
func SetUserIDMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
Expand All @@ -30,6 +54,204 @@ func SetUserIDMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
}
}

// TraPMemberAuthenticate traP部員かの認証
func TraPMemberAuthenticate(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
userID, err := getUserID(c)
if err != nil {
c.Logger().Errorf("failed to get userID: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err))
}

// トークンを持たないユーザはアクセスできない
if userID == "-" {
c.Logger().Info("not logged in")
return echo.NewHTTPError(http.StatusUnauthorized, "You are not logged in")
}

return next(c)
}
}

// TrapRateLimitMiddlewareFunc traP IDベースのリクエスト制限
func TrapRateLimitMiddlewareFunc() echo.MiddlewareFunc {
config := middleware.RateLimiterConfig{
Store: middleware.NewRateLimiterMemoryStore(5),
IdentifierExtractor: func(c echo.Context) (string, error) {
userID, err := getUserID(c)
if err != nil {
c.Logger().Errorf("failed to get userID: %+v", err)
return "", echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err))
}

return userID, nil
},
}

return middleware.RateLimiterWithConfig(config)
}

// QuestionnaireAdministratorAuthenticate アンケートの管理者かどうかの認証
func QuestionnaireAdministratorAuthenticate(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
m := NewMiddleware()

userID, err := getUserID(c)
if err != nil {
c.Logger().Errorf("failed to get userID: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err))
}

strQuestionnaireID := c.Param("questionnaireID")
questionnaireID, err := strconv.Atoi(strQuestionnaireID)
if err != nil {
c.Logger().Infof("failed to convert questionnaireID to int: %+v", err)
return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("invalid questionnaireID:%s(error: %w)", strQuestionnaireID, err))
}

for _, adminID := range adminUserIDs {
if userID == adminID {
c.Set(questionnaireIDKey, questionnaireID)

return next(c)
}
}
isAdmin, err := m.CheckQuestionnaireAdmin(c.Request().Context(), userID, questionnaireID)
if err != nil {
c.Logger().Errorf("failed to check questionnaire admin: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to check if you are administrator: %w", err))
}
if !isAdmin {
return c.String(http.StatusForbidden, "You are not a administrator of this questionnaire.")
}

c.Set(questionnaireIDKey, questionnaireID)

return next(c)
}
}

// ResponseReadAuthenticate 回答閲覧権限があるかの認証
func ResponseReadAuthenticate(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
m := NewMiddleware()

userID, err := getUserID(c)
if err != nil {
c.Logger().Errorf("failed to get userID: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err))
}

strResponseID := c.Param("responseID")
responseID, err := strconv.Atoi(strResponseID)
if err != nil {
c.Logger().Info("failed to convert responseID to int: %+v", err)
return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("invalid responseID:%s(error: %w)", strResponseID, err))
}

// 回答者ならOK
respondent, err := m.GetRespondent(c.Request().Context(), responseID)
if errors.Is(err, model.ErrRecordNotFound) {
c.Logger().Infof("response not found: %+v", err)
return echo.NewHTTPError(http.StatusNotFound, fmt.Errorf("response not found:%d", responseID))
}
if err != nil {
c.Logger().Errorf("failed to check if you are a respondent: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to check if you are a respondent: %w", err))
}
if respondent == nil {
c.Logger().Error("respondent is nil")
return echo.NewHTTPError(http.StatusInternalServerError)
}
if respondent.UserTraqid == userID {
return next(c)
}

// 回答者以外は一時保存の回答は閲覧できない
if !respondent.SubmittedAt.Valid {
c.Logger().Info("not submitted")

// Note: 一時保存の回答の存在もわかってはいけないので、Respondentが見つからない時と全く同じエラーを返す
return echo.NewHTTPError(http.StatusNotFound, fmt.Errorf("response not found:%d", responseID))
}

// アンケートごとの回答閲覧権限チェック
responseReadPrivilegeInfo, err := m.GetResponseReadPrivilegeInfoByResponseID(c.Request().Context(), userID, responseID)
if errors.Is(err, model.ErrRecordNotFound) {
c.Logger().Infof("response not found: %+v", err)
return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid responseID: %d", responseID))
} else if err != nil {
c.Logger().Errorf("failed to get response read privilege info: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get response read privilege info: %w", err))
}

haveReadPrivilege, err := checkResponseReadPrivilege(responseReadPrivilegeInfo)
if err != nil {
c.Logger().Errorf("failed to check response read privilege: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to check response read privilege: %w", err))
}
if !haveReadPrivilege {
return c.String(http.StatusForbidden, "You do not have permission to view this response.")
}

return next(c)
}
}

// RespondentAuthenticate 回答者かどうかの認証
func RespondentAuthenticate(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
m := NewMiddleware()

userID, err := getUserID(c)
if err != nil {
c.Logger().Errorf("failed to get userID: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to get userID: %w", err))
}

strResponseID := c.Param("responseID")
responseID, err := strconv.Atoi(strResponseID)
if err != nil {
c.Logger().Infof("failed to convert responseID to int: %+v", err)
return echo.NewHTTPError(http.StatusBadRequest, fmt.Errorf("invalid responseID:%s(error: %w)", strResponseID, err))
}

respondent, err := m.GetRespondent(c.Request().Context(), responseID)
if errors.Is(err, model.ErrRecordNotFound) {
c.Logger().Infof("response not found: %+v", err)
return echo.NewHTTPError(http.StatusNotFound, fmt.Errorf("response not found:%d", responseID))
}
if err != nil {
c.Logger().Errorf("failed to check if you are a respondent: %+v", err)
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to check if you are a respondent: %w", err))
}
if respondent == nil {
c.Logger().Error("respondent is nil")
return echo.NewHTTPError(http.StatusInternalServerError)
}
if respondent.UserTraqid != userID {
return c.String(http.StatusForbidden, "You are not a respondent of this response.")
}

c.Set(responseIDKey, responseID)

return next(c)
}
}

func checkResponseReadPrivilege(responseReadPrivilegeInfo *model.ResponseReadPrivilegeInfo) (bool, error) {
switch responseReadPrivilegeInfo.ResSharedTo {
case "administrators":
return responseReadPrivilegeInfo.IsAdministrator, nil
case "respondents":
return responseReadPrivilegeInfo.IsAdministrator || responseReadPrivilegeInfo.IsRespondent, nil
case "public":
return true, nil
}

return false, errors.New("invalid resSharedTo")
}

// getValidator Validatorを設定する
func getValidator(c echo.Context) (*validator.Validate, error) {
rowValidate := c.Get(validatorKey)
Expand Down
14 changes: 14 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,21 @@ func main() {
e.Use(handler.SetUserIDMiddleware)
e.Use(middleware.Logger())
e.Use(middleware.Recover())

mws := NewMiddlewareSwitcher()
mws.AddGroupConfig("", handler.TraPMemberAuthenticate)

mws.AddRouteConfig("/questionnaires", http.MethodGet, handler.TrapRateLimitMiddlewareFunc())
mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodPatch, handler.QuestionnaireAdministratorAuthenticate)
mws.AddRouteConfig("/questionnaires/:questionnaireID", http.MethodDelete, handler.QuestionnaireAdministratorAuthenticate)

mws.AddRouteConfig("/responses/:responseID", http.MethodGet, handler.ResponseReadAuthenticate)
mws.AddRouteConfig("/responses/:responseID", http.MethodPatch, handler.RespondentAuthenticate)
mws.AddRouteConfig("/responses/:responseID", http.MethodDelete, handler.RespondentAuthenticate)

openapi.RegisterHandlers(e, handler.Handler{})

e.Use(mws.ApplyMiddlewares)
e.Logger.Fatal(e.Start(port))

// SetRouting(port)
Expand Down
80 changes: 80 additions & 0 deletions middleware.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
package main

import (
"strings"

"github.com/labstack/echo/v4"
)

type RouteConfig struct {
path string
method string
middlewares []echo.MiddlewareFunc
isGroup bool
}

type MiddlewareSwitcher struct {
routeConfigs []RouteConfig
}

func NewMiddlewareSwitcher() *MiddlewareSwitcher {
return &MiddlewareSwitcher{
routeConfigs: []RouteConfig{},
}
}

func (m *MiddlewareSwitcher) AddGroupConfig(grouppath string, middlewares ...echo.MiddlewareFunc) {
m.routeConfigs = append(m.routeConfigs, RouteConfig{
path: grouppath,
middlewares: middlewares,
isGroup: true,
})
}

func (m *MiddlewareSwitcher) AddRouteConfig(path string, method string, middlewares ...echo.MiddlewareFunc) {
m.routeConfigs = append(m.routeConfigs, RouteConfig{
path: path,
method: method,
middlewares: middlewares,
isGroup: false,
})
}

func (m *MiddlewareSwitcher) IsWithinGroup(groupPath string, path string) bool {
if !strings.HasPrefix(path, groupPath) {
return false
}
return len(groupPath) == len(path) || path[len(groupPath)] == '/'
}

func (m *MiddlewareSwitcher) FindMiddlewares(path string, method string) []echo.MiddlewareFunc {
var matchedMiddlewares []echo.MiddlewareFunc

for _, config := range m.routeConfigs {
if config.isGroup && m.IsWithinGroup(config.path, path) {
matchedMiddlewares = append(matchedMiddlewares, config.middlewares...)
}
if !config.isGroup && config.path == path && config.method == method {
matchedMiddlewares = append(matchedMiddlewares, config.middlewares...)
}
}

return matchedMiddlewares
}

func (m *MiddlewareSwitcher) ApplyMiddlewares(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
path := c.Path()
method := c.Request().Method

middlewares := m.FindMiddlewares(path, method)

for _, mw := range middlewares {
if err := mw(next)(c); err != nil {
return err
}
}

return next(c)
}
}

0 comments on commit 4ec607e

Please sign in to comment.