From 8f46cf0ce86021bbd8aa6ac3785a3c24fc4be206 Mon Sep 17 00:00:00 2001 From: Ben Marini Date: Mon, 2 Apr 2018 15:40:59 -0700 Subject: [PATCH] Ensure context arg is the same as request Context --- httpx/middleware/header.go | 2 ++ httpx/middleware/header_test.go | 5 +++++ httpx/middleware/logger.go | 3 +++ httpx/middleware/opentracing.go | 1 + httpx/middleware/reporter.go | 2 ++ httpx/middleware/request_id.go | 3 +++ httpx/middleware/request_id_test.go | 8 +++++++- 7 files changed, 23 insertions(+), 1 deletion(-) diff --git a/httpx/middleware/header.go b/httpx/middleware/header.go index 123b34e3..68a9252a 100644 --- a/httpx/middleware/header.go +++ b/httpx/middleware/header.go @@ -28,6 +28,8 @@ func (h *Header) ServeHTTPContext(ctx context.Context, w http.ResponseWriter, r value := e(r) ctx = httpx.WithHeader(ctx, h.key, value) + r = r.WithContext(ctx) + return h.handler.ServeHTTPContext(ctx, w, r) } diff --git a/httpx/middleware/header_test.go b/httpx/middleware/header_test.go index 94c7fb3f..386542d0 100644 --- a/httpx/middleware/header_test.go +++ b/httpx/middleware/header_test.go @@ -6,6 +6,7 @@ import ( "testing" "context" + "github.com/remind101/pkg/httpx" ) @@ -24,7 +25,11 @@ func TestHeader(t *testing.T) { m := ExtractHeader( httpx.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error { data := httpx.Header(ctx, tt.key) + if got, want := data, tt.val; got != want { + t.Fatalf("%s => %s; want %s", tt.key, got, want) + } + data = httpx.Header(r.Context(), tt.key) if got, want := data, tt.val; got != want { t.Fatalf("%s => %s; want %s", tt.key, got, want) } diff --git a/httpx/middleware/logger.go b/httpx/middleware/logger.go index fc878c96..72383264 100644 --- a/httpx/middleware/logger.go +++ b/httpx/middleware/logger.go @@ -43,7 +43,10 @@ func LogTo(h httpx.Handler, g loggerGenerator) httpx.Handler { func InsertLogger(h httpx.Handler, g loggerGenerator) httpx.Handler { return httpx.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error { l := g(ctx, r) + ctx = logger.WithLogger(ctx, l) + r = r.WithContext(ctx) + return h.ServeHTTPContext(ctx, w, r) }) } diff --git a/httpx/middleware/opentracing.go b/httpx/middleware/opentracing.go index 84ee279b..c9910140 100644 --- a/httpx/middleware/opentracing.go +++ b/httpx/middleware/opentracing.go @@ -45,6 +45,7 @@ func (h *OpentracingTracer) ServeHTTPContext(ctx context.Context, w http.Respons defer span.Finish() ctx = opentracing.ContextWithSpan(ctx, span) + r = r.WithContext(ctx) rw := NewResponseWriter(w) reqErr := h.handler.ServeHTTPContext(ctx, rw, r) diff --git a/httpx/middleware/reporter.go b/httpx/middleware/reporter.go index d0f3e21b..67040bcd 100644 --- a/httpx/middleware/reporter.go +++ b/httpx/middleware/reporter.go @@ -26,6 +26,8 @@ func (m *Reporter) ServeHTTPContext(ctx context.Context, w http.ResponseWriter, // Add the request id to reporter context. ctx = errors.WithInfo(ctx, "request_id", httpx.RequestID(ctx)) + r = r.WithContext(ctx) + return m.handler.ServeHTTPContext(ctx, w, r) } diff --git a/httpx/middleware/request_id.go b/httpx/middleware/request_id.go index 83c75574..569acfaa 100644 --- a/httpx/middleware/request_id.go +++ b/httpx/middleware/request_id.go @@ -4,6 +4,7 @@ import ( "net/http" "context" + "github.com/remind101/pkg/httpx" ) @@ -39,5 +40,7 @@ func (h *RequestID) ServeHTTPContext(ctx context.Context, w http.ResponseWriter, requestID := e(r) ctx = httpx.WithRequestID(ctx, requestID) + r = r.WithContext(ctx) + return h.handler.ServeHTTPContext(ctx, w, r) } diff --git a/httpx/middleware/request_id_test.go b/httpx/middleware/request_id_test.go index 5f0b83d4..8b880b72 100644 --- a/httpx/middleware/request_id_test.go +++ b/httpx/middleware/request_id_test.go @@ -5,8 +5,9 @@ import ( "net/http/httptest" "testing" - "github.com/remind101/pkg/httpx" "context" + + "github.com/remind101/pkg/httpx" ) func TestRequestID(t *testing.T) { @@ -23,7 +24,12 @@ func TestRequestID(t *testing.T) { m := &RequestID{ handler: httpx.HandlerFunc(func(ctx context.Context, w http.ResponseWriter, r *http.Request) error { requestID := httpx.RequestID(ctx) + if got, want := requestID, tt.id; got != want { + t.Fatalf("RequestID => %s; want %s", got, want) + } + // From request.Context() + requestID = httpx.RequestID(r.Context()) if got, want := requestID, tt.id; got != want { t.Fatalf("RequestID => %s; want %s", got, want) }