diff --git a/gateway/mw_api_rate_limit.go b/gateway/mw_api_rate_limit.go index 6471845ce4b..619ae7d64de 100644 --- a/gateway/mw_api_rate_limit.go +++ b/gateway/mw_api_rate_limit.go @@ -1,6 +1,7 @@ package gateway import ( + "fmt" "net/http" "strconv" @@ -54,7 +55,7 @@ func (k *RateLimitForAPI) getSession(r *http.Request) *user.SessionState { if ok { if limits := spec.RateLimit; limits.Valid() { // track per-endpoint with a hash of the path - keyname := k.keyName + "-" + storage.HashStr(limits.Path) + keyname := k.keyName + "-" + storage.HashStr(fmt.Sprintf("%s:%s", limits.Method, limits.Path)) session := &user.SessionState{ Rate: limits.Rate, diff --git a/tests/rate/per_api_limit_test.go b/tests/rate/per_api_limit_test.go index 3935bf9f30f..e36bca1a9f9 100644 --- a/tests/rate/per_api_limit_test.go +++ b/tests/rate/per_api_limit_test.go @@ -2,9 +2,11 @@ package rate_test import ( "encoding/json" - "fmt" + "net/http" "testing" + "github.com/TykTechnologies/tyk/apidef" + "github.com/stretchr/testify/assert" . "github.com/TykTechnologies/tyk/gateway" @@ -12,7 +14,7 @@ import ( "github.com/TykTechnologies/tyk/test" ) -func buildPathRateLimitAPI(tb testing.TB, gw *Gateway, pathName string, rate, per int64) { +func buildPathRateLimitAPI(tb testing.TB, gw *Gateway, per int64, rateLimits []apidef.RateLimitMeta) { tb.Helper() gw.BuildAndLoadAPI(func(spec *APISpec) { @@ -23,39 +25,34 @@ func buildPathRateLimitAPI(tb testing.TB, gw *Gateway, pathName string, rate, pe spec.GlobalRateLimit.Per = float64(per) version := spec.VersionData.Versions["v1"] - versionJSON := []byte(fmt.Sprintf(`{ - "use_extended_paths": true, - "extended_paths": { - "rate_limit": [{ - "method": "GET", - "rate": %d, - "per": %d - }] - } - }`, rate, per)) - err := json.Unmarshal(versionJSON, &version) - assert.NoError(tb, err) - - version.ExtendedPaths.RateLimit[0].Path = pathName + version.UseExtendedPaths = true + version.ExtendedPaths.RateLimit = rateLimits spec.VersionData.Versions["v1"] = version }) } -func testRateLimit(tb testing.TB, ts *Test, testPath string, want int) { +func testRateLimit(tb testing.TB, ts *Test, testPath string, testMethod string, want int) { tb.Helper() // single request _, _ = ts.Run(tb, test.TestCase{ - Path: "/ratelimit" + testPath, - BodyMatch: fmt.Sprintf(`"Url":"%s"`, testPath), + Path: "/ratelimit" + testPath, + Method: testMethod, + BodyMatchFunc: func(bytes []byte) bool { + res := map[string]any{} + err := json.Unmarshal(bytes, &res) + assert.NoError(tb, err) + return assert.Equal(tb, testPath, res["Url"]) && assert.Equal(tb, testMethod, res["Method"]) + }, }) // and 50 more var ok, failed int = 1, 0 for i := 0; i < 50; i++ { res, err := ts.Run(tb, test.TestCase{ - Path: "/ratelimit" + testPath, + Path: "/ratelimit" + testPath, + Method: testMethod, }) assert.NoError(tb, err) @@ -82,8 +79,16 @@ func TestPerAPILimit(t *testing.T) { forPath := "/" + uuid.New() testPath := "/miss" - buildPathRateLimitAPI(t, ts.Gw, forPath, 30, 60) - testRateLimit(t, ts, testPath, 15) + rateLimits := []apidef.RateLimitMeta{ + { + Method: http.MethodGet, + Path: forPath, + Rate: 30, + Per: 60, + }, + } + buildPathRateLimitAPI(t, ts.Gw, 60, rateLimits) + testRateLimit(t, ts, testPath, http.MethodGet, 15) }) t.Run("hit per-endpoint rate limit", func(t *testing.T) { @@ -93,7 +98,40 @@ func TestPerAPILimit(t *testing.T) { forPath := "/" + uuid.New() testPath := forPath - buildPathRateLimitAPI(t, ts.Gw, forPath, 30, 60) - testRateLimit(t, ts, testPath, 30) + rateLimits := []apidef.RateLimitMeta{ + { + Method: http.MethodGet, + Path: forPath, + Rate: 30, + Per: 60, + }, + } + buildPathRateLimitAPI(t, ts.Gw, 60, rateLimits) + testRateLimit(t, ts, testPath, http.MethodGet, 30) + }) + + t.Run("[TT-12990][regression] hit per-endpoint per-method rate limit", func(t *testing.T) { + ts := StartTest(nil) + defer ts.Close() + + forPath := "/anything/" + uuid.New() + testPath := forPath + rateLimits := []apidef.RateLimitMeta{ + { + Method: http.MethodGet, + Path: forPath, + Rate: 20, + Per: 60, + }, + { + Method: http.MethodPost, + Path: forPath, + Rate: 30, + Per: 60, + }, + } + buildPathRateLimitAPI(t, ts.Gw, 60, rateLimits) + testRateLimit(t, ts, testPath, http.MethodGet, 20) + testRateLimit(t, ts, testPath, http.MethodPost, 30) }) }