Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DE-1144 refactor httphelpers #334

Merged
merged 20 commits into from
Oct 29, 2024
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ all: test

.PHONY: test
test:
export GO111MODULE=on; go test . -v
export GO111MODULE=on; go test . -race -count=1

.PHONY: godoc
godoc:
Expand Down
60 changes: 35 additions & 25 deletions httphelpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"encoding/json"
"fmt"
"io"
"io/ioutil"
"mime/multipart"
"net/http"
"net/url"
Expand Down Expand Up @@ -131,6 +130,7 @@ func (f *urlEncodedPayload) getPayloadBuffer() (*bytes.Buffer, error) {
for _, keyVal := range f.Values {
data.Add(keyVal.key, keyVal.value)
}

return bytes.NewBufferString(data.Encode()), nil
}

Expand Down Expand Up @@ -177,6 +177,7 @@ func (f *formDataPayload) getPayloadBuffer() (*bytes.Buffer, error) {

for _, keyVal := range f.Values {
if tmp, err := writer.CreateFormField(keyVal.key); err == nil {
// TODO(DE-1139): handle error:
tmp.Write([]byte(keyVal.value))
} else {
return nil, err
Expand All @@ -186,7 +187,9 @@ func (f *formDataPayload) getPayloadBuffer() (*bytes.Buffer, error) {
for _, file := range f.Files {
if tmp, err := writer.CreateFormFile(file.key, path.Base(file.value)); err == nil {
if fp, err := os.Open(file.value); err == nil {
// TODO(DE-1139): defer in a loop:
defer fp.Close()
// TODO(DE-1139): handle error:
io.Copy(tmp, fp)
} else {
return nil, err
Expand All @@ -198,7 +201,9 @@ func (f *formDataPayload) getPayloadBuffer() (*bytes.Buffer, error) {

for _, file := range f.ReadClosers {
if tmp, err := writer.CreateFormFile(file.key, file.name); err == nil {
// TODO(DE-1139): defer in a loop:
defer file.value.Close()
// TODO(DE-1139): handle error:
io.Copy(tmp, file.value)
} else {
return nil, err
Expand All @@ -208,6 +213,7 @@ func (f *formDataPayload) getPayloadBuffer() (*bytes.Buffer, error) {
for _, buff := range f.Buffers {
if tmp, err := writer.CreateFormFile(buff.key, buff.name); err == nil {
r := bytes.NewReader(buff.value)
// TODO(DE-1139): handle error:
io.Copy(tmp, r)
} else {
return nil, err
Expand All @@ -221,8 +227,10 @@ func (f *formDataPayload) getPayloadBuffer() (*bytes.Buffer, error) {

func (f *formDataPayload) getContentType() string {
if f.contentType == "" {
// TODO(DE-1139): handle error:
f.getPayloadBuffer()
}

return f.contentType
}

Expand All @@ -234,23 +242,23 @@ func (r *httpRequest) addHeader(name, value string) {
}

func (r *httpRequest) makeGetRequest(ctx context.Context) (*httpResponse, error) {
return r.makeRequest(ctx, "GET", nil)
return r.makeRequest(ctx, http.MethodGet, nil)
}

func (r *httpRequest) makePostRequest(ctx context.Context, payload payload) (*httpResponse, error) {
return r.makeRequest(ctx, "POST", payload)
return r.makeRequest(ctx, http.MethodPost, payload)
}

func (r *httpRequest) makePutRequest(ctx context.Context, payload payload) (*httpResponse, error) {
return r.makeRequest(ctx, "PUT", payload)
return r.makeRequest(ctx, http.MethodPut, payload)
}

func (r *httpRequest) makeDeleteRequest(ctx context.Context) (*httpResponse, error) {
return r.makeRequest(ctx, "DELETE", nil)
return r.makeRequest(ctx, http.MethodDelete, nil)
}

func (r *httpRequest) NewRequest(ctx context.Context, method string, payload payload) (*http.Request, error) {
url, err := r.generateUrlWithParameters()
uri, err := r.generateUrlWithParameters()
if err != nil {
return nil, err
}
Expand All @@ -263,13 +271,12 @@ func (r *httpRequest) NewRequest(ctx context.Context, method string, payload pay
} else {
body = nil
}
req, err := http.NewRequest(method, url, body)

req, err := http.NewRequestWithContext(ctx, method, uri, body)
if err != nil {
return nil, err
}

req = req.WithContext(ctx)

if payload != nil && payload.getContentType() != "" {
req.Header.Add("Content-Type", payload.getContentType())
}
Expand All @@ -286,6 +293,7 @@ func (r *httpRequest) NewRequest(ctx context.Context, method string, payload pay
}
req.Header.Add(header, value)
}

return req, nil
}

Expand All @@ -305,52 +313,53 @@ func (r *httpRequest) makeRequest(ctx context.Context, method string, payload pa
}
}

response := httpResponse{}

resp, err := r.Client.Do(req)
if resp != nil {
response.Code = resp.StatusCode
}
if err != nil {
if urlErr, ok := err.(*url.Error); ok {
if urlErr.Err == io.EOF {
return nil, errors.Wrap(err, "remote server prematurely closed connection")
}
var urlErr *url.Error
if errors.As(err, &urlErr) && urlErr != nil && errors.Is(urlErr.Err, io.EOF) {
return nil, errors.Wrap(err, "remote server prematurely closed connection")
}

return nil, errors.Wrap(err, "while making http request")
}

defer resp.Body.Close()
responseBody, err := ioutil.ReadAll(resp.Body)

response := httpResponse{
Code: resp.StatusCode,
}

responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return nil, errors.Wrap(err, "while reading response body")
}

response.Data = responseBody

return &response, nil
}

func (r *httpRequest) generateUrlWithParameters() (string, error) {
url, err := url.Parse(r.URL)
uri, err := url.Parse(r.URL)
if err != nil {
return "", err
}

if !validURL.MatchString(url.Path) {
return "", errors.New(`BaseAPI must end with a /v1, /v2, /v3 or /v4; setBaseAPI("https://host/v3")`)
if !validURL.MatchString(uri.Path) {
return "", errors.New(`APIBase must end with a /v1, /v2, /v3 or /v4; SetAPIBase("https://host/v3")`)
}

q := url.Query()
q := uri.Query()
if r.Parameters != nil && len(r.Parameters) > 0 {
for name, values := range r.Parameters {
for _, value := range values {
q.Add(name, value)
}
}
}
url.RawQuery = q.Encode()
uri.RawQuery = q.Encode()

return url.String(), nil
return uri.String(), nil
}

func (r *httpRequest) curlString(req *http.Request, p payload) string {
Expand Down Expand Up @@ -384,5 +393,6 @@ func (r *httpRequest) curlString(req *http.Request, p payload) string {
}
}
}

return strings.Join(parts, " ")
}
4 changes: 2 additions & 2 deletions mailgun_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

"github.com/facebookgo/ensure"
"github.com/mailgun/mailgun-go/v4"
"github.com/stretchr/testify/assert"
)

const domain = "valid-mailgun-domain"
Expand All @@ -33,8 +34,7 @@ func TestInvalidBaseAPI(t *testing.T) {

ctx := context.Background()
_, err := mg.GetDomain(ctx, "unknown.domain")
ensure.NotNil(t, err)
ensure.DeepEqual(t, err.Error(), `BaseAPI must end with a /v1, /v2, /v3 or /v4; setBaseAPI("https://host/v3")`)
assert.EqualError(t, err, `APIBase must end with a /v1, /v2, /v3 or /v4; SetAPIBase("https://host/v3")`)
}

func TestValidBaseAPI(t *testing.T) {
Expand Down