diff --git a/cmd/agent/commands/grpc.go b/cmd/agent/commands/grpc.go index 2eba9fb7..2dfc6f97 100644 --- a/cmd/agent/commands/grpc.go +++ b/cmd/agent/commands/grpc.go @@ -42,6 +42,13 @@ func BuildGrpcCmd(env runtime.Environment, config *agent.Config) *cobra.Command return fmt.Errorf("upstream host cannot be localhost when running in transparent mode") } + agent, err := agent.BuildAndStart(env, config) + if err != nil { + return fmt.Errorf("initializing agent: %w", err) + } + + defer agent.Stop() + listenAddress := net.JoinHostPort("", fmt.Sprint(port)) upstreamAddress := net.JoinHostPort(upstreamHost, fmt.Sprint(targetPort)) @@ -80,8 +87,6 @@ func BuildGrpcCmd(env runtime.Environment, config *agent.Config) *cobra.Command return err } - agent := agent.BuildAgent(env, config) - return agent.ApplyDisruption(cmd.Context(), disruptor, duration) }, } diff --git a/cmd/agent/commands/http.go b/cmd/agent/commands/http.go index fe7a776e..a65f4b1c 100644 --- a/cmd/agent/commands/http.go +++ b/cmd/agent/commands/http.go @@ -41,6 +41,13 @@ func BuildHTTPCmd(env runtime.Environment, config *agent.Config) *cobra.Command return fmt.Errorf("upstream host cannot be localhost when running in transparent mode") } + agent, err := agent.BuildAndStart(env, config) + if err != nil { + return fmt.Errorf("initializing agent: %w", err) + } + + defer agent.Stop() + listenAddress := net.JoinHostPort("", fmt.Sprint(port)) upstreamAddress := "http://" + net.JoinHostPort(upstreamHost, fmt.Sprint(targetPort)) @@ -79,8 +86,6 @@ func BuildHTTPCmd(env runtime.Environment, config *agent.Config) *cobra.Command return err } - agent := agent.BuildAgent(env, config) - return agent.ApplyDisruption(cmd.Context(), disruptor, duration) }, } diff --git a/pkg/agent/agent.go b/pkg/agent/agent.go index 38ccd603..d0b5560d 100644 --- a/pkg/agent/agent.go +++ b/pkg/agent/agent.go @@ -4,6 +4,8 @@ package agent import ( "context" "fmt" + "io" + "os" "syscall" "time" @@ -19,48 +21,51 @@ type Config struct { // Agent 
maintains the state required for executing an agent command type Agent struct { - env runtime.Environment - config *Config + env runtime.Environment + sc <-chan os.Signal + profileCloser io.Closer } -// BuildAgent builds a instance of an agent -func BuildAgent(env runtime.Environment, config *Config) *Agent { - return &Agent{ - env: env, - config: config, +// BuildAndStart creates and starts a new instance of an agent. +// Returned agent is guaranteed to be unique in the environment it is running, and will handle signals sent to the +// process. +// Callers must Stop the returned agent at the end of its lifecycle. +func BuildAndStart(env runtime.Environment, config *Config) (*Agent, error) { + a := &Agent{ + env: env, } + + if err := a.start(config); err != nil { + a.Stop() // Stop any initialized component if initialization failed. + return nil, err + } + + return a, nil } -// ApplyDisruption applies a disruption to the target -func (r *Agent) ApplyDisruption(ctx context.Context, disruptor protocol.Disruptor, duration time.Duration) error { - sc := r.env.Signal().Notify(syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) - defer func() { - r.env.Signal().Reset() - }() +func (a *Agent) start(config *Config) error { + a.sc = a.env.Signal().Notify(syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP) - acquired, err := r.env.Lock().Acquire() + acquired, err := a.env.Lock().Acquire() if err != nil { return fmt.Errorf("could not acquire process lock: %w", err) } + if !acquired { return fmt.Errorf("another instance of the agent is already running") } - defer func() { - _ = r.env.Lock().Release() - }() - // start profiler - profiler, err := r.env.Profiler().Start(ctx, *r.config.Profiler) + a.profileCloser, err = a.env.Profiler().Start(*config.Profiler) if err != nil { return fmt.Errorf("could not create profiler %w", err) } - // ensure the profiler is closed even if there's an error executing the command - defer func() { - _ = profiler.Close() - }() + return nil +} +// 
ApplyDisruption applies a disruption to the target +func (a *Agent) ApplyDisruption(ctx context.Context, disruptor protocol.Disruptor, duration time.Duration) error { // set context for command ctx, cancel := context.WithCancel(ctx) @@ -83,7 +88,17 @@ func (r *Agent) ApplyDisruption(ctx context.Context, disruptor protocol.Disrupto return ctx.Err() case err := <-cc: return err - case s := <-sc: + case s := <-a.sc: return fmt.Errorf("received signal %q", s) } } + +// Stop stops a running agent: it resets signal handling, releases the process lock, and closes the profiler. +func (a *Agent) Stop() { + a.env.Signal().Reset() + _ = a.env.Lock().Release() + + if a.profileCloser != nil { + _ = a.profileCloser.Close() + } +} diff --git a/pkg/agent/agent_test.go b/pkg/agent/agent_test.go index 4af75425..baa1ff53 100644 --- a/pkg/agent/agent_test.go +++ b/pkg/agent/agent_test.go @@ -63,7 +63,12 @@ func Test_CancelContext(t *testing.T) { t.Parallel() env := runtime.NewFakeRuntime(tc.args, tc.vars) - agent := BuildAgent(env, tc.config) + agent, err := BuildAndStart(env, tc.config) + if err != nil { + t.Fatalf("starting agent: %v", err) + } + + defer agent.Stop() ctx, cancel := context.WithCancel(context.Background()) go func() { @@ -72,7 +77,7 @@ func Test_CancelContext(t *testing.T) { }() disruptor := &FakeProtocolDisruptor{} - err := agent.ApplyDisruption(ctx, disruptor, tc.delay) + err = agent.ApplyDisruption(ctx, disruptor, tc.delay) if !errors.Is(err, tc.expected) { t.Errorf("expected %v got %v", tc.err, err) } @@ -126,7 +131,12 @@ func Test_Signals(t *testing.T) { t.Parallel() env := runtime.NewFakeRuntime(tc.args, tc.vars) - agent := BuildAgent(env, tc.config) + agent, err := BuildAndStart(env, tc.config) + if err != nil { + t.Fatalf("starting agent: %v", err) + } + + defer agent.Stop() go func() { time.Sleep(1 * time.Second) @@ -136,7 +146,7 @@ func Test_Signals(t *testing.T) { }() disruptor := &FakeProtocolDisruptor{} - err := agent.ApplyDisruption(context.TODO(), disruptor, tc.delay) + err = agent.ApplyDisruption(context.TODO(), 
disruptor, tc.delay) if tc.expectErr && err == nil { t.Errorf("should had failed") return