Skip to content

Commit

Permalink
feat(fixer v2): yaml, update replacer (#122)
Browse files Browse the repository at this point in the history
# Description

- Added yaml handler
- preserve ident context when apply replacer
  • Loading branch information
notJoon authored Feb 4, 2025
1 parent 0e77bd5 commit dfaaf49
Show file tree
Hide file tree
Showing 4 changed files with 327 additions and 8 deletions.
49 changes: 42 additions & 7 deletions fixer_v2/replacer.go
Original file line number Diff line number Diff line change
@@ -1,26 +1,51 @@
package fixerv2

// ReplaceAll replaces all occurrences in the subject that match the pattern with the replacement template
func ReplaceAll(patternNodes []Node, replacementNodes []Node, subject string) string {
import "strings"

type Replacer struct {
patternNodes []Node
replacementNodes []Node
baseIndent string
}

func NewReplacer(pattern, replacement []Node) *Replacer {
return &Replacer{
patternNodes: pattern,
replacementNodes: replacement,
}
}

func (r *Replacer) ReplaceAll(subject string) string {
result := ""
pos := 0

for {
found, matchStart, matchEnd, captures := findNextMatch(patternNodes, subject, pos)
found, matchStart, matchEnd, captures := findNextMatch(r.patternNodes, subject, pos)
if !found {
result += subject[pos:]
break
}

result += subject[pos:matchStart]
result += applyReplacement(replacementNodes, captures)

// Update base indent when match is found
if idx := strings.LastIndex(subject[:matchStart], "\n"); idx != -1 {
r.baseIndent = subject[idx+1 : matchStart]
} else {
r.baseIndent = ""
}

rawRepl := r.applyReplacement(captures)
adjusted := r.adjustIndent(rawRepl)
result += adjusted
pos = matchEnd
}
return result
}

// applyReplacement generates a replacement string using the replacement template AST and capture map
func applyReplacement(replacementNodes []Node, captures map[string]string) string {
func (r *Replacer) applyReplacement(captures map[string]string) string {
result := ""
for _, node := range replacementNodes {
for _, node := range r.replacementNodes {
switch n := node.(type) {
case LiteralNode:
result += n.Value
Expand All @@ -32,3 +57,13 @@ func applyReplacement(replacementNodes []Node, captures map[string]string) strin
}
return result
}

func (r *Replacer) adjustIndent(repl string) string {
lines := strings.Split(repl, "\n")
for i := 1; i < len(lines); i++ {
if !strings.HasPrefix(lines[i], r.baseIndent) {
lines[i] = r.baseIndent + lines[i]
}
}
return strings.Join(lines, "\n")
}
92 changes: 91 additions & 1 deletion fixer_v2/replacer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,95 @@ func TestReplacer(t *testing.T) {
subjectStr: "func test() {\n return x + 1\n}",
expected: "func test() {\n return (x + 1)\n}",
},
{
name: "error handling improvement",
patternStr: "if err != nil { return err }",
replacementStr: "if err != nil { return fmt.Errorf(\"failed to process: %w\", err) }",
subjectStr: "func process() error {\n if err != nil { return err }\n}",
expected: "func process() error {\n if err != nil { return fmt.Errorf(\"failed to process: %w\", err) }\n}",
},
{
name: "context with cancel",
patternStr: "ctx, _ := context.WithTimeout(:[parent], :[duration])",
replacementStr: "ctx, cancel := context.WithTimeout(:[parent], :[duration])\ndefer cancel()",
subjectStr: "func process() {\n ctx, _ := context.WithTimeout(context.Background(), 5*time.Second)\n}",
expected: "func process() {\n ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)\n defer cancel()\n}",
},
{
name: "mutex lock with defer unlock",
patternStr: "mu.Lock()\n:[code]",
replacementStr: "mu.Lock()\ndefer mu.Unlock()\n:[code]",
subjectStr: "func process() {\n mu.Lock()\n data.Write()\n}",
expected: "func process() {\n mu.Lock()\n defer mu.Unlock()\n data.Write()\n}",
},
{
name: "error variable naming",
patternStr: ":[type], err := :[func]()",
replacementStr: ":[type], :[type]Err := :[func]()",
subjectStr: "user, err := getUser()",
expected: "user, userErr := getUser()",
},
{
name: "http error handling",
patternStr: "http.Error(:[w], err.Error(), :[code])",
replacementStr: "http.Error(:[w], fmt.Sprintf(\"internal error: %v\", err), :[code])",
subjectStr: "http.Error(w, err.Error(), http.StatusInternalServerError)",
expected: "http.Error(w, fmt.Sprintf(\"internal error: %v\", err), http.StatusInternalServerError)",
},
{
name: "nested if statements",
patternStr: "if :[cond1] {\n if :[cond2] {\n :[code]\n }\n}",
replacementStr: "if :[cond1] && :[cond2] {\n :[code]\n}",
subjectStr: "if x > 0 {\n if y < 10 {\n process()\n }\n}",
expected: "if x > 0 && y < 10 {\n process()\n}",
},
{
name: "unnecessary else block",
patternStr: "if :[cond] {\n :[code1]\n} else {\n :[code2]\n}",
replacementStr: "if :[cond] {\n :[code1]\n}",
subjectStr: "if x > 0 {\n process()\n} else {\n log.Println(\"error\")\n}",
expected: "if x > 0 {\n process()\n}",
},
{
name: "channel close with defer",
patternStr: "ch := make(chan :[type])",
replacementStr: "ch := make(chan :[type])\ndefer close(ch)",
subjectStr: "func process() {\n ch := make(chan int)\n}",
expected: "func process() {\n ch := make(chan int)\n defer close(ch)\n}",
},
{
name: "multiple errors in one line",
patternStr: "err1, err2 := :[func1](), :[func2]()",
replacementStr: "err1Res, err2Res := :[func1](), :[func2]()",
subjectStr: "err1, err2 := readFile(), writeFile()",
expected: "err1Res, err2Res := readFile(), writeFile()",
},
{
name: "whitespace handling",
patternStr: "if :[cond] {",
replacementStr: "if :[cond] {",
subjectStr: "func test() {\n if x > 0 {\n}",
expected: "func test() {\n if x > 0 {\n}",
},
{
name: "complex nested replacement",
patternStr: "for :[i] := range :[slice] {\n if :[cond] {\n :[code]\n }\n}",
replacementStr: "for :[i] := range :[slice] {\n switch {\n case :[cond]:\n :[code]\n }\n}",
subjectStr: "for i := range items {\n if items[i].Valid {\n process(items[i])\n }\n}",
expected: "for i := range items {\n switch {\n case items[i].Valid:\n process(items[i])\n }\n}",
},
// TODO (@notJoon): If capacity is not provided, it will be replaced with an empty string (ex: make([]int, 0, )).
// To solve this, we need add an arbitrary default value like I did here, or analyze the context
// to find an appropriate value from other lines of code.
// The context must be located in the same or higher scope and should be bounded by the lines
// proceeding the current line.
{
name: "slice capacity",
patternStr: "make([]:[type], 0)",
replacementStr: "make([]:[type], 0, 10)",
subjectStr: "func test() {\n data := make([]int, 0)\n}",
expected: "func test() {\n data := make([]int, 0, 10)\n}",
},
}

for _, tt := range tests {
Expand All @@ -71,7 +160,8 @@ func TestReplacer(t *testing.T) {
t.Fatalf("replacement parse error: %v", err)
}

result := ReplaceAll(patternNodes, replacementNodes, tt.subjectStr)
repl := NewReplacer(patternNodes, replacementNodes)
result := repl.ReplaceAll(tt.subjectStr)

if result != tt.expected {
t.Errorf("replaceAll() = %q, want %q", result, tt.expected)
Expand Down
64 changes: 64 additions & 0 deletions fixer_v2/yaml.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package fixerv2

import (
"log"
"os"

"gopkg.in/yaml.v3"
)

type FixRule struct {
Name string `yaml:"name"`
Pattern string `yaml:"pattern"`
Replacement string `yaml:"replacement"`
}

type RulesConfig struct {
Rules []FixRule `yaml:"rules"`
}

func Load(path string) ([]FixRule, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, err
}
var cfg RulesConfig
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, err
}
return cfg.Rules, nil
}

func Apply(subject string, rules []FixRule) string {
result := subject
for _, rule := range rules {
pat, err := Lex(rule.Pattern)
if err != nil {
log.Printf("failed to parse pattern %q: %v", rule.Pattern, err)
continue
}
nodes, err := Parse(pat)
if err != nil {
log.Printf("failed to parse pattern %q: %v", rule.Pattern, err)
continue
}
replacement, err := Lex(rule.Replacement)
if err != nil {
log.Printf("failed to parse replacement %q: %v", rule.Replacement, err)
continue
}
replacementNodes, err := Parse(replacement)
if err != nil {
log.Printf("failed to parse replacement %q: %v", rule.Replacement, err)
continue
}

repl := NewReplacer(nodes, replacementNodes)
newResult := repl.ReplaceAll(result)
if newResult != result {
log.Printf("applied rule %q: %q -> %q", rule.Name, result, newResult)
}
result = newResult
}
return result
}
130 changes: 130 additions & 0 deletions fixer_v2/yaml_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,130 @@
package fixerv2

import (
"os"
"testing"
)

func TestLoadFixRules(t *testing.T) {
tests := []struct {
name string
yamlContent string
wantRules []FixRule
wantErr bool
}{
{
name: "valid rules",
yamlContent: `
rules:
- name: simple replacement
pattern: ":[name]"
replacement: "Hello, :[name]!"
- name: arithmetic replacement
pattern: ":[lhs] + :[rhs]"
replacement: ":[rhs] - :[lhs]"
`,
wantRules: []FixRule{
{
Name: "simple replacement",
Pattern: ":[name]",
Replacement: "Hello, :[name]!",
},
{
Name: "arithmetic replacement",
Pattern: ":[lhs] + :[rhs]",
Replacement: ":[rhs] - :[lhs]",
},
},
wantErr: false,
},
{
name: "invalid yaml",
yamlContent: `
rules:
- name: missing colon
pattern ":[abc]"
replacement: "Should fail"
`,
wantRules: nil,
wantErr: true,
},
{
name: "golang lint rules",
yamlContent: `
rules:
- name: if err handling
pattern: "if err != nil { return err }"
replacement: "if err != nil { return fmt.Errorf(\"failed to process: %w\", err) }"
- name: context timeout
pattern: "context.WithTimeout(:[ctx], :[duration])"
replacement: "context.WithTimeout(:[ctx], :[duration])\ndefer cancel()"
- name: slice capacity
pattern: "make([]:[type], 0)"
replacement: "make([]:[type], 0, :[capacity])"
`,
wantRules: []FixRule{
{
Name: "if err handling",
Pattern: "if err != nil { return err }",
Replacement: "if err != nil { return fmt.Errorf(\"failed to process: %w\", err) }",
},
{
Name: "context timeout",
Pattern: "context.WithTimeout(:[ctx], :[duration])",
Replacement: "context.WithTimeout(:[ctx], :[duration])\ndefer cancel()",
},
{
Name: "slice capacity",
Pattern: "make([]:[type], 0)",
Replacement: "make([]:[type], 0, :[capacity])",
},
},
wantErr: false,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tmpfile, err := os.CreateTemp("", "rules_*.yaml")
if err != nil {
t.Fatalf("TempFile error: %v", err)
}
defer os.Remove(tmpfile.Name())

if _, err := tmpfile.Write([]byte(tt.yamlContent)); err != nil {
t.Fatalf("Write error: %v", err)
}
if err := tmpfile.Close(); err != nil {
t.Fatalf("Close error: %v", err)
}

rules, err := Load(tmpfile.Name())
if tt.wantErr {
if err == nil {
t.Fatalf("expected error but got none")
}
return
}
if err != nil {
t.Fatalf("unexpected error: %v", err)
}

if len(rules) != len(tt.wantRules) {
t.Fatalf("expected %d rule(s), got %d", len(tt.wantRules), len(rules))
}

for i, want := range tt.wantRules {
got := rules[i]
if got.Name != want.Name {
t.Errorf("rule[%d] Name: got %q, want %q", i, got.Name, want.Name)
}
if got.Pattern != want.Pattern {
t.Errorf("rule[%d] Pattern: got %q, want %q", i, got.Pattern, want.Pattern)
}
if got.Replacement != want.Replacement {
t.Errorf("rule[%d] Replacement: got %q, want %q", i, got.Replacement, want.Replacement)
}
}
})
}
}

0 comments on commit dfaaf49

Please sign in to comment.