From 54adc8b476c647211c17efbf0e6ccb42479259fb Mon Sep 17 00:00:00 2001
From: Florian Leskovsek <florian@leskovsek.dev>
Date: Wed, 3 Jul 2024 17:54:48 +0200
Subject: [PATCH] Validate JSON body before unmarshalling into input type

---
 request/jsonbody.go      | 30 ++++++++++++------------------
 request/jsonbody_test.go |  2 +-
 2 files changed, 13 insertions(+), 19 deletions(-)

diff --git a/request/jsonbody.go b/request/jsonbody.go
index 788cab3..e332359 100644
--- a/request/jsonbody.go
+++ b/request/jsonbody.go
@@ -36,34 +36,28 @@ func decodeJSONBody(readJSON func(rd io.Reader, v interface{}) error, tolerateFo
 			return nil
 		}
 
-		var (
-			rd io.Reader = r.Body
-			b  *bytes.Buffer
-		)
+		var b *bytes.Buffer
 
-		validate := validator != nil && validator.HasConstraints(rest.ParamInBody)
-
-		if validate {
-			b = bufPool.Get().(*bytes.Buffer) //nolint:errcheck // bufPool is configured to provide *bytes.Buffer.
-			defer bufPool.Put(b)
+		b = bufPool.Get().(*bytes.Buffer) //nolint:errcheck // bufPool is configured to provide *bytes.Buffer.
+		defer bufPool.Put(b)
+		b.Reset()
 
-			b.Reset()
-			rd = io.TeeReader(r.Body, b)
+		// Read body into buffer.
+		if _, err := b.ReadFrom(r.Body); err != nil {
+			return err
 		}
 
-		err := readJSON(rd, &input)
-		if err != nil {
-			return fmt.Errorf("failed to decode json: %w", err)
-		}
+		validate := validator != nil && validator.HasConstraints(rest.ParamInBody)
 
-		if validator != nil && validate {
-			err = validator.ValidateJSONBody(b.Bytes())
+		if validate {
+			// Perform validation before unmarshalling into input object.
+			err := validator.ValidateJSONBody(b.Bytes())
 			if err != nil {
 				return err
 			}
 		}
 
-		return nil
+		return readJSON(b, input)
 	}
 }
 
diff --git a/request/jsonbody_test.go b/request/jsonbody_test.go
index 3380b62..59cb55c 100644
--- a/request/jsonbody_test.go
+++ b/request/jsonbody_test.go
@@ -81,7 +81,7 @@ func Test_decodeJSONBody_unmarshalFailed(t *testing.T) {
 	var i []int
 
 	err = decodeJSONBody(readJSON, false)(req, &i, nil)
-	assert.EqualError(t, err, "failed to decode json: json: cannot unmarshal number into Go value of type []int")
+	assert.EqualError(t, err, "json: cannot unmarshal number into Go value of type []int")
 }
 
 func Test_decodeJSONBody_validateFailed(t *testing.T) {