Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rate limit per JWT claim #17

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ go 1.13
require (
github.com/devopsfaith/krakend v0.0.0-20190930092458-9e6fc3784eca
github.com/gin-gonic/gin v1.4.0
github.com/google/go-cmp v0.5.5 // indirect
github.com/json-iterator/go v1.1.8 // indirect
github.com/juju/ratelimit v1.0.1
github.com/tidwall/gjson v1.7.4
golang.org/x/time v0.0.0-20191024005414-555d28b269f0
)
8 changes: 8 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ github.com/golang/protobuf v1.0.0 h1:lsek0oXi8iFE9L+EXARyHIjU5rlWIhhTkjDz3vHhWWQ
github.com/golang/protobuf v1.0.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/gorilla/context v0.0.0-20160226214623-1ea25387ff6f/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg=
github.com/gorilla/mux v1.6.1/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs=
Expand All @@ -43,6 +44,12 @@ github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/tidwall/gjson v1.7.4 h1:19cchw8FOxkG5mdLRkGf9jqIqEyqdZhPqW60XfyFxk8=
github.com/tidwall/gjson v1.7.4/go.mod h1:5/xDoumyyDNerp2U36lyolv46b3uF/9Bu6OfyQ9GImk=
github.com/tidwall/match v1.0.3 h1:FQUVvBImDutD8wJLN6c5eMzWtjgONK9MwIBCOrUJKeE=
github.com/tidwall/match v1.0.3/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
github.com/tidwall/pretty v1.1.0 h1:K3hMW5epkdAVwibsQEfR/7Zj0Qgt4DxtNumTq/VloO8=
github.com/tidwall/pretty v1.1.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
github.com/ugorji/go v0.0.0-20180112141927-9831f2c3ac10 h1:4zp+5ElNBLy5qmaDFrbVDolQSOtPmquw+W6EMNEpi+k=
github.com/ugorji/go v0.0.0-20180112141927-9831f2c3ac10/go.mod h1:hnLbHMwcvSihnDhEfx2/BzKp2xb0Y+ErdfYcrs9tkJQ=
github.com/ugorji/go v1.1.4 h1:j4s+tAvLfL3bZyefP2SEWmhBzmuIlH/eqNuPdFPgngw=
Expand All @@ -61,6 +68,7 @@ golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5h
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0 h1:/5xXl8Y5W96D+TtHSlonuFqGHIWVuyCkGJLwGh9JJFs=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/go-playground/assert.v1 v1.2.1 h1:xoYuJVE7KT85PYWrN730RguIQO0ePzVRfFMXadIrXTM=
Expand Down
22 changes: 22 additions & 0 deletions juju/example/krakend.json
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,28 @@
],
"extra_config": {
"github.com/devopsfaith/krakend-ratelimit/juju/router": {
"tierConfiguration": {
"jwtClaim": "tier",
"duration": "1m",
"tiers": [
{
"name": "unlimited",
"limit": 0
},
{
"name": "gold",
"limit": 50
},
{
"name": "silver",
"limit": 20
},
{
"name": "bronze",
"limit": 5
}
]
},
"maxRate": 50,
"clientMaxRate": 5,
"strategy": "ip"
Expand Down
17 changes: 17 additions & 0 deletions juju/juju.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ package juju
import (
"context"

"time"

"github.com/juju/ratelimit"

krakendrate "github.com/devopsfaith/krakend-ratelimit"
Expand All @@ -19,6 +21,10 @@ func NewLimiter(maxRate float64, capacity int64) Limiter {
return Limiter{ratelimit.NewBucketWithRate(maxRate, capacity)}
}

func NewLimiterDuration(fillInterval time.Duration, capacity int64) Limiter {
return Limiter{ratelimit.NewBucketWithQuantum(fillInterval, capacity, capacity)}
}

// Limiter is a simple wrapper over the ratelimit.Bucket struct
type Limiter struct {
limiter *ratelimit.Bucket
Expand All @@ -37,7 +43,18 @@ func NewLimiterStore(maxRate float64, capacity int64, backend krakendrate.Backen
}
}

func NewLimiterDurationStore(fillInterval time.Duration, capacity int64, backend krakendrate.Backend) krakendrate.LimiterStore {
f := func() interface{} { return NewLimiterDuration(fillInterval, capacity) }
return func(t string) krakendrate.Limiter {
return backend.Load(t, f).(Limiter)
}
}

// NewMemoryStore returns a LimiterStore using the memory backend
func NewMemoryStore(maxRate float64, capacity int64) krakendrate.LimiterStore {
return NewLimiterStore(maxRate, capacity, krakendrate.DefaultShardedMemoryBackend(context.Background()))
}

func NewMemoryDurationStore(fillInterval time.Duration, capacity int64) krakendrate.LimiterStore {
return NewLimiterDurationStore(fillInterval, capacity, krakendrate.DefaultShardedMemoryBackend(context.Background()))
}
81 changes: 81 additions & 0 deletions juju/router/gin/gin.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package gin

import (
"encoding/base64"
"log"
"net"
"net/http"
"strings"
"time"

"github.com/devopsfaith/krakend/config"
"github.com/devopsfaith/krakend/proxy"
Expand All @@ -13,6 +16,8 @@ import (
krakendrate "github.com/devopsfaith/krakend-ratelimit"
"github.com/devopsfaith/krakend-ratelimit/juju"
"github.com/devopsfaith/krakend-ratelimit/juju/router"

"github.com/tidwall/gjson"
)

// HandlerFactory is the out-of-the-box basic ratelimit handler factory using the default krakend endpoint
Expand Down Expand Up @@ -40,6 +45,14 @@ func NewRateLimiterMw(next krakendgin.HandlerFactory) krakendgin.HandlerFactory
handlerFunc = NewHeaderLimiterMw(cfg.Key, float64(cfg.ClientMaxRate), cfg.ClientMaxRate)(handlerFunc)
}
}
if cfg.TierConfiguration != nil {
duration, err := time.ParseDuration(cfg.TierConfiguration.Duration)
if err != nil {
log.Printf("%s => Tier Configuration will be ignored.", err)
} else {
handlerFunc = NewJwtClaimLimiterMw(cfg.TierConfiguration, duration)(handlerFunc)
}
}
return handlerFunc
}
}
Expand Down Expand Up @@ -78,6 +91,18 @@ func NewIpLimiterWithKeyMw(header string, maxRate float64, capacity int64) Endpo
return NewTokenLimiterMw(NewIPTokenExtractor(header), juju.NewMemoryStore(maxRate, capacity))
}

func NewJwtClaimLimiterMw(tierConfiguration *router.TierConfiguration, fillInterval time.Duration) EndpointMw {
var stores = map[string]krakendrate.LimiterStore{}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of using a simple map, why don't you go with a sharded store (https://github.com/devopsfaith/krakend-ratelimit/blob/master/krakendrate.go#L52) ?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'll look into it.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hi, there is a map without a mutex. There is a good implementation on the link above.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi, you commented on an outdated version. What do you mean? I'm using a ShardedMemoryBackend.

var capacities = map[string]int64{}
for _, tier := range tierConfiguration.Tiers {
if tier.Limit > 0 {
stores[tier.Name] = nil
capacities[tier.Name] = tier.Limit
}
}
return NewTokenLimiterPerPlanMw(JwtClaimTokenExtractor(tierConfiguration.JwtClaim), fillInterval, stores, capacities)
}

// TokenExtractor defines the interface of the functions to use in order to extract a token for each request
type TokenExtractor func(*gin.Context) string

Expand Down Expand Up @@ -120,3 +145,59 @@ func NewTokenLimiterMw(tokenExtractor TokenExtractor, limiterStore krakendrate.L
}
}
}

func NewTokenLimiterPerPlanMw(tokenExtractor TokenExtractor, fillInterval time.Duration, mapLimiterStore map[string]krakendrate.LimiterStore, mapCapacities map[string]int64) EndpointMw {
return func(next gin.HandlerFunc) gin.HandlerFunc {
return func(c *gin.Context) {
tokenKey := tokenExtractor(c)
if tokenKey == "" {
c.AbortWithError(http.StatusTooManyRequests, krakendrate.ErrLimited)
return
}
tierName := strings.Split(tokenKey, "-")[0]
_, tierNameExists := mapLimiterStore[tierName]
brunobastosg marked this conversation as resolved.
Show resolved Hide resolved
if tierNameExists {
if mapLimiterStore[tierName] == nil {
mapLimiterStore[tierName] = juju.NewMemoryDurationStore(fillInterval, mapCapacities[tierName])
}
if !mapLimiterStore[tierName](tokenKey).Allow() {
c.AbortWithError(http.StatusTooManyRequests, krakendrate.ErrLimited)
return
}
}
next(c)
}
}
}

func JwtClaimTokenExtractor(jwtClaimName string) TokenExtractor {
brunobastosg marked this conversation as resolved.
Show resolved Hide resolved
return func(c *gin.Context) string {
bearer := c.Request.Header.Get("Authorization")
if bearer != "" && strings.Count(bearer, ".") == 2 {
start := strings.Index(bearer, ".")
end := strings.LastIndex(bearer, ".")
rawPayload, err := base64.RawStdEncoding.DecodeString(bearer[start+1 : end])
if err != nil {
log.Println("Invalid JWT payload (not Base64)")
return ""
}
jsonPayload := string(rawPayload)
if !gjson.Valid(jsonPayload) {
log.Println("Invalid JWT payload (not JSON)")
return ""
}
jwtClaim := gjson.Get(jsonPayload, jwtClaimName)
sub := gjson.Get(jsonPayload, "sub")
if !jwtClaim.Exists() {
log.Printf("Claim '%s' not found in payload", jwtClaimName)
return ""
}
if !sub.Exists() {
log.Println("Claim 'sub' not found in payload")
return ""
}
return jwtClaim.String() + "-" + sub.String()
}
return ""
}
}
35 changes: 31 additions & 4 deletions juju/router/router.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ and http://en.wikipedia.org/wiki/Token_bucket for more details.
package router

import (
"encoding/json"
"fmt"

"github.com/devopsfaith/krakend/config"
Expand All @@ -32,10 +33,22 @@ const Namespace = "github.com/devopsfaith/krakend-ratelimit/juju/router"

// Config is the custom config struct containing the params for the router middlewares
type Config struct {
MaxRate int64
Strategy string
ClientMaxRate int64
Key string
MaxRate int64
Strategy string
ClientMaxRate int64
Key string
TierConfiguration *TierConfiguration
}

type TierConfiguration struct {
JwtClaim string
Duration string
Tiers []Tier
}

type Tier struct {
Name string
Limit int64
}

// ZeroCfg is the zero value for the Config struct
Expand Down Expand Up @@ -79,5 +92,19 @@ func ConfigGetter(e config.ExtraConfig) interface{} {
if v, ok := tmp["key"]; ok {
cfg.Key = fmt.Sprintf("%v", v)
}
if v, ok := tmp["tierConfiguration"]; ok {
jsonbody, err := json.Marshal(v)
if err != nil {
fmt.Println(err)
return ZeroCfg
}

tierConfiguration := TierConfiguration{}
if err := json.Unmarshal(jsonbody, &tierConfiguration); err != nil {
fmt.Println(err)
return ZeroCfg
}
cfg.TierConfiguration = &tierConfiguration
}
return cfg
}