diff --git a/middleware_accesstoken.go b/middleware_accesstoken.go index 8e0e4c5..00bb9a2 100644 --- a/middleware_accesstoken.go +++ b/middleware_accesstoken.go @@ -7,12 +7,14 @@ import ( ) type accessTokens struct { - headerName string - tokens []string + paramName string + tokens []string + getFunc func(string, *http.Request) string + missingMessage string } /* -NewMiddlewareAccessToken creates a new handler to verify access tokens in a rye chain. +NewMiddlewareAccessToken creates a new handler to verify access tokens passed as a header. Example usage: @@ -23,19 +25,60 @@ Example usage: })).Methods("POST") */ func NewMiddlewareAccessToken(headerName string, tokens []string) func(rw http.ResponseWriter, req *http.Request) *Response { + return newAccessTokenHandler(headerName, tokens, "header") +} + +/* +NewMiddlewareAccessQueryToken creates a new handler to verify access tokens passed as a query parameter. + +Example usage: + + routes.Handle("/some/route", a.Dependencies.MWHandler.Handle( + []rye.Handler{ + rye.NewMiddlewareAccessQueryToken(queryParamName, []string{token1, token2}), + yourHandler, + })).Methods("POST") +*/ +func NewMiddlewareAccessQueryToken(queryParamName string, tokens []string) func(rw http.ResponseWriter, req *http.Request) *Response { + return newAccessTokenHandler(queryParamName, tokens, "query") +} + +func newAccessTokenHandler(name string, tokens []string, tokenType string) func(rw http.ResponseWriter, req *http.Request) *Response { a := &accessTokens{ - headerName: headerName, - tokens: tokens, + paramName: name, + tokens: tokens, + } + + switch tokenType { + + case "query": + a.getFunc = func(s string, r *http.Request) string { + q, ok := r.URL.Query()[s] + if !ok { + return "" + } + + return q[0] + } + a.missingMessage = fmt.Sprintf("No access token found; ensure you pass the '%s' parameter", name) + + default: + // default to using the header + a.getFunc = func(s string, r *http.Request) string { + return r.Header.Get(s) + } + a.missingMessage = fmt.Sprintf("No access token found; ensure you pass '%s' in header", name) } + return a.handle } func (a *accessTokens) handle(rw http.ResponseWriter, r *http.Request) *Response { - token := r.Header.Get(a.headerName) + token := a.getFunc(a.paramName, r) if token == "" { return &Response{ - Err: fmt.Errorf("No access token found; ensure you pass '%s' in header", a.headerName), + Err: errors.New(a.missingMessage), StatusCode: http.StatusUnauthorized, } } diff --git a/middleware_accesstoken_test.go b/middleware_accesstoken_test.go index ea752b0..09a8022 100644 --- a/middleware_accesstoken_test.go +++ b/middleware_accesstoken_test.go @@ -1,8 +1,10 @@ package rye import ( + "fmt" "net/http" "net/http/httptest" + "net/url" . "github.com/onsi/ginkgo" . "github.com/onsi/gomega" @@ -14,31 +16,40 @@ var _ = Describe("AccessToken Middleware", func() { request *http.Request response *httptest.ResponseRecorder - tokenHeaderName = "at-hname" - token1, token2 string + testHandler func(http.ResponseWriter, *http.Request) *Response + + token1, token2 string ) BeforeEach(func() { response = httptest.NewRecorder() - request = &http.Request{ - Header: map[string][]string{}, - } token1 = "test1" token2 = "test2" }) - Describe("handle", func() { + Context("header token", func() { + var ( + tokenHeaderName = "at-hname" + ) + + BeforeEach(func() { + testHandler = NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2}) + request = &http.Request{ + Header: map[string][]string{}, + } + }) + Context("when a valid token is used", func() { It("should return nil", func() { request.Header.Add(tokenHeaderName, token1) - resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request) + resp := testHandler(response, request) Expect(resp).To(BeNil()) }) It("should return nil", func() { request.Header.Add(tokenHeaderName, token2) - resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request) + resp := testHandler(response, request) Expect(resp).To(BeNil()) }) }) @@ -46,7 +57,7 @@ var _ = Describe("AccessToken Middleware", func() { Context("when an invalid token is used", func() { It("should return an error", func() { request.Header.Add(tokenHeaderName, "blah") - resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request) + resp := testHandler(response, request) Expect(resp).ToNot(BeNil()) Expect(resp.Err).To(HaveOccurred()) Expect(resp.Error()).To(ContainSubstring("invalid access token")) @@ -56,7 +67,7 @@ var _ = Describe("AccessToken Middleware", func() { Context("when no token header exists", func() { It("should return an error", func() { - resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request) + resp := testHandler(response, request) Expect(resp).ToNot(BeNil()) Expect(resp.Err).To(HaveOccurred()) Expect(resp.Error()).To(ContainSubstring("No access token found")) @@ -67,7 +78,7 @@ var _ = Describe("AccessToken Middleware", func() { Context("when token header is blank", func() { It("should return an error", func() { request.Header.Add(tokenHeaderName, "") - resp := NewMiddlewareAccessToken(tokenHeaderName, []string{token1, token2})(response, request) + resp := testHandler(response, request) Expect(resp).ToNot(BeNil()) Expect(resp.Err).To(HaveOccurred()) Expect(resp.Error()).To(ContainSubstring("No access token found")) @@ -75,4 +86,109 @@ var _ = Describe("AccessToken Middleware", func() { }) }) }) + + Context("query param token", func() { + var ( + qParamName string + qParams string + ) + + BeforeEach(func() { + qParamName = "token" + testHandler = NewMiddlewareAccessQueryToken(qParamName, []string{token1, token2}) + }) + + JustBeforeEach(func() { + u, err := url.Parse(fmt.Sprintf("http://doesntmatter.io/blah?%s", qParams)) + Expect(err).ToNot(HaveOccurred()) + + request = &http.Request{ + URL: u, + } + }) + + Context("when a valid token is used", func() { + BeforeEach(func() { + qParams = fmt.Sprintf("%s=%s", qParamName, token1) + }) + + It("should return nil", func() { + resp := testHandler(response, request) + Expect(resp).To(BeNil()) + }) + }) + + Context("when the other valid token is used", func() { + BeforeEach(func() { + qParams = fmt.Sprintf("%s=%s", qParamName, token2) + }) + + It("should return nil", func() { + resp := testHandler(response, request) + Expect(resp).To(BeNil()) + }) + }) + + Context("when an invalid token is used", func() { + BeforeEach(func() { + qParams = fmt.Sprintf("%s=blah", qParamName) + }) + + It("should return an error", func() { + resp := testHandler(response, request) + Expect(resp).ToNot(BeNil()) + Expect(resp.Err).To(HaveOccurred()) + Expect(resp.Error()).To(ContainSubstring("invalid access token")) + Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) + }) + }) + + Context("when no token param exists", func() { + BeforeEach(func() { + qParams = "something=else" + }) + + It("should return an error", func() { + resp := testHandler(response, request) + Expect(resp).ToNot(BeNil()) + Expect(resp.Err).To(HaveOccurred()) + Expect(resp.Error()).To(ContainSubstring("No access token found")) + Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) + }) + }) + + Context("when token param is blank", func() { + BeforeEach(func() { + qParams = fmt.Sprintf("%s=''", qParamName) + }) + + It("should return an error", func() { + resp := testHandler(response, request) + Expect(resp).ToNot(BeNil()) + Expect(resp.Err).To(HaveOccurred()) + Expect(resp.Error()).To(ContainSubstring("invalid access token")) + Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) + }) + }) + + Context("when no query params", func() { + JustBeforeEach(func() { + u, err := url.Parse("http://doesntmatter.io/blah") + Expect(err).ToNot(HaveOccurred()) + + request = &http.Request{ + URL: u, + } + }) + + It("should return an error", func() { + resp := testHandler(response, request) + Expect(resp).ToNot(BeNil()) + Expect(resp.Err).To(HaveOccurred()) + Expect(resp.Error()).To(ContainSubstring("No access token found")) + Expect(resp.StatusCode).To(Equal(http.StatusUnauthorized)) + }) + }) + + }) })