From 6afd6768ea64e4a27a98aa16fb981376cba48731 Mon Sep 17 00:00:00 2001 From: Adrian Hesketh Date: Mon, 30 Dec 2024 20:39:14 +0000 Subject: [PATCH] fix(breaking): trigger rebuild on changes to *.go files, fixes #646 (#1026) --- .version | 2 +- cmd/templ/generatecmd/cmd.go | 25 +++++-- cmd/templ/generatecmd/main.go | 7 +- cmd/templ/generatecmd/main_test.go | 72 +++++++++++++++++++++ cmd/templ/generatecmd/watcher/watch.go | 53 +++++++-------- cmd/templ/generatecmd/watcher/watch_test.go | 50 ++++++++------ cmd/templ/main.go | 2 + get-version/main.go | 2 +- 8 files changed, 153 insertions(+), 60 deletions(-) diff --git a/.version b/.version index e91a1a916..5158f20d0 100644 --- a/.version +++ b/.version @@ -1 +1 @@ -0.2.806 \ No newline at end of file +0.2.808 \ No newline at end of file diff --git a/cmd/templ/generatecmd/cmd.go b/cmd/templ/generatecmd/cmd.go index edbf0c413..3f1707481 100644 --- a/cmd/templ/generatecmd/cmd.go +++ b/cmd/templ/generatecmd/cmd.go @@ -9,6 +9,7 @@ import ( "net/url" "path" "path/filepath" + "regexp" "runtime" "sync" "sync/atomic" @@ -25,7 +26,9 @@ import ( "github.com/fsnotify/fsnotify" ) -func NewGenerate(log *slog.Logger, args Arguments) (g *Generate) { +const defaultWatchPattern = `(.+\.go$)|(.+\.templ$)|(.+_templ\.txt$)` + +func NewGenerate(log *slog.Logger, args Arguments) (g *Generate, err error) { g = &Generate{ Log: log, Args: &args, @@ -33,12 +36,20 @@ func NewGenerate(log *slog.Logger, args Arguments) (g *Generate) { if g.Args.WorkerCount == 0 { g.Args.WorkerCount = runtime.NumCPU() } - return g + if g.Args.WatchPattern == "" { + g.Args.WatchPattern = defaultWatchPattern + } + g.WatchPattern, err = regexp.Compile(g.Args.WatchPattern) + if err != nil { + return nil, fmt.Errorf("failed to compile watch pattern %q: %w", g.Args.WatchPattern, err) + } + return g, nil } type Generate struct { - Log *slog.Logger - Args *Arguments + Log *slog.Logger + Args *Arguments + WatchPattern *regexp.Regexp } type GenerationEvent struct { @@ -143,7 +154,7 @@ func (cmd Generate) Run(ctx context.Context) (err error) { slog.String("path", cmd.Args.Path), slog.Bool("devMode", cmd.Args.Watch), ) - if err := watcher.WalkFiles(ctx, cmd.Args.Path, events); err != nil { + if err := watcher.WalkFiles(ctx, cmd.Args.Path, cmd.WatchPattern, events); err != nil { cmd.Log.Error("WalkFiles failed, exiting", slog.Any("error", err)) errs <- FatalError{Err: fmt.Errorf("failed to walk files: %w", err)} return @@ -153,7 +164,7 @@ func (cmd Generate) Run(ctx context.Context) (err error) { return } cmd.Log.Info("Watching files") - rw, err := watcher.Recursive(ctx, cmd.Args.Path, events, errs) + rw, err := watcher.Recursive(ctx, cmd.Args.Path, cmd.WatchPattern, events, errs) if err != nil { cmd.Log.Error("Recursive watcher setup failed, exiting", slog.Any("error", err)) errs <- FatalError{Err: fmt.Errorf("failed to setup recursive watcher: %w", err)} @@ -187,7 +198,7 @@ func (cmd Generate) Run(ctx context.Context) (err error) { cmd.Args.Lazy, ) errorCount.Store(0) - if err := watcher.WalkFiles(ctx, cmd.Args.Path, events); err != nil { + if err := watcher.WalkFiles(ctx, cmd.Args.Path, cmd.WatchPattern, events); err != nil { cmd.Log.Error("Post dev mode WalkFiles failed", slog.Any("error", err)) errs <- FatalError{Err: fmt.Errorf("failed to walk files: %w", err)} return diff --git a/cmd/templ/generatecmd/main.go b/cmd/templ/generatecmd/main.go index 538f820a4..4d2cbf5a2 100644 --- a/cmd/templ/generatecmd/main.go +++ b/cmd/templ/generatecmd/main.go @@ -13,6 +13,7 @@ type Arguments struct { FileWriter FileWriterFunc Path string Watch bool + WatchPattern string OpenBrowser bool Command string ProxyBind string @@ -30,5 +31,9 @@ type Arguments struct { } func Run(ctx context.Context, log *slog.Logger, args Arguments) (err error) { - return NewGenerate(log, args).Run(ctx) + g, err := NewGenerate(log, args) + if err != nil { + return err + } + return g.Run(ctx) } diff --git a/cmd/templ/generatecmd/main_test.go b/cmd/templ/generatecmd/main_test.go index 9b0b10914..62c5c2094 100644 --- a/cmd/templ/generatecmd/main_test.go +++ b/cmd/templ/generatecmd/main_test.go @@ -6,6 +6,7 @@ import ( "log/slog" "os" "path" + "regexp" "testing" "github.com/a-h/templ/cmd/templ/testproject" @@ -42,3 +43,74 @@ func TestGenerate(t *testing.T) { } }) } + +func TestDefaultWatchPattern(t *testing.T) { + tests := []struct { + name string + input string + matches bool + }{ + { + name: "empty file names do not match", + input: "", + matches: false, + }, + { + name: "*_templ.txt matches, Windows", + input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\strings_templ.txt`, + matches: true, + }, + { + name: "*_templ.txt matches, Unix", + input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/strings_templ.txt", + matches: true, + }, + { + name: "*.templ files match, Windows", + input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\templates.templ`, + matches: true, + }, + { + name: "*.templ files match, Unix", + input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates.templ", + matches: true, + }, + { + name: "*_templ.go files match, Windows", + input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\templates_templ.go`, + matches: true, + }, + { + name: "*_templ.go files match, Unix", + input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates_templ.go", + matches: true, + }, + { + name: "*.go files match, Windows", + input: `C:\Users\adrian\github.com\a-h\templ\cmd\templ\testproject\templates.go`, + matches: true, + }, + { + name: "*.go files match, Unix", + input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates.go", + matches: true, + }, + { + name: "*.css files do not match", + input: "/Users/adrian/github.com/a-h/templ/cmd/templ/testproject/templates.css", + matches: false, + }, + } + wpRegexp, err := regexp.Compile(defaultWatchPattern) + if err != nil { + t.Fatalf("failed to compile default watch pattern: %v", err) + } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + t.Parallel() + if wpRegexp.MatchString(test.input) != test.matches { + t.Fatalf("expected match of %q to be %v", test.input, test.matches) + } + }) + } +} diff --git a/cmd/templ/generatecmd/watcher/watch.go b/cmd/templ/generatecmd/watcher/watch.go index f97149c8d..866560340 100644 --- a/cmd/templ/generatecmd/watcher/watch.go +++ b/cmd/templ/generatecmd/watcher/watch.go @@ -6,6 +6,7 @@ import ( "os" "path" "path/filepath" + "regexp" "strings" "sync" "time" @@ -16,6 +17,7 @@ import ( func Recursive( ctx context.Context, path string, + watchPattern *regexp.Regexp, out chan fsnotify.Event, errors chan error, ) (w *RecursiveWatcher, err error) { @@ -23,20 +25,25 @@ func Recursive( if err != nil { return nil, err } - w = &RecursiveWatcher{ - ctx: ctx, - w: fsnw, - Events: out, - Errors: errors, - timers: make(map[timerKey]*time.Timer), - } + w = NewRecursiveWatcher(ctx, fsnw, watchPattern, out, errors) go w.loop() return w, w.Add(path) } +func NewRecursiveWatcher(ctx context.Context, w *fsnotify.Watcher, watchPattern *regexp.Regexp, events chan fsnotify.Event, errors chan error) *RecursiveWatcher { + return &RecursiveWatcher{ + ctx: ctx, + w: w, + WatchPattern: watchPattern, + Events: events, + Errors: errors, + timers: make(map[timerKey]*time.Timer), + } +} + // WalkFiles walks the file tree rooted at path, sending a Create event for each // file it encounters. -func WalkFiles(ctx context.Context, path string, out chan fsnotify.Event) (err error) { +func WalkFiles(ctx context.Context, path string, watchPattern *regexp.Regexp, out chan fsnotify.Event) (err error) { rootPath := path fileSystem := os.DirFS(rootPath) return fs.WalkDir(fileSystem, ".", func(path string, info os.DirEntry, err error) error { @@ -50,7 +57,7 @@ func WalkFiles(ctx context.Context, path string, out chan fsnotify.Event) (err e if info.IsDir() && shouldSkipDir(absPath) { return filepath.SkipDir } - if !shouldIncludeFile(absPath) { + if !watchPattern.MatchString(absPath) { return nil } out <- fsnotify.Event{ @@ -61,26 +68,14 @@ func WalkFiles(ctx context.Context, path string, out chan fsnotify.Event) (err e }) } -func shouldIncludeFile(name string) bool { - if strings.HasSuffix(name, ".templ") { - return true - } - if strings.HasSuffix(name, "_templ.go") { - return true - } - if strings.HasSuffix(name, "_templ.txt") { - return true - } - return false -} - type RecursiveWatcher struct { - ctx context.Context - w *fsnotify.Watcher - Events chan fsnotify.Event - Errors chan error - timerMu sync.Mutex - timers map[timerKey]*time.Timer + ctx context.Context + w *fsnotify.Watcher + WatchPattern *regexp.Regexp + Events chan fsnotify.Event + Errors chan error + timerMu sync.Mutex + timers map[timerKey]*time.Timer } type timerKey struct { @@ -114,7 +109,7 @@ func (w *RecursiveWatcher) loop() { } } // Only notify on templ related files. - if !shouldIncludeFile(event.Name) { + if !w.WatchPattern.MatchString(event.Name) { continue } tk := timerKeyFromEvent(event) diff --git a/cmd/templ/generatecmd/watcher/watch_test.go b/cmd/templ/generatecmd/watcher/watch_test.go index e90180da7..39a560bb3 100644 --- a/cmd/templ/generatecmd/watcher/watch_test.go +++ b/cmd/templ/generatecmd/watcher/watch_test.go @@ -2,6 +2,8 @@ package watcher import ( "context" + "fmt" + "regexp" "testing" "time" @@ -10,14 +12,16 @@ import ( func TestWatchDebouncesDuplicates(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - rw := &RecursiveWatcher{ - ctx: ctx, - w: &fsnotify.Watcher{ - Events: make(chan fsnotify.Event), - }, - Events: make(chan fsnotify.Event, 2), - timers: make(map[timerKey]*time.Timer), + w := &fsnotify.Watcher{ + Events: make(chan fsnotify.Event), } + events := make(chan fsnotify.Event, 2) + errors := make(chan error) + watchPattern, err := regexp.Compile(".*") + if err != nil { + t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err)) + } + rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors) go func() { rw.w.Events <- fsnotify.Event{Name: "test.templ"} rw.w.Events <- fsnotify.Event{Name: "test.templ"} @@ -60,14 +64,16 @@ func TestWatchDoesNotDebounceDifferentEvents(t *testing.T) { } for _, test := range tests { ctx, cancel := context.WithCancel(context.Background()) - rw := &RecursiveWatcher{ - ctx: ctx, - w: &fsnotify.Watcher{ - Events: make(chan fsnotify.Event), - }, - Events: make(chan fsnotify.Event, 2), - timers: make(map[timerKey]*time.Timer), + w := &fsnotify.Watcher{ + Events: make(chan fsnotify.Event), } + events := make(chan fsnotify.Event, 2) + errors := make(chan error) + watchPattern, err := regexp.Compile(".*") + if err != nil { + t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err)) + } + rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors) go func() { rw.w.Events <- test.event1 rw.w.Events <- test.event2 @@ -93,14 +99,16 @@ func TestWatchDoesNotDebounceDifferentEvents(t *testing.T) { func TestWatchDoesNotDebounceSeparateEvents(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) - rw := &RecursiveWatcher{ - ctx: ctx, - w: &fsnotify.Watcher{ - Events: make(chan fsnotify.Event), - }, - Events: make(chan fsnotify.Event, 2), - timers: make(map[timerKey]*time.Timer), + w := &fsnotify.Watcher{ + Events: make(chan fsnotify.Event), + } + events := make(chan fsnotify.Event, 2) + errors := make(chan error) + watchPattern, err := regexp.Compile(".*") + if err != nil { + t.Fatal(fmt.Errorf("failed to compile watch pattern: %w", err)) } + rw := NewRecursiveWatcher(ctx, w, watchPattern, events, errors) go func() { rw.w.Events <- fsnotify.Event{Name: "test.templ"} <-time.After(200 * time.Millisecond) diff --git a/cmd/templ/main.go b/cmd/templ/main.go index c265c24a8..e24041a00 100644 --- a/cmd/templ/main.go +++ b/cmd/templ/main.go @@ -157,6 +157,8 @@ Args: Set to true to include the current time in the generated code. -watch Set to true to watch the path for changes and regenerate code. + -watch-pattern + Set the regexp pattern of files that will be watched for changes. (default: '(.+\.go$)|(.+\.templ$)|(.+_templ\.txt$)') -cmd Set the command to run after generating code. -proxy diff --git a/get-version/main.go b/get-version/main.go index 86d471290..f538f9f31 100644 --- a/get-version/main.go +++ b/get-version/main.go @@ -25,5 +25,5 @@ func main() { } // The current commit isn't the one we're about to commit. - fmt.Printf("0.2.%d", count+1) + fmt.Printf("0.3.%d", count+1) }