From 96b6b5d1a7681302c71bf20ccafea1a766d75b79 Mon Sep 17 00:00:00 2001 From: Lee ByeongJun Date: Thu, 18 Jul 2024 16:15:13 +0900 Subject: [PATCH] code suggestion (#8) --- formatter/fmt_test.go | 9 ++ formatter/general.go | 11 +-- formatter/unnecessary_else.go | 60 +++++++++++-- internal/fixer.go | 157 ++++++++++++++++++++++++--------- internal/fixer_test.go | 159 +++++++++------------------------- testdata/main.gno | 5 +- 6 files changed, 224 insertions(+), 177 deletions(-) diff --git a/formatter/fmt_test.go b/formatter/fmt_test.go index eb5edf0..f90507c 100644 --- a/formatter/fmt_test.go +++ b/formatter/fmt_test.go @@ -191,6 +191,15 @@ func TestFormatIssuesWithArrows_UnnecessaryElse(t *testing.T) { | ~~~~~~~~~~~~~~~~~~~~ | unnecessary else block +Suggestion: +4 | if condition { +5 | return true +6 | } +7 | return false + +Note: Unnecessary 'else' block removed. +The code inside the 'else' block has been moved outside, as it will only be executed when the 'if' condition is false. + ` result := FormatIssuesWithArrows(issues, sourceCode) diff --git a/formatter/general.go b/formatter/general.go index 1dc9f0b..81fea78 100644 --- a/formatter/general.go +++ b/formatter/general.go @@ -11,11 +11,12 @@ import ( const tabWidth = 8 var ( - errorStyle = color.New(color.FgRed, color.Bold) - ruleStyle = color.New(color.FgYellow, color.Bold) - fileStyle = color.New(color.FgCyan, color.Bold) - lineStyle = color.New(color.FgBlue, color.Bold) - messageStyle = color.New(color.FgRed, color.Bold) + errorStyle = color.New(color.FgRed, color.Bold) + ruleStyle = color.New(color.FgYellow, color.Bold) + fileStyle = color.New(color.FgCyan, color.Bold) + lineStyle = color.New(color.FgBlue, color.Bold) + messageStyle = color.New(color.FgRed, color.Bold) + suggestionStyle = color.New(color.FgGreen, color.Bold) ) // GeneralIssueFormatter is a formatter for general lint issues. diff --git a/formatter/unnecessary_else.go b/formatter/unnecessary_else.go index b82bda3..edc44e1 100644 --- a/formatter/unnecessary_else.go +++ b/formatter/unnecessary_else.go @@ -16,20 +16,23 @@ func (f *UnnecessaryElseFormatter) Format( ) string { var result strings.Builder ifStartLine, elseEndLine := issue.Start.Line-2, issue.End.Line - maxLineNumberStr := fmt.Sprintf("%d", elseEndLine) - padding := strings.Repeat(" ", len(maxLineNumberStr)-1) + code := strings.Join(snippet.Lines, "\n") + problemSnippet := internal.ExtractSnippet(issue, code, ifStartLine-1, elseEndLine-1) + suggestion, err := internal.RemoveUnnecessaryElse(problemSnippet) + if err != nil { + suggestion = problemSnippet + } + + maxLineNumWidth := calculateMaxLineNumWidth(elseEndLine) + padding := strings.Repeat(" ", maxLineNumWidth-1) result.WriteString(lineStyle.Sprintf(" %s|\n", padding)) - maxLen := 0 + maxLen := calculateMaxLineLength(snippet.Lines, ifStartLine, elseEndLine) for i := ifStartLine; i <= elseEndLine; i++ { - if len(snippet.Lines[i-1]) > maxLen { - maxLen = len(snippet.Lines[i-1]) - } line := expandTabs(snippet.Lines[i-1]) - lineNumberStr := fmt.Sprintf("%d", i) - linePadding := strings.Repeat(" ", len(maxLineNumberStr)-len(lineNumberStr)) - result.WriteString(lineStyle.Sprintf("%s%s | ", linePadding, lineNumberStr)) + lineNumberStr := fmt.Sprintf("%*d", maxLineNumWidth, i) + result.WriteString(lineStyle.Sprintf("%s | ", lineNumberStr)) result.WriteString(line + "\n") } @@ -38,5 +41,44 @@ func (f *UnnecessaryElseFormatter) Format( result.WriteString(lineStyle.Sprintf(" %s| ", padding)) result.WriteString(messageStyle.Sprintf("%s\n\n", issue.Message)) + result.WriteString(formatSuggestion(issue, suggestion, ifStartLine)) + result.WriteString("\n") + + return result.String() +} + +func calculateMaxLineNumWidth(endLine int) int { + return len(fmt.Sprintf("%d", endLine)) +} + +func calculateMaxLineLength(lines []string, start, end int) int { + maxLen := 0 + for i := start - 1; i < end; i++ { + if len(lines[i]) > maxLen { + maxLen = len(lines[i]) + } + } + return maxLen +} + +func formatSuggestion(issue internal.Issue, improvedSnippet string, startLine int) string { + var result strings.Builder + lines := strings.Split(improvedSnippet, "\n") + maxLineNumWidth := calculateMaxLineNumWidth(issue.End.Line) + + result.WriteString(suggestionStyle.Sprint("Suggestion:\n")) + + for i, line := range lines { + lineNum := fmt.Sprintf("%*d", maxLineNumWidth, startLine+i) + result.WriteString(lineStyle.Sprintf("%s | ", lineNum)) + result.WriteString(fmt.Sprintln(line)) + } + + // Add a note explaining the improvement + result.WriteString("\n") + result.WriteString(suggestionStyle.Sprint("Note: ")) + result.WriteString("Unnecessary 'else' block removed.\n") + result.WriteString("The code inside the 'else' block has been moved outside, as it will only be executed when the 'if' condition is false.\n") + return result.String() } diff --git a/internal/fixer.go b/internal/fixer.go index a2104c1..0207c2a 100644 --- a/internal/fixer.go +++ b/internal/fixer.go @@ -1,7 +1,6 @@ package internal import ( - "bytes" "go/ast" "go/format" "go/parser" @@ -9,73 +8,147 @@ import ( "strings" ) -// TODO: Must flattening the nested unnecessary if-else blocks. +func RemoveUnnecessaryElse(snippet string) (string, error) { + wrappedSnippet := "package main\nfunc main() {\n" + snippet + "\n}" -// improveCode refactors the input source code and returns the formatted version. -func improveCode(src []byte) (string, error) { fset := token.NewFileSet() - file, err := parser.ParseFile(fset, "", src, parser.ParseComments) + file, err := parser.ParseFile(fset, "", wrappedSnippet, parser.ParseComments) if err != nil { return "", err } - err = refactorAST(file) + var funcBody *ast.BlockStmt + ast.Inspect(file, func(n ast.Node) bool { + if fd, ok := n.(*ast.FuncDecl); ok { + funcBody = fd.Body + return false + } + return true + }) + + removeUnnecessaryElseRecursive(funcBody) + + var buf strings.Builder + err = format.Node(&buf, fset, funcBody) if err != nil { return "", err } - return formatSource(fset, file) + result := cleanUpResult(buf.String()) + + return result, nil } -// refactorAST processes the AST to modify specific patterns. -func refactorAST(file *ast.File) error { - ast.Inspect(file, func(n ast.Node) bool { - ifStmt, ok := n.(*ast.IfStmt) - if !ok || ifStmt.Else == nil { - return true - } +func cleanUpResult(result string) string { + result = strings.TrimSpace(result) + result = strings.TrimPrefix(result, "{") + result = strings.TrimSuffix(result, "}") + result = strings.TrimSpace(result) + + lines := strings.Split(result, "\n") + for i, line := range lines { + lines[i] = strings.TrimPrefix(line, "\t") + } + return strings.Join(lines, "\n") +} - blockStmt, ok := ifStmt.Else.(*ast.BlockStmt) - if !ok || len(ifStmt.Body.List) == 0 { - return true +func removeUnnecessaryElseRecursive(node ast.Node) { + ast.Inspect(node, func(n ast.Node) bool { + if ifStmt, ok := n.(*ast.IfStmt); ok { + processIfStmt(ifStmt, node) + removeUnnecessaryElseRecursive(ifStmt.Body) + if ifStmt.Else != nil { + removeUnnecessaryElseRecursive(ifStmt.Else) + } + return false } + return true + }) +} - _, isReturn := ifStmt.Body.List[len(ifStmt.Body.List)-1].(*ast.ReturnStmt) - if !isReturn { - return true +func processIfStmt(ifStmt *ast.IfStmt, node ast.Node) { + if ifStmt.Else != nil && endsWithReturn(ifStmt.Body) { + parent := findParentBlockStmt(node, ifStmt) + if parent != nil { + switch elseBody := ifStmt.Else.(type) { + case *ast.BlockStmt: + insertStatementsAfter(parent, ifStmt, elseBody.List) + case *ast.IfStmt: + insertStatementsAfter(parent, ifStmt, []ast.Stmt{elseBody}) + } + ifStmt.Else = nil + } + } else if ifStmt.Else != nil { + if elseIfStmt, ok := ifStmt.Else.(*ast.IfStmt); ok && endsWithReturn(elseIfStmt.Body) { + processIfStmt(elseIfStmt, ifStmt) } + } +} - mergeElseIntoIf(file, ifStmt, blockStmt) - ifStmt.Else = nil +func endsWithReturn(block *ast.BlockStmt) bool { + if len(block.List) == 0 { + return false + } + _, isReturn := block.List[len(block.List)-1].(*ast.ReturnStmt) + return isReturn +} +func findParentBlockStmt(root ast.Node, child ast.Node) *ast.BlockStmt { + var parent *ast.BlockStmt + ast.Inspect(root, func(n ast.Node) bool { + if n == child { + return false + } + if block, ok := n.(*ast.BlockStmt); ok { + for _, stmt := range block.List { + if stmt == child { + parent = block + return false + } + } + } return true }) - return nil + return parent } -// mergeElseIntoIf merges the statements of an 'else' block into the enclosing function body. -func mergeElseIntoIf(file *ast.File, ifStmt *ast.IfStmt, blockStmt *ast.BlockStmt) { - for _, list := range file.Decls { - decl, ok := list.(*ast.FuncDecl) - if !ok { - continue - } - for i, stmt := range decl.Body.List { - if ifStmt != stmt { - continue - } - decl.Body.List = append(decl.Body.List[:i+1], append(blockStmt.List, decl.Body.List[i+1:]...)...) +func insertStatementsAfter(block *ast.BlockStmt, target ast.Stmt, stmts []ast.Stmt) { + for i, stmt := range block.List { + if stmt == target { + block.List = append(block.List[:i+1], append(stmts, block.List[i+1:]...)...) break } } } -// formatSource formats the AST back to source code. -func formatSource(fset *token.FileSet, file *ast.File) (string, error) { - var buf bytes.Buffer - err := format.Node(&buf, fset, file) - if err != nil { - return "", err +func ExtractSnippet(issue Issue, code string, startLine, endLine int) string { + lines := strings.Split(code, "\n") + + // ensure we don't go out of bounds + if startLine < 0 { + startLine = 0 } - return strings.TrimRight(buf.String(), "\n"), nil + if endLine > len(lines) { + endLine = len(lines) + } + + // extract the relevant lines + snippet := lines[startLine:endLine] + + // trim any leading empty lines + for len(snippet) > 0 && strings.TrimSpace(snippet[0]) == "" { + snippet = snippet[1:] + } + + // ensure the last line is included if it's a closing brace + if endLine < len(lines) && strings.TrimSpace(lines[endLine]) == "}" { + snippet = append(snippet, lines[endLine]) + } + + // trim any trailing empty lines + for len(snippet) > 0 && strings.TrimSpace(snippet[len(snippet)-1]) == "" { + snippet = snippet[:len(snippet)-1] + } + + return strings.Join(snippet, "\n") } diff --git a/internal/fixer_test.go b/internal/fixer_test.go index e7d5c77..8d0eb2a 100644 --- a/internal/fixer_test.go +++ b/internal/fixer_test.go @@ -7,7 +7,7 @@ import ( "github.com/stretchr/testify/require" ) -func TestImproveCode(t *testing.T) { +func TestRemoveUnnecessaryElse(t *testing.T) { testCases := []struct { name string input string @@ -15,137 +15,62 @@ func TestImproveCode(t *testing.T) { }{ { name: "don't need to modify", - input: `package main - -func foo(x bool) int { - if x { - println("x") - } else { - println("hello") - } + input: `if x { + println("x") +} else { + println("hello") }`, - expected: `package main - -func foo(x bool) int { - if x { - println("x") - } else { - println("hello") - } + expected: `if x { + println("x") +} else { + println("hello") }`, }, { - name: "Remove unnecessary else", - input: ` -package main - -func unnecessaryElse() bool { - if condition { - return true - } else { - return false - } -}`, - expected: `package main - -func unnecessaryElse() bool { - if condition { - return true - } - return false - + name: "remove unnecessary else", + input: `if x { + return 1 +} else { + return 2 }`, + expected: `if x { + return 1 +} +return 2`, }, { - name: "Keep necessary else", - input: ` -package main - -func necessaryElse() int { - if condition { - return 1 - } else { - doSomething() - return 2 - } -}`, - expected: `package main - -func necessaryElse() int { - if condition { - return 1 + name: "nested if else", + input: `if x { + return 1 +} +if z { + println("x") +} else { + if y { + return 2 + } else { + return 3 } - doSomething() - return 2 +} +`, + expected: `if x { + return 1 +} +if z { + println("x") +} else { + if y { + return 2 + } + return 3 }`, }, - // { - // name: "Multiple unnecessary else", - // input: ` - // package main - - // func multipleUnnecessaryElse() int { - // if condition1 { - // return 1 - // } else { - // if condition2 { - // return 2 - // } else { - // return 3 - // } - // } - // }`, - // expected: `package main - - // func multipleUnnecessaryElse() int { - // if condition1 { - // return 1 - // } - // if condition2 { - // return 2 - // } - // return 3 - // } - // `, - // }, - // { - // name: "Mixed necessary and unnecessary else", - // input: ` - // package main - - // func mixedElse() int { - // if condition1 { - // return 1 - // } else { - // if condition2 { - // doSomething() - // return 2 - // } else { - // return 3 - // } - // } - // }`, - // expected: `package main - - // func mixedElse() int { - // if condition1 { - // return 1 - // } else { - // if condition2 { - // doSomething() - // return 2 - // } - // return 3 - // } - // } - // `, - // }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - improved, err := improveCode([]byte(tc.input)) + improved, err := RemoveUnnecessaryElse(tc.input) require.NoError(t, err) assert.Equal(t, tc.expected, improved, "Improved code does not match expected output") }) diff --git a/testdata/main.gno b/testdata/main.gno index 23193f3..f0210df 100644 --- a/testdata/main.gno +++ b/testdata/main.gno @@ -6,10 +6,7 @@ func foo(x, y bool) int { if x { return 1 } else { - if y { - return 2 - } - return 3 + return 2 } }