diff --git a/context/context.go b/context/context.go new file mode 100644 index 0000000..b3b1f4e --- /dev/null +++ b/context/context.go @@ -0,0 +1,21 @@ +package context + +import ( + "context" + "fmt" + "net/http" +) + +type Store interface { + Fetch(ctx context.Context) (string, error) +} + +func Server(store Store) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + data, err := store.Fetch(r.Context()) + if err != nil { + return + } + fmt.Fprint(w, data) + } +} diff --git a/context/context_test.go b/context/context_test.go new file mode 100644 index 0000000..da89fcf --- /dev/null +++ b/context/context_test.go @@ -0,0 +1,97 @@ +package context + +import ( + "context" + "errors" + "log" + "net/http" + "net/http/httptest" + "testing" + "time" +) + +type SpyStore struct { + response string + t *testing.T +} + +func (s *SpyStore) Fetch(ctx context.Context) (string, error) { + data := make(chan string, 1) + + go func() { + var result string + for _, c := range s.response { + select { + case <-ctx.Done(): + log.Println("spy store got cancelled") + return + default: + time.Sleep(10 * time.Millisecond) + result += string(c) + } + } + data <- result + }() + + select { + case <-ctx.Done(): + return "", ctx.Err() + case res := <-data: + return res, nil + } +} + +type SpyResponseWriter struct { + written bool +} + +func (s *SpyResponseWriter) Header() http.Header { + s.written = true + return nil +} + +func (s *SpyResponseWriter) Write([]byte) (int, error) { + s.written = true + return 0, errors.New("not implemented") +} + +func (s *SpyResponseWriter) WriteHeader(statusCode int) { + s.written = true +} + +func TestServer(t *testing.T) { + t.Run("returns data from store", func(t *testing.T) { + data := "hello, world" + store := &SpyStore{response: data, t: t} + srv := Server(store) + + request := httptest.NewRequest(http.MethodGet, "/", nil) + response := httptest.NewRecorder() + + srv.ServeHTTP(response, request) + + if response.Body.String() != data { + t.Errorf("got %v, want %v", response.Body.String(), data) + } + }) + + t.Run("tells store to cancel work if request is cancelled", func(t *testing.T) { + data := "hello, world" + store := &SpyStore{response: data, t: t} + svr := Server(store) + + request := httptest.NewRequest(http.MethodGet, "/", nil) + + cancellingCtx, cancel := context.WithCancel(request.Context()) + time.AfterFunc(5*time.Millisecond, cancel) + request = request.WithContext(cancellingCtx) + + response := &SpyResponseWriter{} + + svr.ServeHTTP(response, request) + + if response.written { + t.Error("a response should not have been written") + } + }) +}