diff --git a/CHANGELOG.md b/CHANGELOG.md index 242b0fdd1d..6ec52440b8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,12 @@ # Changelog +## [1.30.2](https://github.com/rudderlabs/rudder-server/compare/v1.30.1...v1.30.2) (2024-07-24) + + +### Bug Fixes + +* introduce config for get requests in webhook sources ([#4927](https://github.com/rudderlabs/rudder-server/issues/4927)) ([2494e69](https://github.com/rudderlabs/rudder-server/commit/2494e69cb64b0374f9cbfadcd25f2504c6e813de)) + ## [1.30.1](https://github.com/rudderlabs/rudder-server/compare/v1.30.0...v1.30.1) (2024-07-24) diff --git a/gateway/webhook/setup.go b/gateway/webhook/setup.go index 5e7e749b9a..f56df783b6 100644 --- a/gateway/webhook/setup.go +++ b/gateway/webhook/setup.go @@ -9,6 +9,7 @@ import ( "time" "github.com/hashicorp/go-retryablehttp" + "github.com/samber/lo" "github.com/rudderlabs/rudder-go-kit/config" @@ -55,6 +56,14 @@ func Setup(gwHandle Gateway, transformerFeaturesService transformer.FeaturesServ maxTransformerProcess := config.GetIntVar(64, 1, "Gateway.webhook.maxTransformerProcess") // Parse all query params from sources mentioned in this list webhook.config.sourceListForParsingParams = config.GetStringSliceVar([]string{"Shopify", "adjust"}, "Gateway.webhook.sourceListForParsingParams") + + webhook.config.forwardGetRequestForSrcMap = lo.SliceToMap( + config.GetStringSliceVar([]string{"adjust"}, "Gateway.webhook.forwardGetRequestForSrcs"), + func(item string) (string, struct{}) { + return item, struct{}{} + }, + ) + // lowercasing the strings in sourceListForParsingParams for i, s := range webhook.config.sourceListForParsingParams { webhook.config.sourceListForParsingParams[i] = strings.ToLower(s) diff --git a/gateway/webhook/webhook.go b/gateway/webhook/webhook.go index 9d31afeb9d..2a4e96217c 100644 --- a/gateway/webhook/webhook.go +++ b/gateway/webhook/webhook.go @@ -69,6 +69,7 @@ type HandleT struct { webhookBatchTimeout config.ValueLoader[time.Duration] maxWebhookBatchSize config.ValueLoader[int] sourceListForParsingParams []string + forwardGetRequestForSrcMap map[string]struct{} } } @@ -104,6 +105,11 @@ func (webhook *HandleT) failRequest(w http.ResponseWriter, r *http.Request, reas http.Error(w, reason, statusCode) } +func (wb *HandleT) IsGetAndNotAllow(reqMethod, sourceDefName string) bool { + _, ok := wb.config.forwardGetRequestForSrcMap[sourceDefName] + return reqMethod == http.MethodGet && !ok +} + func (webhook *HandleT) RequestHandler(w http.ResponseWriter, r *http.Request) { reqType := r.Context().Value(gwtypes.CtxParamCallType).(string) arctx := r.Context().Value(gwtypes.CtxParamAuthRequestContext).(*gwtypes.AuthRequestContext) @@ -114,7 +120,7 @@ func (webhook *HandleT) RequestHandler(w http.ResponseWriter, r *http.Request) { var postFrom url.Values var multipartForm *multipart.Form - if r.Method == "GET" { + if webhook.IsGetAndNotAllow(r.Method, sourceDefName) { return } contentType := r.Header.Get("Content-Type") diff --git a/gateway/webhook/webhook_test.go b/gateway/webhook/webhook_test.go index 082d4aae64..528bdbab5e 100644 --- a/gateway/webhook/webhook_test.go +++ b/gateway/webhook/webhook_test.go @@ -519,3 +519,52 @@ func TestPrepareRequestBody(t *testing.T) { }) } } + +func TestAllowGetReqForWebhookSrc(t *testing.T) { + cases := []struct { + name string + forwardGetRequestForSrcMap map[string]struct{} + method string + srcDef string + expected bool + }{ + { + name: "should allow get request for adjust", + method: http.MethodGet, + forwardGetRequestForSrcMap: map[string]struct{}{ + "adjust": {}, + }, + srcDef: "adjust", + expected: false, + }, + { + name: "should allow post request for adjust", + method: http.MethodPost, + forwardGetRequestForSrcMap: map[string]struct{}{ + "adjust": {}, + }, + srcDef: "adjust", + expected: false, + }, + { + name: "should not allow get request for shopify", + forwardGetRequestForSrcMap: map[string]struct{}{ + "adjust": {}, + "customerio": {}, + }, + method: http.MethodGet, + srcDef: "Shopify", + expected: true, + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + wbh := HandleT{} + wbh.config.forwardGetRequestForSrcMap = tc.forwardGetRequestForSrcMap + + isGetAndNotAllow := wbh.IsGetAndNotAllow(tc.method, tc.srcDef) + require.Equal(t, tc.expected, isGetAndNotAllow) + }) + } +}