diff --git a/internal/http1parser/request_test.go b/internal/http1parser/request_test.go index 3ed9d8e3..8fda4748 100644 --- a/internal/http1parser/request_test.go +++ b/internal/http1parser/request_test.go @@ -1,8 +1,13 @@ package http1parser_test import ( + "bufio" "bytes" + "fmt" "io" + "net/http" + "net/url" + "strings" "testing" "github.com/elazarl/goproxy/internal/http1parser" @@ -80,3 +85,118 @@ func TestMultipleNonCanonicalRequests(t *testing.T) { assert.True(t, parser.IsEOF()) } + +// reqTest is inspired by https://github.com/golang/go/blob/master/src/net/http/readrequest_test.go +type reqTest struct { + Raw string + Req *http.Request + Body string + Trailer http.Header + Error string +} + +var noError = "" +var noBodyStr = "" +var noTrailer http.Header = nil + +var reqTests = []reqTest{ + // Baseline test; All Request fields included for template use + { + "GET http://www.techcrunch.com/ HTTP/1.1\r\n" + + "Host: www.techcrunch.com\r\n" + + "user-agent: Fake\r\n" + + "Accept: text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8\r\n" + + "Accept-Language: en-us,en;q=0.5\r\n" + + "Accept-Encoding: gzip,deflate\r\n" + + "Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.7\r\n" + + "Keep-Alive: 300\r\n" + + "Content-Length: 7\r\n" + + "Proxy-Connection: keep-alive\r\n\r\n" + + "abcdef\n???", + &http.Request{ + Method: "GET", + URL: &url.URL{ + Scheme: "http", + Host: "www.techcrunch.com", + Path: "/", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{ + "Accept": {"text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8"}, + "Accept-Language": {"en-us,en;q=0.5"}, + "Accept-Encoding": {"gzip,deflate"}, + "Accept-Charset": {"ISO-8859-1,utf-8;q=0.7,*;q=0.7"}, + "Keep-Alive": {"300"}, + "Proxy-Connection": {"keep-alive"}, + "Content-Length": {"7"}, + "user-agent": {"Fake"}, + }, + Close: false, + ContentLength: 7, + Host: "www.techcrunch.com", + RequestURI: "http://www.techcrunch.com/", + }, + "abcdef\n", + noTrailer, + noError, + }, + + // GET request with no body (the normal case) + { + "GET / HTTP/1.1\r\n" + + "Host: foo.com\r\n\r\n", + &http.Request{ + Method: "GET", + URL: &url.URL{ + Path: "/", + }, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: http.Header{}, + Close: false, + ContentLength: 0, + Host: "foo.com", + RequestURI: "/", + }, + noBodyStr, + noTrailer, + noError, + }, +} + +func TestReadRequest(t *testing.T) { + for i := range reqTests { + tt := &reqTests[i] + + testName := fmt.Sprintf("Test %d (%q)", i, tt.Raw) + t.Run(testName, func(t *testing.T) { + r := bufio.NewReader(strings.NewReader(tt.Raw)) + parser := http1parser.NewRequestReader(true, r) + req, err := parser.ReadRequest() + if err != nil && err.Error() == tt.Error { + // Test finished, we expected an error + return + } + require.NoError(t, err) + + // Check request equality (excluding body) + rbody := req.Body + req.Body = nil + assert.Equal(t, tt.Req, req) + + // Check if the two bodies match + var bodyString string + if rbody != nil { + data, err := io.ReadAll(rbody) + require.NoError(t, err) + bodyString = string(data) + _ = rbody.Close() + } + assert.Equal(t, tt.Body, bodyString) + assert.Equal(t, tt.Trailer, req.Trailer) + }) + } +}