diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index 41d8cd5f4..969bae907 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -1,6 +1,9 @@ name: CI -on: [push] +on: + push: + branches: + - '**' jobs: mod: diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index aa846413b..f5bf17450 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -36,6 +36,8 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 + with: + fetch-depth: 100 - uses: actions/setup-go@v1 with: go-version: 1.14 diff --git a/Makefile b/Makefile index 07b86b40f..b595b81f5 100644 --- a/Makefile +++ b/Makefile @@ -64,7 +64,7 @@ goreleaser-snapshot: .PHONY: update-frontend update-frontend: @mkdir -p ./dev/frontend - @curl -L -Ss https://github.com/traPtitech/traQ_R-UI/releases/latest/download/dist.tar.gz | tar zxv -C ./dev/frontend/ --strip-components=2 + @curl -L -Ss https://github.com/traPtitech/traQ_S-UI/releases/latest/download/dist.tar.gz | tar zxv -C ./dev/frontend/ --strip-components=2 .PHONY: reset-frontend reset-frontend: diff --git a/bot/handlers.go b/bot/handlers.go index 524af8fa0..8660708a2 100644 --- a/bot/handlers.go +++ b/bot/handlers.go @@ -133,11 +133,6 @@ func botJoinedAndLeftHandler(p *Processor, ev string, fields hub.Fields) { p.logger.Error("failed to GetChannel", zap.Error(err), zap.Stringer("id", channelID)) return } - path, err := p.repo.GetChannelPath(channelID) - if err != nil { - p.logger.Error("failed to GetChannelPath", zap.Error(err), zap.Stringer("id", channelID)) - return - } user, err := p.repo.GetUser(ch.CreatorID, false) if err != nil && err != repository.ErrNotFound { p.logger.Error("failed to GetUser", zap.Error(err), zap.Stringer("id", ch.CreatorID)) @@ -146,7 +141,7 @@ func botJoinedAndLeftHandler(p *Processor, ev string, fields hub.Fields) { payload := joinAndLeftPayload{ basePayload: makeBasePayload(), - Channel: makeChannelPayload(ch, path, user), + Channel: makeChannelPayload(ch, p.repo.GetChannelTree().GetChannelPath(channelID), user), } buf, release, err := p.makePayloadJSON(&payload) @@ -193,11 +188,6 @@ func channelCreatedHandler(p *Processor, _ string, fields hub.Fields) { return } - path, err := p.repo.GetChannelPath(ch.ID) - if err != nil { - p.logger.Error("failed to GetChannelPath", zap.Error(err), zap.Stringer("id", ch.ID)) - return - } user, err := p.repo.GetUser(ch.CreatorID, false) if err != nil { p.logger.Error("failed to GetUser", zap.Error(err), zap.Stringer("id", ch.CreatorID)) @@ -206,7 +196,7 @@ func channelCreatedHandler(p *Processor, _ string, fields hub.Fields) { multicast(p, model.BotEventChannelCreated, &channelCreatedPayload{ basePayload: makeBasePayload(), - Channel: makeChannelPayload(ch, path, user), + Channel: makeChannelPayload(ch, p.repo.GetChannelTree().GetChannelPath(ch.ID), user), }, bots) } } @@ -231,12 +221,6 @@ func channelTopicUpdatedHandler(p *Processor, _ string, fields hub.Fields) { return } - path, err := p.repo.GetChannelPath(ch.ID) - if err != nil { - p.logger.Error("failed to GetChannelPath", zap.Error(err), zap.Stringer("id", ch.ID)) - return - } - chCreator, err := p.repo.GetUser(ch.CreatorID, false) if err != nil && err != repository.ErrNotFound { p.logger.Error("failed to GetUser", zap.Error(err), zap.Stringer("id", ch.CreatorID)) @@ -251,7 +235,7 @@ func channelTopicUpdatedHandler(p *Processor, _ string, fields hub.Fields) { multicast(p, model.BotEventChannelTopicChanged, &channelTopicChangedPayload{ basePayload: makeBasePayload(), - Channel: makeChannelPayload(ch, path, chCreator), + Channel: makeChannelPayload(ch, p.repo.GetChannelTree().GetChannelPath(ch.ID), chCreator), Topic: topic, Updater: makeUserPayload(user), }, bots) diff --git a/dev/Caddyfile b/dev/Caddyfile index 6391e69cc..f91d6f9e6 100644 --- a/dev/Caddyfile +++ b/dev/Caddyfile @@ -1,24 +1,19 @@ +{ + admin off +} + :80 -log stdout -errors stderr +log +root * /usr/share/caddy -root /srv -gzip { - level 1 - not /api /img +handle /api/* { + reverse_proxy backend:3000 } -header /service-worker.js Cache-Control "max-age=0" -header /sw.js Cache-Control "max-age=0" - -rewrite { - regexp .* - if {path} not_starts_with /api - to {path} / -} +handle { + file_server + try_files {path} /index.html -proxy /api backend:3000 { - transparent - websocket + header /sw.js Cache-Control "max-age=0" } diff --git a/docker-compose.yml b/docker-compose.yml index 657e54d47..0d172f873 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -19,7 +19,7 @@ services: - mysql frontend: - image: abiosoft/caddy:1.0.3-no-stats + image: caddy:latest restart: always expose: - "80" @@ -28,8 +28,8 @@ services: depends_on: - backend volumes: - - ./dev/Caddyfile:/etc/Caddyfile:ro - - ./dev/frontend:/srv:ro + - ./dev/Caddyfile:/etc/caddy/Caddyfile:ro + - ./dev/frontend:/usr/share/caddy:ro mysql: image: mariadb:10.0.19 diff --git a/model/user_group.go b/model/user_group.go index 6247621a8..1c051567c 100644 --- a/model/user_group.go +++ b/model/user_group.go @@ -41,6 +41,14 @@ func (ug *UserGroup) IsMember(uid uuid.UUID) bool { return false } +func (ug *UserGroup) AdminIDArray() []uuid.UUID { + result := make([]uuid.UUID, len(ug.Admins)) + for i, admin := range ug.Admins { + result[i] = admin.UserID + } + return result +} + // UserGroupMember ユーザーグループメンバー構造体 type UserGroupMember struct { GroupID uuid.UUID `gorm:"type:char(36);not null;primary_key"` diff --git a/notification/handlers.go b/notification/handlers.go index b38df03df..6f057c58d 100644 --- a/notification/handlers.go +++ b/notification/handlers.go @@ -106,11 +106,7 @@ func messageCreatedHandler(ns *Service, ev hub.Message) { fcmPayload.Path = "/users/" + mUser.GetName() fcmPayload.SetBodyWithEllipsis(parsed.OneLine()) } else { - path, err := ns.repo.GetChannelPath(m.ChannelID) - if err != nil { - logger.Error("failed to GetChannelPath", zap.Error(err), zap.Stringer("channelId", m.ChannelID)) - return - } + path := ns.repo.GetChannelTree().GetChannelPath(m.ChannelID) fcmPayload.Title = "#" + path fcmPayload.Path = "/channels/" + path fcmPayload.SetBodyWithEllipsis(mUser.GetResponseDisplayName() + ": " + parsed.OneLine()) diff --git a/repository/channel.go b/repository/channel.go index cc919fb83..a666b7326 100644 --- a/repository/channel.go +++ b/repository/channel.go @@ -134,12 +134,6 @@ type ChannelRepository interface { // 存在しないチャンネルを指定した場合は空配列とnilを返します。 // DBによるエラーを返すことがあります。 GetChildrenChannelIDs(channelID uuid.UUID) ([]uuid.UUID, error) - // GetChannelPath 指定したチャンネルのパス文字列を取得する - // - // 成功した場合、パス文字列とnilを返します。 - // 存在しないチャンネルを指定した場合、ErrNotFoundを返します。 - // DBによるエラーを返すことがあります。 - GetChannelPath(id uuid.UUID) (string, error) // GetPrivateChannelMemberIDs 指定したプライベートチャンネルのメンバーのUUIDを全て取得する // // 成功した場合、UUIDの配列とnilを返します。 @@ -169,6 +163,6 @@ type ChannelRepository interface { // 存在しないチャンネルを指定した場合、ErrNotFoundを返します。 // DBによるエラーを返すことがあります。 GetChannelStats(channelID uuid.UUID) (*ChannelStats, error) - // GetChannelTree チャンネルツリーを取得します + // GetChannelTree 公開チャンネルツリーを取得します GetChannelTree() ChannelTree } diff --git a/repository/channel_impl.go b/repository/channel_impl.go index a1bd67560..15ee22510 100644 --- a/repository/channel_impl.go +++ b/repository/channel_impl.go @@ -418,15 +418,6 @@ func (repo *GormRepository) GetChildrenChannelIDs(channelID uuid.UUID) (children return repo.chTree.GetChildrenIDs(channelID), nil } -// GetChannelPath implements ChannelRepository interface. -func (repo *GormRepository) GetChannelPath(id uuid.UUID) (string, error) { - path := repo.chTree.GetChannelPath(id) - if len(path) > 0 { - return path, nil - } - return "", ErrNotFound -} - // GetPrivateChannelMemberIDs implements ChannelRepository interface. func (repo *GormRepository) GetPrivateChannelMemberIDs(channelID uuid.UUID) (users []uuid.UUID, err error) { users = make([]uuid.UUID, 0) diff --git a/repository/channel_impl_test.go b/repository/channel_impl_test.go index 6b935abc4..c2e0b1126 100644 --- a/repository/channel_impl_test.go +++ b/repository/channel_impl_test.go @@ -112,61 +112,6 @@ func TestRepositoryImpl_GetChannel(t *testing.T) { }) } -func TestRepositoryImpl_GetChannelPath(t *testing.T) { - t.Parallel() - repo, _, _ := setup(t, common) - - ch1 := mustMakeChannelDetail(t, repo, uuid.Nil, random, uuid.Nil) - ch2 := mustMakeChannelDetail(t, repo, uuid.Nil, random, ch1.ID) - ch3 := mustMakeChannelDetail(t, repo, uuid.Nil, random, ch2.ID) - - t.Run("ch1", func(t *testing.T) { - t.Parallel() - assert := assert.New(t) - - path, err := repo.GetChannelPath(ch1.ID) - if assert.NoError(err) { - assert.Equal(ch1.Name, path) - } - }) - - t.Run("ch2", func(t *testing.T) { - t.Parallel() - assert := assert.New(t) - - path, err := repo.GetChannelPath(ch2.ID) - if assert.NoError(err) { - assert.Equal(fmt.Sprintf("%s/%s", ch1.Name, ch2.Name), path) - } - }) - - t.Run("ch3", func(t *testing.T) { - t.Parallel() - assert := assert.New(t) - - path, err := repo.GetChannelPath(ch3.ID) - if assert.NoError(err) { - assert.Equal(fmt.Sprintf("%s/%s/%s", ch1.Name, ch2.Name, ch3.Name), path) - } - }) - - t.Run("NotExists1", func(t *testing.T) { - t.Parallel() - assert := assert.New(t) - - _, err := repo.GetChannelPath(uuid.Nil) - assert.Error(err) - }) - - t.Run("NotExists2", func(t *testing.T) { - t.Parallel() - assert := assert.New(t) - - _, err := repo.GetChannelPath(uuid.Must(uuid.NewV4())) - assert.Error(err) - }) -} - func TestRepositoryImpl_ChangeChannelName(t *testing.T) { t.Parallel() repo, _, _, parent := setupWithChannel(t, common) @@ -234,38 +179,6 @@ func TestRepositoryImpl_ChangeChannelParent(t *testing.T) { }) } -func TestRepositoryImpl_GetChildrenChannelIDs(t *testing.T) { - t.Parallel() - repo, _, _, c1 := setupWithChannel(t, common) - - c2 := mustMakeChannelDetail(t, repo, uuid.Nil, random, c1.ID) - c3 := mustMakeChannelDetail(t, repo, uuid.Nil, random, c2.ID) - c4 := mustMakeChannelDetail(t, repo, uuid.Nil, random, c2.ID) - - cases := []struct { - name string - ch uuid.UUID - expect []uuid.UUID - }{ - {"c1", c1.ID, []uuid.UUID{c2.ID}}, - {"c2", c2.ID, []uuid.UUID{c3.ID, c4.ID}}, - {"c3", c3.ID, []uuid.UUID{}}, - {"c4", c4.ID, []uuid.UUID{}}, - } - - for _, v := range cases { - v := v - t.Run(v.name, func(t *testing.T) { - t.Parallel() - - ids, err := repo.GetChildrenChannelIDs(v.ch) - if assert.NoError(t, err) { - assert.ElementsMatch(t, ids, v.expect) - } - }) - } -} - func TestGormRepository_ChangeChannelSubscription(t *testing.T) { t.Parallel() repo, _, _ := setup(t, common) diff --git a/repository/channel_tree.go b/repository/channel_tree.go index 953727344..248a6995c 100644 --- a/repository/channel_tree.go +++ b/repository/channel_tree.go @@ -11,14 +11,23 @@ import ( "sync" ) +// ChannelTree 公開チャンネルのチャンネル階層木 type ChannelTree interface { + // GetChildrenIDs 子チャンネルのIDの配列を取得する GetChildrenIDs(id uuid.UUID) []uuid.UUID + // GetDescendantIDs 子孫チャンネルのIDの配列を取得する GetDescendantIDs(id uuid.UUID) []uuid.UUID + // GetAscendantIDs 祖先チャンネルのIDの配列を取得する GetAscendantIDs(id uuid.UUID) []uuid.UUID + // GetChannelDepth 指定したチャンネル木の深さを取得する GetChannelDepth(id uuid.UUID) int + // IsChildPresent 指定したnameのチャンネルが指定したチャンネルの子に存在するか IsChildPresent(name string, parent uuid.UUID) bool + // GetChannelPath 指定したチャンネルのパスを取得する GetChannelPath(id uuid.UUID) string + // IsChannelPresent 指定したIDのチャンネルが存在するかどうかを取得する IsChannelPresent(id uuid.UUID) bool + // GetChannelIDFromPath チャンネルパスからチャンネルIDを取得する GetChannelIDFromPath(path string) uuid.UUID json.Marshaler } diff --git a/router/consts/keys.go b/router/consts/keys.go index 91fcc0345..caaf27053 100644 --- a/router/consts/keys.go +++ b/router/consts/keys.go @@ -15,4 +15,5 @@ const ( KeyParamChannel = "paramChannel" KeyParamFile = "paramFile" KeyParamClipFolder = "paramClipFolder" + KeyRepo = "_repo" ) diff --git a/router/extension/context.go b/router/extension/context.go index 4a6035f62..54a089d37 100644 --- a/router/extension/context.go +++ b/router/extension/context.go @@ -5,7 +5,10 @@ import ( "github.com/gofrs/uuid" jsoniter "github.com/json-iterator/go" "github.com/labstack/echo/v4" + "github.com/traPtitech/traQ/repository" + "github.com/traPtitech/traQ/router/consts" "github.com/traPtitech/traQ/router/extension/herror" + "github.com/traPtitech/traQ/router/utils" ) // CtxKey context.Context用のキータイプ @@ -41,9 +44,12 @@ func json(c echo.Context, code int, i interface{}, cfg jsoniter.API) error { } // Wrap カスタムコンテキストラッパー -func Wrap() echo.MiddlewareFunc { +func Wrap(repo repository.Repository) echo.MiddlewareFunc { return func(n echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { return n(&Context{Context: c}) } + return func(c echo.Context) error { + c.Set(consts.KeyRepo, repo) + return n(&Context{Context: c}) + } } } @@ -57,7 +63,7 @@ func BindAndValidate(c echo.Context, i interface{}) error { if err := c.Bind(i); err != nil { return err } - if err := vd.Validate(i); err != nil { + if err := vd.ValidateWithContext(utils.NewRequestValidateContext(c), i); err != nil { if e, ok := err.(vd.InternalError); ok { return herror.InternalServerError(e.InternalError()) } diff --git a/router/middlewares/access_control.go b/router/middlewares/access_control.go index dc42ba354..e761bb4c0 100644 --- a/router/middlewares/access_control.go +++ b/router/middlewares/access_control.go @@ -9,7 +9,6 @@ import ( "github.com/traPtitech/traQ/rbac/role" "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/consts" - "github.com/traPtitech/traQ/router/extension" "github.com/traPtitech/traQ/router/extension/herror" "net/http" ) @@ -113,11 +112,16 @@ func CheckWebhookAccessPerm(rbac rbac.RBAC, repo repository.Repository) echo.Mid func CheckFileAccessPerm(rbac rbac.RBAC, repo repository.Repository) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { + file := c.Get(consts.KeyParamFile).(model.FileMeta) userID := c.Get(consts.KeyUser).(model.UserInfo).GetID() - fileID := extension.GetRequestParamAsUUID(c, consts.ParamFileID) + + if t := file.GetFileType(); t == model.FileTypeIcon || t == model.FileTypeStamp { + // スタンプ・アイコン画像の場合はスキップ + return next(c) + } // アクセス権確認 - if ok, err := repo.IsFileAccessible(fileID, userID); err != nil { + if ok, err := repo.IsFileAccessible(file.GetID(), userID); err != nil { switch err { case repository.ErrNilID, repository.ErrNotFound: return herror.NotFound() diff --git a/router/middlewares/access_logging.go b/router/middlewares/access_logging.go index fffc1e35c..2c92bc723 100644 --- a/router/middlewares/access_logging.go +++ b/router/middlewares/access_logging.go @@ -6,7 +6,6 @@ import ( "github.com/traPtitech/traQ/router/extension" "go.uber.org/zap" "strconv" - "strings" "time" ) @@ -30,10 +29,6 @@ func AccessLogging(logger *zap.Logger, dev bool) echo.MiddlewareFunc { } return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - if strings.HasPrefix(c.Path(), "/api/1.0/heartbeat") { - return next(c) - } - start := time.Now() if err := next(c); err != nil { c.Error(err) diff --git a/router/oauth2/oauth2_test.go b/router/oauth2/oauth2_test.go index 9cb7f1b53..a4906d6b2 100644 --- a/router/oauth2/oauth2_test.go +++ b/router/oauth2/oauth2_test.go @@ -83,7 +83,7 @@ func TestMain(m *testing.M) { e.HideBanner = true e.HidePort = true e.HTTPErrorHandler = extension.ErrorHandler(zap.NewNop()) - e.Use(extension.Wrap()) + e.Use(extension.Wrap(repo)) r, err := rbac.New(repo) if err != nil { diff --git a/router/setup.go b/router/setup.go index fe637fb37..bd743d9ed 100644 --- a/router/setup.go +++ b/router/setup.go @@ -31,7 +31,7 @@ func Setup(config *Config) *echo.Echo { if config.Gzipped { e.Use(middlewares.Gzip()) } - e.Use(extension.Wrap()) + e.Use(extension.Wrap(config.Repository)) e.Use(middlewares.RequestCounter()) e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ ExposeHeaders: []string{consts.HeaderVersion, consts.HeaderCacheFile, consts.HeaderFileMetaType, consts.HeaderMore, echo.HeaderXRequestID}, diff --git a/router/utils/validator.go b/router/utils/validator.go new file mode 100644 index 000000000..8df889aac --- /dev/null +++ b/router/utils/validator.go @@ -0,0 +1,135 @@ +package utils + +import ( + "context" + "errors" + vd "github.com/go-ozzo/ozzo-validation/v4" + "github.com/gofrs/uuid" + "github.com/labstack/echo/v4" + "github.com/traPtitech/traQ/model" + "github.com/traPtitech/traQ/repository" + "github.com/traPtitech/traQ/router/consts" +) + +type ctxKey int + +const ( + repoCtxKey ctxKey = iota +) + +func NewRequestValidateContext(c echo.Context) context.Context { + return context.WithValue(context.Background(), repoCtxKey, c.Get(consts.KeyRepo)) +} + +// IsPublicChannelID 公開チャンネルのUUIDである +var IsPublicChannelID = vd.WithContext(func(ctx context.Context, value interface{}) error { + const errMessage = "invalid channel id" + + repo, ok := ctx.Value(repoCtxKey).(repository.Repository) + if !ok { + return vd.NewInternalError(errors.New("this context didn't have repository")) + } + + switch v := value.(type) { + case nil: + return nil + case uuid.UUID: + if !repo.GetChannelTree().IsChannelPresent(v) { + return errors.New(errMessage) + } + case uuid.NullUUID: + if v.Valid && !repo.GetChannelTree().IsChannelPresent(v.UUID) { + return errors.New(errMessage) + } + case string: + if !repo.GetChannelTree().IsChannelPresent(uuid.FromStringOrNil(v)) { + return errors.New(errMessage) + } + case []byte: + if !repo.GetChannelTree().IsChannelPresent(uuid.FromBytesOrNil(v)) { + return errors.New(errMessage) + } + default: + return errors.New(errMessage) + } + return nil +}) + +// IsActiveHumanUserID アカウントが有効な一般ユーザーのUUIDである +var IsActiveHumanUserID = vd.WithContext(func(ctx context.Context, value interface{}) error { + const errMessage = "invalid user id" + + repo, ok := ctx.Value(repoCtxKey).(repository.Repository) + if !ok { + return vd.NewInternalError(errors.New("this context didn't have repository")) + } + + var ( + u model.UserInfo + err error + ) + switch v := value.(type) { + case nil: + return nil + case uuid.UUID: + u, err = repo.GetUser(v, false) + case uuid.NullUUID: + if !v.Valid { + return nil + } + u, err = repo.GetUser(v.UUID, false) + case string: + u, err = repo.GetUser(uuid.FromStringOrNil(v), false) + case []byte: + u, err = repo.GetUser(uuid.FromBytesOrNil(v), false) + default: + return errors.New(errMessage) + } + if err != nil { + switch err { + case repository.ErrNotFound: + return errors.New(errMessage) + default: + return vd.NewInternalError(err) + } + } + + if !u.IsActive() || u.IsBot() { + return errors.New(errMessage) + } + + return nil +}) + +// IsUserID ユーザーのUUIDである +var IsUserID = vd.WithContext(func(ctx context.Context, value interface{}) error { + const errMessage = "invalid user id" + + repo, ok := ctx.Value(repoCtxKey).(repository.Repository) + if !ok { + return vd.NewInternalError(errors.New("this context didn't have repository")) + } + + var err error + switch v := value.(type) { + case nil: + return nil + case uuid.UUID: + ok, err = repo.UserExists(v) + case uuid.NullUUID: + if !v.Valid { + return nil + } + ok, err = repo.UserExists(v.UUID) + case string: + ok, err = repo.UserExists(uuid.FromStringOrNil(v)) + case []byte: + ok, err = repo.UserExists(uuid.FromBytesOrNil(v)) + default: + return errors.New(errMessage) + } + if err != nil { + return vd.NewInternalError(err) + } + return nil +}) diff --git a/router/v1/channels.go b/router/v1/channels.go index 598e34eb3..f6ea21447 100644 --- a/router/v1/channels.go +++ b/router/v1/channels.go @@ -253,8 +253,7 @@ func (h *Handlers) PutTopic(c echo.Context) error { ch := getChannelFromContext(c) if ch.IsArchived() { - path, _ := h.Repo.GetChannelPath(ch.ID) - return herror.BadRequest(fmt.Sprintf("channel #%s has been archived", path)) + return herror.BadRequest(fmt.Sprintf("channel has been archived")) } var req struct { diff --git a/router/v1/messages.go b/router/v1/messages.go index 7c926d532..9dedf1516 100644 --- a/router/v1/messages.go +++ b/router/v1/messages.go @@ -44,8 +44,7 @@ func (h *Handlers) PutMessageByID(c echo.Context) error { return herror.InternalServerError(err) } if ch.IsArchived() { - path, _ := h.Repo.GetChannelPath(ch.ID) - return herror.BadRequest(fmt.Sprintf("channel #%s has been archived", path)) + return herror.BadRequest(fmt.Sprintf("channel has been archived")) } var req PutMessageByIDRequest @@ -102,8 +101,7 @@ func (h *Handlers) DeleteMessageByID(c echo.Context) error { return herror.InternalServerError(err) } if ch.IsArchived() { - path, _ := h.Repo.GetChannelPath(ch.ID) - return herror.BadRequest(fmt.Sprintf("channel #%s has been archived", path)) + return herror.BadRequest(fmt.Sprintf("channel has been archived")) } if err := h.Repo.DeleteMessage(m.ID); err != nil { @@ -143,8 +141,7 @@ func (h *Handlers) PostMessage(c echo.Context) error { ch := getChannelFromContext(c) if ch.IsArchived() { - path, _ := h.Repo.GetChannelPath(ch.ID) - return herror.BadRequest(fmt.Sprintf("channel #%s has been archived", path)) + return herror.BadRequest(fmt.Sprintf("channel has been archived")) } var req PostMessageRequest diff --git a/router/v1/router_repository_test.go b/router/v1/router_repository_test.go index 628afe595..e8ea15639 100644 --- a/router/v1/router_repository_test.go +++ b/router/v1/router_repository_test.go @@ -1209,23 +1209,6 @@ func (repo *TestRepository) IsChannelAccessibleToUser(userID, channelID uuid.UUI return repo.IsUserPrivateChannelMember(channelID, userID) } -func (repo *TestRepository) GetParentChannel(channelID uuid.UUID) (*model.Channel, error) { - repo.ChannelsLock.RLock() - defer repo.ChannelsLock.RUnlock() - ch, ok := repo.Channels[channelID] - if !ok { - return nil, repository.ErrNotFound - } - if ch.ParentID == uuid.Nil { - return nil, nil - } - pCh, ok := repo.Channels[ch.ParentID] - if !ok { - return nil, repository.ErrNotFound - } - return &pCh, nil -} - func (repo *TestRepository) GetChildrenChannelIDs(channelID uuid.UUID) ([]uuid.UUID, error) { result := make([]uuid.UUID, 0) repo.ChannelsLock.RLock() @@ -1238,47 +1221,6 @@ func (repo *TestRepository) GetChildrenChannelIDs(channelID uuid.UUID) ([]uuid.U return result, nil } -func (repo *TestRepository) GetDescendantChannelIDs(channelID uuid.UUID) ([]uuid.UUID, error) { - var descendants []uuid.UUID - children, err := repo.GetChildrenChannelIDs(channelID) - if err != nil { - return nil, err - } - descendants = append(descendants, children...) - for _, v := range children { - sub, err := repo.GetDescendantChannelIDs(v) - if err != nil { - return nil, err - } - descendants = append(descendants, sub...) - } - return descendants, nil -} - -func (repo *TestRepository) GetAscendantChannelIDs(channelID uuid.UUID) ([]uuid.UUID, error) { - var ascendants []uuid.UUID - parent, err := repo.GetParentChannel(channelID) - if err != nil { - if err == repository.ErrNotFound { - return nil, nil - } - return nil, err - } else if parent == nil { - return []uuid.UUID{}, nil - } - ascendants = append(ascendants, parent.ID) - sub, err := repo.GetAscendantChannelIDs(parent.ID) - if err != nil { - return nil, err - } - ascendants = append(ascendants, sub...) - return ascendants, nil -} - -func (repo *TestRepository) GetChannelPath(id uuid.UUID) (string, error) { - panic("implement me") -} - func (repo *TestRepository) getChannelDepthWithoutLock(id uuid.UUID) int { children := make([]uuid.UUID, 0) for cid, ch := range repo.Channels { diff --git a/router/v1/router_test.go b/router/v1/router_test.go index c51a78877..6dcbc27a2 100644 --- a/router/v1/router_test.go +++ b/router/v1/router_test.go @@ -61,13 +61,14 @@ func TestMain(m *testing.M) { s4, } for _, key := range repos { + repo := NewTestRepository() + e := echo.New() e.HideBanner = true e.HidePort = true e.HTTPErrorHandler = extension.ErrorHandler(zap.NewNop()) - e.Use(extension.Wrap()) + e.Use(extension.Wrap(repo)) - repo := NewTestRepository() r, err := rbac.New(repo) if err != nil { panic(err) diff --git a/router/v1/webhooks.go b/router/v1/webhooks.go index c308aa561..2ad3e3fbd 100644 --- a/router/v1/webhooks.go +++ b/router/v1/webhooks.go @@ -192,8 +192,7 @@ func (h *Handlers) PostWebhook(c echo.Context) error { return herror.BadRequest("invalid channel") } if ch.IsArchived() { - path, _ := h.Repo.GetChannelPath(ch.ID) - return herror.BadRequest(fmt.Sprintf("channel #%s has been archived", path)) + return herror.BadRequest(fmt.Sprintf("channel has been archived")) } if c.QueryParam("embed") == "1" { diff --git a/router/v3/bots.go b/router/v3/bots.go index 7ca7ca346..8d20e45c6 100644 --- a/router/v3/bots.go +++ b/router/v3/bots.go @@ -1,6 +1,7 @@ package v3 import ( + "context" vd "github.com/go-ozzo/ozzo-validation/v4" "github.com/go-ozzo/ozzo-validation/v4/is" "github.com/gofrs/uuid" @@ -119,12 +120,12 @@ type PatchBotRequest struct { SubscribeEvents model.BotEvents `json:"subscribeEvents"` } -func (r PatchBotRequest) Validate() error { - return vd.ValidateStruct(&r, +func (r PatchBotRequest) ValidateWithContext(ctx context.Context) error { + return vd.ValidateStructWithContext(ctx, &r, vd.Field(&r.DisplayName, vd.RuneLength(1, 32)), vd.Field(&r.Description, vd.RuneLength(0, 1000)), vd.Field(&r.Endpoint, is.URL, validator.NotInternalURL), - vd.Field(&r.DeveloperID, validator.NotNilUUID), + vd.Field(&r.DeveloperID, validator.NotNilUUID, utils.IsActiveHumanUserID), vd.Field(&r.SubscribeEvents), ) } @@ -293,9 +294,9 @@ type PostBotActionJoinRequest struct { ChannelID uuid.UUID `json:"channelId"` } -func (r PostBotActionJoinRequest) Validate() error { - return vd.ValidateStruct(&r, - vd.Field(&r.ChannelID, vd.Required, validator.NotNilUUID), +func (r PostBotActionJoinRequest) ValidateWithContext(ctx context.Context) error { + return vd.ValidateStructWithContext(ctx, &r, + vd.Field(&r.ChannelID, vd.Required, validator.NotNilUUID, utils.IsPublicChannelID), // 公開チャンネルのみ許可 ) } @@ -308,20 +309,8 @@ func (h *Handlers) LetBotJoinChannel(c echo.Context) error { b := getParamBot(c) - // チャンネル検証 - ch, err := h.Repo.GetChannel(req.ChannelID) - if err != nil { - if err == repository.ErrNotFound { - return herror.BadRequest("invalid channel") - } - return herror.InternalServerError(err) - } - if !ch.IsPublic { - return herror.BadRequest("invalid channel") // 公開チャンネルのみ許可 - } - // 参加 - if err := h.Repo.AddBotToChannel(b.ID, ch.ID); err != nil { + if err := h.Repo.AddBotToChannel(b.ID, req.ChannelID); err != nil { return herror.InternalServerError(err) } diff --git a/router/v3/channels.go b/router/v3/channels.go index 19f62c2ed..23d11ae7e 100644 --- a/router/v3/channels.go +++ b/router/v3/channels.go @@ -78,13 +78,7 @@ func (h *Handlers) CreateChannels(c echo.Context) error { // GetChannel GET /channels/:channelID func (h *Handlers) GetChannel(c echo.Context) error { ch := getParamChannel(c) - - childrenID, err := h.Repo.GetChildrenChannelIDs(ch.ID) - if err != nil { - return herror.InternalServerError(err) - } - - return c.JSON(http.StatusOK, formatChannel(ch, childrenID)) + return c.JSON(http.StatusOK, formatChannel(ch, h.Repo.GetChannelTree().GetChildrenIDs(ch.ID))) } // PatchChannelRequest PATCH /channels/:channelID リクエストボディ @@ -172,8 +166,7 @@ func (h *Handlers) EditChannelTopic(c echo.Context) error { ch := getParamChannel(c) if ch.IsArchived() { - path, _ := h.Repo.GetChannelPath(ch.ID) - return herror.BadRequest(fmt.Sprintf("channel #%s has been archived", path)) + return herror.BadRequest(fmt.Sprintf("channel #%s has been archived", h.Repo.GetChannelTree().GetChannelPath(ch.ID))) } var req PutChannelTopicRequest diff --git a/router/v3/files.go b/router/v3/files.go index eaf0ea9ab..c38a7ace4 100644 --- a/router/v3/files.go +++ b/router/v3/files.go @@ -104,8 +104,7 @@ func (h *Handlers) PostFile(c echo.Context) error { } if ch.IsArchived() { - path, _ := h.Repo.GetChannelPath(ch.ID) - return herror.BadRequest(fmt.Sprintf("channel #%s has been archived", path)) + return herror.BadRequest(fmt.Sprintf("channel #%s has been archived", h.Repo.GetChannelTree().GetChannelPath(ch.ID))) } args := repository.SaveFileArgs{ diff --git a/router/v3/messages.go b/router/v3/messages.go index 203aa285a..d41beb44d 100644 --- a/router/v3/messages.go +++ b/router/v3/messages.go @@ -64,8 +64,7 @@ func (h *Handlers) EditMessage(c echo.Context) error { return herror.InternalServerError(err) } if ch.IsArchived() { - path, _ := h.Repo.GetChannelPath(ch.ID) - return herror.BadRequest(fmt.Sprintf("channel #%s has been archived", path)) + return herror.BadRequest(fmt.Sprintf("channel #%s has been archived", h.Repo.GetChannelTree().GetChannelPath(ch.ID))) } var req PostMessageRequest @@ -142,8 +141,7 @@ func (h *Handlers) DeleteMessage(c echo.Context) error { return herror.InternalServerError(err) } if ch.IsArchived() { - path, _ := h.Repo.GetChannelPath(ch.ID) - return herror.BadRequest(fmt.Sprintf("channel #%s has been archived", path)) + return herror.BadRequest(fmt.Sprintf("channel #%s has been archived", h.Repo.GetChannelTree().GetChannelPath(ch.ID))) } if err := h.Repo.DeleteMessage(m.ID); err != nil { @@ -273,8 +271,7 @@ func (h *Handlers) PostMessage(c echo.Context) error { ch := getParamChannel(c) if ch.IsArchived() { - path, _ := h.Repo.GetChannelPath(ch.ID) - return herror.BadRequest(fmt.Sprintf("channel #%s has been archived", path)) + return herror.BadRequest(fmt.Sprintf("channel #%s has been archived", h.Repo.GetChannelTree().GetChannelPath(ch.ID))) } var req PostMessageRequest diff --git a/router/v3/router_test.go b/router/v3/router_test.go index 6ac070510..ec16de099 100644 --- a/router/v3/router_test.go +++ b/router/v3/router_test.go @@ -82,7 +82,7 @@ func TestMain(m *testing.M) { e.HideBanner = true e.HidePort = true e.HTTPErrorHandler = extension.ErrorHandler(zap.NewNop()) - e.Use(extension.Wrap()) + e.Use(extension.Wrap(repo)) r, err := rbac.New(repo) if err != nil { diff --git a/router/v3/stamp_palettes.go b/router/v3/stamp_palettes.go index 7b1ce0522..c1f3eaac3 100644 --- a/router/v3/stamp_palettes.go +++ b/router/v3/stamp_palettes.go @@ -115,7 +115,7 @@ func (h *Handlers) GetStampPalette(c echo.Context) error { return c.JSON(http.StatusOK, getParamStampPalette(c)) } -// DeleteStampPalette DELETE /stamps/:stampID +// DeleteStampPalette DELETE /stamp-palette/:paletteID func (h *Handlers) DeleteStampPalette(c echo.Context) error { user := getRequestUser(c) stampPalette := getParamStampPalette(c) diff --git a/router/v3/stamps.go b/router/v3/stamps.go index c0d634f8b..1389d72b6 100644 --- a/router/v3/stamps.go +++ b/router/v3/stamps.go @@ -1,6 +1,7 @@ package v3 import ( + "context" vd "github.com/go-ozzo/ozzo-validation/v4" "github.com/gofrs/uuid" "github.com/labstack/echo/v4" @@ -61,15 +62,16 @@ func (h *Handlers) GetStamp(c echo.Context) error { return c.JSON(http.StatusOK, getParamStamp(c)) } -// PatchStampRequest PATCH /users/me リクエストボディ +// PatchStampRequest PATCH /stamps/:stampID リクエストボディ type PatchStampRequest struct { - Name null.String `json:"name"` - CreatorID uuid.UUID `json:"creatorId"` + Name null.String `json:"name"` + CreatorID uuid.NullUUID `json:"creatorId"` } -func (r PatchStampRequest) Validate() error { - return vd.ValidateStruct(&r, +func (r PatchStampRequest) ValidateWithContext(ctx context.Context) error { + return vd.ValidateStructWithContext(ctx, &r, vd.Field(&r.Name, validator.StampNameRule...), + vd.Field(&r.CreatorID, validator.NotNilUUID, utils.IsActiveHumanUserID), ) } @@ -89,20 +91,8 @@ func (h *Handlers) EditStamp(c echo.Context) error { } args := repository.UpdateStampArgs{ - Name: req.Name, - } - - // 作成者変更 - if req.CreatorID != uuid.Nil { - ok, err := h.Repo.UserExists(req.CreatorID) - if err != nil { - return herror.InternalServerError(err) - } - if !ok { - return herror.BadRequest("invalid creatorId") - } - - args.CreatorID = uuid.NullUUID{Valid: true, UUID: req.CreatorID} + Name: req.Name, + CreatorID: req.CreatorID, } // 更新 diff --git a/router/v3/star.go b/router/v3/star.go index 0e4feb185..9f7cca08a 100644 --- a/router/v3/star.go +++ b/router/v3/star.go @@ -1,11 +1,13 @@ package v3 import ( + "context" vd "github.com/go-ozzo/ozzo-validation/v4" "github.com/gofrs/uuid" "github.com/labstack/echo/v4" "github.com/traPtitech/traQ/router/consts" "github.com/traPtitech/traQ/router/extension/herror" + "github.com/traPtitech/traQ/router/utils" "github.com/traPtitech/traQ/utils/validator" "net/http" ) @@ -27,9 +29,9 @@ type PostStarRequest struct { ChannelID uuid.UUID `json:"channelId"` } -func (r PostStarRequest) Validate() error { - return vd.ValidateStruct(&r, - vd.Field(&r.ChannelID, vd.Required, validator.NotNilUUID), +func (r PostStarRequest) ValidateWithContext(ctx context.Context) error { + return vd.ValidateStructWithContext(ctx, &r, + vd.Field(&r.ChannelID, vd.Required, validator.NotNilUUID, utils.IsPublicChannelID), ) } @@ -40,14 +42,7 @@ func (h *Handlers) PostStar(c echo.Context) error { return err } - userID := getRequestUserID(c) - if ok, err := h.Repo.IsChannelAccessibleToUser(userID, req.ChannelID); err != nil { - return herror.InternalServerError(err) - } else if !ok { - return herror.BadRequest("bad channelID") - } - - if err := h.Repo.AddStar(userID, req.ChannelID); err != nil { + if err := h.Repo.AddStar(getRequestUserID(c), req.ChannelID); err != nil { return herror.InternalServerError(err) } diff --git a/router/v3/user_groups.go b/router/v3/user_groups.go index 7e886f74e..8af0e6d1e 100644 --- a/router/v3/user_groups.go +++ b/router/v3/user_groups.go @@ -1,6 +1,7 @@ package v3 import ( + "context" vd "github.com/go-ozzo/ozzo-validation/v4" "github.com/gofrs/uuid" "github.com/labstack/echo/v4" @@ -8,6 +9,7 @@ import ( "github.com/traPtitech/traQ/repository" "github.com/traPtitech/traQ/router/consts" "github.com/traPtitech/traQ/router/extension/herror" + "github.com/traPtitech/traQ/router/utils" "github.com/traPtitech/traQ/utils/validator" "gopkg.in/guregu/null.v3" "net/http" @@ -141,9 +143,9 @@ type PostUserGroupMemberRequest struct { Role string `json:"role"` } -func (r PostUserGroupMemberRequest) Validate() error { - return vd.ValidateStruct(&r, - vd.Field(&r.ID, vd.Required, validator.NotNilUUID), +func (r PostUserGroupMemberRequest) ValidateWithContext(ctx context.Context) error { + return vd.ValidateStructWithContext(ctx, &r, + vd.Field(&r.ID, vd.Required, validator.NotNilUUID, utils.IsUserID), vd.Field(&r.Role, vd.RuneLength(0, 100)), ) } @@ -157,13 +159,6 @@ func (h *Handlers) AddUserGroupMember(c echo.Context) error { return err } - // ユーザーが存在するか - if ok, err := h.Repo.UserExists(req.ID); err != nil { - return herror.InternalServerError(err) - } else if !ok { - return herror.BadRequest("this user doesn't exist") - } - if err := h.Repo.AddUserToGroup(req.ID, g.ID, req.Role); err != nil { return herror.InternalServerError(err) } @@ -218,12 +213,7 @@ func (h *Handlers) RemoveUserGroupMember(c echo.Context) error { // GetUserGroupAdmins GET /groups/:groupID/admins func (h *Handlers) GetUserGroupAdmins(c echo.Context) error { - g := getParamGroup(c) - result := make([]uuid.UUID, 0) - for _, admin := range g.Admins { - result = append(result, admin.UserID) - } - return c.JSON(http.StatusOK, result) + return c.JSON(http.StatusOK, getParamGroup(c).AdminIDArray()) } // PostUserGroupAdminRequest POST /groups/:groupID/admins リクエストボディ @@ -231,9 +221,9 @@ type PostUserGroupAdminRequest struct { ID uuid.UUID `json:"id"` } -func (r PostUserGroupAdminRequest) Validate() error { - return vd.ValidateStruct(&r, - vd.Field(&r.ID, vd.Required, validator.NotNilUUID), +func (r PostUserGroupAdminRequest) ValidateWithContext(ctx context.Context) error { + return vd.ValidateStructWithContext(ctx, &r, + vd.Field(&r.ID, vd.Required, validator.NotNilUUID, utils.IsUserID), ) } @@ -246,13 +236,6 @@ func (h *Handlers) AddUserGroupAdmin(c echo.Context) error { return err } - // ユーザーが存在するか - if ok, err := h.Repo.UserExists(req.ID); err != nil { - return herror.InternalServerError(err) - } else if !ok { - return herror.BadRequest("this user doesn't exist") - } - if err := h.Repo.AddUserToGroupAdmin(req.ID, g.ID); err != nil { return herror.InternalServerError(err) } diff --git a/router/v3/users.go b/router/v3/users.go index 4d1cf55a1..b7aaf158f 100644 --- a/router/v3/users.go +++ b/router/v3/users.go @@ -1,6 +1,7 @@ package v3 import ( + "context" "github.com/dgrijalva/jwt-go" vd "github.com/go-ozzo/ozzo-validation/v4" "github.com/gofrs/uuid" @@ -83,8 +84,8 @@ type PatchMeRequest struct { HomeChannel uuid.NullUUID `json:"homeChannel"` } -func (r PatchMeRequest) Validate() error { - return vd.ValidateStruct(&r, +func (r PatchMeRequest) ValidateWithContext(ctx context.Context) error { + return vd.ValidateStructWithContext(ctx, &r, vd.Field(&r.DisplayName, vd.RuneLength(0, 64)), vd.Field(&r.TwitterID, validator.TwitterIDRule...), vd.Field(&r.Bio, vd.RuneLength(0, 1000)), @@ -102,10 +103,8 @@ func (h *Handlers) EditMe(c echo.Context) error { if req.HomeChannel.Valid { if req.HomeChannel.UUID != uuid.Nil { - // チャンネルアクセス権確認 - if ok, err := h.Repo.IsChannelAccessibleToUser(userID, req.HomeChannel.UUID); err != nil { - return herror.InternalServerError(err) - } else if !ok { + // チャンネル存在確認 + if !h.Repo.GetChannelTree().IsChannelPresent(req.HomeChannel.UUID) { return herror.BadRequest("invalid homeChannel") } } diff --git a/router/v3/webhooks.go b/router/v3/webhooks.go index b042904d4..f9f0a8781 100644 --- a/router/v3/webhooks.go +++ b/router/v3/webhooks.go @@ -1,6 +1,7 @@ package v3 import ( + "context" "crypto/subtle" "encoding/hex" "fmt" @@ -15,6 +16,7 @@ import ( "github.com/traPtitech/traQ/router/utils" "github.com/traPtitech/traQ/utils/hmac" "github.com/traPtitech/traQ/utils/message" + "github.com/traPtitech/traQ/utils/validator" "gopkg.in/guregu/null.v3" "io/ioutil" "net/http" @@ -67,10 +69,11 @@ type PostWebhooksRequest struct { Secret string `json:"secret"` } -func (r PostWebhooksRequest) Validate() error { - return vd.ValidateStruct(&r, +func (r PostWebhooksRequest) ValidateWithContext(ctx context.Context) error { + return vd.ValidateStructWithContext(ctx, &r, vd.Field(&r.Name, vd.Required, vd.RuneLength(1, 32)), vd.Field(&r.Description, vd.Required, vd.RuneLength(1, 1000)), + vd.Field(&r.ChannelID, vd.Required, validator.NotNilUUID, utils.IsPublicChannelID), vd.Field(&r.Secret, vd.RuneLength(0, 50)), ) } @@ -112,11 +115,13 @@ type PatchWebhookRequest struct { OwnerID uuid.NullUUID `json:"ownerId"` } -func (r PatchWebhookRequest) Validate() error { - return vd.ValidateStruct(&r, +func (r PatchWebhookRequest) ValidateWithContext(ctx context.Context) error { + return vd.ValidateStructWithContext(ctx, &r, vd.Field(&r.Name, vd.RuneLength(1, 32)), vd.Field(&r.Description, vd.RuneLength(1, 1000)), + vd.Field(&r.ChannelID, validator.NotNilUUID, utils.IsPublicChannelID), vd.Field(&r.Secret, vd.RuneLength(0, 50)), + vd.Field(&r.OwnerID, validator.NotNilUUID, utils.IsActiveHumanUserID), ) } @@ -202,8 +207,7 @@ func (h *Handlers) PostWebhook(c echo.Context) error { return herror.BadRequest("invalid channel") } if ch.IsArchived() { - path, _ := h.Repo.GetChannelPath(ch.ID) - return herror.BadRequest(fmt.Sprintf("channel #%s has been archived", path)) + return herror.BadRequest(fmt.Sprintf("channel #%s has been archived", h.Repo.GetChannelTree().GetChannelPath(ch.ID))) } // 埋め込み変換