Skip to content

Commit

Permalink
x
Browse files Browse the repository at this point in the history
  • Loading branch information
bassosimone committed Sep 11, 2024
1 parent 7a6255a commit 00af369
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 48 deletions.
7 changes: 5 additions & 2 deletions client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import (
"github.com/neubot/dash/spec"
)

// TODO(bassosimone): we should define the version in a single place

const (
// libraryName is the name of this library
libraryName = "neubot-dash"
Expand Down Expand Up @@ -384,8 +386,8 @@ func (c *Client) collect(

// 5. parse the response body
//
// Implementation note: historically this client did never care
// about saving the response body and we're still doing this
// Implementation note: we are not saving the response and we just
// limit ourselves with checking it's a valid JSON here.
c.Logger.Debugf("dash: body: %s", string(data))
return json.Unmarshal(data, &c.serverResults)
}
Expand Down Expand Up @@ -458,6 +460,7 @@ func (c *Client) StartDownload(ctx context.Context) (<-chan model.ClientResults,

// 1.1: the user manually specified the server -fqdn
case c.FQDN != "":
negotiateURL = &url.URL{}
negotiateURL.Scheme = c.Scheme
negotiateURL.Host = c.FQDN
negotiateURL.Path = spec.NegotiatePath
Expand Down
103 changes: 57 additions & 46 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@ import (
"context"
"errors"
"io"
"io/ioutil"
"net/http"
"net/url"
"strings"
"sync"
"testing"

v2 "github.com/m-lab/locate/api/v2"
"github.com/neubot/dash/model"
)

Expand All @@ -25,7 +26,7 @@ func TestClientNegotiate(t *testing.T) {
client.deps.JSONMarshal = func(v interface{}) ([]byte, error) {
return nil, errors.New("Mocked error")
}
_, err := client.negotiate(context.Background())
_, err := client.negotiate(context.Background(), &url.URL{})
if err == nil {
t.Fatal("Expected an error here")
}
Expand All @@ -38,7 +39,7 @@ func TestClientNegotiate(t *testing.T) {
) (*http.Request, error) {
return nil, errors.New("Mocked error")
}
_, err := client.negotiate(context.Background())
_, err := client.negotiate(context.Background(), &url.URL{})
if err == nil {
t.Fatal("Expected an error here")
}
Expand All @@ -49,7 +50,7 @@ func TestClientNegotiate(t *testing.T) {
client.deps.HTTPClientDo = func(req *http.Request) (*http.Response, error) {
return nil, errors.New("Mocked error")
}
_, err := client.negotiate(context.Background())
_, err := client.negotiate(context.Background(), &url.URL{})
if err == nil {
t.Fatal("Expected an error here")
}
Expand All @@ -60,26 +61,27 @@ func TestClientNegotiate(t *testing.T) {
client.deps.HTTPClientDo = func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 404,
Body: io.NopCloser(bytes.NewReader(nil)),
}, nil
}
_, err := client.negotiate(context.Background())
_, err := client.negotiate(context.Background(), &url.URL{})
if err == nil {
t.Fatal("Expected an error here")
}
})

t.Run("ioutil.ReadAll failure", func(t *testing.T) {
t.Run("io.ReadAll failure", func(t *testing.T) {
client := New(softwareName, softwareVersion)
client.deps.HTTPClientDo = func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(nil)),
Body: io.NopCloser(bytes.NewReader(nil)),
}, nil
}
client.deps.IOReadAll = func(r io.Reader) ([]byte, error) {
return nil, errors.New("Mocked error")
}
_, err := client.negotiate(context.Background())
_, err := client.negotiate(context.Background(), &url.URL{})
if err == nil {
t.Fatal("Expected an error here")
}
Expand All @@ -90,10 +92,10 @@ func TestClientNegotiate(t *testing.T) {
client.deps.HTTPClientDo = func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(nil)),
Body: io.NopCloser(bytes.NewReader(nil)),
}, nil
}
_, err := client.negotiate(context.Background())
_, err := client.negotiate(context.Background(), &url.URL{})
if err == nil {
t.Fatal("Expected an error here")
}
Expand All @@ -104,10 +106,10 @@ func TestClientNegotiate(t *testing.T) {
client.deps.HTTPClientDo = func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("{}")),
Body: io.NopCloser(strings.NewReader("{}")),
}, nil
}
_, err := client.negotiate(context.Background())
_, err := client.negotiate(context.Background(), &url.URL{})
if err == nil {
t.Fatal("Expected an error here")
}
Expand All @@ -118,13 +120,13 @@ func TestClientNegotiate(t *testing.T) {
client.deps.HTTPClientDo = func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader(`{
Body: io.NopCloser(strings.NewReader(`{
"Authorization": "0xdeadbeef",
"Unchoked": 1
}`)),
}, nil
}
_, err := client.negotiate(context.Background())
_, err := client.negotiate(context.Background(), &url.URL{})
if err != nil {
t.Fatal(err)
}
Expand All @@ -140,7 +142,7 @@ func TestClientDownload(t *testing.T) {
return nil, errors.New("Mocked error")
}
current := new(model.ClientResults)
err := client.download(context.Background(), "abc", current)
err := client.download(context.Background(), "abc", current, &url.URL{})
if err == nil {
t.Fatal("Expected an error here")
}
Expand All @@ -152,7 +154,7 @@ func TestClientDownload(t *testing.T) {
return nil, errors.New("Mocked error")
}
current := new(model.ClientResults)
err := client.download(context.Background(), "abc", current)
err := client.download(context.Background(), "abc", current, &url.URL{})
if err == nil {
t.Fatal("Expected an error here")
}
Expand All @@ -163,28 +165,29 @@ func TestClientDownload(t *testing.T) {
client.deps.HTTPClientDo = func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 404,
Body: io.NopCloser(bytes.NewReader(nil)),
}, nil
}
current := new(model.ClientResults)
err := client.download(context.Background(), "abc", current)
err := client.download(context.Background(), "abc", current, &url.URL{})
if err == nil {
t.Fatal("Expected an error here")
}
})

t.Run("ioutil.ReadAll failure", func(t *testing.T) {
t.Run("io.ReadAll failure", func(t *testing.T) {
client := New(softwareName, softwareVersion)
client.deps.HTTPClientDo = func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(nil)),
Body: io.NopCloser(bytes.NewReader(nil)),
}, nil
}
client.deps.IOReadAll = func(r io.Reader) ([]byte, error) {
return nil, errors.New("Mocked error")
}
current := new(model.ClientResults)
err := client.download(context.Background(), "abc", current)
err := client.download(context.Background(), "abc", current, &url.URL{})
if err == nil {
t.Fatal("Expected an error here")
}
Expand All @@ -195,11 +198,11 @@ func TestClientDownload(t *testing.T) {
client.deps.HTTPClientDo = func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(nil)),
Body: io.NopCloser(bytes.NewReader(nil)),
}, nil
}
current := new(model.ClientResults)
err := client.download(context.Background(), "abc", current)
err := client.download(context.Background(), "abc", current, &url.URL{})
if err != nil {
t.Fatal(err)
}
Expand All @@ -212,7 +215,7 @@ func TestClientCollect(t *testing.T) {
client.deps.JSONMarshal = func(v interface{}) ([]byte, error) {
return nil, errors.New("Mocked error")
}
err := client.collect(context.Background(), "abc")
err := client.collect(context.Background(), "abc", &url.URL{})
if err == nil {
t.Fatal("Expected an error here")
}
Expand All @@ -225,7 +228,7 @@ func TestClientCollect(t *testing.T) {
) (*http.Request, error) {
return nil, errors.New("Mocked error")
}
err := client.collect(context.Background(), "abc")
err := client.collect(context.Background(), "abc", &url.URL{})
if err == nil {
t.Fatal("Expected an error here")
}
Expand All @@ -236,7 +239,7 @@ func TestClientCollect(t *testing.T) {
client.deps.HTTPClientDo = func(req *http.Request) (*http.Response, error) {
return nil, errors.New("Mocked error")
}
err := client.collect(context.Background(), "abc")
err := client.collect(context.Background(), "abc", &url.URL{})
if err == nil {
t.Fatal("Expected an error here")
}
Expand All @@ -247,26 +250,27 @@ func TestClientCollect(t *testing.T) {
client.deps.HTTPClientDo = func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 404,
Body: io.NopCloser(bytes.NewReader(nil)),
}, nil
}
err := client.collect(context.Background(), "abc")
err := client.collect(context.Background(), "abc", &url.URL{})
if err == nil {
t.Fatal("Expected an error here")
}
})

t.Run("ioutil.ReadAll failure", func(t *testing.T) {
t.Run("io.ReadAll failure", func(t *testing.T) {
client := New(softwareName, softwareVersion)
client.deps.HTTPClientDo = func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(nil)),
Body: io.NopCloser(bytes.NewReader(nil)),
}, nil
}
client.deps.IOReadAll = func(r io.Reader) ([]byte, error) {
return nil, errors.New("Mocked error")
}
err := client.collect(context.Background(), "abc")
err := client.collect(context.Background(), "abc", &url.URL{})
if err == nil {
t.Fatal("Expected an error here")
}
Expand All @@ -277,10 +281,10 @@ func TestClientCollect(t *testing.T) {
client.deps.HTTPClientDo = func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(bytes.NewReader(nil)),
Body: io.NopCloser(bytes.NewReader(nil)),
}, nil
}
err := client.collect(context.Background(), "abc")
err := client.collect(context.Background(), "abc", &url.URL{})
if err == nil {
t.Fatal("Expected an error here")
}
Expand All @@ -291,10 +295,10 @@ func TestClientCollect(t *testing.T) {
client.deps.HTTPClientDo = func(req *http.Request) (*http.Response, error) {
return &http.Response{
StatusCode: 200,
Body: ioutil.NopCloser(strings.NewReader("[]")),
Body: io.NopCloser(strings.NewReader("[]")),
}, nil
}
err := client.collect(context.Background(), "abc")
err := client.collect(context.Background(), "abc", &url.URL{})
if err != nil {
t.Fatal(err)
}
Expand All @@ -305,10 +309,10 @@ func TestClientLoop(t *testing.T) {
t.Run("negotiate failure", func(t *testing.T) {
ch := make(chan model.ClientResults)
client := New(softwareName, softwareVersion)
client.deps.Negotiate = func(ctx context.Context) (model.NegotiateResponse, error) {
client.deps.Negotiate = func(ctx context.Context, negotiateURL *url.URL) (model.NegotiateResponse, error) {
return model.NegotiateResponse{}, errors.New("Mocked error")
}
client.loop(context.Background(), ch)
client.loop(context.Background(), ch, &url.URL{})
if client.err == nil {
t.Fatal("Expected an error here")
}
Expand All @@ -317,15 +321,16 @@ func TestClientLoop(t *testing.T) {
t.Run("download failure", func(t *testing.T) {
ch := make(chan model.ClientResults)
client := New(softwareName, softwareVersion)
client.deps.Negotiate = func(ctx context.Context) (model.NegotiateResponse, error) {
client.deps.Negotiate = func(ctx context.Context, negotiateURL *url.URL) (model.NegotiateResponse, error) {
return model.NegotiateResponse{}, nil
}
client.deps.Download = func(
ctx context.Context, authorization string, current *model.ClientResults,
ctx context.Context, authorization string,
current *model.ClientResults, negotiateURL *url.URL,
) error {
return errors.New("Mocked error")
}
client.loop(context.Background(), ch)
client.loop(context.Background(), ch, &url.URL{})
if client.err == nil {
t.Fatal("Expected an error here")
}
Expand All @@ -334,15 +339,16 @@ func TestClientLoop(t *testing.T) {
t.Run("collect failure", func(t *testing.T) {
ch := make(chan model.ClientResults)
client := New(softwareName, softwareVersion)
client.deps.Negotiate = func(ctx context.Context) (model.NegotiateResponse, error) {
client.deps.Negotiate = func(ctx context.Context, negotiateURL *url.URL) (model.NegotiateResponse, error) {
return model.NegotiateResponse{}, nil
}
client.deps.Download = func(
ctx context.Context, authorization string, current *model.ClientResults,
ctx context.Context, authorization string,
current *model.ClientResults, negotiateURL *url.URL,
) error {
return nil
}
client.deps.Collect = func(ctx context.Context, authorization string) error {
client.deps.Collect = func(ctx context.Context, authorization string, negotiateURL *url.URL) error {
return errors.New("Mocked error")
}
var wg sync.WaitGroup
Expand All @@ -353,20 +359,25 @@ func TestClientLoop(t *testing.T) {
// drain channel
}
}()
client.loop(context.Background(), ch)
client.loop(context.Background(), ch, &url.URL{})
if client.err == nil {
t.Fatal("Expected an error here")
}
wg.Wait() // make sure we really terminate
})
}

type failingLocator struct{}

// Nearest implements locator.
func (f *failingLocator) Nearest(ctx context.Context, service string) ([]v2.Target, error) {
return nil, errors.New("mocked error")
}

func TestClientStartDownload(t *testing.T) {
t.Run("mlabns failure", func(t *testing.T) {
client := New(softwareName, softwareVersion)
client.deps.Locate = func(ctx context.Context) (string, error) {
return "", errors.New("Mocked error")
}
client.deps.Locator = &failingLocator{}
ch, err := client.StartDownload(context.Background())
if err == nil {
t.Fatal("Expected an error here")
Expand All @@ -378,7 +389,7 @@ func TestClientStartDownload(t *testing.T) {

t.Run("common case", func(t *testing.T) {
client := New(softwareName, softwareVersion)
client.deps.Loop = func(ctx context.Context, ch chan<- model.ClientResults) {
client.deps.Loop = func(ctx context.Context, ch chan<- model.ClientResults, negotiateURL *url.URL) {
close(ch)
}
ch, err := client.StartDownload(context.Background())
Expand Down

0 comments on commit 00af369

Please sign in to comment.