Skip to content

Commit

Permalink
Deny requests that have characteristics of web browsers (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
patricksanders authored Feb 3, 2021
1 parent 811e1a6 commit 1405b82
Show file tree
Hide file tree
Showing 6 changed files with 268 additions and 15 deletions.
4 changes: 2 additions & 2 deletions cmd/ecs_credential_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ func runEcsMetadata(cmd *cobra.Command, args []string) error {

router := mux.NewRouter()
router.HandleFunc("/healthcheck", handlers.HealthcheckHandler)
router.HandleFunc("/ecs/{role:.*}", handlers.MetaDataServiceMiddleware(handlers.ECSMetadataServiceCredentialsHandler))
router.HandleFunc("/{path:.*}", handlers.MetaDataServiceMiddleware(handlers.CustomHandler))
router.HandleFunc("/ecs/{role:.*}", handlers.CredentialServiceMiddleware(handlers.ECSMetadataServiceCredentialsHandler))
router.HandleFunc("/{path:.*}", handlers.CredentialServiceMiddleware(handlers.CustomHandler))

go func() {
log.Info("Starting weep ECS meta-data service...")
Expand Down
18 changes: 9 additions & 9 deletions cmd/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,15 +66,15 @@ func runMetadata(cmd *cobra.Command, args []string) error {

router := mux.NewRouter()
router.HandleFunc("/healthcheck", handlers.HealthcheckHandler)
router.HandleFunc("/{version}/", handlers.MetaDataServiceMiddleware(handlers.BaseVersionHandler))
router.HandleFunc("/{version}/api/token", handlers.MetaDataServiceMiddleware(handlers.TokenHandler)).Methods("PUT")
router.HandleFunc("/{version}/meta-data", handlers.MetaDataServiceMiddleware(handlers.BaseHandler))
router.HandleFunc("/{version}/meta-data/", handlers.MetaDataServiceMiddleware(handlers.BaseHandler))
router.HandleFunc("/{version}/meta-data/iam/info", handlers.MetaDataServiceMiddleware(handlers.IamInfoHandler))
router.HandleFunc("/{version}/meta-data/iam/security-credentials/", handlers.MetaDataServiceMiddleware(handlers.RoleHandler))
router.HandleFunc("/{version}/meta-data/iam/security-credentials/{role}", handlers.MetaDataServiceMiddleware(handlers.CredentialsHandler))
router.HandleFunc("/{version}/dynamic/instance-identity/document", handlers.MetaDataServiceMiddleware(handlers.InstanceIdentityDocumentHandler))
router.HandleFunc("/{path:.*}", handlers.MetaDataServiceMiddleware(handlers.CustomHandler))
router.HandleFunc("/{version}/", handlers.CredentialServiceMiddleware(handlers.BaseVersionHandler))
router.HandleFunc("/{version}/api/token", handlers.CredentialServiceMiddleware(handlers.TokenHandler)).Methods("PUT")
router.HandleFunc("/{version}/meta-data", handlers.CredentialServiceMiddleware(handlers.BaseHandler))
router.HandleFunc("/{version}/meta-data/", handlers.CredentialServiceMiddleware(handlers.BaseHandler))
router.HandleFunc("/{version}/meta-data/iam/info", handlers.CredentialServiceMiddleware(handlers.IamInfoHandler))
router.HandleFunc("/{version}/meta-data/iam/security-credentials/", handlers.CredentialServiceMiddleware(handlers.RoleHandler))
router.HandleFunc("/{version}/meta-data/iam/security-credentials/{role}", handlers.CredentialServiceMiddleware(handlers.CredentialsHandler))
router.HandleFunc("/{version}/dynamic/instance-identity/document", handlers.CredentialServiceMiddleware(handlers.InstanceIdentityDocumentHandler))
router.HandleFunc("/{path:.*}", handlers.CredentialServiceMiddleware(handlers.CustomHandler))

go func() {
log.Info("Starting weep meta-data service...")
Expand Down
2 changes: 1 addition & 1 deletion handlers/ecsCredentialsHandler.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ func ECSMetadataServiceCredentialsHandler(w http.ResponseWriter, r *http.Request
assume, err := parseAssumeRoleQuery(r)
if err != nil {
log.Error(err)
util.WriteError(w, http.StatusBadRequest, err.Error())
util.WriteError(w, err.Error(), http.StatusBadRequest)
return
}
vars := mux.Vars(r)
Expand Down
56 changes: 55 additions & 1 deletion handlers/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,20 @@ import (
"math/rand"
"net/http"
"strconv"
"strings"

"github.com/netflix/weep/util"

"github.com/netflix/weep/metadata"
log "github.com/sirupsen/logrus"
)

func MetaDataServiceMiddleware(next http.HandlerFunc) http.HandlerFunc {
// CredentialServiceMiddleware is a convenience wrapper that chains BrowserFilterMiddleware and AWSHeaderMiddleware
func CredentialServiceMiddleware(next http.HandlerFunc) http.HandlerFunc {
return BrowserFilterMiddleware(AWSHeaderMiddleware(next))
}

func AWSHeaderMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {

w.Header().Set("ETag", strconv.FormatInt(rand.Int63n(10000000000), 10))
Expand Down Expand Up @@ -53,3 +61,49 @@ func MetaDataServiceMiddleware(next http.HandlerFunc) http.HandlerFunc {
next.ServeHTTP(w, r)
}
}

// allowedHosts is a map used to look up Host headers for the purpose of rejecting requests
// for hosts that are not allowed
var allowedHosts = map[string]bool{
"": true, // Empty or no host header, could be curl or similar
"127.0.0.1": true, // localhost
"169.254.169.254": true, // IMDS IP
}

// BrowserFilterMiddleware is a middleware designed mitigate risks related to DNS rebinding,
// cross site request forgery, and any other traffic from a well behaved modern web browser
func BrowserFilterMiddleware(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
// Check User-Agent
// If User-Agent has Mozilla in it, this is almost certainly a browser request
userAgent := r.Header.Get("User-Agent")
userAgent = strings.ToLower(userAgent)
if strings.Contains(userAgent, "mozilla") {
log.Warn("bad user-agent detected")
util.WriteError(w, "forbidden", http.StatusForbidden)
return
}

// Check for Referrer or Origin header
// These also indicate a likely browser request
if referrer := r.Header.Get("Referrer"); referrer != "" {
log.Warn("referrer detected")
util.WriteError(w, "forbidden", http.StatusForbidden)
return
}
if origin := r.Header.Get("Origin"); origin != "" {
log.Warn("origin detected")
util.WriteError(w, "forbidden", http.StatusForbidden)
return
}

// Check host header
// This should only be 127.0.0.1, 169.254.169.254, or nothing
if host := r.Header.Get("Host"); !allowedHosts[host] {
log.Warn("bad host detected")
util.WriteError(w, "forbidden", http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
}
}
199 changes: 199 additions & 0 deletions handlers/middleware_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
package handlers

import (
"net/http"
"net/http/httptest"
"testing"
)

// TestBrowserFilterMiddleware ensures 403 Forbidden is returned for all requests that look like
// they came from a web browser.
func TestBrowserFilterMiddleware(t *testing.T) {
cases := []struct {
Description string
HeaderName string
HeaderValue string
ExpectedStatus int
}{
{
Description: "valid request",
HeaderName: "User-Agent",
HeaderValue: "boto3/foo",
ExpectedStatus: http.StatusOK,
},
{
Description: "Mozilla in user-agent",
HeaderName: "User-Agent",
HeaderValue: "Mozilla",
ExpectedStatus: http.StatusForbidden,
},
{
Description: "mozilla in user-agent",
HeaderName: "User-Agent",
HeaderValue: "mozilla",
ExpectedStatus: http.StatusForbidden,
},
{
Description: "referrer header set",
HeaderName: "Referrer",
HeaderValue: "anything",
ExpectedStatus: http.StatusForbidden,
},
{
Description: "origin header set",
HeaderName: "Origin",
HeaderValue: "anything",
ExpectedStatus: http.StatusForbidden,
},
{
Description: "host header not in allowlist",
HeaderName: "Host",
HeaderValue: "netflix.com",
ExpectedStatus: http.StatusForbidden,
},
{
Description: "host header in allowlist (127.0.0.1)",
HeaderName: "Host",
HeaderValue: "127.0.0.1",
ExpectedStatus: http.StatusOK,
},
{
Description: "host header in allowlist (169.254.169.254)",
HeaderName: "Host",
HeaderValue: "127.0.0.1",
ExpectedStatus: http.StatusOK,
},
}
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
for i, tc := range cases {
t.Logf("test case %d: %s", i, tc.Description)
bfmHandler := BrowserFilterMiddleware(nextHandler)
req := httptest.NewRequest("GET", "http://localhost", nil)
req.Header.Add(tc.HeaderName, tc.HeaderValue)
rec := httptest.NewRecorder()
bfmHandler.ServeHTTP(rec, req)
if rec.Code != tc.ExpectedStatus {
t.Errorf("%s failed: got status %d, expected %d", tc.Description, rec.Code, tc.ExpectedStatus)
continue
}
}
}

// TestAWSHeaderMiddleware checks for headers added for consumption by AWS SDKs
func TestAWSHeaderMiddleware(t *testing.T) {
description := "aws header middleware"
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
t.Logf("test case: %s", description)
bfmHandler := CredentialServiceMiddleware(nextHandler)
req := httptest.NewRequest("GET", "http://localhost", nil)
rec := httptest.NewRecorder()
bfmHandler.ServeHTTP(rec, req)
if rec.Code != http.StatusOK {
t.Errorf("%s failed: got status %d, expected %d", description, rec.Code, http.StatusOK)
}
if etag := rec.Header().Get("ETag"); etag == "" {
t.Errorf("%s failed: ETag header not set", description)
}
if lastModified := rec.Header().Get("Last-Modified"); lastModified == "" {
t.Errorf("%s failed: Last-Modified header not set", description)
}
if server := rec.Header().Get("Server"); server != "EC2ws" {
t.Errorf("%s failed: got Server header %s, expected %s", description, server, "EC2ws")
}
if contentType := rec.Header().Get("Content-Type"); contentType != "text/plain" {
t.Errorf("%s failed: got Content-Type header %s, expected %s", description, contentType, "text/plain")
}
}

// TestCredentialServiceMiddleware is a superset of TestBrowserFilterMiddleware and TestAWSHeaderMiddleware
// since CredentialServiceMiddleware is a chain of BrowserFilterMiddleware and AWSHeaderMiddleware
func TestCredentialServiceMiddleware(t *testing.T) {
cases := []struct {
Description string
HeaderName string
HeaderValue string
ExpectedStatus int
}{
{
Description: "valid request",
HeaderName: "User-Agent",
HeaderValue: "boto3/foo",
ExpectedStatus: http.StatusOK,
},
{
Description: "Mozilla in user-agent",
HeaderName: "User-Agent",
HeaderValue: "Mozilla",
ExpectedStatus: http.StatusForbidden,
},
{
Description: "mozilla in user-agent",
HeaderName: "User-Agent",
HeaderValue: "mozilla",
ExpectedStatus: http.StatusForbidden,
},
{
Description: "referrer header set",
HeaderName: "Referrer",
HeaderValue: "anything",
ExpectedStatus: http.StatusForbidden,
},
{
Description: "origin header set",
HeaderName: "Origin",
HeaderValue: "anything",
ExpectedStatus: http.StatusForbidden,
},
{
Description: "host header not in allowlist",
HeaderName: "Host",
HeaderValue: "netflix.com",
ExpectedStatus: http.StatusForbidden,
},
{
Description: "host header in allowlist (127.0.0.1)",
HeaderName: "Host",
HeaderValue: "127.0.0.1",
ExpectedStatus: http.StatusOK,
},
{
Description: "host header in allowlist (169.254.169.254)",
HeaderName: "Host",
HeaderValue: "127.0.0.1",
ExpectedStatus: http.StatusOK,
},
}
nextHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
})
for i, tc := range cases {
t.Logf("test case %d: %s", i, tc.Description)
bfmHandler := CredentialServiceMiddleware(nextHandler)
req := httptest.NewRequest("GET", "http://localhost", nil)
req.Header.Add(tc.HeaderName, tc.HeaderValue)
rec := httptest.NewRecorder()
bfmHandler.ServeHTTP(rec, req)
if rec.Code != tc.ExpectedStatus {
t.Errorf("%s failed: got status %d, expected %d", tc.Description, rec.Code, tc.ExpectedStatus)
continue
}
if rec.Code == http.StatusOK {
if etag := rec.Header().Get("ETag"); etag == "" {
t.Errorf("%s failed: ETag header not set", tc.Description)
}
if lastModified := rec.Header().Get("Last-Modified"); lastModified == "" {
t.Errorf("%s failed: Last-Modified header not set", tc.Description)
}
if server := rec.Header().Get("Server"); server != "EC2ws" {
t.Errorf("%s failed: got Server header %s, expected %s", tc.Description, server, "EC2ws")
}
if contentType := rec.Header().Get("Content-Type"); contentType != "text/plain" {
t.Errorf("%s failed: got Content-Type header %s, expected %s", tc.Description, contentType, "text/plain")
}
}
}
}
4 changes: 2 additions & 2 deletions util/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ type AwsArn struct {
}

type ErrorResponse struct {
Error string
Error string `json:"error"`
}

func validate(arn string, pieces []string) error {
Expand Down Expand Up @@ -90,7 +90,7 @@ func FileExists(path string) bool {
}

// WriteError writes a status code and JSON-formatted error to the provided http.ResponseWriter.
func WriteError(w http.ResponseWriter, status int, message string) {
func WriteError(w http.ResponseWriter, message string, status int) {
log.Debugf("writing HTTP error response: %s", message)
resp := ErrorResponse{Error: message}
respBytes, err := json.Marshal(resp)
Expand Down

0 comments on commit 1405b82

Please sign in to comment.