Skip to content

Commit

Permalink
Merge pull request #28 from sendgrid/debug_signal_handling
Browse files Browse the repository at this point in the history
Explicitly catch and propagate signals to subcommands #minor
  • Loading branch information
Michael Robinson authored Sep 10, 2020
2 parents 953b95d + 9788a87 commit eea4a1a
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 12 deletions.
15 changes: 12 additions & 3 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
GO_VERSION ?= 1.12.1
GO_VERSION ?= 1.15
GO_CI_VERSION = v1.31.0
BINARIES = aws-env
WD ?= $(shell pwd)
NAMESPACE=sendgrid
Expand Down Expand Up @@ -87,8 +88,16 @@ lint:
-w /code \
golang:$(GO_VERSION) \
sh -c "\
curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b /go/bin v1.16.0 && \
golangci-lint -v run --enable-all -D gochecknoglobals"
curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b /go/bin $(GO_CI_VERSION) && \
golangci-lint -v run --exclude-use-default=false --deadline 5m \
-E golint \
-E gosec \
-E unconvert \
-E unparam \
-E gocyclo \
-E misspell \
-E gocritic \
-E maligned"

.PHONY: release
release:
Expand Down
14 changes: 7 additions & 7 deletions awsenv/file_replacer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ func TestFileReplacer_ReplaceAll_multiple(t *testing.T) {
defer cleanup()

// read content before the change
oldContent, err := ioutil.ReadFile(fileName)
oldContent, err := ioutil.ReadFile(fileName) //nolint: gosec
require.NoError(t, err)
require.Equal(t, sampleCnfFile2, string(oldContent))

Expand All @@ -128,7 +128,7 @@ mysql_users:
}
)
`
f, err := ioutil.ReadFile(fileName)
f, err := ioutil.ReadFile(fileName) //nolint: gosec
require.NoError(t, err)

require.Equal(t, expectedContent, string(f))
Expand All @@ -140,7 +140,7 @@ func TestFileReplacer_ReplaceAll_with_commas(t *testing.T) {
defer cleanup()

// read content before the change
oldContent, err := ioutil.ReadFile(fileName)
oldContent, err := ioutil.ReadFile(fileName) //nolint: gosec
require.NoError(t, err)
require.Equal(t, sampleCnfFile3, string(oldContent))

Expand All @@ -167,7 +167,7 @@ mysql_users:
}
)
`
f, err := ioutil.ReadFile(fileName)
f, err := ioutil.ReadFile(fileName) //nolint: gosec
require.NoError(t, err)

require.Equal(t, expectedContent, string(f))
Expand All @@ -179,7 +179,7 @@ func TestFileReplacer_ReplaceAll_with_multiple_occurrences(t *testing.T) {
defer cleanup()

// read content before the change
oldContent, err := ioutil.ReadFile(fileName)
oldContent, err := ioutil.ReadFile(fileName) //nolint: gosec
require.NoError(t, err)
require.Equal(t, sampleCnfFile4, string(oldContent))

Expand Down Expand Up @@ -208,7 +208,7 @@ mysql_users:
}
)
`
f, err := ioutil.ReadFile(fileName)
f, err := ioutil.ReadFile(fileName) //nolint: gosec
require.NoError(t, err)

require.Equal(t, expectedContent, string(f))
Expand All @@ -234,5 +234,5 @@ func writeTempFile(contents string) (string, func()) {
log.Fatal(err)
}

return tmpfile.Name(), func() { os.Remove(fName) }
return tmpfile.Name(), func() { os.Remove(fName) } //nolint: errcheck,gosec
}
42 changes: 40 additions & 2 deletions cmd/aws-env/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"fmt"
"os"
"os/exec"
"os/signal"

"github.com/sendgrid/aws-env/awsenv"
v1 "github.com/sendgrid/aws-env/awsenv/v1"
Expand Down Expand Up @@ -88,7 +89,7 @@ func initApp() *cli.App {
cli.BoolFlag{
Name: "ecs",
EnvVar: "AWS_ENV_ECS",
Usage: "Enable ECS mode, using the default credential provider to support ECS",
Usage: "enable ECS mode, using the default credential provider to support ECS",
Destination: &ecs,
},
}
Expand Down Expand Up @@ -159,6 +160,7 @@ func run(c *cli.Context) error {
if fileName != "" {
return fileReplacement(ssmClient)
}

return envReplacement(c, ssmClient)
}

Expand Down Expand Up @@ -214,7 +216,43 @@ func invoke(r *awsenv.Replacer, prog string, args []string) error {
cmd.Stdin = os.Stdin
cmd.Stdout = os.Stdout
cmd.Stderr = os.Stderr
return cmd.Run()

// in order to make sure that we catch and propagate signals correctly, we need
// to decouple starting the command and waiting for it to complete, so we can
// send signals as it runs
err = cmd.Start()
if err != nil {
log.WithError(err).Error("failed to start child process")
return err
}

// wait for the command to finish
errCh := make(chan error, 1)
go func() {
errCh <- cmd.Wait()
close(errCh)
}()
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh)

for {
select {
case sig := <-sigCh:
// this errror case only seems possible if the OS has released the process
// or if it isn't started. So we _should_ be able to break
if err := cmd.Process.Signal(sig); err != nil {
log.WithError(err).WithField("signal", sig).Error("error sending signal")
return err
}
case err := <-errCh:
// the command finished.
if err != nil {
log.WithError(err).Error("command failed")
return err
}
return nil
}
}
}

func main() {
Expand Down

0 comments on commit eea4a1a

Please sign in to comment.