From c53c94ec87860d1ada758759e36e55f6f5822009 Mon Sep 17 00:00:00 2001 From: Bruno Bastos Guimaraes Date: Fri, 23 Apr 2021 17:07:29 -0300 Subject: [PATCH 1/6] Rate limit per JWT claim --- go.mod | 2 + go.sum | 8 ++++ juju/example/krakend.json | 22 +++++++++++ juju/juju.go | 17 ++++++++ juju/router/gin/gin.go | 81 +++++++++++++++++++++++++++++++++++++++ juju/router/router.go | 35 +++++++++++++++-- 6 files changed, 161 insertions(+), 4 deletions(-) diff --git a/go.mod b/go.mod index dcc365e..67fd338 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,9 @@ go 1.13 require ( github.com/devopsfaith/krakend v0.0.0-20190930092458-9e6fc3784eca github.com/gin-gonic/gin v1.4.0 + github.com/google/go-cmp v0.5.5 // indirect github.com/json-iterator/go v1.1.8 // indirect github.com/juju/ratelimit v1.0.1 + github.com/tidwall/gjson v1.7.4 golang.org/x/time v0.0.0-20191024005414-555d28b269f0 ) diff --git a/go.sum b/go.sum index 639b366..6736cc6 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,7 @@ github.com/golang/protobuf v1.0.0 h1:lsek0oXi8iFE9L+EXARyHIjU5rlWIhhTkjDz3vHhWWQ github.com/golang/protobuf v1.0.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/gorilla/context v0.0.0-20160226214623-1ea25387ff6f/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/mux v1.6.1/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= @@ -43,6 +44,12 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/tidwall/gjson v1.7.4 h1:19cchw8FOxkG5mdLRkGf9jqIqEyqdZhPqW60XfyFxk8= +github.com/tidwall/gjson v1.7.4/go.mod h1:5/xDoumyyDNerp2U36lyolv46b3uF/9Bu6OfyQ9GImk= +github.com/tidwall/match v1.0.3 h1:FQUVvBImDutD8wJLN6c5eMzWtjgONK9MwIBCOrUJKeE= +github.com/tidwall/match v1.0.3/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.1.0 h1:K3hMW5epkdAVwibsQEfR/7Zj0Qgt4DxtNumTq/VloO8= +github.com/tidwall/pretty v1.1.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/ugorji/go v0.0.0-20180112141927-9831f2c3ac10 h1:4zp+5ElNBLy5qmaDFrbVDolQSOtPmquw+W6EMNEpi+k= github.com/ugorji/go v0.0.0-20180112141927-9831f2c3ac10/go.mod h1:hnLbHMwcvSihnDhEfx2/BzKp2xb0Y+ErdfYcrs9tkJQ= github.com/ugorji/go v1.1.4 h1:j4s+tAvLfL3bZyefP2SEWmhBzmuIlH/eqNuPdFPgngw= @@ -61,6 +68,7 @@ golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/go-playground/assert.v1 v1.2.1 h1:xoYuJVE7KT85PYWrN730RguIQO0ePzVRfFMXadIrXTM= diff --git a/juju/example/krakend.json b/juju/example/krakend.json index 0995178..0f35b1c 100644 --- a/juju/example/krakend.json +++ b/juju/example/krakend.json @@ -37,6 +37,28 @@ ], "extra_config": { "github.com/devopsfaith/krakend-ratelimit/juju/router": { + "tierConfiguration": { + "jwtClaim": "tier", + "duration": "1m", + "tiers": [ + { + "name": "unlimited", + "limit": 0 + }, + { + "name": "gold", + "limit": 50 + }, + { + "name": "silver", + "limit": 20 + }, + { + "name": "bronze", + "limit": 5 + } + ] + }, "maxRate": 50, "clientMaxRate": 5, "strategy": "ip" diff --git a/juju/juju.go b/juju/juju.go index b4cfcd3..f81f7c1 100644 --- a/juju/juju.go +++ b/juju/juju.go @@ -9,6 +9,8 @@ package juju import ( "context" + "time" + "github.com/juju/ratelimit" krakendrate "github.com/devopsfaith/krakend-ratelimit" @@ -19,6 +21,10 @@ func NewLimiter(maxRate float64, capacity int64) Limiter { return Limiter{ratelimit.NewBucketWithRate(maxRate, capacity)} } +func NewLimiterDuration(fillInterval time.Duration, capacity int64) Limiter { + return Limiter{ratelimit.NewBucketWithQuantum(fillInterval, capacity, capacity)} +} + // Limiter is a simple wrapper over the ratelimit.Bucket struct type Limiter struct { limiter *ratelimit.Bucket @@ -37,7 +43,18 @@ func NewLimiterStore(maxRate float64, capacity int64, backend krakendrate.Backen } } +func NewLimiterDurationStore(fillInterval time.Duration, capacity int64, backend krakendrate.Backend) krakendrate.LimiterStore { + f := func() interface{} { return NewLimiterDuration(fillInterval, capacity) } + return func(t string) krakendrate.Limiter { + return backend.Load(t, f).(Limiter) + } +} + // NewMemoryStore returns a LimiterStore using the memory backend func NewMemoryStore(maxRate float64, capacity int64) krakendrate.LimiterStore { return NewLimiterStore(maxRate, capacity, krakendrate.DefaultShardedMemoryBackend(context.Background())) } + +func NewMemoryDurationStore(fillInterval time.Duration, capacity int64) krakendrate.LimiterStore { + return NewLimiterDurationStore(fillInterval, capacity, krakendrate.DefaultShardedMemoryBackend(context.Background())) +} diff --git a/juju/router/gin/gin.go b/juju/router/gin/gin.go index 54f591a..59020f9 100644 --- a/juju/router/gin/gin.go +++ b/juju/router/gin/gin.go @@ -1,9 +1,12 @@ package gin import ( + "encoding/base64" + "log" "net" "net/http" "strings" + "time" "github.com/devopsfaith/krakend/config" "github.com/devopsfaith/krakend/proxy" @@ -13,6 +16,8 @@ import ( krakendrate "github.com/devopsfaith/krakend-ratelimit" "github.com/devopsfaith/krakend-ratelimit/juju" "github.com/devopsfaith/krakend-ratelimit/juju/router" + + "github.com/tidwall/gjson" ) // HandlerFactory is the out-of-the-box basic ratelimit handler factory using the default krakend endpoint @@ -40,6 +45,14 @@ func NewRateLimiterMw(next krakendgin.HandlerFactory) krakendgin.HandlerFactory handlerFunc = NewHeaderLimiterMw(cfg.Key, float64(cfg.ClientMaxRate), cfg.ClientMaxRate)(handlerFunc) } } + if cfg.TierConfiguration != nil { + duration, err := time.ParseDuration(cfg.TierConfiguration.Duration) + if err != nil { + log.Printf("%s => Tier Configuration will be ignored.", err) + } else { + handlerFunc = NewJwtClaimLimiterMw(cfg.TierConfiguration, duration)(handlerFunc) + } + } return handlerFunc } } @@ -78,6 +91,18 @@ func NewIpLimiterWithKeyMw(header string, maxRate float64, capacity int64) Endpo return NewTokenLimiterMw(NewIPTokenExtractor(header), juju.NewMemoryStore(maxRate, capacity)) } +func NewJwtClaimLimiterMw(tierConfiguration *router.TierConfiguration, fillInterval time.Duration) EndpointMw { + var stores = map[string]krakendrate.LimiterStore{} + var capacities = map[string]int64{} + for _, tier := range tierConfiguration.Tiers { + if tier.Limit > 0 { + stores[tier.Name] = nil + capacities[tier.Name] = tier.Limit + } + } + return NewTokenLimiterPerPlanMw(JwtClaimTokenExtractor(tierConfiguration.JwtClaim), fillInterval, stores, capacities) +} + // TokenExtractor defines the interface of the functions to use in order to extract a token for each request type TokenExtractor func(*gin.Context) string @@ -120,3 +145,59 @@ func NewTokenLimiterMw(tokenExtractor TokenExtractor, limiterStore krakendrate.L } } } + +func NewTokenLimiterPerPlanMw(tokenExtractor TokenExtractor, fillInterval time.Duration, mapLimiterStore map[string]krakendrate.LimiterStore, mapCapacities map[string]int64) EndpointMw { + return func(next gin.HandlerFunc) gin.HandlerFunc { + return func(c *gin.Context) { + tokenKey := tokenExtractor(c) + if tokenKey == "" { + c.AbortWithError(http.StatusTooManyRequests, krakendrate.ErrLimited) + return + } + tierName := strings.Split(tokenKey, "-")[0] + _, tierNameExists := mapLimiterStore[tierName] + if tierNameExists { + if mapLimiterStore[tierName] == nil { + mapLimiterStore[tierName] = juju.NewMemoryDurationStore(fillInterval, mapCapacities[tierName]) + } + if !mapLimiterStore[tierName](tokenKey).Allow() { + c.AbortWithError(http.StatusTooManyRequests, krakendrate.ErrLimited) + return + } + } + next(c) + } + } +} + +func JwtClaimTokenExtractor(jwtClaimName string) TokenExtractor { + return func(c *gin.Context) string { + bearer := c.Request.Header.Get("Authorization") + if bearer != "" && strings.Count(bearer, ".") == 2 { + start := strings.Index(bearer, ".") + end := strings.LastIndex(bearer, ".") + rawPayload, err := base64.RawStdEncoding.DecodeString(bearer[start+1 : end]) + if err != nil { + log.Println("Invalid JWT payload (not Base64)") + return "" + } + jsonPayload := string(rawPayload) + if !gjson.Valid(jsonPayload) { + log.Println("Invalid JWT payload (not JSON)") + return "" + } + jwtClaim := gjson.Get(jsonPayload, jwtClaimName) + sub := gjson.Get(jsonPayload, "sub") + if !jwtClaim.Exists() { + log.Printf("Claim '%s' not found in payload", jwtClaimName) + return "" + } + if !sub.Exists() { + log.Println("Claim 'sub' not found in payload") + return "" + } + return jwtClaim.String() + "-" + sub.String() + } + return "" + } +} diff --git a/juju/router/router.go b/juju/router/router.go index 432b601..276904e 100644 --- a/juju/router/router.go +++ b/juju/router/router.go @@ -22,6 +22,7 @@ and http://en.wikipedia.org/wiki/Token_bucket for more details. package router import ( + "encoding/json" "fmt" "github.com/devopsfaith/krakend/config" @@ -32,10 +33,22 @@ const Namespace = "github.com/devopsfaith/krakend-ratelimit/juju/router" // Config is the custom config struct containing the params for the router middlewares type Config struct { - MaxRate int64 - Strategy string - ClientMaxRate int64 - Key string + MaxRate int64 + Strategy string + ClientMaxRate int64 + Key string + TierConfiguration *TierConfiguration +} + +type TierConfiguration struct { + JwtClaim string + Duration string + Tiers []Tier +} + +type Tier struct { + Name string + Limit int64 } // ZeroCfg is the zero value for the Config struct @@ -79,5 +92,19 @@ func ConfigGetter(e config.ExtraConfig) interface{} { if v, ok := tmp["key"]; ok { cfg.Key = fmt.Sprintf("%v", v) } + if v, ok := tmp["tierConfiguration"]; ok { + jsonbody, err := json.Marshal(v) + if err != nil { + fmt.Println(err) + return ZeroCfg + } + + tierConfiguration := TierConfiguration{} + if err := json.Unmarshal(jsonbody, &tierConfiguration); err != nil { + fmt.Println(err) + return ZeroCfg + } + cfg.TierConfiguration = &tierConfiguration + } return cfg } From 04034a7b89333f3fa3686c400e640f6d07a1b40c Mon Sep 17 00:00:00 2001 From: Bruno Bastos Guimaraes Date: Tue, 18 May 2021 09:13:58 -0300 Subject: [PATCH 2/6] Using headers as token for rate limit per tier --- go.mod | 2 -- go.sum | 8 ------ juju/example/krakend.json | 7 ++++- juju/router/gin/gin.go | 56 +++++++++++--------------------------- juju/router/router.go | 7 +++-- juju/router/router_test.go | 3 ++ 6 files changed, 29 insertions(+), 54 deletions(-) diff --git a/go.mod b/go.mod index 67fd338..dcc365e 100644 --- a/go.mod +++ b/go.mod @@ -5,9 +5,7 @@ go 1.13 require ( github.com/devopsfaith/krakend v0.0.0-20190930092458-9e6fc3784eca github.com/gin-gonic/gin v1.4.0 - github.com/google/go-cmp v0.5.5 // indirect github.com/json-iterator/go v1.1.8 // indirect github.com/juju/ratelimit v1.0.1 - github.com/tidwall/gjson v1.7.4 golang.org/x/time v0.0.0-20191024005414-555d28b269f0 ) diff --git a/go.sum b/go.sum index 6736cc6..639b366 100644 --- a/go.sum +++ b/go.sum @@ -17,7 +17,6 @@ github.com/golang/protobuf v1.0.0 h1:lsek0oXi8iFE9L+EXARyHIjU5rlWIhhTkjDz3vHhWWQ github.com/golang/protobuf v1.0.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= github.com/gorilla/context v0.0.0-20160226214623-1ea25387ff6f/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= github.com/gorilla/mux v1.6.1/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= @@ -44,12 +43,6 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/tidwall/gjson v1.7.4 h1:19cchw8FOxkG5mdLRkGf9jqIqEyqdZhPqW60XfyFxk8= -github.com/tidwall/gjson v1.7.4/go.mod h1:5/xDoumyyDNerp2U36lyolv46b3uF/9Bu6OfyQ9GImk= -github.com/tidwall/match v1.0.3 h1:FQUVvBImDutD8wJLN6c5eMzWtjgONK9MwIBCOrUJKeE= -github.com/tidwall/match v1.0.3/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.1.0 h1:K3hMW5epkdAVwibsQEfR/7Zj0Qgt4DxtNumTq/VloO8= -github.com/tidwall/pretty v1.1.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= github.com/ugorji/go v0.0.0-20180112141927-9831f2c3ac10 h1:4zp+5ElNBLy5qmaDFrbVDolQSOtPmquw+W6EMNEpi+k= github.com/ugorji/go v0.0.0-20180112141927-9831f2c3ac10/go.mod h1:hnLbHMwcvSihnDhEfx2/BzKp2xb0Y+ErdfYcrs9tkJQ= github.com/ugorji/go v1.1.4 h1:j4s+tAvLfL3bZyefP2SEWmhBzmuIlH/eqNuPdFPgngw= @@ -68,7 +61,6 @@ golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= -golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/go-playground/assert.v1 v1.2.1 h1:xoYuJVE7KT85PYWrN730RguIQO0ePzVRfFMXadIrXTM= diff --git a/juju/example/krakend.json b/juju/example/krakend.json index 0f35b1c..365d27d 100644 --- a/juju/example/krakend.json +++ b/juju/example/krakend.json @@ -7,6 +7,10 @@ "endpoints": [ { "endpoint": "/showrss/{id}", + "headers_to_pass": [ + "x-user", + "x-tier" + ], "backend": [ { "host": [ @@ -38,7 +42,8 @@ "extra_config": { "github.com/devopsfaith/krakend-ratelimit/juju/router": { "tierConfiguration": { - "jwtClaim": "tier", + "headerTier": "x-tier", + "headerUser": "x-user", "duration": "1m", "tiers": [ { diff --git a/juju/router/gin/gin.go b/juju/router/gin/gin.go index 59020f9..7a7523a 100644 --- a/juju/router/gin/gin.go +++ b/juju/router/gin/gin.go @@ -1,7 +1,6 @@ package gin import ( - "encoding/base64" "log" "net" "net/http" @@ -16,8 +15,6 @@ import ( krakendrate "github.com/devopsfaith/krakend-ratelimit" "github.com/devopsfaith/krakend-ratelimit/juju" "github.com/devopsfaith/krakend-ratelimit/juju/router" - - "github.com/tidwall/gjson" ) // HandlerFactory is the out-of-the-box basic ratelimit handler factory using the default krakend endpoint @@ -30,7 +27,7 @@ func NewRateLimiterMw(next krakendgin.HandlerFactory) krakendgin.HandlerFactory handlerFunc := next(remote, p) cfg := router.ConfigGetter(remote.ExtraConfig).(router.Config) - if cfg == router.ZeroCfg || (cfg.MaxRate <= 0 && cfg.ClientMaxRate <= 0) { + if cfg == router.ZeroCfg || (cfg.MaxRate <= 0 && cfg.ClientMaxRate <= 0 || cfg.TierConfiguration == nil) { return handlerFunc } @@ -50,7 +47,7 @@ func NewRateLimiterMw(next krakendgin.HandlerFactory) krakendgin.HandlerFactory if err != nil { log.Printf("%s => Tier Configuration will be ignored.", err) } else { - handlerFunc = NewJwtClaimLimiterMw(cfg.TierConfiguration, duration)(handlerFunc) + handlerFunc = NewTierLimiterMw(cfg.TierConfiguration, duration)(handlerFunc) } } return handlerFunc @@ -91,7 +88,7 @@ func NewIpLimiterWithKeyMw(header string, maxRate float64, capacity int64) Endpo return NewTokenLimiterMw(NewIPTokenExtractor(header), juju.NewMemoryStore(maxRate, capacity)) } -func NewJwtClaimLimiterMw(tierConfiguration *router.TierConfiguration, fillInterval time.Duration) EndpointMw { +func NewTierLimiterMw(tierConfiguration *router.TierConfiguration, fillInterval time.Duration) EndpointMw { var stores = map[string]krakendrate.LimiterStore{} var capacities = map[string]int64{} for _, tier := range tierConfiguration.Tiers { @@ -100,7 +97,7 @@ func NewJwtClaimLimiterMw(tierConfiguration *router.TierConfiguration, fillInter capacities[tier.Name] = tier.Limit } } - return NewTokenLimiterPerPlanMw(JwtClaimTokenExtractor(tierConfiguration.JwtClaim), fillInterval, stores, capacities) + return NewTokenLimiterPerTierMw(HeadersTokenExtractor([]string{tierConfiguration.HeaderTier, tierConfiguration.HeaderUser}), fillInterval, stores, capacities) } // TokenExtractor defines the interface of the functions to use in order to extract a token for each request @@ -128,6 +125,17 @@ func HeaderTokenExtractor(header string) TokenExtractor { return func(c *gin.Context) string { return c.Request.Header.Get(header) } } +// HeadersTokenExtractor returns a TokenExtractor that looks for the values of the designed headers +func HeadersTokenExtractor(headers []string) TokenExtractor { + return func(c *gin.Context) string { + var headerValues = make([]string, len(headers)) + for i, header := range headers { + headerValues[i] = c.Request.Header.Get(header) + } + return strings.Join(headerValues, "-") + } +} + // NewTokenLimiterMw returns a token based ratelimiting endpoint middleware with the received TokenExtractor and LimiterStore func NewTokenLimiterMw(tokenExtractor TokenExtractor, limiterStore krakendrate.LimiterStore) EndpointMw { return func(next gin.HandlerFunc) gin.HandlerFunc { @@ -146,7 +154,7 @@ func NewTokenLimiterMw(tokenExtractor TokenExtractor, limiterStore krakendrate.L } } -func NewTokenLimiterPerPlanMw(tokenExtractor TokenExtractor, fillInterval time.Duration, mapLimiterStore map[string]krakendrate.LimiterStore, mapCapacities map[string]int64) EndpointMw { +func NewTokenLimiterPerTierMw(tokenExtractor TokenExtractor, fillInterval time.Duration, mapLimiterStore map[string]krakendrate.LimiterStore, mapCapacities map[string]int64) EndpointMw { return func(next gin.HandlerFunc) gin.HandlerFunc { return func(c *gin.Context) { tokenKey := tokenExtractor(c) @@ -169,35 +177,3 @@ func NewTokenLimiterPerPlanMw(tokenExtractor TokenExtractor, fillInterval time.D } } } - -func JwtClaimTokenExtractor(jwtClaimName string) TokenExtractor { - return func(c *gin.Context) string { - bearer := c.Request.Header.Get("Authorization") - if bearer != "" && strings.Count(bearer, ".") == 2 { - start := strings.Index(bearer, ".") - end := strings.LastIndex(bearer, ".") - rawPayload, err := base64.RawStdEncoding.DecodeString(bearer[start+1 : end]) - if err != nil { - log.Println("Invalid JWT payload (not Base64)") - return "" - } - jsonPayload := string(rawPayload) - if !gjson.Valid(jsonPayload) { - log.Println("Invalid JWT payload (not JSON)") - return "" - } - jwtClaim := gjson.Get(jsonPayload, jwtClaimName) - sub := gjson.Get(jsonPayload, "sub") - if !jwtClaim.Exists() { - log.Printf("Claim '%s' not found in payload", jwtClaimName) - return "" - } - if !sub.Exists() { - log.Println("Claim 'sub' not found in payload") - return "" - } - return jwtClaim.String() + "-" + sub.String() - } - return "" - } -} diff --git a/juju/router/router.go b/juju/router/router.go index 276904e..a4b3772 100644 --- a/juju/router/router.go +++ b/juju/router/router.go @@ -41,9 +41,10 @@ type Config struct { } type TierConfiguration struct { - JwtClaim string - Duration string - Tiers []Tier + HeaderTier string + HeaderUser string + Duration string + Tiers []Tier } type Tier struct { diff --git a/juju/router/router_test.go b/juju/router/router_test.go index 7230b58..06c2f50 100644 --- a/juju/router/router_test.go +++ b/juju/router/router_test.go @@ -30,4 +30,7 @@ func TestConfigGetter(t *testing.T) { if cfg.Key != "" { t.Errorf("wrong value for Key. Want: '', have: %s", cfg.Key) } + if cfg.TierConfiguration != nil { + t.Errorf("wrong value for Key. Want: '', have: %+v", cfg.TierConfiguration) + } } From 84a3f290826787e2b66402598427e6a0027922c9 Mon Sep 17 00:00:00 2001 From: Bruno Bastos Guimaraes Date: Tue, 18 May 2021 15:24:40 -0300 Subject: [PATCH 3/6] Using sharded store instead of map --- juju/router/gin/gin.go | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/juju/router/gin/gin.go b/juju/router/gin/gin.go index 7a7523a..7543c49 100644 --- a/juju/router/gin/gin.go +++ b/juju/router/gin/gin.go @@ -1,6 +1,7 @@ package gin import ( + "context" "log" "net" "net/http" @@ -89,15 +90,13 @@ func NewIpLimiterWithKeyMw(header string, maxRate float64, capacity int64) Endpo } func NewTierLimiterMw(tierConfiguration *router.TierConfiguration, fillInterval time.Duration) EndpointMw { - var stores = map[string]krakendrate.LimiterStore{} - var capacities = map[string]int64{} + var storesPerTier = krakendrate.NewShardedMemoryBackend(context.Background(), 256, fillInterval, krakendrate.PseudoFNV64a) for _, tier := range tierConfiguration.Tiers { if tier.Limit > 0 { - stores[tier.Name] = nil - capacities[tier.Name] = tier.Limit + storesPerTier.Store(tier.Name, juju.NewMemoryDurationStore(fillInterval, tier.Limit)) } } - return NewTokenLimiterPerTierMw(HeadersTokenExtractor([]string{tierConfiguration.HeaderTier, tierConfiguration.HeaderUser}), fillInterval, stores, capacities) + return NewTokenLimiterPerTierMw(HeadersTokenExtractor([]string{tierConfiguration.HeaderTier, tierConfiguration.HeaderUser}), fillInterval, storesPerTier) } // TokenExtractor defines the interface of the functions to use in order to extract a token for each request @@ -154,7 +153,8 @@ func NewTokenLimiterMw(tokenExtractor TokenExtractor, limiterStore krakendrate.L } } -func NewTokenLimiterPerTierMw(tokenExtractor TokenExtractor, fillInterval time.Duration, mapLimiterStore map[string]krakendrate.LimiterStore, mapCapacities map[string]int64) EndpointMw { +func NewTokenLimiterPerTierMw(tokenExtractor TokenExtractor, fillInterval time.Duration, storesPerTier *krakendrate.ShardedMemoryBackend) EndpointMw { + var noResult = func() interface{} { return nil } return func(next gin.HandlerFunc) gin.HandlerFunc { return func(c *gin.Context) { tokenKey := tokenExtractor(c) @@ -162,16 +162,16 @@ func NewTokenLimiterPerTierMw(tokenExtractor TokenExtractor, fillInterval time.D c.AbortWithError(http.StatusTooManyRequests, krakendrate.ErrLimited) return } - tierName := strings.Split(tokenKey, "-")[0] - _, tierNameExists := mapLimiterStore[tierName] - if tierNameExists { - if mapLimiterStore[tierName] == nil { - mapLimiterStore[tierName] = juju.NewMemoryDurationStore(fillInterval, mapCapacities[tierName]) - } - if !mapLimiterStore[tierName](tokenKey).Allow() { + tokenKeyParts := strings.Split(tokenKey, "-") + tierName, user := tokenKeyParts[0], tokenKeyParts[1] + tierLimiter := storesPerTier.Load(tierName, noResult) + if tierLimiter != nil { + if !tierLimiter.(krakendrate.LimiterStore)(user).Allow() { c.AbortWithError(http.StatusTooManyRequests, krakendrate.ErrLimited) return } + } else { + log.Printf("Tier %s does not exist.", tierName) } next(c) } From 3eca5ca7344652b371796235163db9464f9e5570 Mon Sep 17 00:00:00 2001 From: Bruno Bastos Guimaraes Date: Tue, 18 May 2021 17:10:01 -0300 Subject: [PATCH 4/6] Refactor to allow IP as client identifier in tier configuration --- juju/example/krakend.json | 3 ++- juju/router/gin/gin.go | 45 ++++++++++++++++++++++++++++++--------- juju/router/router.go | 3 ++- 3 files changed, 39 insertions(+), 12 deletions(-) diff --git a/juju/example/krakend.json b/juju/example/krakend.json index 365d27d..1e22176 100644 --- a/juju/example/krakend.json +++ b/juju/example/krakend.json @@ -43,7 +43,8 @@ "github.com/devopsfaith/krakend-ratelimit/juju/router": { "tierConfiguration": { "headerTier": "x-tier", - "headerUser": "x-user", + "strategy": "header", + "key": "x-user", "duration": "1m", "tiers": [ { diff --git a/juju/router/gin/gin.go b/juju/router/gin/gin.go index 7543c49..f536b9e 100644 --- a/juju/router/gin/gin.go +++ b/juju/router/gin/gin.go @@ -44,9 +44,12 @@ func NewRateLimiterMw(next krakendgin.HandlerFactory) krakendgin.HandlerFactory } } if cfg.TierConfiguration != nil { + strategy := strings.ToLower(cfg.TierConfiguration.Strategy) duration, err := time.ParseDuration(cfg.TierConfiguration.Duration) if err != nil { log.Printf("%s => Tier Configuration will be ignored.", err) + } else if strategy != "ip" && strategy != "header" { + log.Printf("%s is not a valid strategy => Tier Configuration will be ignored", strategy) } else { handlerFunc = NewTierLimiterMw(cfg.TierConfiguration, duration)(handlerFunc) } @@ -89,14 +92,19 @@ func NewIpLimiterWithKeyMw(header string, maxRate float64, capacity int64) Endpo return NewTokenLimiterMw(NewIPTokenExtractor(header), juju.NewMemoryStore(maxRate, capacity)) } +// NewIpLimiterWithKeyMw creates a token ratelimiter using the IP/header of the request and tier name as a token func NewTierLimiterMw(tierConfiguration *router.TierConfiguration, fillInterval time.Duration) EndpointMw { - var storesPerTier = krakendrate.NewShardedMemoryBackend(context.Background(), 256, fillInterval, krakendrate.PseudoFNV64a) + var storesPerTier = krakendrate.NewShardedMemoryBackend(context.Background(), 2, fillInterval, krakendrate.PseudoFNV64a) for _, tier := range tierConfiguration.Tiers { if tier.Limit > 0 { storesPerTier.Store(tier.Name, juju.NewMemoryDurationStore(fillInterval, tier.Limit)) } } - return NewTokenLimiterPerTierMw(HeadersTokenExtractor([]string{tierConfiguration.HeaderTier, tierConfiguration.HeaderUser}), fillInterval, storesPerTier) + return NewTokenLimiterPerTierMw( + NewConcatTokenExtractor(tierConfiguration.HeaderTier, strings.ToLower(tierConfiguration.Strategy), tierConfiguration.Key), + fillInterval, + storesPerTier, + ) } // TokenExtractor defines the interface of the functions to use in order to extract a token for each request @@ -124,17 +132,33 @@ func HeaderTokenExtractor(header string) TokenExtractor { return func(c *gin.Context) string { return c.Request.Header.Get(header) } } -// HeadersTokenExtractor returns a TokenExtractor that looks for the values of the designed headers -func HeadersTokenExtractor(headers []string) TokenExtractor { +// ConcatTokenExtractor returns a TokenExtractor that concatenates all passed token extractors +func ConcatTokenExtractor(tokenExtractors []TokenExtractor) TokenExtractor { return func(c *gin.Context) string { - var headerValues = make([]string, len(headers)) - for i, header := range headers { - headerValues[i] = c.Request.Header.Get(header) + var tokenValues = make([]string, len(tokenExtractors)) + for i, tokenExtractor := range tokenExtractors { + tokenValues[i] = tokenExtractor(c) } - return strings.Join(headerValues, "-") + return strings.Join(tokenValues, "-") } } +// NewConcatTokenExtractor generates a ConcatTokenExtractor using ip or header extractors depending on the strategy +func NewConcatTokenExtractor(headerTier string, strategy string, key string) TokenExtractor { + var tierTokenExtractor = HeaderTokenExtractor(headerTier) + var clientIdentifierTokenExtractor TokenExtractor + if strategy == "ip" { + if key == "" { + clientIdentifierTokenExtractor = IPTokenExtractor + } else { + clientIdentifierTokenExtractor = NewIPTokenExtractor(key) + } + } else if strategy == "header" { + clientIdentifierTokenExtractor = HeaderTokenExtractor(key) + } + return ConcatTokenExtractor([]TokenExtractor{tierTokenExtractor, clientIdentifierTokenExtractor}) +} + // NewTokenLimiterMw returns a token based ratelimiting endpoint middleware with the received TokenExtractor and LimiterStore func NewTokenLimiterMw(tokenExtractor TokenExtractor, limiterStore krakendrate.LimiterStore) EndpointMw { return func(next gin.HandlerFunc) gin.HandlerFunc { @@ -153,6 +177,7 @@ func NewTokenLimiterMw(tokenExtractor TokenExtractor, limiterStore krakendrate.L } } +// NewTokenLimiterPerTierMw returns a token based ratelimiting endpoint middleware with the received TokenExtractor and different LimiterStores per tier func NewTokenLimiterPerTierMw(tokenExtractor TokenExtractor, fillInterval time.Duration, storesPerTier *krakendrate.ShardedMemoryBackend) EndpointMw { var noResult = func() interface{} { return nil } return func(next gin.HandlerFunc) gin.HandlerFunc { @@ -163,10 +188,10 @@ func NewTokenLimiterPerTierMw(tokenExtractor TokenExtractor, fillInterval time.D return } tokenKeyParts := strings.Split(tokenKey, "-") - tierName, user := tokenKeyParts[0], tokenKeyParts[1] + tierName, clientIdentifier := tokenKeyParts[0], tokenKeyParts[1] tierLimiter := storesPerTier.Load(tierName, noResult) if tierLimiter != nil { - if !tierLimiter.(krakendrate.LimiterStore)(user).Allow() { + if !tierLimiter.(krakendrate.LimiterStore)(clientIdentifier).Allow() { c.AbortWithError(http.StatusTooManyRequests, krakendrate.ErrLimited) return } diff --git a/juju/router/router.go b/juju/router/router.go index a4b3772..e908f5c 100644 --- a/juju/router/router.go +++ b/juju/router/router.go @@ -41,8 +41,9 @@ type Config struct { } type TierConfiguration struct { + Strategy string + Key string HeaderTier string - HeaderUser string Duration string Tiers []Tier } From cc6d030f04c79aa44667fe19efe37d2080cccec7 Mon Sep 17 00:00:00 2001 From: Bruno Bastos Guimaraes Date: Wed, 19 May 2021 08:55:44 -0300 Subject: [PATCH 5/6] Fixing test --- juju/router/router_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/juju/router/router_test.go b/juju/router/router_test.go index 06c2f50..4385399 100644 --- a/juju/router/router_test.go +++ b/juju/router/router_test.go @@ -31,6 +31,6 @@ func TestConfigGetter(t *testing.T) { t.Errorf("wrong value for Key. Want: '', have: %s", cfg.Key) } if cfg.TierConfiguration != nil { - t.Errorf("wrong value for Key. Want: '', have: %+v", cfg.TierConfiguration) + t.Errorf("wrong value for TierConfiguration. Want: , have: %+v", cfg.TierConfiguration) } } From 9c8332d7d2d0883808446ba4326be196a22393be Mon Sep 17 00:00:00 2001 From: Bruno Bastos Guimaraes Date: Wed, 19 May 2021 09:32:57 -0300 Subject: [PATCH 6/6] Added tests for the different tier limiters --- juju/router/gin/gin_test.go | 102 ++++++++++++++++++++++++++++++++++++ 1 file changed, 102 insertions(+) diff --git a/juju/router/gin/gin_test.go b/juju/router/gin/gin_test.go index 9d334b0..29c6ae8 100644 --- a/juju/router/gin/gin_test.go +++ b/juju/router/gin/gin_test.go @@ -69,6 +69,108 @@ func TestNewRateLimiterMw_DefaultIP(t *testing.T) { testRateLimiterMw(t, rd, cfg) } +func TestNewRateLimiterMw_TierCustomHeader(t *testing.T) { + headerTier := "X-Tier" + headerUser := "X-User" + + cfg := &config.EndpointConfig{ + ExtraConfig: map[string]interface{}{ + router.Namespace: map[string]interface{}{ + "tierConfiguration": map[string]interface{}{ + "headerTier": headerTier, + "strategy": "header", + "key": headerUser, + "duration": "1s", + "tiers": []map[string]interface{}{ + { + "name": "tier1", + "limit": 100, + }, + { + "name": "tier2", + "limit": 200, + }, + }, + }, + }, + }, + } + + rd := func(req *http.Request) { + req.Header.Add(headerTier, "tier1") + req.Header.Add(headerUser, "1234567890") + } + + testRateLimiterMw(t, rd, cfg) +} + +func TestNewRateLimiterMw_TierDefaultIP(t *testing.T) { + headerTier := "X-Tier" + + cfg := &config.EndpointConfig{ + ExtraConfig: map[string]interface{}{ + router.Namespace: map[string]interface{}{ + "tierConfiguration": map[string]interface{}{ + "headerTier": headerTier, + "strategy": "ip", + "duration": "1s", + "tiers": []map[string]interface{}{ + { + "name": "tier1", + "limit": 100, + }, + { + "name": "tier2", + "limit": 200, + }, + }, + }, + }, + }, + } + + rd := func(req *http.Request) { + req.Header.Add(headerTier, "tier1") + } + + testRateLimiterMw(t, rd, cfg) +} + +func TestNewRateLimiterMw_TierCustomHeaderIP(t *testing.T) { + headerTier := "X-Tier" + headerIP := "X-Custom-Forwarded-For" + + cfg := &config.EndpointConfig{ + ExtraConfig: map[string]interface{}{ + router.Namespace: map[string]interface{}{ + "tierConfiguration": map[string]interface{}{ + "headerTier": headerTier, + "strategy": "ip", + "key": headerIP, + "duration": "1s", + "tiers": []map[string]interface{}{ + { + "name": "tier1", + "limit": 100, + }, + { + "name": "tier2", + "limit": 200, + }, + }, + }, + }, + }, + } + + rd := func(req *http.Request) { + req.Header.Add(headerTier, "tier1") + req.Header.Add(headerIP, "1.1.1.1,2.2.2.2,3.3.3.3") + } + + testRateLimiterMw(t, rd, cfg) +} + type requestDecorator func(*http.Request) func testRateLimiterMw(t *testing.T, rd requestDecorator, cfg *config.EndpointConfig) {