diff --git a/service/aiproxy/common/gin.go b/service/aiproxy/common/gin.go index 113617b4d77..eb2c95f0123 100644 --- a/service/aiproxy/common/gin.go +++ b/service/aiproxy/common/gin.go @@ -3,6 +3,7 @@ package common import ( "bytes" "context" + "errors" "fmt" "io" "net/http" @@ -13,6 +14,31 @@ import ( type RequestBodyKey struct{} +const ( + MaxRequestBodySize = 1024 * 1024 * 50 // 50MB +) + +func LimitReader(r io.Reader, n int64) io.Reader { return &LimitedReader{r, n} } + +type LimitedReader struct { + R io.Reader + N int64 +} + +var ErrLimitedReaderExceeded = errors.New("limited reader exceeded") + +func (l *LimitedReader) Read(p []byte) (n int, err error) { + if l.N <= 0 { + return 0, ErrLimitedReaderExceeded + } + if int64(len(p)) > l.N { + p = p[0:l.N] + } + n, err = l.R.Read(p) + l.N -= int64(n) + return +} + func GetRequestBody(req *http.Request) ([]byte, error) { requestBody := req.Context().Value(RequestBodyKey{}) if requestBody != nil { @@ -27,8 +53,17 @@ func GetRequestBody(req *http.Request) ([]byte, error) { } }() if req.ContentLength <= 0 || req.Header.Get("Content-Type") != "application/json" { - buf, err = io.ReadAll(req.Body) + buf, err = io.ReadAll(LimitReader(req.Body, MaxRequestBodySize)) + if err != nil { + if errors.Is(err, ErrLimitedReaderExceeded) { + return nil, fmt.Errorf("request body too large, max: %d", MaxRequestBodySize) + } + return nil, fmt.Errorf("request body read failed: %w", err) + } } else { + if req.ContentLength > MaxRequestBodySize { + return nil, fmt.Errorf("request body too large: %d, max: %d", req.ContentLength, MaxRequestBodySize) + } buf = make([]byte, req.ContentLength) _, err = io.ReadFull(req.Body, buf) } diff --git a/service/aiproxy/common/image/image.go b/service/aiproxy/common/image/image.go index 3ab4e5df7cd..a584b8b84b6 100644 --- a/service/aiproxy/common/image/image.go +++ b/service/aiproxy/common/image/image.go @@ -20,6 +20,7 @@ import ( "strings" // import webp decoder + "github.com/labring/sealos/service/aiproxy/common" _ "golang.org/x/image/webp" ) @@ -56,6 +57,10 @@ func GetImageSizeFromURL(url string) (width int, height int, err error) { return img.Width, img.Height, nil } +const ( + MaxImageSize = 1024 * 1024 * 5 // 5MB +) + func GetImageFromURL(ctx context.Context, url string) (string, string, error) { // Check if the URL is a data URL matches := dataURLPattern.FindStringSubmatch(url) @@ -82,8 +87,17 @@ func GetImageFromURL(ctx context.Context, url string) (string, string, error) { } var buf []byte if resp.ContentLength <= 0 { - buf, err = io.ReadAll(resp.Body) + buf, err = io.ReadAll(common.LimitReader(resp.Body, MaxImageSize)) + if err != nil { + if errors.Is(err, common.ErrLimitedReaderExceeded) { + return "", "", fmt.Errorf("image too large, max: %d", MaxImageSize) + } + return "", "", fmt.Errorf("image read failed: %w", err) + } } else { + if resp.ContentLength > MaxImageSize { + return "", "", fmt.Errorf("image too large: %d, max: %d", resp.ContentLength, MaxImageSize) + } buf = make([]byte, resp.ContentLength) _, err = io.ReadFull(resp.Body, buf) }