diff --git a/Makefile b/Makefile index 11ec1d9..e5a129e 100644 --- a/Makefile +++ b/Makefile @@ -12,7 +12,7 @@ all: test .PHONY: test test: - export GO111MODULE=on; go test . -v + export GO111MODULE=on; go test . -race -count=1 .PHONY: godoc godoc: diff --git a/httphelpers.go b/httphelpers.go index b0be3b6..6aa3ab3 100644 --- a/httphelpers.go +++ b/httphelpers.go @@ -6,7 +6,6 @@ import ( "encoding/json" "fmt" "io" - "io/ioutil" "mime/multipart" "net/http" "net/url" @@ -131,6 +130,7 @@ func (f *urlEncodedPayload) getPayloadBuffer() (*bytes.Buffer, error) { for _, keyVal := range f.Values { data.Add(keyVal.key, keyVal.value) } + return bytes.NewBufferString(data.Encode()), nil } @@ -177,6 +177,7 @@ func (f *formDataPayload) getPayloadBuffer() (*bytes.Buffer, error) { for _, keyVal := range f.Values { if tmp, err := writer.CreateFormField(keyVal.key); err == nil { + // TODO(DE-1139): handle error: tmp.Write([]byte(keyVal.value)) } else { return nil, err @@ -186,7 +187,9 @@ func (f *formDataPayload) getPayloadBuffer() (*bytes.Buffer, error) { for _, file := range f.Files { if tmp, err := writer.CreateFormFile(file.key, path.Base(file.value)); err == nil { if fp, err := os.Open(file.value); err == nil { + // TODO(DE-1139): defer in a loop: defer fp.Close() + // TODO(DE-1139): handle error: io.Copy(tmp, fp) } else { return nil, err @@ -198,7 +201,9 @@ func (f *formDataPayload) getPayloadBuffer() (*bytes.Buffer, error) { for _, file := range f.ReadClosers { if tmp, err := writer.CreateFormFile(file.key, file.name); err == nil { + // TODO(DE-1139): defer in a loop: defer file.value.Close() + // TODO(DE-1139): handle error: io.Copy(tmp, file.value) } else { return nil, err @@ -208,6 +213,7 @@ func (f *formDataPayload) getPayloadBuffer() (*bytes.Buffer, error) { for _, buff := range f.Buffers { if tmp, err := writer.CreateFormFile(buff.key, buff.name); err == nil { r := bytes.NewReader(buff.value) + // TODO(DE-1139): handle error: io.Copy(tmp, r) } else { return nil, err @@ -221,8 +227,10 @@ func (f *formDataPayload) getPayloadBuffer() (*bytes.Buffer, error) { func (f *formDataPayload) getContentType() string { if f.contentType == "" { + // TODO(DE-1139): handle error: f.getPayloadBuffer() } + return f.contentType } @@ -234,23 +242,23 @@ func (r *httpRequest) addHeader(name, value string) { } func (r *httpRequest) makeGetRequest(ctx context.Context) (*httpResponse, error) { - return r.makeRequest(ctx, "GET", nil) + return r.makeRequest(ctx, http.MethodGet, nil) } func (r *httpRequest) makePostRequest(ctx context.Context, payload payload) (*httpResponse, error) { - return r.makeRequest(ctx, "POST", payload) + return r.makeRequest(ctx, http.MethodPost, payload) } func (r *httpRequest) makePutRequest(ctx context.Context, payload payload) (*httpResponse, error) { - return r.makeRequest(ctx, "PUT", payload) + return r.makeRequest(ctx, http.MethodPut, payload) } func (r *httpRequest) makeDeleteRequest(ctx context.Context) (*httpResponse, error) { - return r.makeRequest(ctx, "DELETE", nil) + return r.makeRequest(ctx, http.MethodDelete, nil) } func (r *httpRequest) NewRequest(ctx context.Context, method string, payload payload) (*http.Request, error) { - url, err := r.generateUrlWithParameters() + uri, err := r.generateUrlWithParameters() if err != nil { return nil, err } @@ -263,13 +271,12 @@ func (r *httpRequest) NewRequest(ctx context.Context, method string, payload pay } else { body = nil } - req, err := http.NewRequest(method, url, body) + + req, err := http.NewRequestWithContext(ctx, method, uri, body) if err != nil { return nil, err } - req = req.WithContext(ctx) - if payload != nil && payload.getContentType() != "" { req.Header.Add("Content-Type", payload.getContentType()) } @@ -286,6 +293,7 @@ func (r *httpRequest) NewRequest(ctx context.Context, method string, payload pay } req.Header.Add(header, value) } + return req, nil } @@ -305,42 +313,43 @@ func (r *httpRequest) makeRequest(ctx context.Context, method string, payload pa } } - response := httpResponse{} - resp, err := r.Client.Do(req) - if resp != nil { - response.Code = resp.StatusCode - } if err != nil { - if urlErr, ok := err.(*url.Error); ok { - if urlErr.Err == io.EOF { - return nil, errors.Wrap(err, "remote server prematurely closed connection") - } + var urlErr *url.Error + if errors.As(err, &urlErr) && urlErr != nil && errors.Is(urlErr.Err, io.EOF) { + return nil, errors.Wrap(err, "remote server prematurely closed connection") } + return nil, errors.Wrap(err, "while making http request") } defer resp.Body.Close() - responseBody, err := ioutil.ReadAll(resp.Body) + + response := httpResponse{ + Code: resp.StatusCode, + } + + responseBody, err := io.ReadAll(resp.Body) if err != nil { return nil, errors.Wrap(err, "while reading response body") } response.Data = responseBody + return &response, nil } func (r *httpRequest) generateUrlWithParameters() (string, error) { - url, err := url.Parse(r.URL) + uri, err := url.Parse(r.URL) if err != nil { return "", err } - if !validURL.MatchString(url.Path) { - return "", errors.New(`BaseAPI must end with a /v1, /v2, /v3 or /v4; setBaseAPI("https://host/v3")`) + if !validURL.MatchString(uri.Path) { + return "", errors.New(`APIBase must end with a /v1, /v2, /v3 or /v4; SetAPIBase("https://host/v3")`) } - q := url.Query() + q := uri.Query() if r.Parameters != nil && len(r.Parameters) > 0 { for name, values := range r.Parameters { for _, value := range values { @@ -348,9 +357,9 @@ func (r *httpRequest) generateUrlWithParameters() (string, error) { } } } - url.RawQuery = q.Encode() + uri.RawQuery = q.Encode() - return url.String(), nil + return uri.String(), nil } func (r *httpRequest) curlString(req *http.Request, p payload) string { @@ -384,5 +393,6 @@ func (r *httpRequest) curlString(req *http.Request, p payload) string { } } } + return strings.Join(parts, " ") } diff --git a/mailgun_test.go b/mailgun_test.go index 9c343ec..0c5c013 100644 --- a/mailgun_test.go +++ b/mailgun_test.go @@ -10,6 +10,7 @@ import ( "github.com/facebookgo/ensure" "github.com/mailgun/mailgun-go/v4" + "github.com/stretchr/testify/assert" ) const domain = "valid-mailgun-domain" @@ -33,8 +34,7 @@ func TestInvalidBaseAPI(t *testing.T) { ctx := context.Background() _, err := mg.GetDomain(ctx, "unknown.domain") - ensure.NotNil(t, err) - ensure.DeepEqual(t, err.Error(), `BaseAPI must end with a /v1, /v2, /v3 or /v4; setBaseAPI("https://host/v3")`) + assert.EqualError(t, err, `APIBase must end with a /v1, /v2, /v3 or /v4; SetAPIBase("https://host/v3")`) } func TestValidBaseAPI(t *testing.T) {