diff --git a/changelog/@unreleased/pr-351.v2.yml b/changelog/@unreleased/pr-351.v2.yml new file mode 100644 index 00000000..1c4585cd --- /dev/null +++ b/changelog/@unreleased/pr-351.v2.yml @@ -0,0 +1,5 @@ +type: improvement +improvement: + description: Implement round robin scorer as URISelector interface + links: + - https://github.com/palantir/conjure-go-runtime/pull/351 diff --git a/conjure-go-client/httpclient/internal/rr_selector.go b/conjure-go-client/httpclient/internal/rr_selector.go new file mode 100644 index 00000000..67d97134 --- /dev/null +++ b/conjure-go-client/httpclient/internal/rr_selector.go @@ -0,0 +1,79 @@ +// Copyright (c) 2022 Palantir Technologies. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "math/rand" + "net/http" + "sync" + + werror "github.com/palantir/witchcraft-go-error" +) + +type roundRobinSelector struct { + sync.Mutex + source rand.Source + + prevURIs []string + offset int +} + +// NewRoundRobinURISelector returns a URI scorer that uses a round robin algorithm for selecting URIs when scoring +// using a rand.Rand seeded by the nanoClock function. The middleware no-ops on each request. This selector will always +// return one URI. +func NewRoundRobinURISelector(nanoClock func() int64) URISelector { + return &roundRobinSelector{ + source: rand.NewSource(nanoClock()), + prevURIs: []string{}, + } +} + +// Select implements Selector interface +func (s *roundRobinSelector) Select(uris []string, _ http.Header) ([]string, error) { + s.Lock() + defer s.Unlock() + if len(uris) == 0 { + return nil, werror.Error("no valid uris provided to round robin uri-selector") + } + + s.updateURIs(uris) + s.offset = (s.offset + 1) % len(uris) + return []string{uris[s.offset]}, nil +} + +// updateURIs determines whether we need to update the stored prevURIs because the current set of URIs differ from the +// last observed URIs. When the URIs we randomize to get a new offest. +func (s *roundRobinSelector) updateURIs(uris []string) { + reset := false + if len(s.prevURIs) == 0 { + reset = true + } + for i, uri := range s.prevURIs { + if uri != uris[i] { + reset = true + break + } + } + + if reset { + s.prevURIs = uris + // randomize offset on reinit + s.offset = rand.New(s.source).Intn(len(uris)) + } +} + +func (s *roundRobinSelector) RoundTrip(req *http.Request, next http.RoundTripper) (*http.Response, error) { + return next.RoundTrip(req) +} diff --git a/conjure-go-client/httpclient/internal/rr_selector_test.go b/conjure-go-client/httpclient/internal/rr_selector_test.go new file mode 100644 index 00000000..38637999 --- /dev/null +++ b/conjure-go-client/httpclient/internal/rr_selector_test.go @@ -0,0 +1,53 @@ +// Copyright (c) 2022 Palantir Technologies. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRoundRobinSelector_Select(t *testing.T) { + scorer := NewRoundRobinURISelector(func() int64 { return time.Now().UnixNano() }) + + t.Run("round robins across valid connections", func(t *testing.T) { + uris := []string{"uri1", "uri2", "uri3", "uri4", "uri5"} + const iterations = 100 + observed := make(map[string]int, iterations) + for i := 0; i < iterations; i++ { + uri, err := scorer.Select(uris, nil) + assert.NoError(t, err) + assert.Len(t, uri, 1) + observed[uri[0]] = observed[uri[0]] + 1 + } + + occurences := make([]int, 0, len(observed)) + for _, count := range observed { + occurences = append(occurences, count) + } + + for _, v := range occurences { + assert.Equal(t, occurences[0], v) + } + }) + + t.Run("erorrs with empty set of provided uris", func(t *testing.T) { + _, err := scorer.Select([]string{}, nil) + require.Error(t, err) + }) +}