From 7c026b6991d496bd0de3dc611cdb9b6be4984f24 Mon Sep 17 00:00:00 2001 From: Noah Kreiger <32901937+nkreiger@users.noreply.github.com> Date: Tue, 30 Apr 2024 17:26:57 -0400 Subject: [PATCH] Allow Context to Configure Default Timeout (#2) * allow context to override the default timeout Signed-off-by: Noah Kreiger --- v2/protocol/http/options.go | 32 +++++++ v2/protocol/http/options_test.go | 114 ++++++++++++++++++++++++- v2/protocol/http/protocol.go | 23 +++++ v2/protocol/http/protocol_lifecycle.go | 4 +- v2/protocol/http/protocol_test.go | 3 + 5 files changed, 172 insertions(+), 4 deletions(-) diff --git a/v2/protocol/http/options.go b/v2/protocol/http/options.go index 6582af3ea..359095004 100644 --- a/v2/protocol/http/options.go +++ b/v2/protocol/http/options.go @@ -83,6 +83,38 @@ func WithShutdownTimeout(timeout time.Duration) Option { } } +// WithReadTimeout overwrites the default read timeout (600s) of the http +// server. The specified timeout must not be negative. A timeout of 0 disables +// read timeouts in the http server. +func WithReadTimeout(timeout time.Duration) Option { + return func(p *Protocol) error { + if p == nil { + return fmt.Errorf("http read timeout option can not set nil protocol") + } + if timeout < 0 { + return fmt.Errorf("http read timeout must not be negative") + } + p.readTimeout = &timeout + return nil + } +} + +// WithWriteTimeout overwrites the default write timeout (600s) of the http +// server. The specified timeout must not be negative. A timeout of 0 disables +// write timeouts in the http server. +func WithWriteTimeout(timeout time.Duration) Option { + return func(p *Protocol) error { + if p == nil { + return fmt.Errorf("http write timeout option can not set nil protocol") + } + if timeout < 0 { + return fmt.Errorf("http write timeout must not be negative") + } + p.writeTimeout = &timeout + return nil + } +} + func checkListen(p *Protocol, prefix string) error { switch { case p.listener.Load() != nil: diff --git a/v2/protocol/http/options_test.go b/v2/protocol/http/options_test.go index fd0af7fcf..21cb841d2 100644 --- a/v2/protocol/http/options_test.go +++ b/v2/protocol/http/options_test.go @@ -315,6 +315,106 @@ func TestWithShutdownTimeout(t *testing.T) { } } +func TestWithReadTimeout(t *testing.T) { + expected := time.Minute * 4 + testCases := map[string]struct { + t *Protocol + timeout time.Duration + want *Protocol + wantErr string + }{ + "valid timeout": { + t: &Protocol{}, + timeout: time.Minute * 4, + want: &Protocol{ + readTimeout: &expected, + }, + }, + "negative timeout": { + t: &Protocol{}, + timeout: -1, + wantErr: "http read timeout must not be negative", + }, + "nil protocol": { + wantErr: "http read timeout option can not set nil protocol", + }, + } + for n, tc := range testCases { + t.Run(n, func(t *testing.T) { + + err := tc.t.applyOptions(WithReadTimeout(tc.timeout)) + + if tc.wantErr != "" || err != nil { + var gotErr string + if err != nil { + gotErr = err.Error() + } + if diff := cmp.Diff(tc.wantErr, gotErr); diff != "" { + t.Errorf("unexpected error (-want, +got) = %v", diff) + } + return + } + + got := tc.t + + if diff := cmp.Diff(tc.want, got, + cmpopts.IgnoreUnexported(Protocol{})); diff != "" { + t.Errorf("unexpected (-want, +got) = %v", diff) + } + }) + } +} + +func TestWithWriteTimeout(t *testing.T) { + expected := time.Minute * 4 + + testCases := map[string]struct { + t *Protocol + timeout time.Duration + want *Protocol + wantErr string + }{ + "valid timeout": { + t: &Protocol{}, + timeout: time.Minute * 4, + want: &Protocol{ + writeTimeout: &expected, + }, + }, + "negative timeout": { + t: &Protocol{}, + timeout: -1, + wantErr: "http write timeout must not be negative", + }, + "nil protocol": { + wantErr: "http write timeout option can not set nil protocol", + }, + } + for n, tc := range testCases { + t.Run(n, func(t *testing.T) { + + err := tc.t.applyOptions(WithWriteTimeout(tc.timeout)) + + if tc.wantErr != "" || err != nil { + var gotErr string + if err != nil { + gotErr = err.Error() + } + if diff := cmp.Diff(tc.wantErr, gotErr); diff != "" { + t.Errorf("unexpected error (-want, +got) = %v", diff) + } + return + } + + got := tc.t + + if diff := cmp.Diff(tc.want, got, + cmpopts.IgnoreUnexported(Protocol{})); diff != "" { + t.Errorf("unexpected (-want, +got) = %v", diff) + } + }) + } +} func TestWithPort(t *testing.T) { testCases := map[string]struct { t *Protocol @@ -389,9 +489,19 @@ func forceClose(tr *Protocol) { } func TestWithPort0(t *testing.T) { + noReadWriteTimeout := time.Duration(0) + testCases := map[string]func() (*Protocol, error){ - "WithPort0": func() (*Protocol, error) { return New(WithPort(0)) }, - "SetPort0": func() (*Protocol, error) { return &Protocol{Port: 0}, nil }, + "WithPort0": func() (*Protocol, error) { + return New(WithPort(0)) + }, + "SetPort0": func() (*Protocol, error) { + return &Protocol{ + Port: 0, + readTimeout: &noReadWriteTimeout, + writeTimeout: &noReadWriteTimeout, + }, nil + }, } for name, f := range testCases { t.Run(name, func(t *testing.T) { diff --git a/v2/protocol/http/protocol.go b/v2/protocol/http/protocol.go index 7ee3b8fe1..18bd604a6 100644 --- a/v2/protocol/http/protocol.go +++ b/v2/protocol/http/protocol.go @@ -70,6 +70,18 @@ type Protocol struct { // If 0, DefaultShutdownTimeout is used. ShutdownTimeout time.Duration + // readTimeout defines the http.Server ReadTimeout It is the maximum duration + // for reading the entire request, including the body. If not overwritten by an + // option, the default value (600s) is used + readTimeout *time.Duration + + // writeTimeout defines the http.Server WriteTimeout It is the maximum duration + // before timing out writes of the response. It is reset whenever a new + // request's header is read. Like ReadTimeout, it does not let Handlers make + // decisions on a per-request basis. If not overwritten by an option, the + // default value (600s) is used + writeTimeout *time.Duration + // Port is the port configured to bind the receiver to. Defaults to 8080. // If you want to know the effective port you're listening to, use GetListeningPort() Port int @@ -116,6 +128,17 @@ func New(opts ...Option) (*Protocol, error) { p.ShutdownTimeout = DefaultShutdownTimeout } + // use default timeout from abuse protection value + defaultTimeout := DefaultTimeout + + if p.readTimeout == nil { + p.readTimeout = &defaultTimeout + } + + if p.writeTimeout == nil { + p.writeTimeout = &defaultTimeout + } + if p.isRetriableFunc == nil { p.isRetriableFunc = defaultIsRetriableFunc } diff --git a/v2/protocol/http/protocol_lifecycle.go b/v2/protocol/http/protocol_lifecycle.go index 04ef96915..7551c31c5 100644 --- a/v2/protocol/http/protocol_lifecycle.go +++ b/v2/protocol/http/protocol_lifecycle.go @@ -40,8 +40,8 @@ func (p *Protocol) OpenInbound(ctx context.Context) error { p.server = &http.Server{ Addr: listener.Addr().String(), Handler: attachMiddleware(p.Handler, p.middleware), - ReadTimeout: DefaultTimeout, - WriteTimeout: DefaultTimeout, + ReadTimeout: *p.readTimeout, + WriteTimeout: *p.writeTimeout, } // Shutdown diff --git a/v2/protocol/http/protocol_test.go b/v2/protocol/http/protocol_test.go index 818ef60c2..4014989e6 100644 --- a/v2/protocol/http/protocol_test.go +++ b/v2/protocol/http/protocol_test.go @@ -26,6 +26,7 @@ import ( func TestNew(t *testing.T) { dst := DefaultShutdownTimeout + ot := DefaultTimeout testCases := map[string]struct { opts []Option @@ -36,6 +37,8 @@ func TestNew(t *testing.T) { want: &Protocol{ Client: http.DefaultClient, ShutdownTimeout: dst, + readTimeout: &ot, + writeTimeout: &ot, Port: -1, }, },