diff --git a/context/context_test.go b/context/context_test.go index fa5efd7..44d8120 100644 --- a/context/context_test.go +++ b/context/context_test.go @@ -11,6 +11,7 @@ import ( type SpyStore struct { response string cancelled bool + t *testing.T } func (s *SpyStore) Fetch() string { @@ -22,30 +23,25 @@ func (s *SpyStore) Cancel() { s.cancelled = true } -func TestServer(t *testing.T) { - t.Run("tells store to cancel work if request is cancelled", func(t *testing.T) { - data := "hello, world" - store := &SpyStore{response: data} - svr := Server(store) - - request := httptest.NewRequest(http.MethodGet, "/", nil) - - cancellingCtx, cancel := context.WithCancel(request.Context()) - time.AfterFunc(5*time.Millisecond, cancel) - request = request.WithContext(cancellingCtx) - - response := httptest.NewRecorder() +func (s *SpyStore) assertWasCancelled() { + s.t.Helper() - svr.ServeHTTP(response, request) + if !s.cancelled { + s.t.Error("store was not told to cancel") + } +} - if !store.cancelled { - t.Error("store was not told to cancel") - } - }) +func (s *SpyStore) assertWasNotCancelled() { + s.t.Helper() + if s.cancelled { + s.t.Error("store was told to cancel") + } +} +func TestServer(t *testing.T) { t.Run("returns data from store", func(t *testing.T) { data := "hello, world" - store := &SpyStore{response: data} + store := &SpyStore{response: data, t: t} srv := Server(store) request := httptest.NewRequest(http.MethodGet, "/", nil) @@ -57,8 +53,24 @@ func TestServer(t *testing.T) { t.Errorf("got %v, want %v", response.Body.String(), data) } - if store.cancelled { - t.Errorf("it should not have cancelled the store") - } + store.assertWasNotCancelled() + }) + + t.Run("tells store to cancel work if request is cancelled", func(t *testing.T) { + data := "hello, world" + store := &SpyStore{response: data, t: t} + svr := Server(store) + + request := httptest.NewRequest(http.MethodGet, "/", nil) + + cancellingCtx, cancel := context.WithCancel(request.Context()) + time.AfterFunc(5*time.Millisecond, cancel) + request = request.WithContext(cancellingCtx) + + response := httptest.NewRecorder() + + svr.ServeHTTP(response, request) + + store.assertWasCancelled() }) }