Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve watch mode to reduce recompilation #366

Merged
merged 12 commits into from
Jan 7, 2024
21 changes: 6 additions & 15 deletions benchmarks/templ/template_templ.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

179 changes: 147 additions & 32 deletions cmd/templ/generatecmd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bufio"
"bytes"
"context"
"crypto/sha256"
_ "embed"
"errors"
"fmt"
Expand Down Expand Up @@ -56,6 +57,8 @@ var defaultWorkerCount = runtime.NumCPU()

func Run(w io.Writer, args Arguments) (err error) {
ctx, cancel := context.WithCancel(context.Background())
watchCtx, watchCancel := context.WithCancel(context.Background())

signalChan := make(chan os.Signal, 1)
signal.Notify(signalChan, os.Interrupt)
defer func() {
Expand All @@ -67,29 +70,44 @@ func Run(w io.Writer, args Arguments) (err error) {
_ = http.ListenAndServe(fmt.Sprintf("localhost:%d", args.PPROFPort), nil)
}()
}

go func() {
select {
case <-signalChan: // First signal, cancel context.
fmt.Fprintln(w, "\nCancelling...")
err = run.Stop()
if err != nil {
fmt.Fprintf(w, "Error killing command: %v\n", err)
watching := args.Watch
for {
select {
case <-signalChan: // First signal, cancel context.
if watching {
fmt.Println("Stopping watch operation...")
watchCancel()
continue
}

if ctx.Err() != nil {
fmt.Fprintln(w, "\nHARD EXIT")
os.Exit(2) // hard exit
continue
}

fmt.Fprintln(w, "\nCancelling...")
cancel()

case <-ctx.Done():
break
}
cancel()
case <-ctx.Done():
}
<-signalChan // Second signal, hard exit.
os.Exit(2)
}()
err = runCmd(ctx, w, args)

err = runCmd(ctx, watchCtx, w, args)
if errors.Is(err, context.Canceled) {
return nil
}

return err
}

func runCmd(ctx context.Context, w io.Writer, args Arguments) (err error) {
start := time.Now()
func runCmd(ctx, watchCtx context.Context, w io.Writer, args Arguments) error {
var err error

if args.Watch && args.FileName != "" {
return fmt.Errorf("cannot watch a single file, remove the -f or -watch flag")
}
Expand All @@ -101,7 +119,7 @@ func runCmd(ctx context.Context, w io.Writer, args Arguments) (err error) {
opts = append(opts, generator.WithTimestamp(time.Now()))
}
if args.FileName != "" {
return processSingleFile(ctx, w, "", args.FileName, args.GenerateSourceMapVisualisations, opts)
return processSingleFile(ctx, w, "", args.FileName, nil, args.GenerateSourceMapVisualisations, opts)
}
var target *url.URL
if args.Proxy != "" {
Expand All @@ -120,7 +138,7 @@ func runCmd(ctx context.Context, w io.Writer, args Arguments) (err error) {
if !path.IsAbs(args.Path) {
args.Path, err = filepath.Abs(args.Path)
if err != nil {
return
return err
}
}

Expand All @@ -129,15 +147,36 @@ func runCmd(ctx context.Context, w io.Writer, args Arguments) (err error) {
p = proxy.New(args.ProxyPort, target)
}
fmt.Fprintln(w, "Processing path:", args.Path)

if args.Watch {
err = generateWatched(watchCtx, w, args, opts, p)
if err != nil && !errors.Is(err, context.Canceled) {
return err
}
}

return generateProduction(ctx, w, args, opts, p)
}

func generateWatched(ctx context.Context, w io.Writer, args Arguments, opts []generator.GenerateOpt, p *proxy.Handler) error {
fmt.Fprintln(w, "Generating dev code:", args.Path)
start := time.Now()

bo := backoff.NewExponentialBackOff()
bo.InitialInterval = time.Millisecond * 500
bo.MaxInterval = time.Second * 3
bo.MaxElapsedTime = 0

var firstRunComplete bool
fileNameToLastModTime := make(map[string]time.Time)
fileNameToHash := make(map[string][sha256.Size]byte)

for !firstRunComplete || args.Watch {
changesFound, errs := processChanges(ctx, w, fileNameToLastModTime, args.Path, args.GenerateSourceMapVisualisations, opts, args.WorkerCount, args.KeepOrphanedFiles)
changesFound, errs := processChanges(
ctx, w,
fileNameToLastModTime, fileNameToHash,
args.Path, args.GenerateSourceMapVisualisations,
opts, args.WorkerCount, true, args.KeepOrphanedFiles)
if len(errs) > 0 {
if errors.Is(errs[0], context.Canceled) {
return errs[0]
Expand Down Expand Up @@ -179,20 +218,58 @@ func runCmd(ctx context.Context, w io.Writer, args Arguments) (err error) {
}()
}
}
if err = checkTemplVersion(args.Path); err != nil {
if err := checkTemplVersion(args.Path); err != nil {
logWarning(w, "templ version check failed: %v\n", err)
err = nil
}

if firstRunComplete {
if changesFound > 0 {
bo.Reset()
}
time.Sleep(bo.NextBackOff())
}

firstRunComplete = true
start = time.Now()
}
return err

return nil
}

func generateProduction(ctx context.Context, w io.Writer, args Arguments, opts []generator.GenerateOpt, p *proxy.Handler) error {
fmt.Fprintln(w, "Generating production code:", args.Path)
start := time.Now()

changesFound, errs := processChanges(
ctx, w, nil, nil,
args.Path, args.GenerateSourceMapVisualisations,
opts, args.WorkerCount, false, args.KeepOrphanedFiles)
if len(errs) > 0 {
if errors.Is(errs[0], context.Canceled) {
return errs[0]
}
logError(w, "Error processing path: %v\n", errors.Join(errs...))
}

if changesFound > 0 {
if len(errs) > 0 {
logError(w, "Generated code for %d templates with %d errors in %s\n", changesFound, len(errs), time.Since(start))
} else {
logSuccess(w, "Generated code for %d templates with %d errors in %s\n", changesFound, len(errs), time.Since(start))
}
if args.Command != "" {
fmt.Fprintf(w, "Executing command: %s\n", args.Command)
if _, err := run.Run(ctx, args.Path, args.Command); err != nil {
fmt.Fprintf(w, "Error starting command: %v\n", err)
}
}
}

if err := checkTemplVersion(args.Path); err != nil {
logWarning(w, "templ version check failed: %v\n", err)
}

return nil
}

func shouldSkipDir(dir string) bool {
Expand All @@ -210,10 +287,18 @@ func shouldSkipDir(dir string) bool {
return false
}

func processChanges(ctx context.Context, stdout io.Writer, fileNameToLastModTime map[string]time.Time, path string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt, maxWorkerCount int, keepOrphanedFiles bool) (changesFound int, errs []error) {
func processChanges(ctx context.Context, stdout io.Writer, fileNameToLastModTime map[string]time.Time, hashes map[string][sha256.Size]byte, path string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt, maxWorkerCount int, watching, keepOrphanedFiles bool) (changesFound int, errs []error) {
sem := make(chan struct{}, maxWorkerCount)
var wg sync.WaitGroup

if watching {
opts = append(opts, generator.WithExtractStrings())
}

if fileNameToLastModTime == nil {
fileNameToLastModTime = make(map[string]time.Time)
}

err := filepath.WalkDir(path, func(fileName string, info os.DirEntry, err error) error {
if err != nil {
return err
Expand All @@ -227,19 +312,25 @@ func processChanges(ctx context.Context, stdout io.Writer, fileNameToLastModTime
if info.IsDir() {
return nil
}
if !keepOrphanedFiles && strings.HasSuffix(fileName, "_templ.go") {

orphaned := !keepOrphanedFiles && strings.HasSuffix(fileName, "_templ.go")
if orphaned {
// Make sure the generated file is orphaned
// by checking if the corresponding .templ file exists.
if _, err := os.Stat(strings.TrimSuffix(fileName, "_templ.go") + ".templ"); err == nil {
// The .templ file exists, so we don't delete the generated file.
return nil
orphaned = false
}
}

devTextFile := !watching && strings.HasSuffix(fileName, "_templ.txt")
if orphaned || devTextFile {
if err = os.Remove(fileName); err != nil {
return fmt.Errorf("failed to remove file: %w", err)
}
logWarning(stdout, "Deleted orphaned file %q\n", fileName)
logWarning(stdout, "Deleted file %q\n", fileName)
return nil
}

if strings.HasSuffix(fileName, ".templ") {
lastModTime := fileNameToLastModTime[fileName]
fileInfo, err := info.Info()
Expand All @@ -255,7 +346,7 @@ func processChanges(ctx context.Context, stdout io.Writer, fileNameToLastModTime
wg.Add(1)
go func() {
defer wg.Done()
if err := processSingleFile(ctx, stdout, path, fileName, generateSourceMapVisualisations, opts); err != nil {
if err := processSingleFile(ctx, stdout, path, fileName, hashes, generateSourceMapVisualisations, opts); err != nil {
errs = append(errs, err)
}
<-sem
Expand Down Expand Up @@ -291,9 +382,9 @@ func openURL(w io.Writer, url string) error {

// processSingleFile generates Go code for a single template.
// If a basePath is provided, the filename included in error messages is relative to it.
func processSingleFile(ctx context.Context, stdout io.Writer, basePath, fileName string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt) (err error) {
func processSingleFile(ctx context.Context, stdout io.Writer, basePath, fileName string, hashes map[string][sha256.Size]byte, generateSourceMapVisualisations bool, opts []generator.GenerateOpt) (err error) {
start := time.Now()
diag, err := generate(ctx, basePath, fileName, generateSourceMapVisualisations, opts)
diag, err := generate(ctx, basePath, fileName, hashes, generateSourceMapVisualisations, opts)
if err != nil {
return err
}
Expand All @@ -320,11 +411,15 @@ func printDiagnostics(w io.Writer, fileName string, diags []parser.Diagnostic) {

// generate Go code for a single template.
// If a basePath is provided, the filename included in error messages is relative to it.
func generate(ctx context.Context, basePath, fileName string, generateSourceMapVisualisations bool, opts []generator.GenerateOpt) (diagnostics []parser.Diagnostic, err error) {
func generate(ctx context.Context, basePath, fileName string, hashes map[string][sha256.Size]byte, generateSourceMapVisualisations bool, opts []generator.GenerateOpt) (diagnostics []parser.Diagnostic, err error) {
if err = ctx.Err(); err != nil {
return
}

if hashes == nil {
hashes = make(map[string][sha256.Size]byte)
}

t, err := parser.Parse(fileName)
if err != nil {
return nil, fmt.Errorf("%s parsing error: %w", fileName, err)
Expand All @@ -338,18 +433,35 @@ func generate(ctx context.Context, basePath, fileName string, generateSourceMapV
}

var b bytes.Buffer
sourceMap, err := generator.Generate(t, &b, append(opts, generator.WithFileName(errorMessageFileName))...)
sourceMap, literals, err := generator.Generate(t, &b, append(opts, generator.WithFileName(errorMessageFileName))...)
if err != nil {
return nil, fmt.Errorf("%s generation error: %w", fileName, err)
}

data, err := format.Source(b.Bytes())
formattedGoCode, err := format.Source(b.Bytes())
if err != nil {
return nil, fmt.Errorf("%s source formatting error: %w", fileName, err)
}

if err = os.WriteFile(targetFileName, data, 0644); err != nil {
return nil, fmt.Errorf("%s write file error: %w", targetFileName, err)
// Hash output, and write out the file if the goCodeHash has changed.
goCodeHash := sha256.Sum256(formattedGoCode)
if hashes[targetFileName] != goCodeHash {
if err = os.WriteFile(targetFileName, formattedGoCode, 0o644); err != nil {
return nil, fmt.Errorf("failed to write target file %q: %w", targetFileName, err)
}
hashes[targetFileName] = goCodeHash
}

// Add the txt file if it has changed.
if len(literals) > 0 {
txtFileName := strings.TrimSuffix(fileName, ".templ") + "_templ.txt"
txtHash := sha256.Sum256([]byte(literals))
if hashes[txtFileName] != txtHash {
if err = os.WriteFile(txtFileName, []byte(literals), 0o644); err != nil {
return nil, fmt.Errorf("failed to write string literal file %q: %w", txtFileName, err)
}
hashes[txtFileName] = txtHash
}
}

if generateSourceMapVisualisations {
Expand Down Expand Up @@ -397,12 +509,15 @@ func generateSourceMapVisualisation(ctx context.Context, templFileName, goFileNa
func logError(w io.Writer, format string, a ...any) {
logWithDecoration(w, "✗", color.FgRed, format, a...)
}

func logWarning(w io.Writer, format string, a ...any) {
logWithDecoration(w, "!", color.FgYellow, format, a...)
}

func logSuccess(w io.Writer, format string, a ...any) {
logWithDecoration(w, "✓", color.FgGreen, format, a...)
}

func logWithDecoration(w io.Writer, decoration string, col color.Attribute, format string, a ...any) {
color.New(col).Fprintf(w, "(%s) ", decoration)
fmt.Fprintf(w, format, a...)
Expand Down
Loading