diff --git a/cmd/app/approve.go b/cmd/app/approve.go index 1426b7b6..3d7a034d 100644 --- a/cmd/app/approve.go +++ b/cmd/app/approve.go @@ -17,14 +17,7 @@ type mergeRequestApproverService struct { } /* approveHandler approves a merge request. */ -func (a mergeRequestApproverService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method != http.MethodPost { - w.Header().Set("Access-Control-Allow-Methods", http.MethodPost) - handleError(w, InvalidRequestError{}, "Expected POST", http.StatusMethodNotAllowed) - return - } - +func (a mergeRequestApproverService) ServeHTTP(w http.ResponseWriter, r *http.Request) { _, res, err := a.client.ApproveMergeRequest(a.projectInfo.ProjectId, a.projectInfo.MergeId, nil, nil) if err != nil { @@ -33,15 +26,12 @@ func (a mergeRequestApproverService) handler(w http.ResponseWriter, r *http.Requ } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/approve"}, "Could not approve merge request", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not approve merge request", res.StatusCode) return } w.WriteHeader(http.StatusOK) - response := SuccessResponse{ - Message: "Approved MR", - Status: http.StatusOK, - } + response := SuccessResponse{Message: "Approved MR"} err = json.NewEncoder(w).Encode(response) if err != nil { diff --git a/cmd/app/approve_test.go b/cmd/app/approve_test.go index 5baf1c8a..b450169c 100644 --- a/cmd/app/approve_test.go +++ b/cmd/app/approve_test.go @@ -23,33 +23,36 @@ func TestApproveHandler(t *testing.T) { t.Run("Approves merge request", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/approve", nil) client := fakeApproverClient{} - svc := mergeRequestApproverService{testProjectData, client} + svc := middleware( + mergeRequestApproverService{testProjectData, client}, + withMr(testProjectData, fakeMergeRequestLister{}), + withMethodCheck(http.MethodPost), + ) data := getSuccessData(t, svc, request) assert(t, data.Message, "Approved MR") - assert(t, data.Status, http.StatusOK) - }) - - t.Run("Disallows non-POST method", func(t *testing.T) { - request := makeRequest(t, http.MethodGet, "/mr/approve", nil) - client := fakeApproverClient{} - svc := mergeRequestApproverService{testProjectData, client} - data := getFailData(t, svc, request) - checkBadMethod(t, data, http.MethodPost) }) t.Run("Handles errors from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/approve", nil) client := fakeApproverClient{testBase{errFromGitlab: true}} - svc := mergeRequestApproverService{testProjectData, client} - data := getFailData(t, svc, request) + svc := middleware( + mergeRequestApproverService{testProjectData, client}, + withMr(testProjectData, fakeMergeRequestLister{}), + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) checkErrorFromGitlab(t, data, "Could not approve merge request") }) t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/approve", nil) client := fakeApproverClient{testBase{status: http.StatusSeeOther}} - svc := mergeRequestApproverService{testProjectData, client} - data := getFailData(t, svc, request) + svc := middleware( + mergeRequestApproverService{testProjectData, client}, + withMr(testProjectData, fakeMergeRequestLister{}), + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) checkNon200(t, data, "Could not approve merge request", "/mr/approve") }) } diff --git a/cmd/app/assignee.go b/cmd/app/assignee.go index cf419e19..b7be1670 100644 --- a/cmd/app/assignee.go +++ b/cmd/app/assignee.go @@ -2,14 +2,14 @@ package app import ( "encoding/json" - "io" + "errors" "net/http" "github.com/xanzy/go-gitlab" ) type AssigneeUpdateRequest struct { - Ids []int `json:"ids"` + Ids []int `json:"ids" validate:"required"` } type AssigneeUpdateResponse struct { @@ -17,37 +17,18 @@ type AssigneeUpdateResponse struct { Assignees []*gitlab.BasicUser `json:"assignees"` } -type AssigneesRequestResponse struct { - SuccessResponse - Assignees []int `json:"assignees"` -} - type assigneesService struct { data client MergeRequestUpdater } /* assigneesHandler adds or removes assignees from a merge request. */ -func (a assigneesService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method != http.MethodPut { - w.Header().Set("Access-Control-Allow-Methods", http.MethodPut) - handleError(w, InvalidRequestError{}, "Expected PUT", http.StatusMethodNotAllowed) - return - } +func (a assigneesService) ServeHTTP(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - return - } + assigneeUpdateRequest, ok := r.Context().Value(payload("payload")).(*AssigneeUpdateRequest) - defer r.Body.Close() - var assigneeUpdateRequest AssigneeUpdateRequest - err = json.Unmarshal(body, &assigneeUpdateRequest) - - if err != nil { - handleError(w, err, "Could not read JSON from request", http.StatusBadRequest) + if !ok { + handleError(w, errors.New("Could not get payload from context"), "Bad payload", http.StatusInternalServerError) return } @@ -61,17 +42,14 @@ func (a assigneesService) handler(w http.ResponseWriter, r *http.Request) { } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/assignee"}, "Could not modify merge request assignees", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not modify merge request assignees", res.StatusCode) return } w.WriteHeader(http.StatusOK) response := AssigneeUpdateResponse{ - SuccessResponse: SuccessResponse{ - Message: "Assignees updated", - Status: http.StatusOK, - }, - Assignees: mr.Assignees, + SuccessResponse: SuccessResponse{Message: "Assignees updated"}, + Assignees: mr.Assignees, } err = json.NewEncoder(w).Encode(response) diff --git a/cmd/app/assignee_test.go b/cmd/app/assignee_test.go index 24e7e60c..8cdde646 100644 --- a/cmd/app/assignee_test.go +++ b/cmd/app/assignee_test.go @@ -24,34 +24,39 @@ func TestAssigneeHandler(t *testing.T) { t.Run("Updates assignees", func(t *testing.T) { request := makeRequest(t, http.MethodPut, "/mr/assignee", updatePayload) - client := fakeAssigneeClient{} - svc := assigneesService{testProjectData, client} + svc := middleware( + assigneesService{testProjectData, fakeAssigneeClient{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPut: &AssigneeUpdateRequest{}}), + withMethodCheck(http.MethodPut), + ) data := getSuccessData(t, svc, request) assert(t, data.Message, "Assignees updated") - assert(t, data.Status, http.StatusOK) - }) - - t.Run("Disallows non-PUT method", func(t *testing.T) { - request := makeRequest(t, http.MethodGet, "/mr/assignee", nil) - client := fakeAssigneeClient{} - svc := assigneesService{testProjectData, client} - data := getFailData(t, svc, request) - checkBadMethod(t, data, http.MethodPut) }) t.Run("Handles errors from Gitlab client", func(t *testing.T) { - request := makeRequest(t, http.MethodPut, "/mr/approve", updatePayload) + request := makeRequest(t, http.MethodPut, "/mr/assignee", updatePayload) client := fakeAssigneeClient{testBase{errFromGitlab: true}} - svc := assigneesService{testProjectData, client} - data := getFailData(t, svc, request) + svc := middleware( + assigneesService{testProjectData, client}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPut: &AssigneeUpdateRequest{}}), + withMethodCheck(http.MethodPut), + ) + data, _ := getFailData(t, svc, request) checkErrorFromGitlab(t, data, "Could not modify merge request assignees") }) t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { - request := makeRequest(t, http.MethodPut, "/mr/approve", updatePayload) + request := makeRequest(t, http.MethodPut, "/mr/assignee", updatePayload) client := fakeAssigneeClient{testBase{status: http.StatusSeeOther}} - svc := assigneesService{testProjectData, client} - data := getFailData(t, svc, request) + svc := middleware( + assigneesService{testProjectData, client}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPut: &AssigneeUpdateRequest{}}), + withMethodCheck(http.MethodPut), + ) + data, _ := getFailData(t, svc, request) checkNon200(t, data, "Could not modify merge request assignees", "/mr/assignee") }) } diff --git a/cmd/app/attachment.go b/cmd/app/attachment.go index 10ee33dd..b34ea23d 100644 --- a/cmd/app/attachment.go +++ b/cmd/app/attachment.go @@ -16,8 +16,8 @@ type FileReader interface { } type AttachmentRequest struct { - FilePath string `json:"file_path"` - FileName string `json:"file_name"` + FilePath string `json:"file_path" validate:"required"` + FileName string `json:"file_name" validate:"required"` } type AttachmentResponse struct { @@ -58,55 +58,31 @@ type attachmentService struct { } /* attachmentHandler uploads an attachment (file, image, etc) to Gitlab and returns metadata about the upload. */ -func (a attachmentService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method != http.MethodPost { - w.Header().Set("Access-Control-Allow-Methods", http.MethodPost) - handleError(w, InvalidRequestError{}, "Expected POST", http.StatusMethodNotAllowed) - return - } - - var attachmentRequest AttachmentRequest - - body, err := io.ReadAll(r.Body) - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - return - } - - defer r.Body.Close() - - err = json.Unmarshal(body, &attachmentRequest) - if err != nil { - handleError(w, err, "Could not unmarshal JSON", http.StatusBadRequest) - return - } +func (a attachmentService) ServeHTTP(w http.ResponseWriter, r *http.Request) { + payload := r.Context().Value(payload("payload")).(*AttachmentRequest) - file, err := a.fileReader.ReadFile(attachmentRequest.FilePath) + file, err := a.fileReader.ReadFile(payload.FilePath) if err != nil || file == nil { - handleError(w, err, fmt.Sprintf("Could not read %s file", attachmentRequest.FileName), http.StatusInternalServerError) + handleError(w, err, fmt.Sprintf("Could not read %s file", payload.FileName), http.StatusInternalServerError) return } - projectFile, res, err := a.client.UploadFile(a.projectInfo.ProjectId, file, attachmentRequest.FileName) + projectFile, res, err := a.client.UploadFile(a.projectInfo.ProjectId, file, payload.FileName) if err != nil { - handleError(w, err, fmt.Sprintf("Could not upload %s to Gitlab", attachmentRequest.FileName), http.StatusInternalServerError) + handleError(w, err, fmt.Sprintf("Could not upload %s to Gitlab", payload.FileName), http.StatusInternalServerError) return } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/attachment"}, fmt.Sprintf("Could not upload %s to Gitlab", attachmentRequest.FileName), res.StatusCode) + handleError(w, GenericError{r.URL.Path}, fmt.Sprintf("Could not upload %s to Gitlab", payload.FileName), res.StatusCode) return } response := AttachmentResponse{ - SuccessResponse: SuccessResponse{ - Status: http.StatusOK, - Message: "File uploaded successfully", - }, - Markdown: projectFile.Markdown, - Alt: projectFile.Alt, - Url: projectFile.URL, + SuccessResponse: SuccessResponse{Message: "File uploaded successfully"}, + Markdown: projectFile.Markdown, + Alt: projectFile.Alt, + Url: projectFile.URL, } err = json.NewEncoder(w).Encode(response) diff --git a/cmd/app/attachment_test.go b/cmd/app/attachment_test.go index 8fefaf0b..38a1a0d5 100644 --- a/cmd/app/attachment_test.go +++ b/cmd/app/attachment_test.go @@ -36,29 +36,34 @@ func TestAttachmentHandler(t *testing.T) { t.Run("Returns 200-status response after upload", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/attachment", attachmentTestRequestData) - svc := attachmentService{testProjectData, fakeFileReader{}, fakeFileUploaderClient{}} + svc := middleware( + attachmentService{testProjectData, fakeFileReader{}, fakeFileUploaderClient{}}, + withPayloadValidation(methodToPayload{http.MethodPost: &AttachmentRequest{}}), + withMethodCheck(http.MethodPost), + ) data := getSuccessData(t, svc, request) - assert(t, data.Status, http.StatusOK) assert(t, data.Message, "File uploaded successfully") }) - t.Run("Disallows non-POST method", func(t *testing.T) { - request := makeRequest(t, http.MethodGet, "/attachment", nil) - svc := attachmentService{testProjectData, fakeFileReader{}, fakeFileUploaderClient{}} - data := getFailData(t, svc, request) - checkBadMethod(t, data, http.MethodPost) - }) t.Run("Handles errors from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/attachment", attachmentTestRequestData) - svc := attachmentService{testProjectData, fakeFileReader{}, fakeFileUploaderClient{testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) + svc := middleware( + attachmentService{testProjectData, fakeFileReader{}, fakeFileUploaderClient{testBase{errFromGitlab: true}}}, + withPayloadValidation(methodToPayload{http.MethodPost: &AttachmentRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) checkErrorFromGitlab(t, data, "Could not upload some_file_name to Gitlab") }) t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/attachment", attachmentTestRequestData) - svc := attachmentService{testProjectData, fakeFileReader{}, fakeFileUploaderClient{testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) + svc := middleware( + attachmentService{testProjectData, fakeFileReader{}, fakeFileUploaderClient{testBase{status: http.StatusSeeOther}}}, + withPayloadValidation(methodToPayload{http.MethodPost: &AttachmentRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) checkNon200(t, data, "Could not upload some_file_name to Gitlab", "/attachment") }) } diff --git a/cmd/app/client.go b/cmd/app/client.go index c9e8ddb2..756371f7 100644 --- a/cmd/app/client.go +++ b/cmd/app/client.go @@ -5,10 +5,7 @@ import ( "encoding/json" "errors" "fmt" - "log" "net/http" - "net/http/httputil" - "os" "github.com/harrisoncramer/gitlab.nvim/cmd/app/git" "github.com/hashicorp/go-retryablehttp" @@ -48,12 +45,19 @@ func NewClient() (error, *Client) { gitlab.WithBaseURL(apiCustUrl), } - if pluginOptions.Debug.Request { - gitlabOptions = append(gitlabOptions, gitlab.WithRequestLogHook(requestLogger)) + if pluginOptions.Debug.GitlabRequest { + gitlabOptions = append(gitlabOptions, gitlab.WithRequestLogHook( + func(l retryablehttp.Logger, r *http.Request, i int) { + logRequest("REQUEST TO GITLAB", r) + }, + )) } - if pluginOptions.Debug.Response { - gitlabOptions = append(gitlabOptions, gitlab.WithResponseLogHook(responseLogger)) + if pluginOptions.Debug.GitlabResponse { + gitlabOptions = append(gitlabOptions, gitlab.WithResponseLogHook(func(l retryablehttp.Logger, response *http.Response) { + logResponse("RESPONSE FROM GITLAB", response) + }, + )) } tr := &http.Transport{ @@ -106,7 +110,6 @@ func InitProjectSettings(c *Client, gitInfo git.GitData) (error, *ProjectInfo) { return nil, &ProjectInfo{ ProjectId: projectId, } - } /* handleError is a utililty handler that returns errors to the client along with their statuses and messages */ @@ -115,7 +118,6 @@ func handleError(w http.ResponseWriter, err error, message string, status int) { response := ErrorResponse{ Message: message, Details: err.Error(), - Status: status, } err = json.NewEncoder(w).Encode(response) @@ -123,53 +125,3 @@ func handleError(w http.ResponseWriter, err error, message string, status int) { handleError(w, err, "Could not encode error response", http.StatusInternalServerError) } } - -var requestLogger retryablehttp.RequestLogHook = func(l retryablehttp.Logger, r *http.Request, i int) { - file := openLogFile() - defer file.Close() - - token := r.Header.Get("Private-Token") - r.Header.Set("Private-Token", "REDACTED") - res, err := httputil.DumpRequest(r, true) - if err != nil { - log.Fatalf("Error dumping request: %v", err) - os.Exit(1) - } - r.Header.Set("Private-Token", token) - - _, err = file.Write([]byte("\n-- REQUEST --\n")) //nolint:all - _, err = file.Write(res) //nolint:all - _, err = file.Write([]byte("\n")) //nolint:all -} - -var responseLogger retryablehttp.ResponseLogHook = func(l retryablehttp.Logger, response *http.Response) { - file := openLogFile() - defer file.Close() - - res, err := httputil.DumpResponse(response, true) - if err != nil { - log.Fatalf("Error dumping response: %v", err) - os.Exit(1) - } - - _, err = file.Write([]byte("\n-- RESPONSE --\n")) //nolint:all - _, err = file.Write(res) //nolint:all - _, err = file.Write([]byte("\n")) //nolint:all -} - -func openLogFile() *os.File { - file, err := os.OpenFile(pluginOptions.LogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) - if err != nil { - if os.IsNotExist(err) { - log.Printf("Log file %s does not exist", pluginOptions.LogPath) - } else if os.IsPermission(err) { - log.Printf("Permission denied for log file %s", pluginOptions.LogPath) - } else { - log.Printf("Error opening log file %s: %v", pluginOptions.LogPath, err) - } - - os.Exit(1) - } - - return file -} diff --git a/cmd/app/comment.go b/cmd/app/comment.go index 6ccc14f3..ebc4a412 100644 --- a/cmd/app/comment.go +++ b/cmd/app/comment.go @@ -2,45 +2,17 @@ package app import ( "encoding/json" - "fmt" - "io" "net/http" "github.com/xanzy/go-gitlab" ) -type PostCommentRequest struct { - Comment string `json:"comment"` - PositionData -} - -type DeleteCommentRequest struct { - NoteId int `json:"note_id"` - DiscussionId string `json:"discussion_id"` -} - -type EditCommentRequest struct { - Comment string `json:"comment"` - NoteId int `json:"note_id"` - DiscussionId string `json:"discussion_id"` - Resolved bool `json:"resolved"` -} - type CommentResponse struct { SuccessResponse Comment *gitlab.Note `json:"note"` Discussion *gitlab.Discussion `json:"discussion"` } -/* CommentWithPosition is a comment with an (optional) position data value embedded in it. The position data will be non-nil for range-based comments. */ -type CommentWithPosition struct { - PositionData PositionData -} - -func (comment CommentWithPosition) GetPositionData() PositionData { - return comment.PositionData -} - type CommentManager interface { CreateMergeRequestDiscussion(pid interface{}, mergeRequest int, opt *gitlab.CreateMergeRequestDiscussionOptions, options ...gitlab.RequestOptionFunc) (*gitlab.Discussion, *gitlab.Response, error) UpdateMergeRequestDiscussionNote(pid interface{}, mergeRequest int, discussion string, note int, opt *gitlab.UpdateMergeRequestDiscussionNoteOptions, options ...gitlab.RequestOptionFunc) (*gitlab.Note, *gitlab.Response, error) @@ -53,7 +25,7 @@ type commentService struct { } /* commentHandler creates, edits, and deletes discussions (comments, multi-line comments) */ -func (a commentService) handler(w http.ResponseWriter, r *http.Request) { +func (a commentService) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") switch r.Method { case http.MethodPost: @@ -62,30 +34,19 @@ func (a commentService) handler(w http.ResponseWriter, r *http.Request) { a.editComment(w, r) case http.MethodDelete: a.deleteComment(w, r) - default: - w.Header().Set("Access-Control-Allow-Methods", fmt.Sprintf("%s, %s, %s", http.MethodDelete, http.MethodPost, http.MethodPatch)) - handleError(w, InvalidRequestError{}, "Expected DELETE, POST or PATCH", http.StatusMethodNotAllowed) } } +type DeleteCommentRequest struct { + NoteId int `json:"note_id" validate:"required"` + DiscussionId string `json:"discussion_id" validate:"required"` +} + /* deleteComment deletes a note, multiline comment, or comment, which are all considered discussion notes. */ func (a commentService) deleteComment(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - return - } - - defer r.Body.Close() - - var deleteCommentRequest DeleteCommentRequest - err = json.Unmarshal(body, &deleteCommentRequest) - if err != nil { - handleError(w, err, "Could not read JSON from request", http.StatusBadRequest) - return - } + payload := r.Context().Value(payload("payload")).(*DeleteCommentRequest) - res, err := a.client.DeleteMergeRequestDiscussionNote(a.projectInfo.ProjectId, a.projectInfo.MergeId, deleteCommentRequest.DiscussionId, deleteCommentRequest.NoteId) + res, err := a.client.DeleteMergeRequestDiscussionNote(a.projectInfo.ProjectId, a.projectInfo.MergeId, payload.DiscussionId, payload.NoteId) if err != nil { handleError(w, err, "Could not delete comment", http.StatusInternalServerError) @@ -93,15 +54,12 @@ func (a commentService) deleteComment(w http.ResponseWriter, r *http.Request) { } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/comment"}, "Could not delete comment", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not delete comment", res.StatusCode) return } w.WriteHeader(http.StatusOK) - response := SuccessResponse{ - Message: "Comment deleted successfully", - Status: http.StatusOK, - } + response := SuccessResponse{Message: "Comment deleted successfully"} err = json.NewEncoder(w).Encode(response) if err != nil { @@ -109,32 +67,33 @@ func (a commentService) deleteComment(w http.ResponseWriter, r *http.Request) { } } -/* postComment creates a note, multiline comment, or comment. */ -func (a commentService) postComment(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - return - } +type PostCommentRequest struct { + Comment string `json:"comment" validate:"required"` + PositionData +} - defer r.Body.Close() +/* CommentWithPosition is a comment with an (optional) position data value embedded in it. The position data will be non-nil for range-based comments. */ +type CommentWithPosition struct { + PositionData PositionData +} - var postCommentRequest PostCommentRequest - err = json.Unmarshal(body, &postCommentRequest) - if err != nil { - handleError(w, err, "Could not unmarshal data from request body", http.StatusBadRequest) - return - } +func (comment CommentWithPosition) GetPositionData() PositionData { + return comment.PositionData +} + +/* postComment creates a note, multiline comment, or comment. */ +func (a commentService) postComment(w http.ResponseWriter, r *http.Request) { + payload := r.Context().Value(payload("payload")).(*PostCommentRequest) opt := gitlab.CreateMergeRequestDiscussionOptions{ - Body: &postCommentRequest.Comment, + Body: &payload.Comment, } /* If we are leaving a comment on a line, leave position. Otherwise, we are leaving a note (unlinked comment) */ - if postCommentRequest.FileName != "" { - commentWithPositionData := CommentWithPosition{postCommentRequest.PositionData} + if payload.FileName != "" { + commentWithPositionData := CommentWithPosition{payload.PositionData} opt.Position = buildCommentPosition(commentWithPositionData) } @@ -146,18 +105,15 @@ func (a commentService) postComment(w http.ResponseWriter, r *http.Request) { } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/comment"}, "Could not create discussion", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not create discussion", res.StatusCode) return } w.WriteHeader(http.StatusOK) response := CommentResponse{ - SuccessResponse: SuccessResponse{ - Message: "Comment created successfully", - Status: http.StatusOK, - }, - Comment: discussion.Notes[0], - Discussion: discussion, + SuccessResponse: SuccessResponse{Message: "Comment created successfully"}, + Comment: discussion.Notes[0], + Discussion: discussion, } err = json.NewEncoder(w).Encode(response) @@ -166,28 +122,23 @@ func (a commentService) postComment(w http.ResponseWriter, r *http.Request) { } } +type EditCommentRequest struct { + Comment string `json:"comment" validate:"required"` + NoteId int `json:"note_id" validate:"required"` + DiscussionId string `json:"discussion_id" validate:"required"` + Resolved bool `json:"resolved"` +} + /* editComment changes the text of a comment or changes it's resolved status. */ func (a commentService) editComment(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) - - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - return - } - defer r.Body.Close() + payload := r.Context().Value(payload("payload")).(*EditCommentRequest) - var editCommentRequest EditCommentRequest - err = json.Unmarshal(body, &editCommentRequest) - if err != nil { - handleError(w, err, "Could not unmarshal data from request body", http.StatusBadRequest) - return + options := gitlab.UpdateMergeRequestDiscussionNoteOptions{ + Body: gitlab.Ptr(payload.Comment), } - options := gitlab.UpdateMergeRequestDiscussionNoteOptions{} - options.Body = gitlab.Ptr(editCommentRequest.Comment) - - note, res, err := a.client.UpdateMergeRequestDiscussionNote(a.projectInfo.ProjectId, a.projectInfo.MergeId, editCommentRequest.DiscussionId, editCommentRequest.NoteId, &options) + note, res, err := a.client.UpdateMergeRequestDiscussionNote(a.projectInfo.ProjectId, a.projectInfo.MergeId, payload.DiscussionId, payload.NoteId, &options) if err != nil { handleError(w, err, "Could not update comment", http.StatusInternalServerError) @@ -195,17 +146,14 @@ func (a commentService) editComment(w http.ResponseWriter, r *http.Request) { } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/comment"}, "Could not update comment", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not update comment", res.StatusCode) return } w.WriteHeader(http.StatusOK) response := CommentResponse{ - SuccessResponse: SuccessResponse{ - Message: "Comment updated successfully", - Status: http.StatusOK, - }, - Comment: note, + SuccessResponse: SuccessResponse{Message: "Comment updated successfully"}, + Comment: note, } err = json.NewEncoder(w).Encode(response) diff --git a/cmd/app/comment_test.go b/cmd/app/comment_test.go index 03ed716e..f86902f4 100644 --- a/cmd/app/comment_test.go +++ b/cmd/app/comment_test.go @@ -40,10 +40,18 @@ func TestPostComment(t *testing.T) { var testCommentCreationData = PostCommentRequest{Comment: "Some comment"} t.Run("Creates a new note (unlinked comment)", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/comment", testCommentCreationData) - svc := commentService{testProjectData, fakeCommentClient{}} + svc := middleware( + commentService{testProjectData, fakeCommentClient{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &PostCommentRequest{}, + http.MethodDelete: &DeleteCommentRequest{}, + http.MethodPatch: &EditCommentRequest{}, + }), + withMethodCheck(http.MethodPost, http.MethodDelete, http.MethodPatch), + ) data := getSuccessData(t, svc, request) assert(t, data.Message, "Comment created successfully") - assert(t, data.Status, http.StatusOK) }) t.Run("Creates a new comment", func(t *testing.T) { @@ -54,23 +62,49 @@ func TestPostComment(t *testing.T) { }, } request := makeRequest(t, http.MethodPost, "/mr/comment", testCommentCreationData) - svc := commentService{testProjectData, fakeCommentClient{}} + svc := middleware( + commentService{testProjectData, fakeCommentClient{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &PostCommentRequest{}, + http.MethodDelete: &DeleteCommentRequest{}, + http.MethodPatch: &EditCommentRequest{}, + }), + withMethodCheck(http.MethodPost, http.MethodDelete, http.MethodPatch), + ) data := getSuccessData(t, svc, request) assert(t, data.Message, "Comment created successfully") - assert(t, data.Status, http.StatusOK) }) t.Run("Handles errors from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/comment", testCommentCreationData) - svc := commentService{testProjectData, fakeCommentClient{testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) + svc := middleware( + commentService{testProjectData, fakeCommentClient{testBase{errFromGitlab: true}}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &PostCommentRequest{}, + http.MethodDelete: &DeleteCommentRequest{}, + http.MethodPatch: &EditCommentRequest{}, + }), + withMethodCheck(http.MethodPost, http.MethodDelete, http.MethodPatch), + ) + data, _ := getFailData(t, svc, request) checkErrorFromGitlab(t, data, "Could not create discussion") }) t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/comment", testCommentCreationData) - svc := commentService{testProjectData, fakeCommentClient{testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) + svc := middleware( + commentService{testProjectData, fakeCommentClient{testBase{status: http.StatusSeeOther}}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &PostCommentRequest{}, + http.MethodDelete: &DeleteCommentRequest{}, + http.MethodPatch: &EditCommentRequest{}, + }), + withMethodCheck(http.MethodPost, http.MethodDelete, http.MethodPatch), + ) + data, _ := getFailData(t, svc, request) checkNon200(t, data, "Could not create discussion", "/mr/comment") }) } @@ -79,23 +113,18 @@ func TestDeleteComment(t *testing.T) { var testCommentDeletionData = DeleteCommentRequest{NoteId: 3, DiscussionId: "abc123"} t.Run("Deletes a comment", func(t *testing.T) { request := makeRequest(t, http.MethodDelete, "/mr/comment", testCommentDeletionData) - svc := commentService{testProjectData, fakeCommentClient{}} + svc := middleware( + commentService{testProjectData, fakeCommentClient{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &PostCommentRequest{}, + http.MethodDelete: &DeleteCommentRequest{}, + http.MethodPatch: &EditCommentRequest{}, + }), + withMethodCheck(http.MethodPost, http.MethodDelete, http.MethodPatch), + ) data := getSuccessData(t, svc, request) assert(t, data.Message, "Comment deleted successfully") - assert(t, data.Status, http.StatusOK) - }) - t.Run("Handles errors from Gitlab client", func(t *testing.T) { - request := makeRequest(t, http.MethodDelete, "/mr/comment", testCommentDeletionData) - svc := commentService{testProjectData, fakeCommentClient{testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) - checkErrorFromGitlab(t, data, "Could not delete comment") - }) - - t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { - request := makeRequest(t, http.MethodDelete, "/mr/comment", testCommentDeletionData) - svc := commentService{testProjectData, fakeCommentClient{testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) - checkNon200(t, data, "Could not delete comment", "/mr/comment") }) } @@ -103,22 +132,17 @@ func TestEditComment(t *testing.T) { var testEditCommentData = EditCommentRequest{Comment: "Some comment", NoteId: 3, DiscussionId: "abc123"} t.Run("Edits a comment", func(t *testing.T) { request := makeRequest(t, http.MethodPatch, "/mr/comment", testEditCommentData) - svc := commentService{testProjectData, fakeCommentClient{}} + svc := middleware( + commentService{testProjectData, fakeCommentClient{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &PostCommentRequest{}, + http.MethodDelete: &DeleteCommentRequest{}, + http.MethodPatch: &EditCommentRequest{}, + }), + withMethodCheck(http.MethodPost, http.MethodDelete, http.MethodPatch), + ) data := getSuccessData(t, svc, request) assert(t, data.Message, "Comment updated successfully") - assert(t, data.Status, http.StatusOK) - }) - t.Run("Handles errors from Gitlab client", func(t *testing.T) { - request := makeRequest(t, http.MethodPatch, "/mr/comment", testEditCommentData) - svc := commentService{testProjectData, fakeCommentClient{testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) - checkErrorFromGitlab(t, data, "Could not update comment") - }) - - t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { - request := makeRequest(t, http.MethodPatch, "/mr/comment", testEditCommentData) - svc := commentService{testProjectData, fakeCommentClient{testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) - checkNon200(t, data, "Could not update comment", "/mr/comment") }) } diff --git a/cmd/app/config.go b/cmd/app/config.go index 166ac391..5861b4b3 100644 --- a/cmd/app/config.go +++ b/cmd/app/config.go @@ -6,8 +6,10 @@ type PluginOptions struct { AuthToken string `json:"auth_token"` LogPath string `json:"log_path"` Debug struct { - Request bool `json:"go_request"` - Response bool `json:"go_response"` + Request bool `json:"request"` + Response bool `json:"response"` + GitlabRequest bool `json:"gitlab_request"` + GitlabResponse bool `json:"gitlab_response"` } `json:"debug"` ChosenTargetBranch *string `json:"chosen_target_branch,omitempty"` ConnectionSettings struct { diff --git a/cmd/app/create_mr.go b/cmd/app/create_mr.go index b8a318d7..76985be9 100644 --- a/cmd/app/create_mr.go +++ b/cmd/app/create_mr.go @@ -2,21 +2,19 @@ package app import ( "encoding/json" - "errors" "fmt" - "io" "net/http" "github.com/xanzy/go-gitlab" ) type CreateMrRequest struct { - Title string `json:"title"` + Title string `json:"title" validate:"required"` + TargetBranch string `json:"target_branch" validate:"required"` Description string `json:"description"` - TargetBranch string `json:"target_branch"` + TargetProjectID int `json:"forked_project_id,omitempty"` DeleteBranch bool `json:"delete_branch"` Squash bool `json:"squash"` - TargetProjectID int `json:"forked_project_id,omitempty"` } type MergeRequestCreator interface { @@ -29,36 +27,9 @@ type mergeRequestCreatorService struct { } /* createMr creates a merge request */ -func (a mergeRequestCreatorService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Methods", http.MethodGet) - if r.Method != http.MethodPost { - handleError(w, InvalidRequestError{}, "Expected POST", http.StatusMethodNotAllowed) - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - return - } - - var createMrRequest CreateMrRequest - err = json.Unmarshal(body, &createMrRequest) - if err != nil { - handleError(w, err, "Could not unmarshal request body", http.StatusBadRequest) - return - } +func (a mergeRequestCreatorService) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if createMrRequest.Title == "" { - handleError(w, errors.New("Title cannot be empty"), "Could not create MR", http.StatusBadRequest) - return - } - - if createMrRequest.TargetBranch == "" { - handleError(w, errors.New("Target branch cannot be empty"), "Could not create MR", http.StatusBadRequest) - return - } + createMrRequest := r.Context().Value(payload("payload")).(*CreateMrRequest) opts := gitlab.CreateMergeRequestOptions{ Title: &createMrRequest.Title, @@ -81,14 +52,11 @@ func (a mergeRequestCreatorService) handler(w http.ResponseWriter, r *http.Reque } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/create_mr"}, "Could not create MR", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not create MR", res.StatusCode) return } - response := SuccessResponse{ - Status: http.StatusOK, - Message: fmt.Sprintf("MR '%s' created", createMrRequest.Title), - } + response := SuccessResponse{Message: fmt.Sprintf("MR '%s' created", createMrRequest.Title)} w.WriteHeader(http.StatusOK) diff --git a/cmd/app/create_mr_test.go b/cmd/app/create_mr_test.go index 9e5a78de..e10bd7a7 100644 --- a/cmd/app/create_mr_test.go +++ b/cmd/app/create_mr_test.go @@ -29,30 +29,34 @@ func TestCreateMr(t *testing.T) { } t.Run("Creates an MR", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/create_mr", testCreateMrRequestData) - svc := mergeRequestCreatorService{testProjectData, fakeMergeCreatorClient{}} + svc := middleware( + mergeRequestCreatorService{testProjectData, fakeMergeCreatorClient{}}, + withPayloadValidation(methodToPayload{http.MethodPost: &CreateMrRequest{}}), + withMethodCheck(http.MethodPost), + ) data := getSuccessData(t, svc, request) assert(t, data.Message, "MR 'Some title' created") - assert(t, data.Status, http.StatusOK) - }) - - t.Run("Disallows non-POST methods", func(t *testing.T) { - request := makeRequest(t, http.MethodGet, "/create_mr", testCreateMrRequestData) - svc := mergeRequestCreatorService{testProjectData, fakeMergeCreatorClient{}} - data := getFailData(t, svc, request) - checkBadMethod(t, data, http.MethodPost) }) t.Run("Handles errors from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/create_mr", testCreateMrRequestData) - svc := mergeRequestCreatorService{testProjectData, fakeMergeCreatorClient{testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) + svc := middleware( + mergeRequestCreatorService{testProjectData, fakeMergeCreatorClient{testBase{errFromGitlab: true}}}, + withPayloadValidation(methodToPayload{http.MethodPost: &CreateMrRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) checkErrorFromGitlab(t, data, "Could not create MR") }) t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/create_mr", testCreateMrRequestData) - svc := mergeRequestCreatorService{testProjectData, fakeMergeCreatorClient{testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) + svc := middleware( + mergeRequestCreatorService{testProjectData, fakeMergeCreatorClient{testBase{status: http.StatusSeeOther}}}, + withPayloadValidation(methodToPayload{http.MethodPost: &CreateMrRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) checkNon200(t, data, "Could not create MR", "/create_mr") }) @@ -60,21 +64,27 @@ func TestCreateMr(t *testing.T) { reqData := testCreateMrRequestData reqData.Title = "" request := makeRequest(t, http.MethodPost, "/create_mr", reqData) - svc := mergeRequestCreatorService{testProjectData, fakeMergeCreatorClient{}} - data := getFailData(t, svc, request) - assert(t, data.Status, http.StatusBadRequest) - assert(t, data.Message, "Could not create MR") - assert(t, data.Details, "Title cannot be empty") + svc := middleware( + mergeRequestCreatorService{testProjectData, fakeMergeCreatorClient{}}, + withPayloadValidation(methodToPayload{http.MethodPost: &CreateMrRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) + assert(t, data.Message, "Invalid payload") + assert(t, data.Details, "Title is required") }) t.Run("Handles missing target branch", func(t *testing.T) { reqData := testCreateMrRequestData reqData.TargetBranch = "" request := makeRequest(t, http.MethodPost, "/create_mr", reqData) - svc := mergeRequestCreatorService{testProjectData, fakeMergeCreatorClient{}} - data := getFailData(t, svc, request) - assert(t, data.Status, http.StatusBadRequest) - assert(t, data.Message, "Could not create MR") - assert(t, data.Details, "Target branch cannot be empty") + svc := middleware( + mergeRequestCreatorService{testProjectData, fakeMergeCreatorClient{}}, + withPayloadValidation(methodToPayload{http.MethodPost: &CreateMrRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) + assert(t, data.Message, "Invalid payload") + assert(t, data.Details, "TargetBranch is required") }) } diff --git a/cmd/app/draft_note_publisher.go b/cmd/app/draft_note_publisher.go index 57cfa033..548c5ca2 100644 --- a/cmd/app/draft_note_publisher.go +++ b/cmd/app/draft_note_publisher.go @@ -2,8 +2,6 @@ package app import ( "encoding/json" - "errors" - "io" "net/http" "github.com/xanzy/go-gitlab" @@ -19,38 +17,19 @@ type draftNotePublisherService struct { client DraftNotePublisher } -func (a draftNotePublisherService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method != http.MethodPost { - w.Header().Set("Access-Control-Allow-Methods", http.MethodPost) - handleError(w, InvalidRequestError{}, "Expected POST", http.StatusMethodNotAllowed) - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - return - } - - defer r.Body.Close() - var draftNotePublishRequest DraftNotePublishRequest - err = json.Unmarshal(body, &draftNotePublishRequest) +type DraftNotePublishRequest struct { + Note int `json:"note,omitempty"` +} - if err != nil { - handleError(w, err, "Could not read JSON from request", http.StatusBadRequest) - return - } +func (a draftNotePublisherService) ServeHTTP(w http.ResponseWriter, r *http.Request) { + payload := r.Context().Value(payload("payload")).(*DraftNotePublishRequest) var res *gitlab.Response - if draftNotePublishRequest.PublishAll { - res, err = a.client.PublishAllDraftNotes(a.projectInfo.ProjectId, a.projectInfo.MergeId) + var err error + if payload.Note != 0 { + res, err = a.client.PublishDraftNote(a.projectInfo.ProjectId, a.projectInfo.MergeId, payload.Note) } else { - if draftNotePublishRequest.Note == 0 { - handleError(w, errors.New("No ID provided"), "Must provide Note ID", http.StatusBadRequest) - return - } - res, err = a.client.PublishDraftNote(a.projectInfo.ProjectId, a.projectInfo.MergeId, draftNotePublishRequest.Note) + res, err = a.client.PublishAllDraftNotes(a.projectInfo.ProjectId, a.projectInfo.MergeId) } if err != nil { @@ -59,15 +38,12 @@ func (a draftNotePublisherService) handler(w http.ResponseWriter, r *http.Reques } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/draft_notes/publish"}, "Could not publish dfaft note", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not publish dfaft note", res.StatusCode) return } w.WriteHeader(http.StatusOK) - response := SuccessResponse{ - Message: "Draft note(s) published", - Status: http.StatusOK, - } + response := SuccessResponse{Message: "Draft note(s) published"} err = json.NewEncoder(w).Encode(response) if err != nil { diff --git a/cmd/app/draft_note_publisher_test.go b/cmd/app/draft_note_publisher_test.go index b53dda1f..36e3e4f0 100644 --- a/cmd/app/draft_note_publisher_test.go +++ b/cmd/app/draft_note_publisher_test.go @@ -19,56 +19,53 @@ func (f fakeDraftNotePublisher) PublishDraftNote(pid interface{}, mergeRequest i } func TestPublishDraftNote(t *testing.T) { - var testDraftNotePublishRequest = DraftNotePublishRequest{Note: 3, PublishAll: false} + var testDraftNotePublishRequest = DraftNotePublishRequest{Note: 3} t.Run("Publishes draft note", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/draft_notes/publish", testDraftNotePublishRequest) - svc := draftNotePublisherService{testProjectData, fakeDraftNotePublisher{}} + svc := middleware( + draftNotePublisherService{testProjectData, fakeDraftNotePublisher{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPost: &DraftNotePublishRequest{}}), + withMethodCheck(http.MethodPost), + ) data := getSuccessData(t, svc, request) - assert(t, data.Status, http.StatusOK) assert(t, data.Message, "Draft note(s) published") }) - t.Run("Disallows non-POST method", func(t *testing.T) { - request := makeRequest(t, http.MethodGet, "/mr/draft_notes/publish", testDraftNotePublishRequest) - svc := draftNotePublisherService{testProjectData, fakeDraftNotePublisher{}} - data := getFailData(t, svc, request) - checkBadMethod(t, data, http.MethodPost) - }) - t.Run("Handles bad ID", func(t *testing.T) { - badData := testDraftNotePublishRequest - badData.Note = 0 - request := makeRequest(t, http.MethodPost, "/mr/draft_notes/publish", badData) - svc := draftNotePublisherService{testProjectData, fakeDraftNotePublisher{}} - data := getFailData(t, svc, request) - assert(t, data.Status, http.StatusBadRequest) - assert(t, data.Message, "Must provide Note ID") - }) t.Run("Handles error from Gitlab", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/draft_notes/publish", testDraftNotePublishRequest) - svc := draftNotePublisherService{testProjectData, fakeDraftNotePublisher{testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) + svc := middleware( + draftNotePublisherService{testProjectData, fakeDraftNotePublisher{testBase: testBase{errFromGitlab: true}}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPost: &DraftNotePublishRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) checkErrorFromGitlab(t, data, "Could not publish draft note(s)") }) } func TestPublishAllDraftNotes(t *testing.T) { - var testDraftNotePublishRequest = DraftNotePublishRequest{PublishAll: true} + var testDraftNotePublishRequest = DraftNotePublishRequest{} t.Run("Should publish all draft notes", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/draft_notes/publish", testDraftNotePublishRequest) - svc := draftNotePublisherService{testProjectData, fakeDraftNotePublisher{}} + svc := middleware( + draftNotePublisherService{testProjectData, fakeDraftNotePublisher{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPost: &DraftNotePublishRequest{}}), + withMethodCheck(http.MethodPost), + ) data := getSuccessData(t, svc, request) - assert(t, data.Status, http.StatusOK) assert(t, data.Message, "Draft note(s) published") }) - t.Run("Disallows non-POST method", func(t *testing.T) { - request := makeRequest(t, http.MethodGet, "/mr/draft_notes/publish", testDraftNotePublishRequest) - svc := draftNotePublisherService{testProjectData, fakeDraftNotePublisher{}} - data := getFailData(t, svc, request) - checkBadMethod(t, data, http.MethodPost) - }) t.Run("Handles error from Gitlab", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/draft_notes/publish", testDraftNotePublishRequest) - svc := draftNotePublisherService{testProjectData, fakeDraftNotePublisher{testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) + svc := middleware( + draftNotePublisherService{testProjectData, fakeDraftNotePublisher{testBase: testBase{errFromGitlab: true}}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPost: &DraftNotePublishRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) checkErrorFromGitlab(t, data, "Could not publish draft note(s)") }) } diff --git a/cmd/app/draft_notes.go b/cmd/app/draft_notes.go index 6af8f0d7..6c248115 100644 --- a/cmd/app/draft_notes.go +++ b/cmd/app/draft_notes.go @@ -3,8 +3,6 @@ package app import ( "encoding/json" "errors" - "fmt" - "io" "net/http" "strconv" "strings" @@ -12,36 +10,15 @@ import ( "github.com/xanzy/go-gitlab" ) -/* The data coming from the client when creating a draft note is the same, +/* The data coming from the client when creating a draft note is the same as when they are creating a normal comment, but the Gitlab endpoints + resources we handle are different */ -type PostDraftNoteRequest struct { - Comment string `json:"comment"` - DiscussionId string `json:"discussion_id,omitempty"` - PositionData -} - -type UpdateDraftNoteRequest struct { - Note string `json:"note"` - Position gitlab.PositionOptions -} - -type DraftNotePublishRequest struct { - Note int `json:"note,omitempty"` - PublishAll bool `json:"publish_all"` -} - type DraftNoteResponse struct { SuccessResponse DraftNote *gitlab.DraftNote `json:"draft_note"` } -type ListDraftNotesResponse struct { - SuccessResponse - DraftNotes []*gitlab.DraftNote `json:"draft_notes"` -} - /* DraftNoteWithPosition is a draft comment with an (optional) position data value embedded in it. The position data will be non-nil for range-based draft comments. */ type DraftNoteWithPosition struct { PositionData PositionData @@ -64,7 +41,7 @@ type draftNoteService struct { } /* draftNoteHandler creates, edits, and deletes draft notes */ -func (a draftNoteService) handler(w http.ResponseWriter, r *http.Request) { +func (a draftNoteService) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") switch r.Method { case http.MethodGet: @@ -75,14 +52,16 @@ func (a draftNoteService) handler(w http.ResponseWriter, r *http.Request) { a.updateDraftNote(w, r) case http.MethodDelete: a.deleteDraftNote(w, r) - default: - w.Header().Set("Access-Control-Allow-Methods", fmt.Sprintf("%s, %s, %s, %s", http.MethodDelete, http.MethodPost, http.MethodPatch, http.MethodGet)) - handleError(w, InvalidRequestError{}, "Expected DELETE, GET, POST or PATCH", http.StatusMethodNotAllowed) } } +type ListDraftNotesResponse struct { + SuccessResponse + DraftNotes []*gitlab.DraftNote `json:"draft_notes"` +} + /* listDraftNotes lists all draft notes for the currently authenticated user */ -func (a draftNoteService) listDraftNotes(w http.ResponseWriter, _ *http.Request) { +func (a draftNoteService) listDraftNotes(w http.ResponseWriter, r *http.Request) { opt := gitlab.ListDraftNotesOptions{} draftNotes, res, err := a.client.ListDraftNotes(a.projectInfo.ProjectId, a.projectInfo.MergeId, &opt) @@ -93,17 +72,14 @@ func (a draftNoteService) listDraftNotes(w http.ResponseWriter, _ *http.Request) } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/draft_notes/"}, "Could not get draft notes", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not get draft notes", res.StatusCode) return } w.WriteHeader(http.StatusOK) response := ListDraftNotesResponse{ - SuccessResponse: SuccessResponse{ - Message: "Draft notes fetched successfully", - Status: http.StatusOK, - }, - DraftNotes: draftNotes, + SuccessResponse: SuccessResponse{Message: "Draft notes fetched successfully"}, + DraftNotes: draftNotes, } err = json.NewEncoder(w).Encode(response) @@ -112,34 +88,27 @@ func (a draftNoteService) listDraftNotes(w http.ResponseWriter, _ *http.Request) } } +type PostDraftNoteRequest struct { + Comment string `json:"comment" validate:"required"` + DiscussionId string `json:"discussion_id,omitempty" validate:"required"` + PositionData // TODO: How to add validations to data from external package??? +} + /* postDraftNote creates a draft note */ func (a draftNoteService) postDraftNote(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - return - } - - defer r.Body.Close() - - var postDraftNoteRequest PostDraftNoteRequest - err = json.Unmarshal(body, &postDraftNoteRequest) - if err != nil { - handleError(w, err, "Could not unmarshal data from request body", http.StatusBadRequest) - return - } + payload := r.Context().Value(payload("payload")).(*PostDraftNoteRequest) opt := gitlab.CreateDraftNoteOptions{ - Note: &postDraftNoteRequest.Comment, + Note: &payload.Comment, } // Draft notes can be posted in "response" to existing discussions - if postDraftNoteRequest.DiscussionId != "" { - opt.InReplyToDiscussionID = gitlab.Ptr(postDraftNoteRequest.DiscussionId) + if payload.DiscussionId != "" { + opt.InReplyToDiscussionID = gitlab.Ptr(payload.DiscussionId) } - if postDraftNoteRequest.FileName != "" { - draftNoteWithPosition := DraftNoteWithPosition{postDraftNoteRequest.PositionData} + if payload.FileName != "" { + draftNoteWithPosition := DraftNoteWithPosition{payload.PositionData} opt.Position = buildCommentPosition(draftNoteWithPosition) } @@ -151,17 +120,14 @@ func (a draftNoteService) postDraftNote(w http.ResponseWriter, r *http.Request) } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/draft_notes/"}, "Could not create draft note", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not create draft note", res.StatusCode) return } w.WriteHeader(http.StatusOK) response := DraftNoteResponse{ - SuccessResponse: SuccessResponse{ - Message: "Draft note created successfully", - Status: http.StatusOK, - }, - DraftNote: draftNote, + SuccessResponse: SuccessResponse{Message: "Draft note created successfully"}, + DraftNote: draftNote, } err = json.NewEncoder(w).Encode(response) @@ -187,15 +153,12 @@ func (a draftNoteService) deleteDraftNote(w http.ResponseWriter, r *http.Request } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: fmt.Sprintf("/mr/draft_notes/%d", id)}, "Could not delete draft note", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not delete draft note", res.StatusCode) return } w.WriteHeader(http.StatusOK) - response := SuccessResponse{ - Message: "Draft note deleted", - Status: http.StatusOK, - } + response := SuccessResponse{Message: "Draft note deleted"} err = json.NewEncoder(w).Encode(response) if err != nil { @@ -203,6 +166,11 @@ func (a draftNoteService) deleteDraftNote(w http.ResponseWriter, r *http.Request } } +type UpdateDraftNoteRequest struct { + Note string `json:"note" validate:"required"` + Position gitlab.PositionOptions +} + /* updateDraftNote edits the text of a draft comment */ func (a draftNoteService) updateDraftNote(w http.ResponseWriter, r *http.Request) { suffix := strings.TrimPrefix(r.URL.Path, "/mr/draft_notes/") @@ -212,29 +180,16 @@ func (a draftNoteService) updateDraftNote(w http.ResponseWriter, r *http.Request return } - body, err := io.ReadAll(r.Body) - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - return - } - - defer r.Body.Close() - - var updateDraftNoteRequest UpdateDraftNoteRequest - err = json.Unmarshal(body, &updateDraftNoteRequest) - if err != nil { - handleError(w, err, "Could not unmarshal data from request body", http.StatusBadRequest) - return - } + payload := r.Context().Value(payload("payload")).(*UpdateDraftNoteRequest) - if updateDraftNoteRequest.Note == "" { + if payload.Note == "" { handleError(w, errors.New("Draft note text missing"), "Must provide draft note text", http.StatusBadRequest) return } opt := gitlab.UpdateDraftNoteOptions{ - Note: &updateDraftNoteRequest.Note, - Position: &updateDraftNoteRequest.Position, + Note: &payload.Note, + Position: &payload.Position, } draftNote, res, err := a.client.UpdateDraftNote(a.projectInfo.ProjectId, a.projectInfo.MergeId, id, &opt) @@ -245,17 +200,14 @@ func (a draftNoteService) updateDraftNote(w http.ResponseWriter, r *http.Request } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: fmt.Sprintf("/mr/draft_notes/%d", id)}, "Could not update draft note", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not update draft note", res.StatusCode) return } w.WriteHeader(http.StatusOK) response := DraftNoteResponse{ - SuccessResponse: SuccessResponse{ - Message: "Draft note updated", - Status: http.StatusOK, - }, - DraftNote: draftNote, + SuccessResponse: SuccessResponse{Message: "Draft note updated"}, + DraftNote: draftNote, } err = json.NewEncoder(w).Encode(response) diff --git a/cmd/app/draft_notes_test.go b/cmd/app/draft_notes_test.go index 17b32599..3702a69e 100644 --- a/cmd/app/draft_notes_test.go +++ b/cmd/app/draft_notes_test.go @@ -42,74 +42,99 @@ func (f fakeDraftNoteManager) UpdateDraftNote(pid interface{}, mergeRequest int, func TestListDraftNotes(t *testing.T) { t.Run("Lists all draft notes", func(t *testing.T) { request := makeRequest(t, http.MethodGet, "/mr/draft_notes/", nil) - svc := draftNoteService{testProjectData, fakeDraftNoteManager{}} + svc := middleware( + draftNoteService{testProjectData, fakeDraftNoteManager{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &PostDraftNoteRequest{}, + http.MethodPatch: &UpdateDraftNoteRequest{}, + }), + withMethodCheck(http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodDelete), + ) + data := getSuccessData(t, svc, request) - assert(t, data.Status, http.StatusOK) assert(t, data.Message, "Draft notes fetched successfully") }) t.Run("Handles error from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodGet, "/mr/draft_notes/", nil) - svc := draftNoteService{testProjectData, fakeDraftNoteManager{testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) + svc := middleware( + draftNoteService{testProjectData, fakeDraftNoteManager{testBase: testBase{errFromGitlab: true}}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &PostDraftNoteRequest{}, + http.MethodPatch: &UpdateDraftNoteRequest{}, + }), + withMethodCheck(http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodDelete), + ) + data, _ := getFailData(t, svc, request) checkErrorFromGitlab(t, data, "Could not get draft notes") }) t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodGet, "/mr/draft_notes/", nil) - svc := draftNoteService{testProjectData, fakeDraftNoteManager{testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) + svc := middleware( + draftNoteService{testProjectData, fakeDraftNoteManager{testBase: testBase{status: http.StatusSeeOther}}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &PostDraftNoteRequest{}, + http.MethodPatch: &UpdateDraftNoteRequest{}, + }), + withMethodCheck(http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodDelete), + ) + data, _ := getFailData(t, svc, request) checkNon200(t, data, "Could not get draft notes", "/mr/draft_notes/") }) } func TestPostDraftNote(t *testing.T) { - var testPostDraftNoteRequestData = PostDraftNoteRequest{Comment: "Some comment"} + var testPostDraftNoteRequestData = PostDraftNoteRequest{ + Comment: "Some comment", + DiscussionId: "abc123", + } t.Run("Posts new draft note", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/draft_notes/", testPostDraftNoteRequestData) - svc := draftNoteService{testProjectData, fakeDraftNoteManager{}} + svc := middleware( + draftNoteService{testProjectData, fakeDraftNoteManager{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &PostDraftNoteRequest{}, + http.MethodPatch: &UpdateDraftNoteRequest{}, + }), + withMethodCheck(http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodDelete), + ) data := getSuccessData(t, svc, request) - assert(t, data.Status, http.StatusOK) assert(t, data.Message, "Draft note created successfully") }) - t.Run("Handles error from Gitlab client", func(t *testing.T) { - request := makeRequest(t, http.MethodPost, "/mr/draft_notes/", testPostDraftNoteRequestData) - svc := draftNoteService{testProjectData, fakeDraftNoteManager{testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) - checkErrorFromGitlab(t, data, "Could not create draft note") - }) - t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { - request := makeRequest(t, http.MethodPost, "/mr/draft_notes/", testPostDraftNoteRequestData) - svc := draftNoteService{testProjectData, fakeDraftNoteManager{testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) - checkNon200(t, data, "Could not create draft note", "/mr/draft_notes/") - }) } func TestDeleteDraftNote(t *testing.T) { t.Run("Deletes new draft note", func(t *testing.T) { request := makeRequest(t, http.MethodDelete, "/mr/draft_notes/3", nil) - svc := draftNoteService{testProjectData, fakeDraftNoteManager{}} + svc := middleware( + draftNoteService{testProjectData, fakeDraftNoteManager{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &PostDraftNoteRequest{}, + http.MethodPatch: &UpdateDraftNoteRequest{}, + }), + withMethodCheck(http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodDelete), + ) data := getSuccessData(t, svc, request) - assert(t, data.Status, http.StatusOK) assert(t, data.Message, "Draft note deleted") }) - t.Run("Handles error from Gitlab client", func(t *testing.T) { - request := makeRequest(t, http.MethodDelete, "/mr/draft_notes/3", nil) - svc := draftNoteService{testProjectData, fakeDraftNoteManager{testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) - checkErrorFromGitlab(t, data, "Could not delete draft note") - }) - t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { - request := makeRequest(t, http.MethodDelete, "/mr/draft_notes/3", nil) - svc := draftNoteService{testProjectData, fakeDraftNoteManager{testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) - checkNon200(t, data, "Could not delete draft note", "/mr/draft_notes/3") - }) t.Run("Handles bad ID", func(t *testing.T) { request := makeRequest(t, http.MethodDelete, "/mr/draft_notes/blah", nil) - svc := draftNoteService{testProjectData, fakeDraftNoteManager{testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) + svc := middleware( + draftNoteService{testProjectData, fakeDraftNoteManager{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &PostDraftNoteRequest{}, + http.MethodPatch: &UpdateDraftNoteRequest{}, + }), + withMethodCheck(http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodDelete), + ) + data, status := getFailData(t, svc, request) assert(t, data.Message, "Could not parse draft note ID") - assert(t, data.Status, http.StatusBadRequest) + assert(t, status, http.StatusBadRequest) }) } @@ -117,37 +142,49 @@ func TestEditDraftNote(t *testing.T) { var testUpdateDraftNoteRequest = UpdateDraftNoteRequest{Note: "Some new note"} t.Run("Edits new draft note", func(t *testing.T) { request := makeRequest(t, http.MethodPatch, "/mr/draft_notes/3", testUpdateDraftNoteRequest) - svc := draftNoteService{testProjectData, fakeDraftNoteManager{}} + svc := middleware( + draftNoteService{testProjectData, fakeDraftNoteManager{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &PostDraftNoteRequest{}, + http.MethodPatch: &UpdateDraftNoteRequest{}, + }), + withMethodCheck(http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodDelete), + ) data := getSuccessData(t, svc, request) - assert(t, data.Status, http.StatusOK) assert(t, data.Message, "Draft note updated") }) - t.Run("Handles error from Gitlab client", func(t *testing.T) { - request := makeRequest(t, http.MethodPatch, "/mr/draft_notes/3", testUpdateDraftNoteRequest) - svc := draftNoteService{testProjectData, fakeDraftNoteManager{testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) - checkErrorFromGitlab(t, data, "Could not update draft note") - }) - t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { - request := makeRequest(t, http.MethodPatch, "/mr/draft_notes/3", testUpdateDraftNoteRequest) - svc := draftNoteService{testProjectData, fakeDraftNoteManager{testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) - checkNon200(t, data, "Could not update draft note", "/mr/draft_notes/3") - }) t.Run("Handles bad ID", func(t *testing.T) { request := makeRequest(t, http.MethodPatch, "/mr/draft_notes/blah", testUpdateDraftNoteRequest) - svc := draftNoteService{testProjectData, fakeDraftNoteManager{testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) + svc := middleware( + draftNoteService{testProjectData, fakeDraftNoteManager{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &PostDraftNoteRequest{}, + http.MethodPatch: &UpdateDraftNoteRequest{}, + }), + withMethodCheck(http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodDelete), + ) + data, status := getFailData(t, svc, request) assert(t, data.Message, "Could not parse draft note ID") - assert(t, data.Status, http.StatusBadRequest) + assert(t, status, http.StatusBadRequest) }) t.Run("Handles empty note", func(t *testing.T) { requestData := testUpdateDraftNoteRequest requestData.Note = "" request := makeRequest(t, http.MethodPatch, "/mr/draft_notes/3", requestData) - svc := draftNoteService{testProjectData, fakeDraftNoteManager{testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) - assert(t, data.Message, "Must provide draft note text") - assert(t, data.Status, http.StatusBadRequest) + svc := middleware( + draftNoteService{testProjectData, fakeDraftNoteManager{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &PostDraftNoteRequest{}, + http.MethodPatch: &UpdateDraftNoteRequest{}, + }), + withMethodCheck(http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodDelete), + ) + data, status := getFailData(t, svc, request) + assert(t, data.Message, "Invalid payload") + assert(t, data.Details, "Note is required") + assert(t, status, http.StatusBadRequest) }) } diff --git a/cmd/app/emoji.go b/cmd/app/emoji.go index 979fce6f..4a535038 100644 --- a/cmd/app/emoji.go +++ b/cmd/app/emoji.go @@ -48,16 +48,13 @@ type emojiService struct { client EmojiManager } -func (a emojiService) handler(w http.ResponseWriter, r *http.Request) { +func (a emojiService) ServeHTTP(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") switch r.Method { case http.MethodPost: a.postEmojiOnNote(w, r) case http.MethodDelete: a.deleteEmojiFromNote(w, r) - default: - w.Header().Set("Access-Control-Allow-Methods", fmt.Sprintf("%s, %s", http.MethodDelete, http.MethodPost)) - handleError(w, InvalidRequestError{}, "Expected DELETE or POST", http.StatusMethodNotAllowed) } } @@ -87,15 +84,12 @@ func (a emojiService) deleteEmojiFromNote(w http.ResponseWriter, r *http.Request } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/pipeline"}, "Could not delete awardable", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not delete awardable", res.StatusCode) return } w.WriteHeader(http.StatusOK) - response := SuccessResponse{ - Message: "Emoji deleted", - Status: http.StatusOK, - } + response := SuccessResponse{Message: "Emoji deleted"} err = json.NewEncoder(w).Encode(response) if err != nil { @@ -131,17 +125,14 @@ func (a emojiService) postEmojiOnNote(w http.ResponseWriter, r *http.Request) { } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/awardable/note"}, "Could not post emoji", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not post emoji", res.StatusCode) return } w.WriteHeader(http.StatusOK) response := CreateEmojiResponse{ - SuccessResponse: SuccessResponse{ - Message: "Merge requests retrieved", - Status: http.StatusOK, - }, - Emoji: awardEmoji, + SuccessResponse: SuccessResponse{Message: "Merge requests retrieved"}, + Emoji: awardEmoji, } err = json.NewEncoder(w).Encode(response) diff --git a/cmd/app/git/git.go b/cmd/app/git/git.go index 2398d99f..7307430b 100644 --- a/cmd/app/git/git.go +++ b/cmd/app/git/git.go @@ -59,7 +59,7 @@ func NewGitData(remote string, g GitManager) (GitData, error) { https://git@gitlab.com/namespace/subnamespace/dummy-test-repo.git git@git@gitlab.com:namespace/subnamespace/dummy-test-repo.git */ - re := regexp.MustCompile(`(?:^https?:\/\/|^ssh:\/\/|^git@)(?:[^\/:]+)(?::\d+)?[\/:](.*)\/([^\/]+?)(?:\.git)?$`) + re := regexp.MustCompile(`^(?:git@[^\/:]*|https?:\/\/[^\/]+|ssh:\/\/[^\/:]+)(?::\d+)?[\/:](.*)\/([^\/]+?)(?:\.git)?\/?$`) matches := re.FindStringSubmatch(url) if len(matches) != 3 { return GitData{}, fmt.Errorf("Invalid Git URL format: %s", url) diff --git a/cmd/app/git/git_test.go b/cmd/app/git/git_test.go index 60b496d8..e8864fb2 100644 --- a/cmd/app/git/git_test.go +++ b/cmd/app/git/git_test.go @@ -101,6 +101,13 @@ func TestExtractGitInfo_Success(t *testing.T) { projectName: "project-name", namespace: "namespace-1", }, + { + desc: "Project configured in HTTP and under a single folder without .git extension (with embedded credentials)", + remote: "http://username:password@custom-gitlab.com/namespace-1/project-name", + branch: "feature/abc", + projectName: "project-name", + namespace: "namespace-1", + }, { desc: "Project configured in HTTPS and under a single folder", remote: "https://custom-gitlab.com/namespace-1/project-name.git", @@ -108,6 +115,13 @@ func TestExtractGitInfo_Success(t *testing.T) { projectName: "project-name", namespace: "namespace-1", }, + { + desc: "Project configured in HTTPS and under a single folder (with embedded credentials)", + remote: "https://username:password@custom-gitlab.com/namespace-1/project-name.git", + branch: "feature/abc", + projectName: "project-name", + namespace: "namespace-1", + }, { desc: "Project configured in HTTPS and under a nested folder", remote: "https://custom-gitlab.com/namespace-1/namespace-2/project-name.git", @@ -115,6 +129,13 @@ func TestExtractGitInfo_Success(t *testing.T) { projectName: "project-name", namespace: "namespace-1/namespace-2", }, + { + desc: "Project configured in HTTPS and under a nested folder (with embedded credentials)", + remote: "https://username:password@custom-gitlab.com/namespace-1/namespace-2/project-name.git", + branch: "feature/abc", + projectName: "project-name", + namespace: "namespace-1/namespace-2", + }, { desc: "Project configured in HTTPS and under two nested folders", remote: "https://custom-gitlab.com/namespace-1/namespace-2/namespace-3/project-name.git", @@ -122,6 +143,13 @@ func TestExtractGitInfo_Success(t *testing.T) { projectName: "project-name", namespace: "namespace-1/namespace-2/namespace-3", }, + { + desc: "Project configured in HTTPS and under two nested folders (with embedded credentials)", + remote: "https://username:password@custom-gitlab.com/namespace-1/namespace-2/namespace-3/project-name.git", + branch: "feature/abc", + projectName: "project-name", + namespace: "namespace-1/namespace-2/namespace-3", + }, } for _, tC := range testCases { t.Run(tC.desc, func(t *testing.T) { diff --git a/cmd/app/info.go b/cmd/app/info.go index 8838efdc..ea1321ae 100644 --- a/cmd/app/info.go +++ b/cmd/app/info.go @@ -22,14 +22,7 @@ type infoService struct { } /* infoHandler fetches infomation about the current git project. The data returned here is used in many other API calls */ -func (a infoService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method != http.MethodGet { - w.Header().Set("Access-Control-Allow-Methods", http.MethodGet) - handleError(w, InvalidRequestError{}, "Expected GET", http.StatusMethodNotAllowed) - return - } - +func (a infoService) ServeHTTP(w http.ResponseWriter, r *http.Request) { mr, res, err := a.client.GetMergeRequest(a.projectInfo.ProjectId, a.projectInfo.MergeId, &gitlab.GetMergeRequestsOptions{}) if err != nil { handleError(w, err, "Could not get project info", http.StatusInternalServerError) @@ -37,17 +30,14 @@ func (a infoService) handler(w http.ResponseWriter, r *http.Request) { } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/info"}, "Could not get project info", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not get project info", res.StatusCode) return } w.WriteHeader(http.StatusOK) response := InfoResponse{ - SuccessResponse: SuccessResponse{ - Message: "Merge requests retrieved", - Status: http.StatusOK, - }, - Info: mr, + SuccessResponse: SuccessResponse{Message: "Merge requests retrieved"}, + Info: mr, } err = json.NewEncoder(w).Encode(response) diff --git a/cmd/app/info_test.go b/cmd/app/info_test.go index e2253b50..66dd90af 100644 --- a/cmd/app/info_test.go +++ b/cmd/app/info_test.go @@ -23,27 +23,29 @@ func (f fakeMergeRequestGetter) GetMergeRequest(pid interface{}, mergeRequest in func TestInfoHandler(t *testing.T) { t.Run("Returns normal information", func(t *testing.T) { request := makeRequest(t, http.MethodGet, "/mr/info", nil) - svc := infoService{testProjectData, fakeMergeRequestGetter{}} + svc := middleware( + infoService{testProjectData, fakeMergeRequestGetter{}}, + withMethodCheck(http.MethodGet), + ) data := getSuccessData(t, svc, request) assert(t, data.Message, "Merge requests retrieved") - assert(t, data.Status, http.StatusOK) - }) - t.Run("Disallows non-GET methods", func(t *testing.T) { - request := makeRequest(t, http.MethodPost, "/mr/info", nil) - svc := infoService{testProjectData, fakeMergeRequestGetter{}} - data := getFailData(t, svc, request) - checkBadMethod(t, data, http.MethodGet) }) t.Run("Handles errors from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodGet, "/mr/info", nil) - svc := infoService{testProjectData, fakeMergeRequestGetter{testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) + svc := middleware( + infoService{testProjectData, fakeMergeRequestGetter{testBase{errFromGitlab: true}}}, + withMethodCheck(http.MethodGet), + ) + data, _ := getFailData(t, svc, request) checkErrorFromGitlab(t, data, "Could not get project info") }) t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodGet, "/mr/info", nil) - svc := infoService{testProjectData, fakeMergeRequestGetter{testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) + svc := middleware( + infoService{testProjectData, fakeMergeRequestGetter{testBase{status: http.StatusSeeOther}}}, + withMethodCheck(http.MethodGet), + ) + data, _ := getFailData(t, svc, request) checkNon200(t, data, "Could not get project info", "/mr/info") }) } diff --git a/cmd/app/job.go b/cmd/app/job.go index 4fc7abd9..ce7436e8 100644 --- a/cmd/app/job.go +++ b/cmd/app/job.go @@ -10,7 +10,7 @@ import ( ) type JobTraceRequest struct { - JobId int `json:"job_id"` + JobId int `json:"job_id" validate:"required"` } type JobTraceResponse struct { @@ -28,30 +28,11 @@ type traceFileService struct { } /* jobHandler returns a string that shows the output of a specific job run in a Gitlab pipeline */ -func (a traceFileService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method != http.MethodGet { - w.Header().Set("Access-Control-Allow-Methods", http.MethodGet) - handleError(w, InvalidRequestError{}, "Expected GET", http.StatusMethodNotAllowed) - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - return - } - - defer r.Body.Close() +func (a traceFileService) ServeHTTP(w http.ResponseWriter, r *http.Request) { - var jobTraceRequest JobTraceRequest - err = json.Unmarshal(body, &jobTraceRequest) - if err != nil { - handleError(w, err, "Could not unmarshal data from request body", http.StatusBadRequest) - return - } + payload := r.Context().Value(payload("payload")).(*JobTraceRequest) - reader, res, err := a.client.GetTraceFile(a.projectInfo.ProjectId, jobTraceRequest.JobId) + reader, res, err := a.client.GetTraceFile(a.projectInfo.ProjectId, payload.JobId) if err != nil { handleError(w, err, "Could not get trace file for job", http.StatusInternalServerError) @@ -59,7 +40,7 @@ func (a traceFileService) handler(w http.ResponseWriter, r *http.Request) { } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/job"}, "Could not get trace file for job", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not get trace file for job", res.StatusCode) return } @@ -71,11 +52,8 @@ func (a traceFileService) handler(w http.ResponseWriter, r *http.Request) { } response := JobTraceResponse{ - SuccessResponse: SuccessResponse{ - Status: http.StatusOK, - Message: "Log file read", - }, - File: string(file), + SuccessResponse: SuccessResponse{Message: "Log file read"}, + File: string(file), } err = json.NewEncoder(w).Encode(response) diff --git a/cmd/app/job_test.go b/cmd/app/job_test.go index 37073707..f9465484 100644 --- a/cmd/app/job_test.go +++ b/cmd/app/job_test.go @@ -14,9 +14,9 @@ type fakeTraceFileGetter struct { testBase } -func getTraceFileData(t *testing.T, svc ServiceWithHandler, request *http.Request) JobTraceResponse { +func getTraceFileData(t *testing.T, svc http.Handler, request *http.Request) JobTraceResponse { res := httptest.NewRecorder() - svc.handler(res, request) + svc.ServeHTTP(res, request) var data JobTraceResponse err := json.Unmarshal(res.Body.Bytes(), &data) @@ -35,37 +35,37 @@ func (f fakeTraceFileGetter) GetTraceFile(pid interface{}, jobID int, options .. return re, resp, err } -// var jobId = 0 func TestJobHandler(t *testing.T) { t.Run("Should read a job trace file", func(t *testing.T) { - request := makeRequest(t, http.MethodGet, "/job", JobTraceRequest{}) - client := fakeTraceFileGetter{} - svc := traceFileService{testProjectData, client} + request := makeRequest(t, http.MethodGet, "/job", JobTraceRequest{JobId: 3}) + svc := middleware( + traceFileService{testProjectData, fakeTraceFileGetter{}}, + withPayloadValidation(methodToPayload{http.MethodGet: &JobTraceRequest{}}), + withMethodCheck(http.MethodGet), + ) data := getTraceFileData(t, svc, request) assert(t, data.Message, "Log file read") - assert(t, data.Status, http.StatusOK) assert(t, data.File, "Some data") }) - t.Run("Disallows non-GET methods", func(t *testing.T) { - request := makeRequest(t, http.MethodPost, "/job", JobTraceRequest{}) - client := fakeTraceFileGetter{} - svc := traceFileService{testProjectData, client} - data := getFailData(t, svc, request) - checkBadMethod(t, data, http.MethodGet) - }) t.Run("Handles errors from Gitlab client", func(t *testing.T) { - request := makeRequest(t, http.MethodGet, "/job", JobTraceRequest{}) - client := fakeTraceFileGetter{testBase{errFromGitlab: true}} - svc := traceFileService{testProjectData, client} - data := getFailData(t, svc, request) + request := makeRequest(t, http.MethodGet, "/job", JobTraceRequest{JobId: 2}) + svc := middleware( + traceFileService{testProjectData, fakeTraceFileGetter{testBase{errFromGitlab: true}}}, + withPayloadValidation(methodToPayload{http.MethodGet: &JobTraceRequest{}}), + withMethodCheck(http.MethodGet), + ) + data, _ := getFailData(t, svc, request) checkErrorFromGitlab(t, data, "Could not get trace file for job") }) t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { - request := makeRequest(t, http.MethodGet, "/job", JobTraceRequest{}) - client := fakeTraceFileGetter{testBase{status: http.StatusSeeOther}} - svc := traceFileService{testProjectData, client} - data := getFailData(t, svc, request) + request := makeRequest(t, http.MethodGet, "/job", JobTraceRequest{JobId: 1}) + svc := middleware( + traceFileService{testProjectData, fakeTraceFileGetter{testBase{status: http.StatusSeeOther}}}, + withPayloadValidation(methodToPayload{http.MethodGet: &JobTraceRequest{}}), + withMethodCheck(http.MethodGet), + ) + data, _ := getFailData(t, svc, request) checkNon200(t, data, "Could not get trace file for job", "/job") }) } diff --git a/cmd/app/label.go b/cmd/app/label.go index d2c7b998..ca0ba426 100644 --- a/cmd/app/label.go +++ b/cmd/app/label.go @@ -2,7 +2,6 @@ package app import ( "encoding/json" - "fmt" "io" "net/http" @@ -39,20 +38,16 @@ type labelService struct { } /* labelsHandler adds or removes labels from a merge request, and returns all labels for the current project */ -func (a labelService) handler(w http.ResponseWriter, r *http.Request) { +func (a labelService) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: a.getLabels(w, r) case http.MethodPut: a.updateLabels(w, r) - default: - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Methods", fmt.Sprintf("%s, %s", http.MethodPut, http.MethodGet)) - handleError(w, InvalidRequestError{}, "Expected GET or PUT", http.StatusMethodNotAllowed) } } -func (a labelService) getLabels(w http.ResponseWriter, _ *http.Request) { +func (a labelService) getLabels(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json") labels, res, err := a.client.ListLabels(a.projectInfo.ProjectId, &gitlab.ListLabelsOptions{}) @@ -63,7 +58,7 @@ func (a labelService) getLabels(w http.ResponseWriter, _ *http.Request) { } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/label"}, "Could not modify merge request labels", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not modify merge request labels", res.StatusCode) return } @@ -78,11 +73,8 @@ func (a labelService) getLabels(w http.ResponseWriter, _ *http.Request) { w.WriteHeader(http.StatusOK) response := LabelsRequestResponse{ - SuccessResponse: SuccessResponse{ - Message: "Labels updated", - Status: http.StatusOK, - }, - Labels: convertedLabels, + SuccessResponse: SuccessResponse{Message: "Labels updated"}, + Labels: convertedLabels, } err = json.NewEncoder(w).Encode(response) @@ -120,17 +112,14 @@ func (a labelService) updateLabels(w http.ResponseWriter, r *http.Request) { } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/label"}, "Could not modify merge request labels", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not modify merge request labels", res.StatusCode) return } w.WriteHeader(http.StatusOK) response := LabelUpdateResponse{ - SuccessResponse: SuccessResponse{ - Message: "Labels updated", - Status: http.StatusOK, - }, - Labels: mr.Labels, + SuccessResponse: SuccessResponse{Message: "Labels updated"}, + Labels: mr.Labels, } err = json.NewEncoder(w).Encode(response) diff --git a/cmd/app/list_discussions.go b/cmd/app/list_discussions.go index c2a47b7d..a75134ce 100644 --- a/cmd/app/list_discussions.go +++ b/cmd/app/list_discussions.go @@ -1,7 +1,6 @@ package app import ( - "io" "net/http" "sort" "sync" @@ -21,7 +20,7 @@ func Contains[T comparable](elems []T, v T) bool { } type DiscussionsRequest struct { - Blacklist []string `json:"blacklist"` + Blacklist []string `json:"blacklist" validate:"required"` } type DiscussionsResponse struct { @@ -61,27 +60,9 @@ type discussionsListerService struct { listDiscussionsHandler lists all discusions for a given merge request, both those linked and unlinked to particular points in the code. The responses are sorted by date created, and blacklisted users are not included */ -func (a discussionsListerService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method != http.MethodPost { - w.Header().Set("Access-Control-Allow-Methods", http.MethodPost) - handleError(w, InvalidRequestError{}, "Expected POST", http.StatusMethodNotAllowed) - return - } - - body, err := io.ReadAll(r.Body) +func (a discussionsListerService) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - } - - defer r.Body.Close() - - var requestBody DiscussionsRequest - err = json.Unmarshal(body, &requestBody) - if err != nil { - handleError(w, err, "Could not unmarshal request body", http.StatusBadRequest) - } + request := r.Context().Value(payload(payload("payload"))).(*DiscussionsRequest) mergeRequestDiscussionOptions := gitlab.ListMergeRequestDiscussionsOptions{ Page: 1, @@ -96,7 +77,7 @@ func (a discussionsListerService) handler(w http.ResponseWriter, r *http.Request } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/discussions/list"}, "Could not list discussions", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not list discussions", res.StatusCode) return } @@ -106,7 +87,7 @@ func (a discussionsListerService) handler(w http.ResponseWriter, r *http.Request var linkedDiscussions []*gitlab.Discussion for _, discussion := range discussions { - if discussion.Notes == nil || len(discussion.Notes) == 0 || Contains(requestBody.Blacklist, discussion.Notes[0].Author.Username) { + if discussion.Notes == nil || len(discussion.Notes) == 0 || Contains(request.Blacklist, discussion.Notes[0].Author.Username) { continue } for _, note := range discussion.Notes { @@ -142,10 +123,7 @@ func (a discussionsListerService) handler(w http.ResponseWriter, r *http.Request w.WriteHeader(http.StatusOK) response := DiscussionsResponse{ - SuccessResponse: SuccessResponse{ - Message: "Discussions retrieved", - Status: http.StatusOK, - }, + SuccessResponse: SuccessResponse{Message: "Discussions retrieved"}, Discussions: linkedDiscussions, UnlinkedDiscussions: unlinkedDiscussions, Emojis: emojis, diff --git a/cmd/app/list_discussions_test.go b/cmd/app/list_discussions_test.go index ee4e0020..25d284c2 100644 --- a/cmd/app/list_discussions_test.go +++ b/cmd/app/list_discussions_test.go @@ -53,9 +53,9 @@ func (f fakeDiscussionsLister) ListMergeRequestAwardEmojiOnNote(pid interface{}, return []*gitlab.AwardEmoji{}, resp, err } -func getDiscussionsList(t *testing.T, svc ServiceWithHandler, request *http.Request) DiscussionsResponse { +func getDiscussionsList(t *testing.T, svc http.Handler, request *http.Request) DiscussionsResponse { res := httptest.NewRecorder() - svc.handler(res, request) + svc.ServeHTTP(res, request) var data DiscussionsResponse err := json.Unmarshal(res.Body.Bytes(), &data) @@ -67,46 +67,63 @@ func getDiscussionsList(t *testing.T, svc ServiceWithHandler, request *http.Requ func TestListDiscussions(t *testing.T) { t.Run("Returns sorted discussions", func(t *testing.T) { - request := makeRequest(t, http.MethodPost, "/mr/discussions/list", DiscussionsRequest{}) - svc := discussionsListerService{testProjectData, fakeDiscussionsLister{}} + request := makeRequest(t, http.MethodPost, "/mr/discussions/list", DiscussionsRequest{Blacklist: []string{}}) + svc := middleware( + discussionsListerService{testProjectData, fakeDiscussionsLister{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPost: &DiscussionsRequest{}}), + withMethodCheck(http.MethodPost), + ) data := getDiscussionsList(t, svc, request) assert(t, data.Message, "Discussions retrieved") - assert(t, data.SuccessResponse.Status, http.StatusOK) assert(t, data.Discussions[0].Notes[0].Author.Username, "hcramer2") /* Sorting applied */ assert(t, data.Discussions[1].Notes[0].Author.Username, "hcramer") }) t.Run("Uses blacklist to filter unwanted authors", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/discussions/list", DiscussionsRequest{Blacklist: []string{"hcramer"}}) - svc := discussionsListerService{testProjectData, fakeDiscussionsLister{}} + svc := middleware( + discussionsListerService{testProjectData, fakeDiscussionsLister{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPost: &DiscussionsRequest{}}), + withMethodCheck(http.MethodPost), + ) data := getDiscussionsList(t, svc, request) assert(t, data.SuccessResponse.Message, "Discussions retrieved") - assert(t, data.SuccessResponse.Status, http.StatusOK) assert(t, len(data.Discussions), 1) assert(t, data.Discussions[0].Notes[0].Author.Username, "hcramer2") }) - t.Run("Disallows non-GET methods", func(t *testing.T) { - request := makeRequest(t, http.MethodGet, "/mr/discussions/list", DiscussionsRequest{}) - svc := discussionsListerService{testProjectData, fakeDiscussionsLister{}} - data := getFailData(t, svc, request) - checkBadMethod(t, data, http.MethodPost) - }) t.Run("Handles errors from Gitlab client", func(t *testing.T) { - request := makeRequest(t, http.MethodPost, "/mr/discussions/list", DiscussionsRequest{}) - svc := discussionsListerService{testProjectData, fakeDiscussionsLister{testBase: testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) + request := makeRequest(t, http.MethodPost, "/mr/discussions/list", DiscussionsRequest{Blacklist: []string{}}) + svc := middleware( + discussionsListerService{testProjectData, fakeDiscussionsLister{testBase: testBase{errFromGitlab: true}}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPost: &DiscussionsRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) checkErrorFromGitlab(t, data, "Could not list discussions") }) t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { - request := makeRequest(t, http.MethodPost, "/mr/discussions/list", DiscussionsRequest{}) - svc := discussionsListerService{testProjectData, fakeDiscussionsLister{testBase: testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) + request := makeRequest(t, http.MethodPost, "/mr/discussions/list", DiscussionsRequest{Blacklist: []string{}}) + svc := middleware( + discussionsListerService{testProjectData, fakeDiscussionsLister{testBase: testBase{status: http.StatusSeeOther}}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPost: &DiscussionsRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) checkNon200(t, data, "Could not list discussions", "/mr/discussions/list") }) t.Run("Handles error from emoji service", func(t *testing.T) { - request := makeRequest(t, http.MethodPost, "/mr/discussions/list", DiscussionsRequest{}) - svc := discussionsListerService{testProjectData, fakeDiscussionsLister{badEmojiResponse: true}} - data := getFailData(t, svc, request) + request := makeRequest(t, http.MethodPost, "/mr/discussions/list", DiscussionsRequest{Blacklist: []string{}}) + svc := middleware( + discussionsListerService{testProjectData, fakeDiscussionsLister{badEmojiResponse: true, testBase: testBase{}}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPost: &DiscussionsRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) assert(t, data.Message, "Could not fetch emojis") assert(t, data.Details, "Some error from emoji service") }) diff --git a/cmd/app/logging.go b/cmd/app/logging.go new file mode 100644 index 00000000..e7fda5d9 --- /dev/null +++ b/cmd/app/logging.go @@ -0,0 +1,96 @@ +package app + +import ( + "bytes" + "fmt" + "io" + "log" + "net/http" + "net/http/httputil" + "os" +) + +// LoggingServer is a wrapper around an http.Handler to log incoming requests and outgoing responses. +type LoggingServer struct { + handler http.Handler +} + +type LoggingResponseWriter struct { + statusCode int + body *bytes.Buffer + http.ResponseWriter +} + +func (l *LoggingResponseWriter) WriteHeader(statusCode int) { + l.statusCode = statusCode +} + +func (l *LoggingResponseWriter) Write(b []byte) (int, error) { + l.body.Write(b) + return l.ResponseWriter.Write(b) +} + +// Logs the request, calls the original handler on the ServeMux, then logs the response +func (l LoggingServer) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json") + if pluginOptions.Debug.Request { + logRequest("REQUEST TO GO SERVER", r) + } + lrw := &LoggingResponseWriter{ResponseWriter: w, body: &bytes.Buffer{}} + l.handler.ServeHTTP(lrw, r) + resp := &http.Response{ + Status: http.StatusText(lrw.statusCode), + StatusCode: lrw.statusCode, + Body: io.NopCloser(bytes.NewBuffer(lrw.body.Bytes())), // Use the captured body + ContentLength: int64(lrw.body.Len()), + Header: lrw.Header(), + Request: r, + } + if pluginOptions.Debug.Response { + logResponse("RESPONSE FROM GO SERVER", resp) + } +} + +func logRequest(prefix string, r *http.Request) { + file := openLogFile() + defer file.Close() + token := r.Header.Get("Private-Token") + r.Header.Set("Private-Token", "REDACTED") + res, err := httputil.DumpRequest(r, true) + if err != nil { + log.Fatalf("Error dumping request: %v", err) + os.Exit(1) + } + r.Header.Set("Private-Token", token) + _, err = file.Write([]byte(fmt.Sprintf("\n-- %s --\n%s\n", prefix, res))) //nolint:all +} + +func logResponse(prefix string, r *http.Response) { + file := openLogFile() + defer file.Close() + + res, err := httputil.DumpResponse(r, true) + if err != nil { + log.Fatalf("Error dumping response: %v", err) + os.Exit(1) + } + + _, err = file.Write([]byte(fmt.Sprintf("\n-- %s --\n%s\n", prefix, res))) //nolint:all +} + +func openLogFile() *os.File { + file, err := os.OpenFile(pluginOptions.LogPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) + if err != nil { + if os.IsNotExist(err) { + log.Printf("Log file %s does not exist", pluginOptions.LogPath) + } else if os.IsPermission(err) { + log.Printf("Permission denied for log file %s", pluginOptions.LogPath) + } else { + log.Printf("Error opening log file %s: %v", pluginOptions.LogPath, err) + } + + os.Exit(1) + } + + return file +} diff --git a/cmd/app/members.go b/cmd/app/members.go index 9002ac0b..d8293a4a 100644 --- a/cmd/app/members.go +++ b/cmd/app/members.go @@ -22,13 +22,7 @@ type projectMemberService struct { } /* projectMembersHandler returns all members of the current Gitlab project */ -func (a projectMemberService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method != http.MethodGet { - w.Header().Set("Access-Control-Allow-Methods", http.MethodGet) - handleError(w, InvalidRequestError{}, "Expected GET", http.StatusMethodNotAllowed) - return - } +func (a projectMemberService) ServeHTTP(w http.ResponseWriter, r *http.Request) { projectMemberOptions := gitlab.ListProjectMembersOptions{ ListOptions: gitlab.ListOptions{ @@ -44,18 +38,15 @@ func (a projectMemberService) handler(w http.ResponseWriter, r *http.Request) { } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/project/members"}, "Could not retrieve project members", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not retrieve project members", res.StatusCode) return } w.WriteHeader(http.StatusOK) response := ProjectMembersResponse{ - SuccessResponse: SuccessResponse{ - Status: http.StatusOK, - Message: "Project members retrieved", - }, - ProjectMembers: projectMembers, + SuccessResponse: SuccessResponse{Message: "Project members retrieved"}, + ProjectMembers: projectMembers, } err = json.NewEncoder(w).Encode(response) diff --git a/cmd/app/members_test.go b/cmd/app/members_test.go index cdf10efe..e641b1ce 100644 --- a/cmd/app/members_test.go +++ b/cmd/app/members_test.go @@ -22,27 +22,29 @@ func (f fakeMemberLister) ListAllProjectMembers(pid interface{}, opt *gitlab.Lis func TestMembersHandler(t *testing.T) { t.Run("Returns project members", func(t *testing.T) { request := makeRequest(t, http.MethodGet, "/project/members", nil) - svc := projectMemberService{testProjectData, fakeMemberLister{}} + svc := middleware( + projectMemberService{testProjectData, fakeMemberLister{}}, + withMethodCheck(http.MethodGet), + ) data := getSuccessData(t, svc, request) - assert(t, data.Status, http.StatusOK) assert(t, data.Message, "Project members retrieved") }) - t.Run("Disallows non-GET methods", func(t *testing.T) { - request := makeRequest(t, http.MethodPost, "/project/members", nil) - svc := projectMemberService{testProjectData, fakeMemberLister{}} - data := getFailData(t, svc, request) - checkBadMethod(t, data, http.MethodGet) - }) t.Run("Handles error from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodGet, "/project/members", nil) - svc := projectMemberService{testProjectData, fakeMemberLister{testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) + svc := middleware( + projectMemberService{testProjectData, fakeMemberLister{testBase{errFromGitlab: true}}}, + withMethodCheck(http.MethodGet), + ) + data, _ := getFailData(t, svc, request) checkErrorFromGitlab(t, data, "Could not retrieve project members") }) t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodGet, "/project/members", nil) - svc := projectMemberService{testProjectData, fakeMemberLister{testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) + svc := middleware( + projectMemberService{testProjectData, fakeMemberLister{testBase{status: http.StatusSeeOther}}}, + withMethodCheck(http.MethodGet), + ) + data, _ := getFailData(t, svc, request) checkNon200(t, data, "Could not retrieve project members", "/project/members") }) } diff --git a/cmd/app/merge_mr.go b/cmd/app/merge_mr.go index 89245cb4..872ac548 100644 --- a/cmd/app/merge_mr.go +++ b/cmd/app/merge_mr.go @@ -2,16 +2,15 @@ package app import ( "encoding/json" - "io" "net/http" "github.com/xanzy/go-gitlab" ) type AcceptMergeRequestRequest struct { - Squash bool `json:"squash"` - SquashMessage string `json:"squash_message"` DeleteBranch bool `json:"delete_branch"` + SquashMessage string `json:"squash_message"` + Squash bool `json:"squash"` } type MergeRequestAccepter interface { @@ -24,34 +23,16 @@ type mergeRequestAccepterService struct { } /* acceptAndMergeHandler merges a given merge request into the target branch */ -func (a mergeRequestAccepterService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Methods", http.MethodGet) - if r.Method != http.MethodPost { - handleError(w, InvalidRequestError{}, "Expected POST", http.StatusMethodNotAllowed) - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - return - } - - var acceptAndMergeRequest AcceptMergeRequestRequest - err = json.Unmarshal(body, &acceptAndMergeRequest) - if err != nil { - handleError(w, err, "Could not unmarshal request body", http.StatusBadRequest) - return - } +func (a mergeRequestAccepterService) ServeHTTP(w http.ResponseWriter, r *http.Request) { + payload := r.Context().Value(payload("payload")).(*AcceptMergeRequestRequest) opts := gitlab.AcceptMergeRequestOptions{ - Squash: &acceptAndMergeRequest.Squash, - ShouldRemoveSourceBranch: &acceptAndMergeRequest.DeleteBranch, + Squash: &payload.Squash, + ShouldRemoveSourceBranch: &payload.DeleteBranch, } - if acceptAndMergeRequest.SquashMessage != "" { - opts.SquashCommitMessage = &acceptAndMergeRequest.SquashMessage + if payload.SquashMessage != "" { + opts.SquashCommitMessage = &payload.SquashMessage } _, res, err := a.client.AcceptMergeRequest(a.projectInfo.ProjectId, a.projectInfo.MergeId, &opts) @@ -62,14 +43,11 @@ func (a mergeRequestAccepterService) handler(w http.ResponseWriter, r *http.Requ } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/merge"}, "Could not merge MR", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not merge MR", res.StatusCode) return } - response := SuccessResponse{ - Status: http.StatusOK, - Message: "MR merged successfully", - } + response := SuccessResponse{Message: "MR merged successfully"} w.WriteHeader(http.StatusOK) diff --git a/cmd/app/merge_mr_test.go b/cmd/app/merge_mr_test.go index 1adabba4..9c48cf41 100644 --- a/cmd/app/merge_mr_test.go +++ b/cmd/app/merge_mr_test.go @@ -24,27 +24,41 @@ func TestAcceptAndMergeHandler(t *testing.T) { var testAcceptMergeRequestPayload = AcceptMergeRequestRequest{Squash: false, SquashMessage: "Squash me!", DeleteBranch: false} t.Run("Accepts and merges a merge request", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/merge", testAcceptMergeRequestPayload) - svc := mergeRequestAccepterService{testProjectData, fakeMergeRequestAccepter{}} + svc := middleware( + mergeRequestAccepterService{testProjectData, fakeMergeRequestAccepter{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &AcceptMergeRequestRequest{}, + }), + withMethodCheck(http.MethodPost), + ) data := getSuccessData(t, svc, request) assert(t, data.Message, "MR merged successfully") - assert(t, data.Status, http.StatusOK) - }) - t.Run("Disallows non-POST methods", func(t *testing.T) { - request := makeRequest(t, http.MethodPut, "/mr/merge", testAcceptMergeRequestPayload) - svc := mergeRequestAccepterService{testProjectData, fakeMergeRequestAccepter{}} - data := getFailData(t, svc, request) - checkBadMethod(t, data, http.MethodPost) }) t.Run("Handles errors from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/merge", testAcceptMergeRequestPayload) - svc := mergeRequestAccepterService{testProjectData, fakeMergeRequestAccepter{testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) + svc := middleware( + mergeRequestAccepterService{testProjectData, fakeMergeRequestAccepter{testBase{errFromGitlab: true}}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &AcceptMergeRequestRequest{}, + }), + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) checkErrorFromGitlab(t, data, "Could not merge MR") }) t.Run("Handles non-200s from Gitlab", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/merge", testAcceptMergeRequestPayload) - svc := mergeRequestAccepterService{testProjectData, fakeMergeRequestAccepter{testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) + svc := middleware( + mergeRequestAccepterService{testProjectData, fakeMergeRequestAccepter{testBase{status: http.StatusSeeOther}}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{ + http.MethodPost: &AcceptMergeRequestRequest{}, + }), + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) checkNon200(t, data, "Could not merge MR", "/mr/merge") }) } diff --git a/cmd/app/merge_requests.go b/cmd/app/merge_requests.go index bc7d1ab1..4fd58b6d 100644 --- a/cmd/app/merge_requests.go +++ b/cmd/app/merge_requests.go @@ -3,7 +3,6 @@ package app import ( "encoding/json" "errors" - "io" "net/http" "github.com/xanzy/go-gitlab" @@ -23,37 +22,20 @@ type mergeRequestListerService struct { client MergeRequestLister } -func (a mergeRequestListerService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method != http.MethodPost { - w.Header().Set("Access-Control-Allow-Methods", http.MethodPost) - handleError(w, InvalidRequestError{}, "Expected POST", http.StatusMethodNotAllowed) - return - } +// Lists all merge requests in Gitlab according to the provided filters +func (a mergeRequestListerService) ServeHTTP(w http.ResponseWriter, r *http.Request) { - body, err := io.ReadAll(r.Body) - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - return - } - - defer r.Body.Close() - var listMergeRequestRequest gitlab.ListProjectMergeRequestsOptions - err = json.Unmarshal(body, &listMergeRequestRequest) - if err != nil { - handleError(w, err, "Could not read JSON from request", http.StatusBadRequest) - return - } + payload := r.Context().Value(payload("payload")).(*gitlab.ListProjectMergeRequestsOptions) - if listMergeRequestRequest.State == nil { - listMergeRequestRequest.State = gitlab.Ptr("opened") + if payload.State == nil { + payload.State = gitlab.Ptr("opened") } - if listMergeRequestRequest.Scope == nil { - listMergeRequestRequest.Scope = gitlab.Ptr("all") + if payload.Scope == nil { + payload.Scope = gitlab.Ptr("all") } - mergeRequests, res, err := a.client.ListProjectMergeRequests(a.projectInfo.ProjectId, &listMergeRequestRequest) + mergeRequests, res, err := a.client.ListProjectMergeRequests(a.projectInfo.ProjectId, payload) if err != nil { handleError(w, err, "Failed to list merge requests", http.StatusInternalServerError) @@ -61,7 +43,7 @@ func (a mergeRequestListerService) handler(w http.ResponseWriter, r *http.Reques } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/merge_requests"}, "Failed to list merge requests", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Failed to list merge requests", res.StatusCode) return } @@ -72,11 +54,8 @@ func (a mergeRequestListerService) handler(w http.ResponseWriter, r *http.Reques w.WriteHeader(http.StatusOK) response := ListMergeRequestResponse{ - SuccessResponse: SuccessResponse{ - Message: "Merge requests fetched successfully", - Status: http.StatusOK, - }, - MergeRequests: mergeRequests, + SuccessResponse: SuccessResponse{Message: "Merge requests fetched successfully"}, + MergeRequests: mergeRequests, } err = json.NewEncoder(w).Encode(response) diff --git a/cmd/app/merge_requests_by_username.go b/cmd/app/merge_requests_by_username.go index 4f64448a..4a097566 100644 --- a/cmd/app/merge_requests_by_username.go +++ b/cmd/app/merge_requests_by_username.go @@ -4,7 +4,6 @@ import ( "encoding/json" "errors" "fmt" - "io" "net/http" "sync" @@ -21,42 +20,15 @@ type mergeRequestListerByUsernameService struct { } type MergeRequestByUsernameRequest struct { - UserId int `json:"user_id"` - Username string `json:"username"` + UserId int `json:"user_id" validate:"required"` + Username string `json:"username" validate:"required"` State string `json:"state,omitempty"` } -func (a mergeRequestListerByUsernameService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method != http.MethodPost { - w.Header().Set("Access-Control-Allow-Methods", http.MethodPost) - handleError(w, InvalidRequestError{}, "Expected POST", http.StatusMethodNotAllowed) - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - return - } - - defer r.Body.Close() - var request MergeRequestByUsernameRequest - err = json.Unmarshal(body, &request) - if err != nil { - handleError(w, err, "Could not read JSON from request", http.StatusBadRequest) - return - } - - if request.Username == "" { - handleError(w, errors.New("username is a required payload field"), "username is required", http.StatusBadRequest) - return - } +// Returns a list of merge requests where the given username/id is either an assignee, reviewer, or author +func (a mergeRequestListerByUsernameService) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if request.UserId == 0 { - handleError(w, errors.New("user_id is a required payload field"), "user_id is required", http.StatusBadRequest) - return - } + request := r.Context().Value(payload("payload")).(*MergeRequestByUsernameRequest) if request.State == "" { request.State = "opened" @@ -133,14 +105,11 @@ func (a mergeRequestListerByUsernameService) handler(w http.ResponseWriter, r *h w.WriteHeader(http.StatusOK) response := ListMergeRequestResponse{ - SuccessResponse: SuccessResponse{ - Message: fmt.Sprintf("Merge requests fetched for %s", request.Username), - Status: http.StatusOK, - }, - MergeRequests: mergeRequests, + SuccessResponse: SuccessResponse{Message: fmt.Sprintf("Merge requests fetched for %s", request.Username)}, + MergeRequests: mergeRequests, } - err = json.NewEncoder(w).Encode(response) + err := json.NewEncoder(w).Encode(response) if err != nil { handleError(w, err, "Could not encode response", http.StatusInternalServerError) } diff --git a/cmd/app/merge_requests_by_username_test.go b/cmd/app/merge_requests_by_username_test.go index 34631a0b..a6c2d010 100644 --- a/cmd/app/merge_requests_by_username_test.go +++ b/cmd/app/merge_requests_by_username_test.go @@ -30,58 +30,81 @@ func TestListMergeRequestByUsername(t *testing.T) { var testListMrsByUsernamePayload = MergeRequestByUsernameRequest{Username: "hcramer", UserId: 1234, State: "opened"} t.Run("Gets merge requests by username", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/merge_requests_by_username", testListMrsByUsernamePayload) - svc := mergeRequestListerByUsernameService{testProjectData, fakeMergeRequestListerByUsername{}} + svc := middleware( + mergeRequestListerByUsernameService{testProjectData, fakeMergeRequestListerByUsername{}}, + withPayloadValidation(methodToPayload{http.MethodPost: &MergeRequestByUsernameRequest{}}), + withMethodCheck(http.MethodPost), + ) data := getSuccessData(t, svc, request) assert(t, data.Message, "Merge requests fetched for hcramer") - assert(t, data.Status, http.StatusOK) }) t.Run("Should handle no merge requests", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/merge_requests_by_username", testListMrsByUsernamePayload) - svc := mergeRequestListerByUsernameService{testProjectData, fakeMergeRequestListerByUsername{emptyResponse: true}} - data := getFailData(t, svc, request) + svc := middleware( + mergeRequestListerByUsernameService{testProjectData, fakeMergeRequestListerByUsername{emptyResponse: true}}, + withPayloadValidation(methodToPayload{http.MethodPost: &MergeRequestByUsernameRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, status := getFailData(t, svc, request) assert(t, data.Message, "No MRs found") assert(t, data.Details, "hcramer did not have any MRs") - assert(t, data.Status, http.StatusNotFound) + assert(t, status, http.StatusNotFound) }) t.Run("Should require username", func(t *testing.T) { missingUsernamePayload := testListMrsByUsernamePayload missingUsernamePayload.Username = "" request := makeRequest(t, http.MethodPost, "/merge_requests_by_username", missingUsernamePayload) - svc := mergeRequestListerByUsernameService{testProjectData, fakeMergeRequestListerByUsername{}} - data := getFailData(t, svc, request) - assert(t, data.Message, "username is required") - assert(t, data.Details, "username is a required payload field") - assert(t, data.Status, http.StatusBadRequest) + svc := middleware( + mergeRequestListerByUsernameService{testProjectData, fakeMergeRequestListerByUsername{}}, + withPayloadValidation(methodToPayload{http.MethodPost: &MergeRequestByUsernameRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, status := getFailData(t, svc, request) + assert(t, data.Message, "Invalid payload") + assert(t, data.Details, "Username is required") + assert(t, status, http.StatusBadRequest) }) t.Run("Should require User ID for assignee call", func(t *testing.T) { missingUsernamePayload := testListMrsByUsernamePayload missingUsernamePayload.UserId = 0 request := makeRequest(t, http.MethodPost, "/merge_requests_by_username", missingUsernamePayload) - svc := mergeRequestListerByUsernameService{testProjectData, fakeMergeRequestListerByUsername{}} - data := getFailData(t, svc, request) - assert(t, data.Message, "user_id is required") - assert(t, data.Details, "user_id is a required payload field") - assert(t, data.Status, http.StatusBadRequest) + svc := middleware( + mergeRequestListerByUsernameService{testProjectData, fakeMergeRequestListerByUsername{}}, + withPayloadValidation(methodToPayload{http.MethodPost: &MergeRequestByUsernameRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, status := getFailData(t, svc, request) + assert(t, data.Message, "Invalid payload") + assert(t, data.Details, "UserId is required") + assert(t, status, http.StatusBadRequest) }) t.Run("Should handle error from Gitlab", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/merge_requests_by_username", testListMrsByUsernamePayload) - svc := mergeRequestListerByUsernameService{testProjectData, fakeMergeRequestListerByUsername{testBase: testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) + svc := middleware( + mergeRequestListerByUsernameService{testProjectData, fakeMergeRequestListerByUsername{testBase: testBase{errFromGitlab: true}}}, + withPayloadValidation(methodToPayload{http.MethodPost: &MergeRequestByUsernameRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, status := getFailData(t, svc, request) assert(t, data.Message, "An error occurred") assert(t, data.Details, strings.Repeat("Some error from Gitlab; ", 3)) - assert(t, data.Status, http.StatusInternalServerError) + assert(t, status, http.StatusInternalServerError) }) t.Run("Handles non-200 from Gitlab", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/merge_requests_by_username", testListMrsByUsernamePayload) - svc := mergeRequestListerByUsernameService{testProjectData, fakeMergeRequestListerByUsername{testBase: testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) + svc := middleware( + mergeRequestListerByUsernameService{testProjectData, fakeMergeRequestListerByUsername{testBase: testBase{status: http.StatusSeeOther}}}, + withPayloadValidation(methodToPayload{http.MethodPost: &MergeRequestByUsernameRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, status := getFailData(t, svc, request) assert(t, data.Message, "An error occurred") assert(t, data.Details, strings.Repeat("An error occurred on the /merge_requests_by_username endpoint; ", 3)) - assert(t, data.Status, http.StatusInternalServerError) + assert(t, status, http.StatusInternalServerError) }) } diff --git a/cmd/app/merge_requests_test.go b/cmd/app/merge_requests_test.go index 79020ce8..5f8cd1c3 100644 --- a/cmd/app/merge_requests_test.go +++ b/cmd/app/merge_requests_test.go @@ -10,6 +10,7 @@ import ( type fakeMergeRequestLister struct { testBase emptyResponse bool + multipleMrs bool } func (f fakeMergeRequestLister) ListProjectMergeRequests(pid interface{}, opt *gitlab.ListProjectMergeRequestsOptions, options ...gitlab.RequestOptionFunc) ([]*gitlab.MergeRequest, *gitlab.Response, error) { @@ -22,6 +23,10 @@ func (f fakeMergeRequestLister) ListProjectMergeRequests(pid interface{}, opt *g return []*gitlab.MergeRequest{}, resp, err } + if f.multipleMrs { + return []*gitlab.MergeRequest{{IID: 10}, {IID: 11}}, resp, err + } + return []*gitlab.MergeRequest{{IID: 10}}, resp, err } @@ -29,30 +34,45 @@ func TestMergeRequestHandler(t *testing.T) { var testListMergeRequestsRequest = gitlab.ListProjectMergeRequestsOptions{} t.Run("Should fetch merge requests", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/merge_requests", testListMergeRequestsRequest) - svc := mergeRequestListerService{testProjectData, fakeMergeRequestLister{}} + svc := middleware( + mergeRequestListerService{testProjectData, fakeMergeRequestLister{}}, + withPayloadValidation(methodToPayload{http.MethodPost: &gitlab.ListProjectMergeRequestsOptions{}}), + withMethodCheck(http.MethodPost), + ) data := getSuccessData(t, svc, request) - assert(t, data.Status, http.StatusOK) assert(t, data.Message, "Merge requests fetched successfully") }) t.Run("Handles error from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/merge_requests", testListMergeRequestsRequest) - svc := mergeRequestListerService{testProjectData, fakeMergeRequestLister{testBase: testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) + svc := middleware( + mergeRequestListerService{testProjectData, fakeMergeRequestLister{testBase: testBase{errFromGitlab: true}}}, + withPayloadValidation(methodToPayload{http.MethodPost: &gitlab.ListProjectMergeRequestsOptions{}}), + withMethodCheck(http.MethodPost), + ) + data, status := getFailData(t, svc, request) checkErrorFromGitlab(t, data, "Failed to list merge requests") - assert(t, data.Status, http.StatusInternalServerError) + assert(t, status, http.StatusInternalServerError) }) t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/merge_requests", testListMergeRequestsRequest) - svc := mergeRequestListerService{testProjectData, fakeMergeRequestLister{testBase: testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) + svc := middleware( + mergeRequestListerService{testProjectData, fakeMergeRequestLister{testBase: testBase{status: http.StatusSeeOther}}}, + withPayloadValidation(methodToPayload{http.MethodPost: &gitlab.ListProjectMergeRequestsOptions{}}), + withMethodCheck(http.MethodPost), + ) + data, status := getFailData(t, svc, request) checkNon200(t, data, "Failed to list merge requests", "/merge_requests") - assert(t, data.Status, http.StatusSeeOther) + assert(t, status, http.StatusSeeOther) }) t.Run("Should handle not having any merge requests with 404", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/merge_requests", testListMergeRequestsRequest) - svc := mergeRequestListerService{testProjectData, fakeMergeRequestLister{emptyResponse: true}} - data := getFailData(t, svc, request) + svc := middleware( + mergeRequestListerService{testProjectData, fakeMergeRequestLister{emptyResponse: true}}, + withPayloadValidation(methodToPayload{http.MethodPost: &gitlab.ListProjectMergeRequestsOptions{}}), + withMethodCheck(http.MethodPost), + ) + data, status := getFailData(t, svc, request) assert(t, data.Message, "No merge requests found") - assert(t, data.Status, http.StatusNotFound) + assert(t, status, http.StatusNotFound) }) } diff --git a/cmd/app/middleware.go b/cmd/app/middleware.go new file mode 100644 index 00000000..ae991911 --- /dev/null +++ b/cmd/app/middleware.go @@ -0,0 +1,173 @@ +package app + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + + "github.com/go-playground/validator/v10" + "github.com/xanzy/go-gitlab" +) + +type mw func(http.Handler) http.Handler + +type payload string + +// Wraps a series of middleware around the base handler. Functions are called from bottom to top. +// The middlewares should call the serveHTTP method on their http.Handler argument to pass along the request. +func middleware(h http.Handler, middlewares ...mw) http.HandlerFunc { + for _, middleware := range middlewares { + h = middleware(h) + } + return h.ServeHTTP +} + +var validate = validator.New() + +type methodToPayload map[string]any + +type validatorMiddleware struct { + validate *validator.Validate + methodToPayload methodToPayload +} + +// Validates the fields in a payload and then attaches the validated payload to the request context so that +// subsequent handlers can use it. +func (p validatorMiddleware) handle(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + + if p.methodToPayload[r.Method] == nil { // If no payload to validate for this method type... + next.ServeHTTP(w, r) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + handleError(w, err, "Could not read request body", http.StatusBadRequest) + return + } + + pl := p.methodToPayload[r.Method] + err = json.Unmarshal(body, &pl) + + if err != nil { + handleError(w, err, "Could not parse JSON request body", http.StatusBadRequest) + return + } + + err = p.validate.Struct(pl) + if err != nil { + switch err := err.(type) { + case validator.ValidationErrors: + handleError(w, formatValidationErrors(err), "Invalid payload", http.StatusBadRequest) + return + case *validator.InvalidValidationError: + handleError(w, err, "Invalid validation error", http.StatusInternalServerError) + return + } + } + + // Pass the parsed data so we don't have to re-parse it in the handler + ctx := context.WithValue(r.Context(), payload(payload("payload")), pl) + r = r.WithContext(ctx) + + next.ServeHTTP(w, r) + }) +} + +func withPayloadValidation(mtp methodToPayload) mw { + return validatorMiddleware{validate: validate, methodToPayload: mtp}.handle +} + +type withMrMiddleware struct { + data data + client MergeRequestLister +} + +// Gets the current merge request ID and attaches it to the projectInfo +func (m withMrMiddleware) handle(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // If the merge request is already attached, skip the middleware logic + if m.data.projectInfo.MergeId == 0 { + options := gitlab.ListProjectMergeRequestsOptions{ + Scope: gitlab.Ptr("all"), + SourceBranch: &m.data.gitInfo.BranchName, + TargetBranch: pluginOptions.ChosenTargetBranch, + } + + mergeRequests, _, err := m.client.ListProjectMergeRequests(m.data.projectInfo.ProjectId, &options) + if err != nil { + handleError(w, fmt.Errorf("Failed to list merge requests: %w", err), "Failed to list merge requests", http.StatusInternalServerError) + return + } + + if len(mergeRequests) == 0 { + err := fmt.Errorf("Branch '%s' does not have any merge requests", m.data.gitInfo.BranchName) + handleError(w, err, "No MRs Found", http.StatusNotFound) + return + } + + if len(mergeRequests) > 1 { + err := errors.New("Please call gitlab.choose_merge_request()") + handleError(w, err, "Multiple MRs found", http.StatusBadRequest) + return + } + + mergeIdInt := mergeRequests[0].IID + m.data.projectInfo.MergeId = mergeIdInt + } + + // Call the next handler if middleware succeeds + next.ServeHTTP(w, r) + }) +} + +// Att +func withMr(data data, client MergeRequestLister) mw { + return withMrMiddleware{data, client}.handle +} + +type methodMiddleware struct { + methods []string +} + +func (m methodMiddleware) handle(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + method := r.Method + for _, acceptableMethod := range m.methods { + if method == acceptableMethod { + next.ServeHTTP(w, r) + return + } + } + + w.Header().Set("Access-Control-Allow-Methods", http.MethodPut) + handleError(w, InvalidRequestError{fmt.Sprintf("Expected: %s", strings.Join(m.methods, "; "))}, "Invalid request type", http.StatusMethodNotAllowed) + }) +} + +func withMethodCheck(methods ...string) mw { + return methodMiddleware{methods: methods}.handle +} + +// Helper function to format validation errors into more readable strings +func formatValidationErrors(errors validator.ValidationErrors) error { + var s strings.Builder + for i, e := range errors { + if i > 0 { + s.WriteString("; ") + } + switch e.Tag() { + case "required": + s.WriteString(fmt.Sprintf("%s is required", e.Field())) + default: + s.WriteString(fmt.Sprintf("The field '%s' failed on validation on the '%s' tag", e.Field(), e.Tag())) + } + } + + return fmt.Errorf(s.String()) +} diff --git a/cmd/app/middleware_test.go b/cmd/app/middleware_test.go new file mode 100644 index 00000000..6e9afdfb --- /dev/null +++ b/cmd/app/middleware_test.go @@ -0,0 +1,114 @@ +package app + +import ( + "encoding/json" + "net/http" + "testing" + + "github.com/harrisoncramer/gitlab.nvim/cmd/app/git" +) + +type FakePayload struct { + Foo string `json:"foo" validate:"required"` +} + +type fakeHandler struct{} + +func (f fakeHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + data := SuccessResponse{Message: "Some message"} + j, _ := json.Marshal(data) + w.Write(j) // nolint + +} + +func TestMethodMiddleware(t *testing.T) { + t.Run("Fails a bad method", func(t *testing.T) { + request := makeRequest(t, http.MethodGet, "/foo", nil) + mw := withMethodCheck(http.MethodPost) + handler := middleware(fakeHandler{}, mw) + data, status := getFailData(t, handler, request) + assert(t, data.Message, "Invalid request type") + assert(t, data.Details, "Expected: POST") + assert(t, status, http.StatusMethodNotAllowed) + }) + t.Run("Fails bad method with multiple", func(t *testing.T) { + request := makeRequest(t, http.MethodGet, "/foo", nil) + mw := withMethodCheck(http.MethodPost, http.MethodPatch) + handler := middleware(fakeHandler{}, mw) + data, status := getFailData(t, handler, request) + assert(t, data.Message, "Invalid request type") + assert(t, data.Details, "Expected: POST; PATCH") + assert(t, status, http.StatusMethodNotAllowed) + }) + t.Run("Allows ok method through", func(t *testing.T) { + request := makeRequest(t, http.MethodGet, "/foo", nil) + mw := withMethodCheck(http.MethodGet) + handler := middleware(fakeHandler{}, mw) + data := getSuccessData(t, handler, request) + assert(t, data.Message, "Some message") + }) +} + +func TestWithMrMiddleware(t *testing.T) { + t.Run("Loads an MR ID into the projectInfo", func(t *testing.T) { + request := makeRequest(t, http.MethodGet, "/foo", nil) + d := data{ + projectInfo: &ProjectInfo{}, + gitInfo: &git.GitData{BranchName: "foo"}, + } + mw := withMr(d, fakeMergeRequestLister{}) + handler := middleware(fakeHandler{}, mw) + getSuccessData(t, handler, request) + if d.projectInfo.MergeId != 10 { + t.FailNow() + } + }) + t.Run("Handles when there are no MRs", func(t *testing.T) { + request := makeRequest(t, http.MethodGet, "/foo", nil) + d := data{ + projectInfo: &ProjectInfo{}, + gitInfo: &git.GitData{BranchName: "foo"}, + } + mw := withMr(d, fakeMergeRequestLister{emptyResponse: true}) + handler := middleware(fakeHandler{}, mw) + data, status := getFailData(t, handler, request) + assert(t, status, http.StatusNotFound) + assert(t, data.Message, "No MRs Found") + assert(t, data.Details, "Branch 'foo' does not have any merge requests") + }) + t.Run("Handles when there are too many MRs", func(t *testing.T) { + request := makeRequest(t, http.MethodGet, "/foo", nil) + d := data{ + projectInfo: &ProjectInfo{}, + gitInfo: &git.GitData{BranchName: "foo"}, + } + mw := withMr(d, fakeMergeRequestLister{multipleMrs: true}) + handler := middleware(fakeHandler{}, mw) + data, status := getFailData(t, handler, request) + assert(t, status, http.StatusBadRequest) + assert(t, data.Message, "Multiple MRs found") + assert(t, data.Details, "Please call gitlab.choose_merge_request()") + }) +} + +func TestValidatorMiddleware(t *testing.T) { + t.Run("Should error with missing field", func(t *testing.T) { + request := makeRequest(t, http.MethodPost, "/foo", FakePayload{}) // No Foo field + data, status := getFailData(t, middleware( + fakeHandler{}, + withPayloadValidation(methodToPayload{http.MethodPost: &FakePayload{}}), + ), request) + assert(t, data.Message, "Invalid payload") + assert(t, data.Details, "Foo is required") + assert(t, status, http.StatusBadRequest) + }) + t.Run("Should allow valid payload through", func(t *testing.T) { + request := makeRequest(t, http.MethodPost, "/foo", FakePayload{Foo: "Some payload"}) + data := getSuccessData(t, middleware( + fakeHandler{}, + withPayloadValidation(methodToPayload{http.MethodPost: &FakePayload{}}), + ), request) + assert(t, data.Message, "Some message") + }) +} diff --git a/cmd/app/pipeline.go b/cmd/app/pipeline.go index 587c1e0c..8640613a 100644 --- a/cmd/app/pipeline.go +++ b/cmd/app/pipeline.go @@ -43,16 +43,12 @@ type pipelineService struct { pipelineHandler fetches information about the current pipeline, and retriggers a pipeline run. For more detailed information about a given job in a pipeline, see the jobHandler function */ -func (a pipelineService) handler(w http.ResponseWriter, r *http.Request) { +func (a pipelineService) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch r.Method { case http.MethodGet: a.GetPipelineAndJobs(w, r) case http.MethodPost: a.RetriggerPipeline(w, r) - default: - w.Header().Set("Content-Type", "application/json") - w.Header().Set("Access-Control-Allow-Methods", fmt.Sprintf("%s, %s", http.MethodGet, http.MethodPost)) - handleError(w, InvalidRequestError{}, "Expected GET or POST", http.StatusMethodNotAllowed) } } @@ -100,7 +96,7 @@ func (a pipelineService) GetPipelineAndJobs(w http.ResponseWriter, r *http.Reque } if pipeline == nil { - handleError(w, GenericError{endpoint: "/pipeline"}, fmt.Sprintf("No pipeline found for %s branch", a.gitInfo.BranchName), http.StatusInternalServerError) + handleError(w, GenericError{r.URL.Path}, fmt.Sprintf("No pipeline found for %s branch", a.gitInfo.BranchName), http.StatusInternalServerError) return } @@ -112,16 +108,13 @@ func (a pipelineService) GetPipelineAndJobs(w http.ResponseWriter, r *http.Reque } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/pipeline"}, "Could not get pipeline jobs", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not get pipeline jobs", res.StatusCode) return } w.WriteHeader(http.StatusOK) response := GetPipelineAndJobsResponse{ - SuccessResponse: SuccessResponse{ - Status: http.StatusOK, - Message: "Pipeline retrieved", - }, + SuccessResponse: SuccessResponse{Message: "Pipeline retrieved"}, Pipeline: PipelineWithJobs{ LatestPipeline: pipeline, Jobs: jobs, @@ -153,17 +146,14 @@ func (a pipelineService) RetriggerPipeline(w http.ResponseWriter, r *http.Reques } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/pipeline"}, "Could not retrigger pipeline", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not retrigger pipeline", res.StatusCode) return } w.WriteHeader(http.StatusOK) response := RetriggerPipelineResponse{ - SuccessResponse: SuccessResponse{ - Message: "Pipeline retriggered", - Status: http.StatusOK, - }, - LatestPipeline: pipeline, + SuccessResponse: SuccessResponse{Message: "Pipeline retriggered"}, + LatestPipeline: pipeline, } err = json.NewEncoder(w).Encode(response) diff --git a/cmd/app/pipeline_test.go b/cmd/app/pipeline_test.go index d8fcecb9..db354946 100644 --- a/cmd/app/pipeline_test.go +++ b/cmd/app/pipeline_test.go @@ -38,27 +38,29 @@ func (f fakePipelineManager) RetryPipelineBuild(pid interface{}, pipeline int, o func TestPipelineGetter(t *testing.T) { t.Run("Gets all pipeline jobs", func(t *testing.T) { request := makeRequest(t, http.MethodGet, "/pipeline", nil) - svc := pipelineService{testProjectData, fakePipelineManager{}, FakeGitManager{}} + svc := middleware( + pipelineService{testProjectData, fakePipelineManager{}, FakeGitManager{}}, + withMethodCheck(http.MethodGet), + ) data := getSuccessData(t, svc, request) assert(t, data.Message, "Pipeline retrieved") - assert(t, data.Status, http.StatusOK) - }) - t.Run("Disallows non-GET, non-POST methods", func(t *testing.T) { - request := makeRequest(t, http.MethodPatch, "/pipeline", nil) - svc := pipelineService{testProjectData, fakePipelineManager{}, FakeGitManager{}} - data := getFailData(t, svc, request) - checkBadMethod(t, data, http.MethodGet, http.MethodPost) }) t.Run("Handles errors from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodGet, "/pipeline", nil) - svc := pipelineService{testProjectData, fakePipelineManager{testBase{errFromGitlab: true}}, FakeGitManager{}} - data := getFailData(t, svc, request) + svc := middleware( + pipelineService{testProjectData, fakePipelineManager{testBase{errFromGitlab: true}}, FakeGitManager{}}, + withMethodCheck(http.MethodGet), + ) + data, _ := getFailData(t, svc, request) checkErrorFromGitlab(t, data, "Failed to get latest pipeline for some-branch branch") }) t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodGet, "/pipeline", nil) - svc := pipelineService{testProjectData, fakePipelineManager{testBase: testBase{status: http.StatusSeeOther}}, FakeGitManager{}} - data := getFailData(t, svc, request) + svc := middleware( + pipelineService{testProjectData, fakePipelineManager{testBase{status: http.StatusSeeOther}}, FakeGitManager{}}, + withMethodCheck(http.MethodGet), + ) + data, _ := getFailData(t, svc, request) assert(t, data.Message, "Failed to get latest pipeline for some-branch branch") // Expected, we treat this as an error }) } @@ -66,21 +68,29 @@ func TestPipelineGetter(t *testing.T) { func TestPipelineTrigger(t *testing.T) { t.Run("Retriggers pipeline", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/pipeline/trigger/3", nil) - svc := pipelineService{testProjectData, fakePipelineManager{}, FakeGitManager{}} + svc := middleware( + pipelineService{testProjectData, fakePipelineManager{}, FakeGitManager{}}, + withMethodCheck(http.MethodPost), + ) data := getSuccessData(t, svc, request) assert(t, data.Message, "Pipeline retriggered") - assert(t, data.Status, http.StatusOK) }) t.Run("Handles errors from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/pipeline/trigger/3", nil) - svc := pipelineService{testProjectData, fakePipelineManager{testBase{errFromGitlab: true}}, FakeGitManager{}} - data := getFailData(t, svc, request) + svc := middleware( + pipelineService{testProjectData, fakePipelineManager{testBase{errFromGitlab: true}}, FakeGitManager{}}, + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) checkErrorFromGitlab(t, data, "Could not retrigger pipeline") }) t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/pipeline/trigger/3", nil) - svc := pipelineService{testProjectData, fakePipelineManager{testBase: testBase{status: http.StatusSeeOther}}, FakeGitManager{}} - data := getFailData(t, svc, request) - checkNon200(t, data, "Could not retrigger pipeline", "/pipeline") + svc := middleware( + pipelineService{testProjectData, fakePipelineManager{testBase{status: http.StatusSeeOther}}, FakeGitManager{}}, + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) + checkNon200(t, data, "Could not retrigger pipeline", "/pipeline/trigger/3") }) } diff --git a/cmd/app/reply.go b/cmd/app/reply.go index eda9c80a..38deed86 100644 --- a/cmd/app/reply.go +++ b/cmd/app/reply.go @@ -2,7 +2,6 @@ package app import ( "encoding/json" - "io" "net/http" "time" @@ -10,8 +9,8 @@ import ( ) type ReplyRequest struct { - DiscussionId string `json:"discussion_id"` - Reply string `json:"reply"` + DiscussionId string `json:"discussion_id" validate:"required"` + Reply string `json:"reply" validate:"required"` IsDraft bool `json:"is_draft"` } @@ -30,28 +29,8 @@ type replyService struct { } /* replyHandler sends a reply to a note or comment */ -func (a replyService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method != http.MethodPost { - w.Header().Set("Access-Control-Allow-Methods", http.MethodPost) - handleError(w, InvalidRequestError{}, "Expected POST", http.StatusMethodNotAllowed) - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - return - } - - defer r.Body.Close() - var replyRequest ReplyRequest - err = json.Unmarshal(body, &replyRequest) - - if err != nil { - handleError(w, err, "Could not read JSON from request", http.StatusBadRequest) - return - } +func (a replyService) ServeHTTP(w http.ResponseWriter, r *http.Request) { + replyRequest := r.Context().Value(payload("payload")).(*ReplyRequest) now := time.Now() options := gitlab.AddMergeRequestDiscussionNoteOptions{ @@ -67,17 +46,14 @@ func (a replyService) handler(w http.ResponseWriter, r *http.Request) { } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/reply"}, "Could not leave reply", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not leave reply", res.StatusCode) return } w.WriteHeader(http.StatusOK) response := ReplyResponse{ - SuccessResponse: SuccessResponse{ - Message: "Replied to comment", - Status: http.StatusOK, - }, - Note: note, + SuccessResponse: SuccessResponse{Message: "Replied to comment"}, + Note: note, } err = json.NewEncoder(w).Encode(response) diff --git a/cmd/app/reply_test.go b/cmd/app/reply_test.go index a57c2716..27d4416e 100644 --- a/cmd/app/reply_test.go +++ b/cmd/app/reply_test.go @@ -24,22 +24,36 @@ func TestReplyHandler(t *testing.T) { var testReplyRequest = ReplyRequest{DiscussionId: "abc123", Reply: "Some Reply", IsDraft: false} t.Run("Sends a reply", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/reply", testReplyRequest) - svc := replyService{testProjectData, fakeReplyManager{}} + svc := middleware( + replyService{testProjectData, fakeReplyManager{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPost: &ReplyRequest{}}), + withMethodCheck(http.MethodPost), + ) data := getSuccessData(t, svc, request) assert(t, data.Message, "Replied to comment") - assert(t, data.Status, http.StatusOK) }) t.Run("Handles errors from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/reply", testReplyRequest) - svc := replyService{testProjectData, fakeReplyManager{testBase{errFromGitlab: true}}} - data := getFailData(t, svc, request) + svc := middleware( + replyService{testProjectData, fakeReplyManager{testBase{errFromGitlab: true}}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPost: &ReplyRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) checkErrorFromGitlab(t, data, "Could not leave reply") }) t.Run("Handles non-200s from Gitlab client", func(t *testing.T) { request := makeRequest(t, http.MethodPost, "/mr/reply", testReplyRequest) - svc := replyService{testProjectData, fakeReplyManager{testBase{status: http.StatusSeeOther}}} - data := getFailData(t, svc, request) + svc := middleware( + replyService{testProjectData, fakeReplyManager{testBase{status: http.StatusSeeOther}}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPost: &ReplyRequest{}}), + withMethodCheck(http.MethodPost), + ) + data, _ := getFailData(t, svc, request) checkNon200(t, data, "Could not leave reply", "/mr/reply") }) } diff --git a/cmd/app/resolve_discussion.go b/cmd/app/resolve_discussion.go index 89284574..2b55acc0 100644 --- a/cmd/app/resolve_discussion.go +++ b/cmd/app/resolve_discussion.go @@ -3,17 +3,11 @@ package app import ( "encoding/json" "fmt" - "io" "net/http" "github.com/xanzy/go-gitlab" ) -type DiscussionResolveRequest struct { - DiscussionID string `json:"discussion_id"` - Resolved bool `json:"resolved"` -} - type DiscussionResolver interface { ResolveMergeRequestDiscussion(pid interface{}, mergeRequest int, discussion string, opt *gitlab.ResolveMergeRequestDiscussionOptions, options ...gitlab.RequestOptionFunc) (*gitlab.Discussion, *gitlab.Response, error) } @@ -23,40 +17,24 @@ type discussionsResolutionService struct { client DiscussionResolver } -/* discussionsResolveHandler sets a discussion to be "resolved" or not resolved, depending on the payload */ -func (a discussionsResolutionService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method != http.MethodPut { - w.Header().Set("Access-Control-Allow-Methods", http.MethodPut) - handleError(w, InvalidRequestError{}, "Expected PUT", http.StatusMethodNotAllowed) - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - return - } - - defer r.Body.Close() - - var resolveDiscussionRequest DiscussionResolveRequest - err = json.Unmarshal(body, &resolveDiscussionRequest) +type DiscussionResolveRequest struct { + DiscussionID string `json:"discussion_id" validate:"required"` + Resolved bool `json:"resolved"` +} - if err != nil { - handleError(w, err, "Could not read JSON from request", http.StatusBadRequest) - return - } +/* discussionsResolveHandler sets a discussion to be "resolved" or not resolved, depending on the payload */ +func (a discussionsResolutionService) ServeHTTP(w http.ResponseWriter, r *http.Request) { + payload := r.Context().Value(payload("payload")).(*DiscussionResolveRequest) _, res, err := a.client.ResolveMergeRequestDiscussion( a.projectInfo.ProjectId, a.projectInfo.MergeId, - resolveDiscussionRequest.DiscussionID, - &gitlab.ResolveMergeRequestDiscussionOptions{Resolved: &resolveDiscussionRequest.Resolved}, + payload.DiscussionID, + &gitlab.ResolveMergeRequestDiscussionOptions{Resolved: &payload.Resolved}, ) friendlyName := "unresolve" - if resolveDiscussionRequest.Resolved { + if payload.Resolved { friendlyName = "resolve" } @@ -66,15 +44,12 @@ func (a discussionsResolutionService) handler(w http.ResponseWriter, r *http.Req } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/discussions/resolve"}, fmt.Sprintf("Could not %s discussion", friendlyName), res.StatusCode) + handleError(w, GenericError{r.URL.Path}, fmt.Sprintf("Could not %s discussion", friendlyName), res.StatusCode) return } w.WriteHeader(http.StatusOK) - response := SuccessResponse{ - Message: fmt.Sprintf("Discussion %sd", friendlyName), - Status: http.StatusOK, - } + response := SuccessResponse{Message: fmt.Sprintf("Discussion %sd", friendlyName)} err = json.NewEncoder(w).Encode(response) if err != nil { diff --git a/cmd/app/resolve_discussion_test.go b/cmd/app/resolve_discussion_test.go new file mode 100644 index 00000000..18918e1a --- /dev/null +++ b/cmd/app/resolve_discussion_test.go @@ -0,0 +1,84 @@ +package app + +import ( + "net/http" + "testing" + + "github.com/xanzy/go-gitlab" +) + +type fakeDiscussionResolver struct { + testBase +} + +func (f fakeDiscussionResolver) ResolveMergeRequestDiscussion(pid interface{}, mergeRequest int, discussion string, opt *gitlab.ResolveMergeRequestDiscussionOptions, options ...gitlab.RequestOptionFunc) (*gitlab.Discussion, *gitlab.Response, error) { + resp, err := f.handleGitlabError() + if err != nil { + return nil, nil, err + } + + return &gitlab.Discussion{}, resp, err +} + +func TestResolveDiscussion(t *testing.T) { + var testResolveMergeRequestPayload = DiscussionResolveRequest{ + DiscussionID: "abc123", + Resolved: true, + } + + t.Run("Resolves a discussion", func(t *testing.T) { + svc := middleware( + discussionsResolutionService{testProjectData, fakeDiscussionResolver{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPut: &DiscussionResolveRequest{}}), + withMethodCheck(http.MethodPut), + ) + request := makeRequest(t, http.MethodPut, "/mr/discussions/resolve", testResolveMergeRequestPayload) + data := getSuccessData(t, svc, request) + assert(t, data.Message, "Discussion resolved") + }) + + t.Run("Unresolves a discussion", func(t *testing.T) { + payload := testResolveMergeRequestPayload + payload.Resolved = false + svc := middleware( + discussionsResolutionService{testProjectData, fakeDiscussionResolver{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPut: &DiscussionResolveRequest{}}), + withMethodCheck(http.MethodPut), + ) + request := makeRequest(t, http.MethodPut, "/mr/discussions/resolve", payload) + data := getSuccessData(t, svc, request) + assert(t, data.Message, "Discussion unresolved") + }) + + t.Run("Requires a discussion ID", func(t *testing.T) { + payload := testResolveMergeRequestPayload + payload.DiscussionID = "" + svc := middleware( + discussionsResolutionService{testProjectData, fakeDiscussionResolver{}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPut: &DiscussionResolveRequest{}}), + withMethodCheck(http.MethodPut), + ) + request := makeRequest(t, http.MethodPut, "/mr/discussions/resolve", payload) + data, status := getFailData(t, svc, request) + assert(t, data.Message, "Invalid payload") + assert(t, data.Details, "DiscussionID is required") + assert(t, status, http.StatusBadRequest) + }) + + t.Run("Handles error from Gitlab", func(t *testing.T) { + svc := middleware( + discussionsResolutionService{testProjectData, fakeDiscussionResolver{testBase: testBase{errFromGitlab: true}}}, + withMr(testProjectData, fakeMergeRequestLister{}), + withPayloadValidation(methodToPayload{http.MethodPut: &DiscussionResolveRequest{}}), + withMethodCheck(http.MethodPut), + ) + request := makeRequest(t, http.MethodPut, "/mr/discussions/resolve", testResolveMergeRequestPayload) + data, status := getFailData(t, svc, request) + assert(t, data.Message, "Could not resolve discussion") + assert(t, data.Details, "Some error from Gitlab") + assert(t, status, http.StatusInternalServerError) + }) +} diff --git a/cmd/app/response_types.go b/cmd/app/response_types.go index 2fb3af04..c34fa0a2 100644 --- a/cmd/app/response_types.go +++ b/cmd/app/response_types.go @@ -7,12 +7,10 @@ import ( type ErrorResponse struct { Message string `json:"message"` Details string `json:"details"` - Status int `json:"status"` } type SuccessResponse struct { Message string `json:"message"` - Status int `json:"status"` } type GenericError struct { @@ -23,8 +21,8 @@ func (e GenericError) Error() string { return fmt.Sprintf("An error occurred on the %s endpoint", e.endpoint) } -type InvalidRequestError struct{} +type InvalidRequestError struct{ msg string } func (e InvalidRequestError) Error() string { - return "Invalid request type" + return e.msg } diff --git a/cmd/app/reviewer.go b/cmd/app/reviewer.go index 1ffd6a0e..7b98bdee 100644 --- a/cmd/app/reviewer.go +++ b/cmd/app/reviewer.go @@ -2,14 +2,13 @@ package app import ( "encoding/json" - "io" "net/http" "github.com/xanzy/go-gitlab" ) type ReviewerUpdateRequest struct { - Ids []int `json:"ids"` + Ids []int `json:"ids" validate:"required"` } type ReviewerUpdateResponse struct { @@ -32,31 +31,11 @@ type reviewerService struct { } /* reviewersHandler adds or removes reviewers from an MR */ -func (a reviewerService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method != http.MethodPut { - w.Header().Set("Access-Control-Allow-Methods", http.MethodPut) - handleError(w, InvalidRequestError{}, "Expected PUT", http.StatusMethodNotAllowed) - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - return - } - - defer r.Body.Close() - var reviewerUpdateRequest ReviewerUpdateRequest - err = json.Unmarshal(body, &reviewerUpdateRequest) - - if err != nil { - handleError(w, err, "Could not read JSON from request", http.StatusBadRequest) - return - } +func (a reviewerService) ServeHTTP(w http.ResponseWriter, r *http.Request) { + payload := r.Context().Value(payload("payload")).(*ReviewerUpdateRequest) mr, res, err := a.client.UpdateMergeRequest(a.projectInfo.ProjectId, a.projectInfo.MergeId, &gitlab.UpdateMergeRequestOptions{ - ReviewerIDs: &reviewerUpdateRequest.Ids, + ReviewerIDs: &payload.Ids, }) if err != nil { @@ -65,17 +44,14 @@ func (a reviewerService) handler(w http.ResponseWriter, r *http.Request) { } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/reviewer"}, "Could not modify merge request reviewers", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not modify merge request reviewers", res.StatusCode) return } w.WriteHeader(http.StatusOK) response := ReviewerUpdateResponse{ - SuccessResponse: SuccessResponse{ - Message: "Reviewers updated", - Status: http.StatusOK, - }, - Reviewers: mr.Reviewers, + SuccessResponse: SuccessResponse{Message: "Reviewers updated"}, + Reviewers: mr.Reviewers, } err = json.NewEncoder(w).Encode(response) diff --git a/cmd/app/revisions.go b/cmd/app/revisions.go index f423f656..16e68226 100644 --- a/cmd/app/revisions.go +++ b/cmd/app/revisions.go @@ -25,13 +25,7 @@ type revisionsService struct { revisionsHandler gets revision information about the current MR. This data is not used directly but is a precursor API call for other functionality */ -func (a revisionsService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method != http.MethodGet { - w.Header().Set("Access-Control-Allow-Methods", http.MethodGet) - handleError(w, InvalidRequestError{}, "Expected GET", http.StatusMethodNotAllowed) - return - } +func (a revisionsService) ServeHTTP(w http.ResponseWriter, r *http.Request) { versionInfo, res, err := a.client.GetMergeRequestDiffVersions(a.projectInfo.ProjectId, a.projectInfo.MergeId, &gitlab.GetMergeRequestDiffVersionsOptions{}) if err != nil { @@ -40,17 +34,14 @@ func (a revisionsService) handler(w http.ResponseWriter, r *http.Request) { } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/revisions"}, "Could not get diff version info", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not get diff version info", res.StatusCode) return } w.WriteHeader(http.StatusOK) response := RevisionsResponse{ - SuccessResponse: SuccessResponse{ - Message: "Revisions fetched successfully", - Status: http.StatusOK, - }, - Revisions: versionInfo, + SuccessResponse: SuccessResponse{Message: "Revisions fetched successfully"}, + Revisions: versionInfo, } err = json.NewEncoder(w).Encode(response) diff --git a/cmd/app/revoke.go b/cmd/app/revoke.go index c065b88f..ae45f805 100644 --- a/cmd/app/revoke.go +++ b/cmd/app/revoke.go @@ -17,13 +17,7 @@ type mergeRequestRevokerService struct { } /* revokeHandler revokes approval for the current merge request */ -func (a mergeRequestRevokerService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method != http.MethodPost { - w.Header().Set("Access-Control-Allow-Methods", http.MethodPost) - handleError(w, InvalidRequestError{}, "Expected POST", http.StatusMethodNotAllowed) - return - } +func (a mergeRequestRevokerService) ServeHTTP(w http.ResponseWriter, r *http.Request) { res, err := a.client.UnapproveMergeRequest(a.projectInfo.ProjectId, a.projectInfo.MergeId, nil, nil) @@ -33,15 +27,12 @@ func (a mergeRequestRevokerService) handler(w http.ResponseWriter, r *http.Reque } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/mr/revoke"}, "Could not revoke approval", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not revoke approval", res.StatusCode) return } w.WriteHeader(http.StatusOK) - response := SuccessResponse{ - Message: "Success! Revoked MR approval", - Status: http.StatusOK, - } + response := SuccessResponse{Message: "Success! Revoked MR approval"} err = json.NewEncoder(w).Encode(response) if err != nil { diff --git a/cmd/app/server.go b/cmd/app/server.go index 7357a1bc..69b9c60d 100644 --- a/cmd/app/server.go +++ b/cmd/app/server.go @@ -76,7 +76,7 @@ type data struct { type optFunc func(a *data) error -func CreateRouter(gitlabClient *Client, projectInfo *ProjectInfo, s ShutdownHandler, optFuncs ...optFunc) *http.ServeMux { +func CreateRouter(gitlabClient *Client, projectInfo *ProjectInfo, s ShutdownHandler, optFuncs ...optFunc) http.Handler { m := http.NewServeMux() d := data{ @@ -92,37 +92,149 @@ func CreateRouter(gitlabClient *Client, projectInfo *ProjectInfo, s ShutdownHand } } - m.HandleFunc("/mr/approve", withMr(mergeRequestApproverService{d, gitlabClient}, d, gitlabClient)) - m.HandleFunc("/mr/comment", withMr(commentService{d, gitlabClient}, d, gitlabClient)) - m.HandleFunc("/mr/merge", withMr(mergeRequestAccepterService{d, gitlabClient}, d, gitlabClient)) - m.HandleFunc("/mr/discussions/list", withMr(discussionsListerService{d, gitlabClient}, d, gitlabClient)) - m.HandleFunc("/mr/discussions/resolve", withMr(discussionsResolutionService{d, gitlabClient}, d, gitlabClient)) - m.HandleFunc("/mr/info", withMr(infoService{d, gitlabClient}, d, gitlabClient)) - m.HandleFunc("/mr/assignee", withMr(assigneesService{d, gitlabClient}, d, gitlabClient)) - m.HandleFunc("/mr/summary", withMr(summaryService{d, gitlabClient}, d, gitlabClient)) - m.HandleFunc("/mr/reviewer", withMr(reviewerService{d, gitlabClient}, d, gitlabClient)) - m.HandleFunc("/mr/revisions", withMr(revisionsService{d, gitlabClient}, d, gitlabClient)) - m.HandleFunc("/mr/reply", withMr(replyService{d, gitlabClient}, d, gitlabClient)) - m.HandleFunc("/mr/label", withMr(labelService{d, gitlabClient}, d, gitlabClient)) - m.HandleFunc("/mr/revoke", withMr(mergeRequestRevokerService{d, gitlabClient}, d, gitlabClient)) - m.HandleFunc("/mr/awardable/note/", withMr(emojiService{d, gitlabClient}, d, gitlabClient)) - m.HandleFunc("/mr/draft_notes/", withMr(draftNoteService{d, gitlabClient}, d, gitlabClient)) - m.HandleFunc("/mr/draft_notes/publish", withMr(draftNotePublisherService{d, gitlabClient}, d, gitlabClient)) - - m.HandleFunc("/pipeline", pipelineService{d, gitlabClient, git.Git{}}.handler) - m.HandleFunc("/pipeline/trigger/", pipelineService{d, gitlabClient, git.Git{}}.handler) - m.HandleFunc("/users/me", meService{d, gitlabClient}.handler) - m.HandleFunc("/attachment", attachmentService{data: d, client: gitlabClient, fileReader: attachmentReader{}}.handler) - m.HandleFunc("/create_mr", mergeRequestCreatorService{d, gitlabClient}.handler) - m.HandleFunc("/job", traceFileService{d, gitlabClient}.handler) - m.HandleFunc("/project/members", projectMemberService{d, gitlabClient}.handler) - m.HandleFunc("/merge_requests", mergeRequestListerService{d, gitlabClient}.handler) - m.HandleFunc("/merge_requests_by_username", mergeRequestListerByUsernameService{d, gitlabClient}.handler) + m.HandleFunc("/mr/approve", middleware( + mergeRequestApproverService{d, gitlabClient}, // These functions are called from bottom to top... + withMr(d, gitlabClient), + withMethodCheck(http.MethodPost), + )) + m.HandleFunc("/mr/comment", middleware( + commentService{d, gitlabClient}, + withMr(d, gitlabClient), + withPayloadValidation(methodToPayload{ + http.MethodPost: &PostCommentRequest{}, + http.MethodDelete: &DeleteCommentRequest{}, + http.MethodPatch: &EditCommentRequest{}, + }), + withMethodCheck(http.MethodPost, http.MethodDelete, http.MethodPatch), + )) + m.HandleFunc("/mr/merge", middleware( + mergeRequestAccepterService{d, gitlabClient}, + withMr(d, gitlabClient), + withPayloadValidation(methodToPayload{http.MethodPost: &AcceptMergeRequestRequest{}}), + withMethodCheck(http.MethodPost), + )) + m.HandleFunc("/mr/discussions/list", middleware( + discussionsListerService{d, gitlabClient}, + withMr(d, gitlabClient), + withPayloadValidation(methodToPayload{http.MethodPost: &DiscussionsRequest{}}), + withMethodCheck(http.MethodPost), + )) + m.HandleFunc("/mr/discussions/resolve", middleware( + discussionsResolutionService{d, gitlabClient}, + withMr(d, gitlabClient), + withPayloadValidation(methodToPayload{http.MethodPut: &DiscussionResolveRequest{}}), + withMethodCheck(http.MethodPut), + )) + m.HandleFunc("/mr/info", middleware( + infoService{d, gitlabClient}, + withMr(d, gitlabClient), + withMethodCheck(http.MethodGet), + )) + m.HandleFunc("/mr/assignee", middleware( + assigneesService{d, gitlabClient}, + withMr(d, gitlabClient), + withPayloadValidation(methodToPayload{http.MethodPut: &AssigneeUpdateRequest{}}), + withMethodCheck(http.MethodPut), + )) + m.HandleFunc("/mr/summary", middleware( + summaryService{d, gitlabClient}, + withMr(d, gitlabClient), + withPayloadValidation(methodToPayload{http.MethodPut: &SummaryUpdateRequest{}}), + withMethodCheck(http.MethodPut), + )) + m.HandleFunc("/mr/reviewer", middleware( + reviewerService{d, gitlabClient}, + withMr(d, gitlabClient), + withPayloadValidation(methodToPayload{http.MethodPut: &ReviewerUpdateRequest{}}), + withMethodCheck(http.MethodPut), + )) + m.HandleFunc("/mr/revisions", middleware( + revisionsService{d, gitlabClient}, + withMr(d, gitlabClient), + withMethodCheck(http.MethodGet), + )) + m.HandleFunc("/mr/reply", middleware( + replyService{d, gitlabClient}, + withMr(d, gitlabClient), + withPayloadValidation(methodToPayload{http.MethodPost: &ReplyRequest{}}), + withMethodCheck(http.MethodPost), + )) + m.HandleFunc("/mr/label", middleware( + labelService{d, gitlabClient}, + withMr(d, gitlabClient), + )) + m.HandleFunc("/mr/revoke", middleware( + mergeRequestRevokerService{d, gitlabClient}, + withMethodCheck(http.MethodPost), + withMr(d, gitlabClient), + )) + m.HandleFunc("/mr/awardable/note/", middleware( + emojiService{d, gitlabClient}, + withMethodCheck(http.MethodPost, http.MethodDelete), + withMr(d, gitlabClient), + )) + m.HandleFunc("/mr/draft_notes/", middleware( + draftNoteService{d, gitlabClient}, + withMr(d, gitlabClient), + withPayloadValidation(methodToPayload{ + http.MethodPost: &PostDraftNoteRequest{}, + http.MethodPatch: &UpdateDraftNoteRequest{}, + }), + withMethodCheck(http.MethodGet, http.MethodPost, http.MethodPatch, http.MethodDelete), + )) + m.HandleFunc("/mr/draft_notes/publish", middleware( + draftNotePublisherService{d, gitlabClient}, + withMr(d, gitlabClient), + withPayloadValidation(methodToPayload{http.MethodPost: &DraftNotePublishRequest{}}), + withMethodCheck(http.MethodPost), + )) + + m.HandleFunc("/pipeline", middleware( + pipelineService{d, gitlabClient, git.Git{}}, + withMethodCheck(http.MethodGet), + )) + m.HandleFunc("/pipeline/trigger/", middleware( + pipelineService{d, gitlabClient, git.Git{}}, + withMethodCheck(http.MethodPost), + )) + m.HandleFunc("/users/me", middleware( + meService{d, gitlabClient}, + withMethodCheck(http.MethodGet), + )) + m.HandleFunc("/attachment", middleware( + attachmentService{data: d, client: gitlabClient, fileReader: attachmentReader{}}, + withPayloadValidation(methodToPayload{http.MethodPost: &AttachmentRequest{}}), + withMethodCheck(http.MethodPost), + )) + m.HandleFunc("/create_mr", middleware( + mergeRequestCreatorService{d, gitlabClient}, + withPayloadValidation(methodToPayload{http.MethodPost: &CreateMrRequest{}}), + withMethodCheck(http.MethodPost), + )) + m.HandleFunc("/job", middleware( + traceFileService{d, gitlabClient}, + withPayloadValidation(methodToPayload{http.MethodGet: &JobTraceRequest{}}), + withMethodCheck(http.MethodGet), + )) + m.HandleFunc("/project/members", middleware( + projectMemberService{d, gitlabClient}, + withMethodCheck(http.MethodGet), + )) + m.HandleFunc("/merge_requests", middleware( + mergeRequestListerService{d, gitlabClient}, + withPayloadValidation(methodToPayload{http.MethodPost: &gitlab.ListProjectMergeRequestsOptions{}}), // TODO: How to validate external object + withMethodCheck(http.MethodPost), + )) + m.HandleFunc("/merge_requests_by_username", middleware( + mergeRequestListerByUsernameService{d, gitlabClient}, + withPayloadValidation(methodToPayload{http.MethodPost: &MergeRequestByUsernameRequest{}}), + withMethodCheck(http.MethodPost), + )) m.HandleFunc("/shutdown", s.shutdownHandler) m.Handle("/ping", http.HandlerFunc(pingHandler)) - return m + return LoggingServer{handler: m} } /* Used to check whether the server has started yet */ @@ -155,45 +267,3 @@ func createListener() (l net.Listener) { return l } - -type ServiceWithHandler interface { - handler(http.ResponseWriter, *http.Request) -} - -/* withMr is a Middlware that gets the current merge request ID and attaches it to the projectInfo */ -func withMr(svc ServiceWithHandler, c data, client MergeRequestLister) http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - // If the merge request is already attached, skip the middleware logic - if c.projectInfo.MergeId == 0 { - options := gitlab.ListProjectMergeRequestsOptions{ - Scope: gitlab.Ptr("all"), - SourceBranch: &c.gitInfo.BranchName, - TargetBranch: pluginOptions.ChosenTargetBranch, - } - - mergeRequests, _, err := client.ListProjectMergeRequests(c.projectInfo.ProjectId, &options) - if err != nil { - handleError(w, fmt.Errorf("Failed to list merge requests: %w", err), "Failed to list merge requests", http.StatusInternalServerError) - return - } - - if len(mergeRequests) == 0 { - err := fmt.Errorf("No merge requests found for branch '%s'", c.gitInfo.BranchName) - handleError(w, err, "No merge requests found", http.StatusBadRequest) - return - } - - if len(mergeRequests) > 1 { - err := errors.New("Please call gitlab.choose_merge_request()") - handleError(w, err, "Multiple MRs found", http.StatusBadRequest) - return - } - - mergeIdInt := mergeRequests[0].IID - c.projectInfo.MergeId = mergeIdInt - } - - // Call the next handler if middleware succeeds - svc.handler(w, r) - } -} diff --git a/cmd/app/shutdown.go b/cmd/app/shutdown.go index 9f1a6d6e..5f4bd305 100644 --- a/cmd/app/shutdown.go +++ b/cmd/app/shutdown.go @@ -69,10 +69,7 @@ func (s shutdown) shutdownHandler(w http.ResponseWriter, r *http.Request) { } w.WriteHeader(http.StatusOK) - response := SuccessResponse{ - Message: text, - Status: http.StatusOK, - } + response := SuccessResponse{Message: text} err = json.NewEncoder(w).Encode(response) if err != nil { diff --git a/cmd/app/summary.go b/cmd/app/summary.go index 411747a6..a345135c 100644 --- a/cmd/app/summary.go +++ b/cmd/app/summary.go @@ -2,15 +2,14 @@ package app import ( "encoding/json" - "io" "net/http" "github.com/xanzy/go-gitlab" ) type SummaryUpdateRequest struct { + Title string `json:"title" validate:"required"` Description string `json:"description"` - Title string `json:"title"` } type SummaryUpdateResponse struct { @@ -23,33 +22,13 @@ type summaryService struct { client MergeRequestUpdater } -func (a summaryService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") +func (a summaryService) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.Method != http.MethodPut { - w.Header().Set("Access-Control-Allow-Methods", http.MethodPut) - handleError(w, InvalidRequestError{}, "Expected PUT", http.StatusMethodNotAllowed) - return - } - - body, err := io.ReadAll(r.Body) - if err != nil { - handleError(w, err, "Could not read request body", http.StatusBadRequest) - return - } - - defer r.Body.Close() - var SummaryUpdateRequest SummaryUpdateRequest - err = json.Unmarshal(body, &SummaryUpdateRequest) - - if err != nil { - handleError(w, err, "Could not read JSON from request", http.StatusBadRequest) - return - } + payload := r.Context().Value(payload("payload")).(*SummaryUpdateRequest) mr, res, err := a.client.UpdateMergeRequest(a.projectInfo.ProjectId, a.projectInfo.MergeId, &gitlab.UpdateMergeRequestOptions{ - Description: &SummaryUpdateRequest.Description, - Title: &SummaryUpdateRequest.Title, + Description: &payload.Description, + Title: &payload.Title, }) if err != nil { @@ -58,18 +37,15 @@ func (a summaryService) handler(w http.ResponseWriter, r *http.Request) { } if res.StatusCode >= 300 { - handleError(w, GenericError{endpoint: "/summary"}, "Could not edit merge request summary", res.StatusCode) + handleError(w, GenericError{r.URL.Path}, "Could not edit merge request summary", res.StatusCode) return } w.WriteHeader(http.StatusOK) response := SummaryUpdateResponse{ - SuccessResponse: SuccessResponse{ - Message: "Summary updated", - Status: http.StatusOK, - }, - MergeRequest: mr, + SuccessResponse: SuccessResponse{Message: "Summary updated"}, + MergeRequest: mr, } err = json.NewEncoder(w).Encode(response) diff --git a/cmd/app/test_helpers.go b/cmd/app/test_helpers.go index 29b5903d..98ec8d2b 100644 --- a/cmd/app/test_helpers.go +++ b/cmd/app/test_helpers.go @@ -8,7 +8,6 @@ import ( "io" "net/http" "net/http/httptest" - "strings" "testing" "github.com/harrisoncramer/gitlab.nvim/cmd/app/git" @@ -63,9 +62,9 @@ var testProjectData = data{ }, } -func getSuccessData(t *testing.T, svc ServiceWithHandler, request *http.Request) SuccessResponse { +func getSuccessData(t *testing.T, svc http.Handler, request *http.Request) SuccessResponse { res := httptest.NewRecorder() - svc.handler(res, request) + svc.ServeHTTP(res, request) var data SuccessResponse err := json.Unmarshal(res.Body.Bytes(), &data) @@ -75,16 +74,16 @@ func getSuccessData(t *testing.T, svc ServiceWithHandler, request *http.Request) return data } -func getFailData(t *testing.T, svc ServiceWithHandler, request *http.Request) ErrorResponse { +func getFailData(t *testing.T, svc http.Handler, request *http.Request) (errResponse ErrorResponse, status int) { res := httptest.NewRecorder() - svc.handler(res, request) + svc.ServeHTTP(res, request) var data ErrorResponse err := json.Unmarshal(res.Body.Bytes(), &data) if err != nil { t.Error(err) } - return data + return data, res.Result().StatusCode } type testBase struct { @@ -105,22 +104,12 @@ func (f *testBase) handleGitlabError() (*gitlab.Response, error) { func checkErrorFromGitlab(t *testing.T, data ErrorResponse, msg string) { t.Helper() - assert(t, data.Status, http.StatusInternalServerError) assert(t, data.Message, msg) assert(t, data.Details, errorFromGitlab.Error()) } -func checkBadMethod(t *testing.T, data ErrorResponse, methods ...string) { - t.Helper() - assert(t, data.Status, http.StatusMethodNotAllowed) - assert(t, data.Details, "Invalid request type") - expectedMethods := strings.Join(methods, " or ") - assert(t, data.Message, fmt.Sprintf("Expected %s", expectedMethods)) -} - func checkNon200(t *testing.T, data ErrorResponse, msg, endpoint string) { t.Helper() - assert(t, data.Status, http.StatusSeeOther) assert(t, data.Message, msg) assert(t, data.Details, fmt.Sprintf("An error occurred on the %s endpoint", endpoint)) } diff --git a/cmd/app/user.go b/cmd/app/user.go index ae2c6c58..da31a9ac 100644 --- a/cmd/app/user.go +++ b/cmd/app/user.go @@ -21,13 +21,7 @@ type meService struct { client MeGetter } -func (a meService) handler(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Content-Type", "application/json") - if r.Method != http.MethodGet { - w.Header().Set("Access-Control-Allow-Methods", http.MethodGet) - handleError(w, InvalidRequestError{}, "Expected GET", http.StatusMethodNotAllowed) - return - } +func (a meService) ServeHTTP(w http.ResponseWriter, r *http.Request) { user, res, err := a.client.CurrentUser() @@ -42,11 +36,8 @@ func (a meService) handler(w http.ResponseWriter, r *http.Request) { } response := UserResponse{ - SuccessResponse: SuccessResponse{ - Message: "User fetched successfully", - Status: http.StatusOK, - }, - User: user, + SuccessResponse: SuccessResponse{Message: "User fetched successfully"}, + User: user, } err = json.NewEncoder(w).Encode(response) diff --git a/doc/gitlab.nvim.txt b/doc/gitlab.nvim.txt index 58cc8d0a..2a529e00 100644 --- a/doc/gitlab.nvim.txt +++ b/doc/gitlab.nvim.txt @@ -152,7 +152,12 @@ you call this function with no values the defaults will be used: port = nil, -- The port of the Go server, which runs in the background, if omitted or `nil` the port will be chosen automatically log_path = vim.fn.stdpath("cache") .. "/gitlab.nvim.log", -- Log path for the Go server config_path = nil, -- Custom path for `.gitlab.nvim` file, please read the "Connecting to Gitlab" section - debug = { go_request = false, go_response = false }, -- Which values to log + debug = { + request = false, -- Requests to/from Go server + response = false, + gitlab_request = false, -- Requests to/from Gitlab + gitlab_response = false, + }, attachment_dir = nil, -- The local directory for files (see the "summary" section) reviewer_settings = { jump_with_no_diagnostics = false, -- Jump to last position in discussion tree if true, otherwise stay in reviewer and show warning. diff --git a/go.mod b/go.mod index 6410a147..e16b85bd 100644 --- a/go.mod +++ b/go.mod @@ -3,16 +3,24 @@ module github.com/harrisoncramer/gitlab.nvim go 1.19 require ( + github.com/go-playground/validator/v10 v10.22.1 github.com/hashicorp/go-retryablehttp v0.7.7 github.com/xanzy/go-gitlab v0.108.0 ) require ( + github.com/gabriel-vasile/mimetype v1.4.3 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect github.com/golang/protobuf v1.5.3 // indirect github.com/google/go-querystring v1.1.0 // indirect github.com/hashicorp/go-cleanhttp v0.5.2 // indirect - golang.org/x/net v0.8.0 // indirect + github.com/leodido/go-urn v1.4.0 // indirect + golang.org/x/crypto v0.19.0 // indirect + golang.org/x/net v0.21.0 // indirect golang.org/x/oauth2 v0.6.0 // indirect + golang.org/x/sys v0.20.0 // indirect + golang.org/x/text v0.14.0 // indirect golang.org/x/time v0.3.0 // indirect google.golang.org/appengine v1.6.7 // indirect google.golang.org/protobuf v1.29.1 // indirect diff --git a/go.sum b/go.sum index f00dd864..5df86b54 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,14 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= +github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= +github.com/gabriel-vasile/mimetype v1.4.3/go.mod h1:d8uq/6HKRL6CGdk+aubisF/M5GcPfT7nKyLpA0lbSSk= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.22.1 h1:40JcKH+bBNGFczGuoBYgX4I6m/i27HYW8P9FDk5PbgA= +github.com/go-playground/validator/v10 v10.22.1/go.mod h1:dbuPbCMFw/DrkbEynArYaCwl3amGuJotoKCe95atGMM= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg= @@ -14,22 +23,29 @@ github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/S github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk= +github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= +github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= -github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/xanzy/go-gitlab v0.108.0 h1:IEvEUWFR5G1seslRhJ8gC//INiIUqYXuSUoBd7/gFKE= github.com/xanzy/go-gitlab v0.108.0/go.mod h1:wKNKh3GkYDMOsGmnfuX+ITCmDuSDWFO0G+C4AygL9RY= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= +golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks= -golang.org/x/net v0.8.0 h1:Zrh2ngAOFYneWTAIAPethzeaQLuHwhuBkuV6ZiRnUaQ= -golang.org/x/net v0.8.0/go.mod h1:QVkue5JL9kW//ek3r6jTKnTFis1tRmNAW2P1shuFdJc= +golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= +golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/oauth2 v0.6.0 h1:Lh8GPgSKBfWSwFvtuWOfeI3aAAnbXTSutYxJiOJFgIw= golang.org/x/oauth2 v0.6.0/go.mod h1:ycmewcwgD4Rpr3eZJLSB4Kyyljb3qDh40vJ8STE5HKw= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.20.0 h1:Od9JTbYCk261bKm4M/mw7AklTlFYIa0bIp9BgSm1S8Y= +golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/time v0.3.0 h1:rg5rLMjNzMS1RkNLzCG38eapWhnYLFYXDXj2gOlr8j4= golang.org/x/time v0.3.0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/lua/gitlab/actions/comment.lua b/lua/gitlab/actions/comment.lua index 4723a364..daadd3be 100644 --- a/lua/gitlab/actions/comment.lua +++ b/lua/gitlab/actions/comment.lua @@ -3,6 +3,7 @@ --- to this module the data required to make the API calls local Popup = require("nui.popup") local Layout = require("nui.layout") +local diffview_lib = require("diffview.lib") local state = require("gitlab.state") local job = require("gitlab.job") local u = require("gitlab.utils") @@ -153,17 +154,45 @@ end ---@class LayoutOpts ---@field ranged boolean ----@field discussion_id string|nil ---@field unlinked boolean +---@field discussion_id string|nil ---This function sets up the layout and popups needed to create a comment, note and ---multi-line comment. It also sets up the basic keybindings for switching between ---window panes, and for the non-primary sections. ----@param opts LayoutOpts|nil ----@return NuiLayout +---@param opts LayoutOpts +---@return NuiLayout|nil M.create_comment_layout = function(opts) - if opts == nil then - opts = {} + if opts.unlinked ~= true then + -- Check that diffview is initialized + if reviewer.tabnr == nil then + u.notify("Reviewer must be initialized first", vim.log.levels.ERROR) + return + end + + -- Check that Diffview is the current view + local view = diffview_lib.get_current_view() + if view == nil then + u.notify("Comments should be left in the reviewer pane", vim.log.levels.ERROR) + return + end + + -- Check that we are in the diffview tab + local tabnr = vim.api.nvim_get_current_tabpage() + if tabnr ~= reviewer.tabnr then + u.notify("Line location can only be determined within reviewer window", vim.log.levels.ERROR) + return + end + + -- Check that we are hovering over the code + local filetype = vim.bo[0].filetype + if filetype == "DiffviewFiles" or filetype == "gitlab" then + u.notify( + "Comments can only be left on the code. To leave unlinked comments, use gitlab.create_note() instead", + vim.log.levels.ERROR + ) + return + end end local title = opts.discussion_id and "Reply" or "Comment" @@ -229,7 +258,8 @@ M.create_comment = function() if err ~= nil then return end - local is_modified = vim.api.nvim_buf_get_option(0, "modified") + + local is_modified = vim.bo[0].modified if state.settings.reviewer_settings.diffview.imply_local and (is_modified or not has_clean_tree) then u.notify( "Cannot leave comments on changed files. \n Please stash all local changes or push them to the feature branch.", @@ -243,7 +273,9 @@ M.create_comment = function() end local layout = M.create_comment_layout({ ranged = false, unlinked = false }) - layout:mount() + if layout ~= nil then + layout:mount() + end end --- This function will open a multi-line comment popup in order to create a multi-line comment @@ -257,14 +289,18 @@ M.create_multiline_comment = function() end local layout = M.create_comment_layout({ ranged = true, unlinked = false }) - layout:mount() + if layout ~= nil then + layout:mount() + end end --- This function will open a a popup to create a "note" (e.g. unlinked comment) --- on the changed/updated line in the current MR M.create_note = function() local layout = M.create_comment_layout({ ranged = false, unlinked = true }) - layout:mount() + if layout ~= nil then + layout:mount() + end end ---Given the current visually selected area of text, builds text to fill in the @@ -319,7 +355,9 @@ M.create_comment_suggestion = function() local suggestion_lines, range_length = build_suggestion() local layout = M.create_comment_layout({ ranged = range_length > 0, unlinked = false }) - layout:mount() + if layout ~= nil then + layout:mount() + end vim.schedule(function() if suggestion_lines then vim.api.nvim_buf_set_lines(M.comment_popup.bufnr, 0, -1, false, suggestion_lines) diff --git a/lua/gitlab/actions/draft_notes/init.lua b/lua/gitlab/actions/draft_notes/init.lua index 1122b967..a8716750 100755 --- a/lua/gitlab/actions/draft_notes/init.lua +++ b/lua/gitlab/actions/draft_notes/init.lua @@ -84,7 +84,7 @@ end ---Publishes all draft notes and comments. Re-renders all discussion views. M.confirm_publish_all_drafts = function() - local body = { publish_all = true } + local body = {} job.run_job("/mr/draft_notes/publish", "POST", body, function(data) u.notify(data.message, vim.log.levels.INFO) state.DRAFT_NOTES = {} @@ -109,7 +109,7 @@ M.confirm_publish_draft = function(tree) ---@type integer local note_id = note_node.is_root and root_node.id or note_node.id - local body = { note = note_id, publish_all = false } + local body = { note = note_id } job.run_job("/mr/draft_notes/publish", "POST", body, function(data) u.notify(data.message, vim.log.levels.INFO) diff --git a/lua/gitlab/annotations.lua b/lua/gitlab/annotations.lua index 82a675b9..e4993498 100644 --- a/lua/gitlab/annotations.lua +++ b/lua/gitlab/annotations.lua @@ -226,6 +226,8 @@ ---@class DebugSettings: table ---@field go_request? boolean -- Log the requests to Gitlab sent by the Go server ---@field go_response? boolean -- Log the responses received from Gitlab to the Go server +---@field request? boolean -- Log the requests to the Go server +---@field response? boolean -- Log the responses from the Go server ---@class PopupSettings: table ---@field width? string -- The width of the popup, by default "40%" diff --git a/lua/gitlab/job.lua b/lua/gitlab/job.lua index 2ca4e52e..7f5f4d8e 100644 --- a/lua/gitlab/job.lua +++ b/lua/gitlab/job.lua @@ -26,6 +26,8 @@ M.run_job = function(endpoint, method, body, callback) return end local data_ok, data = pcall(vim.json.decode, output) + + -- Failing to unmarshal JSON if not data_ok then local msg = string.format("Failed to parse JSON from %s endpoint", endpoint) if type(output) == "string" then @@ -34,17 +36,22 @@ M.run_job = function(endpoint, method, body, callback) u.notify(string.format(msg, endpoint, output), vim.log.levels.WARN) return end + + -- If JSON provided, handle success or error cases if data ~= nil then - local status = (tonumber(data.status) >= 200 and tonumber(data.status) < 300) and "success" or "error" - if status == "success" and callback ~= nil then - callback(data) - elseif status == "success" then + if data.details == nil then + if callback then + callback(data) + return + end local message = string.format("%s", data.message) u.notify(message, vim.log.levels.INFO) - else - local message = string.format("%s: %s", data.message, data.details) - u.notify(message, vim.log.levels.ERROR) + return end + + -- Handle error case + local message = string.format("%s: %s", data.message, data.details) + u.notify(message, vim.log.levels.ERROR) end end, 0) end, diff --git a/lua/gitlab/reviewer/init.lua b/lua/gitlab/reviewer/init.lua index e561162b..9ef37ea2 100644 --- a/lua/gitlab/reviewer/init.lua +++ b/lua/gitlab/reviewer/init.lua @@ -67,11 +67,11 @@ M.open = function() end if state.INFO.state == "closed" then - u.notify(string.format("This MR was closed on %s", u.format_date(state.INFO.closed_at)), vim.log.levels.WARN) + u.notify(string.format("This MR was closed %s", u.time_since(state.INFO.closed_at)), vim.log.levels.WARN) end if state.INFO.state == "merged" then - u.notify(string.format("This MR was merged on %s", u.format_date(state.INFO.merged_at)), vim.log.levels.WARN) + u.notify(string.format("This MR was merged %s", u.time_since(state.INFO.merged_at)), vim.log.levels.WARN) end if state.settings.discussion_diagnostic ~= nil or state.settings.discussion_sign ~= nil then @@ -151,25 +151,7 @@ end ---other modules such as the comment module to create line codes or set diagnostics ---@return DiffviewInfo | nil M.get_reviewer_data = function() - if M.tabnr == nil then - u.notify("Diffview reviewer must be initialized first", vim.log.levels.ERROR) - return - end - - -- Check if we are in the diffview tab - local tabnr = vim.api.nvim_get_current_tabpage() - if tabnr ~= M.tabnr then - u.notify("Line location can only be determined within reviewer window", vim.log.levels.ERROR) - return - end - - -- Check if we are in the diffview buffer local view = diffview_lib.get_current_view() - if view == nil then - u.notify("Could not find Diffview view", vim.log.levels.ERROR) - return - end - local layout = view.cur_layout local old_win = u.get_window_id_by_buffer_id(layout.a.file.bufnr) local new_win = u.get_window_id_by_buffer_id(layout.b.file.bufnr) @@ -321,7 +303,7 @@ local set_keymaps = function(bufnr, keymaps) if keymaps.reviewer.create_comment ~= false then -- Set keymap for repeated operator keybinding vim.keymap.set("o", keymaps.reviewer.create_comment, function() - vim.api.nvim_cmd({ cmd = "normal", bang = true, args = { tostring(vim.v.count1) .. "j" } }, {}) + vim.api.nvim_cmd({ cmd = "normal", bang = true, args = { tostring(vim.v.count1) .. "$" } }, {}) end, { buffer = bufnr, desc = "Create comment for [count] lines", @@ -351,7 +333,7 @@ local set_keymaps = function(bufnr, keymaps) if keymaps.reviewer.create_suggestion ~= false then -- Set keymap for repeated operator keybinding vim.keymap.set("o", keymaps.reviewer.create_suggestion, function() - vim.api.nvim_cmd({ cmd = "normal", bang = true, args = { tostring(vim.v.count1) .. "j" } }, {}) + vim.api.nvim_cmd({ cmd = "normal", bang = true, args = { tostring(vim.v.count1) .. "$" } }, {}) end, { buffer = bufnr, desc = "Create suggestion for [count] lines", diff --git a/lua/gitlab/state.lua b/lua/gitlab/state.lua index 2007fe32..a6349b50 100644 --- a/lua/gitlab/state.lua +++ b/lua/gitlab/state.lua @@ -47,8 +47,10 @@ M.settings = { file_separator = u.path_separator, port = nil, -- choose random port debug = { - go_request = false, - go_response = false, + request = false, + response = false, + gitlab_request = false, + gitlab_response = false, }, log_path = (vim.fn.stdpath("cache") .. "/gitlab.nvim.log"), config_path = nil,