Skip to content

Commit

Permalink
feat: added semaphore http client (#3)
Browse files Browse the repository at this point in the history
  • Loading branch information
mojixcoder authored Jun 15, 2023
1 parent 8b00782 commit 426a192
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 5 deletions.
26 changes: 23 additions & 3 deletions gosrm.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,17 @@ import (
)

type (
// HTTPClient is the interface that can be used to do HTTP calls.
HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
}

// OSRMClient is the base type with helper methods to call OSRM APIs.
// It only holds the base OSRM URL.
OSRMClient struct {
baseURL *url.URL

client HTTPClient
}

// Request is the OSRM's request structure.
Expand All @@ -29,12 +36,25 @@ type (

// New returns a new OSRM client.
func New(baseURL string) (OSRMClient, error) {
var client OSRMClient

u, err := url.Parse(baseURL)
if err != nil {
return OSRMClient{}, err
return client, err
}

return OSRMClient{baseURL: u}, nil
client.baseURL = u
client.SetHTTPClient(NewHTTPClient(HTTPClientConfig{}))

return client, nil
}

// SetHTTPClient sets the HTTP client that will be used to call OSRM.
func (osrm *OSRMClient) SetHTTPClient(client HTTPClient) {
if client == nil {
panic("http client can't be nil")
}
osrm.client = client
}

// get calls the given URL and parses the response.
Expand All @@ -45,7 +65,7 @@ func (osrm OSRMClient) get(ctx context.Context, url string, out any) error {
}
req = req.WithContext(ctx)

res, err := http.DefaultClient.Do(req)
res, err := osrm.client.Do(req)
if err != nil {
return err
}
Expand Down
22 changes: 20 additions & 2 deletions gosrm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ import (

const invalidURL string = "postgres://user:abc{[email protected]:5432/db?sslmode=require"

func newOSRMClient() OSRMClient {
return OSRMClient{client: NewHTTPClient(HTTPClientConfig{})}
}

func TestNew(t *testing.T) {
testCases := []struct {
name, baseURL string
Expand All @@ -38,8 +42,22 @@ func TestNew(t *testing.T) {
}
}

func TestOSRMClient_SetHTTPClient(t *testing.T) {
osrm := newOSRMClient()

assert.PanicsWithValue(t, "http client can't be nil", func() {
osrm.SetHTTPClient(nil)
})

client := NewHTTPClient(HTTPClientConfig{MaxConcurrency: 10})

osrm.SetHTTPClient(client)

assert.Equal(t, client, osrm.client)
}

func TestOSRMClient_get(t *testing.T) {
osrm := OSRMClient{}
osrm := newOSRMClient()
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("{\"message\": \"Ok\"}"))
}))
Expand All @@ -59,7 +77,7 @@ func TestOSRMClient_get(t *testing.T) {
}

func TestOSRMClient_applyOpts(t *testing.T) {
osrm := OSRMClient{}
osrm := newOSRMClient()
u := url.URL{}

osrm.applyOpts(&u, []Option{
Expand Down
64 changes: 64 additions & 0 deletions semaphore_client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
package gosrm

import "net/http"

type (
// httpClient is the default implementation of HTTPClient interface.
httpClient struct {
client *http.Client
pool chan struct{}
}

// HTTPClientConfig is the config used to customize http client.
HTTPClientConfig struct {
// MaxConcurrency is the max number of concurrent requests.
// If it's 0 then there is no limit.
//
// Defaults to 0.
MaxConcurrency uint

// HTTPClient is the client which will be used to do HTTP calls.
//
// Defaults to http.DefaultClient
HTTPClient *http.Client
}
)

// acquire acquires a spot in the pool.
func (c httpClient) acquire() {
if cap(c.pool) == 0 {
return
}
c.pool <- struct{}{}
}

// release releases a spot from the pool.
func (c httpClient) release() {
if cap(c.pool) == 0 {
return
}
<-c.pool
}

// Do does the HTTP call.
func (c httpClient) Do(req *http.Request) (*http.Response, error) {
c.acquire()
defer c.release()

return c.client.Do(req)
}

// NewHTTPClient returns a new HTTP client.
func NewHTTPClient(cfg HTTPClientConfig) HTTPClient {
var c httpClient

if cfg.HTTPClient != nil {
c.client = cfg.HTTPClient
} else {
c.client = http.DefaultClient
}

c.pool = make(chan struct{}, cfg.MaxConcurrency)

return c
}
74 changes: 74 additions & 0 deletions semaphore_client_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package gosrm

import (
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func TestNewHTTPClient(t *testing.T) {
cfg := HTTPClientConfig{}

c := NewHTTPClient(cfg).(httpClient)
assert.Equal(t, http.DefaultClient, c.client)
assert.Equal(t, 0, cap(c.pool))

cfg.MaxConcurrency = 100

c = NewHTTPClient(cfg).(httpClient)
assert.Equal(t, 100, cap(c.pool))

cfg.HTTPClient = &http.Client{}
c = NewHTTPClient(cfg).(httpClient)
assert.Equal(t, cfg.HTTPClient, c.client)
}

func TestHTTPClient_acquire_and_release(t *testing.T) {
cfg := HTTPClientConfig{}
c := NewHTTPClient(cfg).(httpClient)

for i := 0; i < 5; i++ {
c.acquire()
}

// chan is not used since it's disabled.
assert.Len(t, c.pool, 0)

c.release()
assert.Len(t, c.pool, 0)

cfg.MaxConcurrency = 2
c = NewHTTPClient(cfg).(httpClient)

for i := 0; i < int(cfg.MaxConcurrency); i++ {
c.acquire()
}

// chan is full.
assert.Len(t, c.pool, 2)

for i := cfg.MaxConcurrency; i > 0; i-- {
c.release()
assert.Len(t, c.pool, int(i-1))
}

// chan is empty.
assert.Len(t, c.pool, 0)
}

func TestHTTPClient_Do(t *testing.T) {
client := NewHTTPClient(HTTPClientConfig{MaxConcurrency: 1})

testsrv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(201)
}))

req, err := http.NewRequest("GET", testsrv.URL, nil)
assert.NoError(t, err)

res, err := client.Do(req)
assert.NoError(t, err)
assert.Equal(t, 201, res.StatusCode)
}

0 comments on commit 426a192

Please sign in to comment.