diff --git a/.github/workflows/golangci-lint.yml b/.github/workflows/golangci-lint.yml index 5b95a86..ce9a3e8 100644 --- a/.github/workflows/golangci-lint.yml +++ b/.github/workflows/golangci-lint.yml @@ -23,4 +23,4 @@ jobs: - name: golangci-lint uses: golangci/golangci-lint-action@v6 with: - version: v1.60 + version: v1.64 diff --git a/Makefile b/Makefile index b097031..c074c5d 100644 --- a/Makefile +++ b/Makefile @@ -54,6 +54,12 @@ profile_cache: profile_handlers: $(MAKE) -B $(PROFILES)/BenchmarkHandlers.bench BENCH_PKG=./internal/server +profile_typeurl: + $(MAKE) -B $(PROFILES)/BenchmarkGetTrimmedTypeURL.bench BENCH_PKG=ads + +profile_parse_glob_urn: + $(MAKE) -B $(PROFILES)/BenchmarkParseGlobCollectionURN.bench BENCH_PKG=ads + BENCHCOUNT = 1 BENCHTIME = 1s @@ -78,8 +84,8 @@ else $(error BENCH_PKG undefined) endif ifdef OPEN_PROFILES - go tool pprof $(BENCHBIN) $(PROFILES)/$*.cpu <<< web - go tool pprof $(PROFILES)/$*.mem <<< web + go tool pprof -http : $(BENCHBIN) $(PROFILES)/$*.cpu & \ + go tool pprof -http : $(PROFILES)/$*.mem ; kill %1 else $(info Not opening profiles since OPEN_PROFILES is not set) endif diff --git a/ads/ads.go b/ads/ads.go index f6650b6..5226319 100644 --- a/ads/ads.go +++ b/ads/ads.go @@ -132,6 +132,21 @@ func (r *Resource[T]) TypeURL() string { return types.APITypePrefix + string(r.Resource.ProtoReflect().Descriptor().FullName()) } +func (r *Resource[T]) Equals(other *Resource[T]) bool { + if r == other { + return true + } + if r == nil || other == nil { + return false + } + return r.Name == other.Name && + r.Version == other.Version && + proto.Equal(r.Ttl, other.Ttl) && + proto.Equal(r.CacheControl, other.CacheControl) && + proto.Equal(r.Metadata, other.Metadata) && + proto.Equal(r.Resource, other.Resource) +} + // UnmarshalRawResource unmarshals the given RawResource and returns a Resource of the corresponding // type. Resource.Marshal on the returned Resource will return the given RawResource instead of // re-serializing the resource. diff --git a/ads/glob_collection_url.go b/ads/glob_collection_url.go index d501393..da8af74 100644 --- a/ads/glob_collection_url.go +++ b/ads/glob_collection_url.go @@ -5,7 +5,7 @@ import ( "net/url" "strings" - types "github.com/envoyproxy/go-control-plane/pkg/resource/v3" + "google.golang.org/protobuf/proto" ) // GlobCollectionURL represents the individual elements of a glob collection URL. Please refer to the @@ -27,14 +27,22 @@ type GlobCollectionURL struct { } func (u GlobCollectionURL) String() string { + return u.uri(WildcardSubscription) +} + +func (u GlobCollectionURL) MemberURN(name string) string { + return u.uri(name) +} + +func (u GlobCollectionURL) uri(name string) string { var path string switch u.Path { case "": - path = WildcardSubscription + path = name case "/": - path = "/" + WildcardSubscription + path = "/" + name default: - path = u.Path + "/" + WildcardSubscription + path = u.Path + "/" + name } return XDSTPScheme + @@ -44,6 +52,15 @@ func (u GlobCollectionURL) String() string { u.ContextParameters } +func NewGlobCollectionURL[T proto.Message](authority, path string, contextParameters url.Values) GlobCollectionURL { + return GlobCollectionURL{ + Authority: authority, + ResourceType: getTrimmedTypeURL[T](), + Path: path, + ContextParameters: contextParameters.Encode(), + } +} + // ErrInvalidGlobCollectionURI is always returned by the various glob collection URL parsing // functions. var ErrInvalidGlobCollectionURI = errors.New("diderot: invalid glob collection URI") @@ -57,15 +74,13 @@ var ErrInvalidGlobCollectionURI = errors.New("diderot: invalid glob collection U // exact definition of a glob collection. // // [TP1 proposal]: https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md#uri-based-xds-resource-names -func ParseGlobCollectionURL(name, resourceType string) (GlobCollectionURL, error) { - gcURL, err := parseXDSTPURI(name, resourceType) +func ParseGlobCollectionURL[T proto.Message](name string) (GlobCollectionURL, error) { + gcURL, resource, err := ParseGlobCollectionURN[T](name) if err != nil { return GlobCollectionURL{}, err } - var ok bool - gcURL.Path, ok = strings.CutSuffix(gcURL.Path, "/"+WildcardSubscription) - if !ok { + if resource != WildcardSubscription { // URLs must end with /* return GlobCollectionURL{}, ErrInvalidGlobCollectionURI } @@ -73,9 +88,9 @@ func ParseGlobCollectionURL(name, resourceType string) (GlobCollectionURL, error return gcURL, nil } -// ExtractGlobCollectionURLFromResourceURN checks if the given name is a resource URN, and returns -// the corresponding GlobCollectionURL. The format of a resource URN is defined in the -// [TP1 proposal], and looks like this: +// ParseGlobCollectionURN checks if the given name is a resource URN, and returns the corresponding +// GlobCollectionURL. The format of a resource URN is defined in the [TP1 proposal], and looks like +// this: // // xdstp://[{authority}]/{resource type}/{id/*}?{context parameters} // @@ -99,22 +114,19 @@ func ParseGlobCollectionURL(name, resourceType string) (GlobCollectionURL, error // // [TP1 proposal]: https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md#uri-based-xds-resource-names // [here]: https://github.com/cncf/xds/issues/91 -func ExtractGlobCollectionURLFromResourceURN(name, resourceType string) (GlobCollectionURL, error) { - gcURL, err := parseXDSTPURI(name, resourceType) +func ParseGlobCollectionURN[T proto.Message](name string) (GlobCollectionURL, string, error) { + gcURL, err := parseXDSTPURI[T](name) if err != nil { - return GlobCollectionURL{}, err + return GlobCollectionURL{}, "", err } lastSlash := strings.LastIndex(gcURL.Path, "/") if lastSlash == -1 { // Missing path in URL - return GlobCollectionURL{}, ErrInvalidGlobCollectionURI + return GlobCollectionURL{}, "", ErrInvalidGlobCollectionURI } - if gcURL.Path[lastSlash:] == "/"+WildcardSubscription { - // resource URN cannot end in /* - return GlobCollectionURL{}, ErrInvalidGlobCollectionURI - } + resource := gcURL.Path[lastSlash+1:] if lastSlash == 0 { gcURL.Path = "/" @@ -122,10 +134,10 @@ func ExtractGlobCollectionURLFromResourceURN(name, resourceType string) (GlobCol gcURL.Path = gcURL.Path[:lastSlash] } - return gcURL, nil + return gcURL, resource, nil } -func parseXDSTPURI(resourceName, resourceType string) (GlobCollectionURL, error) { +func parseXDSTPURI[T proto.Message](resourceName string) (GlobCollectionURL, error) { // Skip deserializing the resource name if it doesn't start with the correct scheme if !strings.HasPrefix(resourceName, XDSTPScheme) { // doesn't start with xdstp:// @@ -138,8 +150,7 @@ func parseXDSTPURI(resourceName, resourceType string) (GlobCollectionURL, error) return GlobCollectionURL{}, ErrInvalidGlobCollectionURI } - // Glob collection URLs do not start with the type prefix, so trim it here. - resourceType = strings.TrimPrefix(resourceType, types.APITypePrefix) + resourceType := getTrimmedTypeURL[T]() collectionPath, ok := strings.CutPrefix(parsedURL.EscapedPath(), "/"+resourceType+"/") if !ok { @@ -160,3 +171,8 @@ func parseXDSTPURI(resourceName, resourceType string) (GlobCollectionURL, error) return u, nil } + +func getTrimmedTypeURL[T proto.Message]() string { + var t T + return string(t.ProtoReflect().Descriptor().FullName()) +} diff --git a/ads/glob_collection_url_test.go b/ads/glob_collection_url_test.go index de928b3..fa3dedf 100644 --- a/ads/glob_collection_url_test.go +++ b/ads/glob_collection_url_test.go @@ -3,14 +3,21 @@ package ads import ( "testing" + cluster "github.com/envoyproxy/go-control-plane/envoy/config/cluster/v3" + endpoint "github.com/envoyproxy/go-control-plane/envoy/config/endpoint/v3" + listener "github.com/envoyproxy/go-control-plane/envoy/config/listener/v3" + route "github.com/envoyproxy/go-control-plane/envoy/config/route/v3" + "github.com/envoyproxy/go-control-plane/pkg/resource/v3" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/wrapperspb" ) const ( resourceType = "google.protobuf.Int64Value" ) -func testBadURIs(t *testing.T, parser func(string, string) (GlobCollectionURL, error)) { +func testBadURIs(t *testing.T, parser func(string) (GlobCollectionURL, error)) { badURIs := []struct { name string resourceName string @@ -43,13 +50,13 @@ func testBadURIs(t *testing.T, parser func(string, string) (GlobCollectionURL, e for _, test := range badURIs { t.Run(test.name, func(t *testing.T) { - _, err := parser(test.resourceName, resourceType) + _, err := parser(test.resourceName) require.Error(t, err) }) } } -func testGoodURIs(t *testing.T, id string, parser func(string, string) (GlobCollectionURL, error)) { +func testGoodURIs(t *testing.T, id string, parser func(string) (GlobCollectionURL, error)) { tests := []struct { name string resourceName string @@ -120,7 +127,7 @@ func testGoodURIs(t *testing.T, id string, parser func(string, string) (GlobColl for _, test := range tests { t.Run(test.name, func(t *testing.T) { - actual, err := parser(test.resourceName, resourceType) + actual, err := parser(test.resourceName) if test.expectErr { require.Error(t, err) } else { @@ -133,26 +140,80 @@ func testGoodURIs(t *testing.T, id string, parser func(string, string) (GlobColl func TestParseGlobCollectionURL(t *testing.T) { t.Run("bad URIs", func(t *testing.T) { - testBadURIs(t, ParseGlobCollectionURL) + testBadURIs(t, ParseGlobCollectionURL[*wrapperspb.Int64Value]) }) t.Run("good URIs", func(t *testing.T) { - testGoodURIs(t, WildcardSubscription, ParseGlobCollectionURL) + testGoodURIs(t, WildcardSubscription, ParseGlobCollectionURL[*wrapperspb.Int64Value]) }) t.Run("rejects URNs", func(t *testing.T) { - _, err := ParseGlobCollectionURL("xdstp:///"+resourceType+"/foo/bar", resourceType) + _, err := ParseGlobCollectionURL[*wrapperspb.Int64Value]("xdstp:///" + resourceType + "/foo/bar") require.Error(t, err) }) } -func TestExtractGlobCollectionURLFromResourceURN(t *testing.T) { +func TestParseGlobCollectionURN(t *testing.T) { + parser := func(s string) (GlobCollectionURL, error) { + gcURL, _, err := ParseGlobCollectionURN[*wrapperspb.Int64Value](s) + return gcURL, err + } + t.Run("bad URIs", func(t *testing.T) { - testBadURIs(t, ExtractGlobCollectionURLFromResourceURN) + testBadURIs(t, parser) }) t.Run("good URIs", func(t *testing.T) { - testGoodURIs(t, "foo", ExtractGlobCollectionURLFromResourceURN) + testGoodURIs(t, "foo", parser) }) - t.Run("rejects glob collection URLs", func(t *testing.T) { - _, err := ExtractGlobCollectionURLFromResourceURN("xdstp:///"+resourceType+"/foo/*", resourceType) - require.Error(t, err) + t.Run("handles glob collection URLs", func(t *testing.T) { + gcURL, r, err := ParseGlobCollectionURN[*wrapperspb.Int64Value]("xdstp:///" + resourceType + "/foo/*") + require.NoError(t, err) + require.Equal(t, NewGlobCollectionURL[*wrapperspb.Int64Value]("", "foo", nil), gcURL) + require.Equal(t, WildcardSubscription, r) + }) +} + +func TestGetTrimmedTypeURL(t *testing.T) { + check := func(expected, actualTrimmed string) { + require.Equal(t, expected, resource.APITypePrefix+actualTrimmed) + } + check(resource.ListenerType, getTrimmedTypeURL[*listener.Listener]()) + check(resource.EndpointType, getTrimmedTypeURL[*endpoint.ClusterLoadAssignment]()) + check(resource.ClusterType, getTrimmedTypeURL[*cluster.Cluster]()) + check(resource.RouteType, getTrimmedTypeURL[*route.RouteConfiguration]()) +} + +func BenchmarkGetTrimmedTypeURL(b *testing.B) { + benchmarkGetTrimmedTypeURL[*wrapperspb.Int64Value](b) + benchmarkGetTrimmedTypeURL[*cluster.Cluster](b) +} + +func benchmarkGetTrimmedTypeURL[T proto.Message](b *testing.B) { + b.Run(getTrimmedTypeURL[T](), func(b *testing.B) { + var url string + for range b.N { + url = getTrimmedTypeURL[T]() + } + require.NotEmpty(b, url) + }) +} + +func BenchmarkParseGlobCollectionURN(b *testing.B) { + benchmarkParseGlobCollectionURN[*wrapperspb.Int64Value](b) + benchmarkParseGlobCollectionURN[*cluster.Cluster](b) +} + +func benchmarkParseGlobCollectionURN[T proto.Message](b *testing.B) { + expectedURL := NewGlobCollectionURL[T]("foo", "bar", nil) + url := expectedURL.String() + + var err error + b.Run(url, func(b *testing.B) { + var actualURL GlobCollectionURL + for range b.N { + actualURL, _, err = ParseGlobCollectionURN[T](url) + if err != nil { + b.Fatal(err) + } + } + require.Equal(b, expectedURL, actualURL) }) } diff --git a/cache.go b/cache.go index 288af19..3fe9e2a 100644 --- a/cache.go +++ b/cache.go @@ -162,9 +162,8 @@ func NewPrioritizedCache[T proto.Message](prioritySlots int) []Cache[T] { func newCache[T proto.Message](prioritySlots int) *cache[T] { ref := TypeOf[T]() return &cache[T]{ - typeReference: ref, - trimmedTypeURL: ref.TrimmedURL(), - prioritySlots: prioritySlots, + typeReference: ref, + prioritySlots: prioritySlots, } } @@ -179,9 +178,6 @@ type cache[T proto.Message] struct { // This is the type of each resource in this cache. Set and SetResource guarantee that all insertions // in this cache satisfy this invariant. typeReference TypeReference[T] - // The typeURL of the resources in this cache, without the leading "type.googleapis.com/". Used for - // resource URNs which do not include this prefix. - trimmedTypeURL string // This resourceMap maps the resource's name to its corresponding WatchableValue. resources internal.ResourceMap[string, *internal.WatchableValue[T]] // The number of slots watchableValue instances should be created with (see NewPrioritizedCache for @@ -204,7 +200,7 @@ func (c *cache[T]) IsSubscribedTo(name string, handler ads.SubscriptionHandler[T return true } - if gcURL, err := ads.ParseGlobCollectionURL(name, c.trimmedTypeURL); err == nil { + if gcURL, err := ads.ParseGlobCollectionURL[T](name); err == nil { return c.globCollections.IsSubscribed(gcURL, handler) } @@ -228,7 +224,7 @@ func (c *cache[T]) Subscribe(name string, handler ads.SubscriptionHandler[T]) { }) return true }) - } else if gcURL, err := ads.ParseGlobCollectionURL(name, c.trimmedTypeURL); err == nil { + } else if gcURL, err := ads.ParseGlobCollectionURL[T](name); err == nil { c.globCollections.Subscribe(gcURL, handler) } else { c.createOrModifyEntry(name, func(name string, value *internal.WatchableValue[T]) { @@ -237,6 +233,20 @@ func (c *cache[T]) Subscribe(name string, handler ads.SubscriptionHandler[T]) { } } +// parseGlobCollectionURN checks if the given name is a valid glob collection URN. Note: by +// definition, a URN is not a URL! Therefore, if the name ends with /*, this function will return an +// error. This should be used when setting or clearing individual members of a glob collection, as it +// is meaningless to "set" an entire glob collection. Similarly, clearing an entire glob collection +// by calling [RawCache.Clear] with the corresponding glob collection URL is not supported, and is +// effectively a noop. +func parseGlobCollectionURN[T proto.Message](name string) (ads.GlobCollectionURL, error) { + gcURL, resource, err := ads.ParseGlobCollectionURN[T](name) + if err != nil || resource == ads.WildcardSubscription { + return ads.GlobCollectionURL{}, ads.ErrInvalidGlobCollectionURI + } + return gcURL, nil +} + // createOrModifyEntry executes the given function on the value of that name after ensuring that it exists in the map. func (c *cache[T]) createOrModifyEntry(name string, f func(name string, value *internal.WatchableValue[T])) { c.resources.Compute( @@ -245,7 +255,7 @@ func (c *cache[T]) createOrModifyEntry(name string, f func(name string, value *i v := internal.NewValue[T](name, c.prioritySlots) v.SubscriberSets[internal.WildcardSubscription] = &c.wildcardSubscribers - if gcURL, err := ads.ExtractGlobCollectionURLFromResourceURN(name, c.trimmedTypeURL); err == nil { + if gcURL, err := parseGlobCollectionURN[T](name); err == nil { c.globCollections.PutValueInCollection(gcURL, v) } @@ -262,7 +272,7 @@ func (c *cache[T]) deleteEntryIfNilAndNoSubscribers(name string) { c.resources.DeleteIf(name, func(name string, value *internal.WatchableValue[T]) bool { hasNoExplicitSubscribers := value.SubscriberSets[internal.ExplicitSubscription].Size() == 0 if value.Read() == nil && hasNoExplicitSubscribers { - if gcURL, err := ads.ExtractGlobCollectionURLFromResourceURN(name, c.trimmedTypeURL); err == nil { + if gcURL, err := parseGlobCollectionURN[T](name); err == nil { c.globCollections.RemoveValueFromCollection(gcURL, value) } return true @@ -291,7 +301,7 @@ func (c *cache[T]) unsubscribe(name string, handler ads.SubscriptionHandler[T]) func (c *cache[T]) Unsubscribe(name string, handler ads.SubscriptionHandler[T]) { if name == ads.WildcardSubscription { c.wildcardSubscribers.Unsubscribe(handler) - } else if gcURL, err := ads.ParseGlobCollectionURL(name, c.trimmedTypeURL); err == nil { + } else if gcURL, err := ads.ParseGlobCollectionURL[T](name); err == nil { c.globCollections.Unsubscribe(gcURL, handler) } else { c.unsubscribe(name, handler) diff --git a/client.go b/client.go new file mode 100644 index 0000000..7fc5a24 --- /dev/null +++ b/client.go @@ -0,0 +1,362 @@ +package diderot + +import ( + "context" + "fmt" + "iter" + "log/slog" + "slices" + "sync" + "time" + + discoveryv3 "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" + "github.com/linkedin/diderot/ads" + internal "github.com/linkedin/diderot/internal/client" + "github.com/linkedin/diderot/internal/utils" + "google.golang.org/grpc" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" +) + +type ADSClientOption func(*options) + +const ( + defaultInitialReconnectBackoff = 100 * time.Millisecond + defaultMaxReconnectBackoff = 2 * time.Minute + defaultResponseChunkingSupported = true +) + +// NewADSClient creates a new [*ADSClient] with the given options. To stop the client, close the +// backing [grpc.ClientConn]. +func NewADSClient(conn *grpc.ClientConn, node *ads.Node, opts ...ADSClientOption) *ADSClient { + c := &ADSClient{ + conn: conn, + node: node, + newSubscription: make(chan struct{}, 1), + handlers: make(map[string]internal.RawResourceHandler), + options: options{ + initialReconnectBackoff: defaultInitialReconnectBackoff, + maxReconnectBackoff: defaultMaxReconnectBackoff, + responseChunkingSupported: defaultResponseChunkingSupported, + }, + } + + for _, opt := range opts { + opt(&c.options) + } + + go c.loop() + + return c +} + +type options struct { + initialReconnectBackoff time.Duration + maxReconnectBackoff time.Duration + responseChunkingSupported bool +} + +// WithReconnectBackoff provides backoff configuration when reconnecting to the xDS backend after a +// connection failure. The default settings are 100ms and 2m for the initial and max backoff +// respectively. +func WithReconnectBackoff(initialBackoff, maxBackoff time.Duration) ADSClientOption { + return func(o *options) { + o.initialReconnectBackoff = initialBackoff + o.maxReconnectBackoff = maxBackoff + } +} + +// WithResponseChunkingSupported changes whether response chunking should be supported (see +// [ads.ParseRemainingChunksFromNonce] for additional details). This feature is only provided by the +// [ADSServer] implemented in this package. This enabled by default. +func WithResponseChunkingSupported(supported bool) ADSClientOption { + return func(o *options) { + o.responseChunkingSupported = supported + } +} + +// An ADSClient is a client that implements the xDS protocol, and can therefore be used to talk to +// any xDS backend. Use the [Watch], [WatchGlob] and [WatchWildcard] to subscribe to resources. +type ADSClient struct { + options + node *ads.Node + conn *grpc.ClientConn + + newSubscription chan struct{} + + lock sync.Mutex + handlers map[string]internal.RawResourceHandler +} + +// A Watcher is used to receive updates from the xDS backend using an [ADSClient]. It is passed into +// the various [Watch] methods in this package. Note that it is imperative that implementations be +// hashable as it will be stored as the key to a map (unhashable types include slices and functions). +type Watcher[T proto.Message] interface { + // Notify is invoked whenever a response is processed. The given sequence will iterate over all the + // resources in the response, with a nil resource indicating a deletion. Implementations should + // return an error if any resource is invalid, and this error will be propagated as a NACK to the xDS + // backend. + Notify(resources iter.Seq2[string, *ads.Resource[T]]) error +} + +// Watch registers the given watcher in the given client, triggering a subscription (if necessary) +// for the given resource name such that the [Watcher] will be notified whenever the resource is +// updated. If a resource is already known (for example from a previous existing subscription), the +// watcher will be immediately notified. Glob or wildcard subscriptions are supported, and +// [Watcher.Notify] will be invoked with a sequence that iterates over all the updated resources. +func Watch[T proto.Message](c *ADSClient, name string, watcher Watcher[T]) { + if getResourceHandler[T](c).AddWatcher(name, watcher) { + c.notifyNewSubscription() + } +} + +// getResourceHandler gets or initializes the [internal.ResourceHandler] for the specified type in +// the given client. +func getResourceHandler[T proto.Message](c *ADSClient) *internal.ResourceHandler[T] { + c.lock.Lock() + defer c.lock.Unlock() + + typeURL := utils.GetTypeURL[T]() + if hAny, ok := c.handlers[typeURL]; !ok { + h := internal.NewResourceHandler[T]() + c.handlers[typeURL] = h + return h + } else { + return hAny.(*internal.ResourceHandler[T]) + } +} + +func (c *ADSClient) getResourceHandler(typeURL string) (internal.RawResourceHandler, bool) { + c.lock.Lock() + defer c.lock.Unlock() + h, ok := c.handlers[typeURL] + return h, ok +} + +// notifyNewSubscription signals to the subscription loop that a new subscription was added. +func (c *ADSClient) notifyNewSubscription() { + select { + case c.newSubscription <- struct{}{}: + default: + } +} + +// This is a type alias for the set of resources the client is subscribed to. The key is the typeURL +// and the value is the set of resource names subscribed to within that type. +type subscriptionSet map[string]utils.Set[string] + +// getPendingSubscriptions iterates over all the subscriptions returned by invoking +// [internal.ResourceHandler.AllSubscriptions] on all registered resource handlers, and compares it +// against the given set of already registered subscriptions. If any are missing, they are added to +// the returned subscription set after being added to the given set. This means that repeated +// invocations of this method will return an empty set if no new subscriptions are added in between. +func (c *ADSClient) getPendingSubscriptions(registeredSubscriptions subscriptionSet) subscriptionSet { + c.lock.Lock() + defer c.lock.Unlock() + + pendingSubscriptions := make(subscriptionSet) + add := func(typeURL string, name string) { + registered := internal.GetNestedMap(registeredSubscriptions, typeURL) + if !registered.Contains(name) { + registered.Add(name) + internal.GetNestedMap(pendingSubscriptions, typeURL).Add(name) + } + } + + for t, handler := range c.handlers { + for k := range handler.AllSubscriptions() { + add(t, k) + } + } + + return pendingSubscriptions +} + +// loop simply calls newStream and subscriptionLoop forever, until the underlying gRPC connection is +// closed. +func (c *ADSClient) loop() { + for { + // See documentation on subscriptionLoop. It returns when the stream ends, so a fresh stream needs to + // be created every time. + stream, responses, err := c.newStream() + if err != nil { + return + } + + err = c.subscriptionLoop(stream, responses) + slog.WarnContext(stream.Context(), "Restarting ADS stream", "err", err) + } +} + +// subscriptionLoop is the critical logic loop for the client. It polls the given responses channel, +// notifying watchers when new responses come in. Each slice returned by the responses channel is +// expected to contain responses that are all for the same typeURL. In most cases, the slice will +// only have one response in it, but if response chunking is supported, the slice will have all the +// response chunks in it. It also waits for any new subscriptions to be registered, and sends them to +// the server. This returns whenever the stream ends. +func (c *ADSClient) subscriptionLoop(stream deltaClient, responsesCh <-chan []*ads.DeltaDiscoveryResponse) error { + registeredSubscriptions := make(subscriptionSet) + + sendPendingSubscriptions := func() error { + pending := c.getPendingSubscriptions(registeredSubscriptions) + if len(pending) == 0 { + return nil + } + + slog.InfoContext(stream.Context(), "Subscribing to resources", "subscriptions", pending) + for t, subs := range pending { + err := stream.Send(&ads.DeltaDiscoveryRequest{ + Node: c.node, + TypeUrl: t, + ResourceNamesSubscribe: slices.Collect(subs.Values()), + }) + if err != nil { + return err + } + } + return nil + } + + isFirst := true + for { + err := sendPendingSubscriptions() + if err != nil { + return err + } + + select { + case <-c.newSubscription: + err := sendPendingSubscriptions() + if err != nil { + return err + } + case responses := <-responsesCh: + h, ok := c.getResourceHandler(responses[0].TypeUrl) + if !ok { + for _, res := range responses { + err := c.sendACKOrNACK( + stream, + res, + fmt.Errorf("received response with unknown type: %q", res.TypeUrl), + ) + if err != nil { + slog.WarnContext(stream.Context(), "ADS stream closed", "err", err) + return err + } + } + continue + } + + // Always ACK all but the last response. Errors will only be reported back to the server once all + // chunks are processed. + for _, res := range responses[:len(responses)-1] { + err := c.sendACKOrNACK(stream, res, nil) + if err != nil { + return err + } + } + + handlerErr := h.HandleResponses(isFirst, responses) + isFirst = false + if err = c.sendACKOrNACK(stream, responses[len(responses)-1], handlerErr); err != nil { + return err + } + case <-stream.Context().Done(): + return stream.Context().Err() + } + } +} + +// sendACKOrNACK will send an ACK or NACK (depending on the given error) for the given response. +func (c *ADSClient) sendACKOrNACK(stream deltaClient, res *ads.DeltaDiscoveryResponse, err error) error { + req := &ads.DeltaDiscoveryRequest{ + Node: c.node, + TypeUrl: res.TypeUrl, + ResponseNonce: res.Nonce, + } + if err != nil { + req.ErrorDetail = status.New(codes.InvalidArgument, err.Error()).Proto() + slog.WarnContext(stream.Context(), "NACKing response", "res", res, "err", err) + } else { + slog.DebugContext(stream.Context(), "ACKing response", "res", res) + } + return stream.Send(req) +} + +// newStream acquires a fresh stream from getDeltaClient and kicks off a goroutine that will read all +// responses from the stream, writing them to the returned channel. The goroutine will exit when the +// stream ends. +func (c *ADSClient) newStream() (deltaClient, <-chan []*ads.DeltaDiscoveryResponse, error) { + stream, err := c.getDeltaClient() + if err != nil { + return nil, nil, err + } + + responses := make(chan []*ads.DeltaDiscoveryResponse) + go func() { + chunkedResponses := make(map[string][]*ads.DeltaDiscoveryResponse) + + for { + res, err := stream.Recv() + if err != nil { + slog.WarnContext(stream.Context(), "ADS stream closed", "err", err) + return + } + + slog.Debug("Response received", "res", res) + + var resSlice []*ads.DeltaDiscoveryResponse + + if c.responseChunkingSupported { + resSlice = chunkedResponses[res.TypeUrl] + resSlice = append(resSlice, res) + chunkedResponses[res.TypeUrl] = resSlice + if remainingChunks, _ := ads.ParseRemainingChunksFromNonce(res.Nonce); remainingChunks != 0 { + continue + } else { + delete(chunkedResponses, res.TypeUrl) + } + } else { + resSlice = []*ads.DeltaDiscoveryResponse{res} + } + + select { + case responses <- resSlice: + case <-stream.Context().Done(): + slog.WarnContext(stream.Context(), "ADS stream closed", "err", stream.Context().Err()) + return + } + } + }() + + return stream, responses, nil +} + +type deltaClient interface { + Send(*ads.DeltaDiscoveryRequest) error + Recv() (*ads.DeltaDiscoveryResponse, error) + Context() context.Context +} + +// getDeltaClient attempts to reconnect to the ADS Server until it either successfully establishes a +// stream, or the underlying gRPC connection is explicitly closed, signaling a shutdown. +func (c *ADSClient) getDeltaClient() (deltaClient, error) { + backoff := c.initialReconnectBackoff + for { + delta, err := discoveryv3.NewAggregatedDiscoveryServiceClient(c.conn). + DeltaAggregatedResources(context.Background()) + if err != nil { + // This only occurs if c.conn was closed since context.Background() is used to create the stream. + if st := status.Convert(err); st.Code() == codes.Canceled { + return nil, err + } + + slog.Warn("Failed to create Delta stream, retrying", "backoff", backoff, "err", err) + time.Sleep(backoff) + backoff = min(backoff*2, c.maxReconnectBackoff) + continue + } + return delta, nil + } +} diff --git a/client_test.go b/client_test.go new file mode 100644 index 0000000..f5ac1a8 --- /dev/null +++ b/client_test.go @@ -0,0 +1,377 @@ +package diderot + +import ( + "context" + "iter" + "log/slog" + "maps" + "sync" + "sync/atomic" + "testing" + "time" + + discovery "github.com/envoyproxy/go-control-plane/envoy/service/discovery/v3" + "github.com/linkedin/diderot/ads" + "github.com/linkedin/diderot/internal/utils" + "github.com/linkedin/diderot/testutils" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/types/known/timestamppb" +) + +type Timestamp = timestamppb.Timestamp + +var Now = timestamppb.Now + +func TestADSClient(t *testing.T) { + slog.SetLogLoggerLevel(slog.LevelDebug) + + ts := testutils.NewTestGRPCServer(t) + + server := newMockServer(t) + discovery.RegisterAggregatedDiscoveryServiceServer(ts.Server, server) + ts.Start() + + client := NewADSClient(ts.Dial(), &ads.Node{Id: "test"}) + fooH := make(testutils.ChanSubscriptionHandler[*Timestamp], 1) + foo := ads.NewResource[*Timestamp]("foo", "0", Now()) + Watch(client, foo.Name, ChanWatcher[*Timestamp](fooH)) + + // The stream has not yet been established, no updates should be received. + checkNoUpdate(t, fooH) + + // Accept a new stream + closeStream := server.accept() + + // The resource does not initially exist, the first update should be a deletion. + server.expectSubscriptions(foo.Name) + nonce := server.respondDeletes(0, foo.Name) + fooH.WaitForDelete(t, foo.Name) + server.expectACK(nonce) + + // Set foo, and wait for the creation update + nonce = server.respondUpdates(0, foo) + fooH.WaitForUpdate(t, foo) + server.expectACK(nonce) + + closeStream() + closeStream = server.accept() + // Closing and reopening the stream makes the client reconnect, but since foo hasn't changed, nothing + // should happen. + server.expectSubscriptions(foo.Name) + nonce = server.respondUpdates(0, foo) + checkNoUpdate(t, fooH) + server.expectACK(nonce) + + // Disconnect the client, foo is updated during disconnect so expect a notification + closeStream() + foo = ads.NewResource(foo.Name, "1", Now()) + closeStream = server.accept() + server.expectSubscriptions(foo.Name) + nonce = server.respondUpdates(0, foo) + fooH.WaitForUpdate(t, foo) + server.expectACK(nonce) + + wildcardH := make(testutils.ChanSubscriptionHandler[*Timestamp], 2) + var wildcardExpectedCount atomic.Int32 + Watch(client, ads.WildcardSubscription, &FuncWatcher[*Timestamp]{ + notify: func(resources iter.Seq2[string, *ads.Resource[*Timestamp]]) error { + require.Len(t, maps.Collect(resources), int(wildcardExpectedCount.Load())) + for name, resource := range resources { + wildcardH <- testutils.Notification[*Timestamp]{ + Name: name, + Resource: resource, + } + } + return nil + }, + }) + + server.expectSubscriptions(ads.WildcardSubscription) + bar := ads.NewResource[*Timestamp]("bar", "0", Now()) + // Respond in multiple chunks, to test that those are handled correctly + chunkNonce1 := server.respondUpdates(1, foo) + // No update expected after first chunk + checkNoUpdate(t, wildcardH) + // As soon as the second chunk arrives, an update is expected, so update the expected count before + // sending the response. + wildcardExpectedCount.Store(2) + chunkNonce2 := server.respondUpdates(0, bar) + server.expectACK(chunkNonce1) + server.expectACK(chunkNonce2) + + // Expect a notification for foo and bar for wildcardH, but since fooH has already seen that version + // of foo, it should not receive an update. + wildcardH.WaitForNotifications(t, + testutils.ExpectUpdate(foo), + testutils.ExpectUpdate(bar), + ) + checkNoUpdate(t, fooH) + + // Clear foo, expect a deletion on fooH and the wildcard subscriber. + wildcardExpectedCount.Store(1) + nonce = server.respondDeletes(0, foo.Name) + server.expectACK(nonce) + fooH.WaitForDelete(t, foo.Name) + wildcardH.WaitForDelete(t, foo.Name) + + // Create new glob collection entries, which the wildcard subscriber should receive. + wildcardExpectedCount.Store(1) + gcURL := ads.NewGlobCollectionURL[*Timestamp]("", "collection", nil) + fooGlob := ads.NewResource(gcURL.MemberURN("foo"), "0", Now()) + nonce = server.respondUpdates(0, fooGlob) + server.expectACK(nonce) + wildcardH.WaitForNotifications(t, testutils.ExpectUpdate(fooGlob)) + + barGlob := ads.NewResource(gcURL.MemberURN("bar"), "0", Now()) + nonce = server.respondUpdates(0, barGlob) + server.expectACK(nonce) + wildcardH.WaitForNotifications(t, testutils.ExpectUpdate(barGlob)) + + // Subscribe to the glob collection. expecting an update for fooGlob and barGlob. + globH := make(testutils.ChanSubscriptionHandler[*Timestamp], 2) + var globExpectedCount atomic.Int32 + // Because the resources are already known thanks to the wildcard, this expects a notification + // immediately, before the subscription is even sent. + Watch(client, gcURL.String(), &FuncWatcher[*Timestamp]{ + notify: func(resources iter.Seq2[string, *ads.Resource[*Timestamp]]) error { + require.Len(t, maps.Collect(resources), int(globExpectedCount.Load())) + for name, resource := range resources { + globH <- testutils.Notification[*Timestamp]{ + Name: name, + Resource: resource, + } + } + return nil + }, + }) + server.expectSubscriptions(gcURL.String()) + globExpectedCount.Store(2) + nonce = server.respondUpdates(0, fooGlob, barGlob) + server.expectACK(nonce) + globH.WaitForNotifications(t, + testutils.ExpectUpdate(fooGlob), + testutils.ExpectUpdate(barGlob), + ) + globExpectedCount.Store(0) + + // Clear fooGlob, expect deletions for it. + wildcardExpectedCount.Store(1) + globExpectedCount.Store(1) + nonce = server.respondDeletes(0, fooGlob.Name) + server.expectACK(nonce) + wildcardH.WaitForDelete(t, fooGlob.Name) + globH.WaitForDelete(t, fooGlob.Name) + + // Disconnect the client and clear the collection during the disconnect. When the client reconnects, + // because it explicitly subscribes to the glob collection it will receive a deletion notification + // for the entire collection, but not for barGlob explicitly, as the server has forgotten that it + // exists. The client must figure out that barGlob has disappeared while it was disconnected. The + // same is true for the wildcard subscription: the client will not receive an explicit notification + // that barGlob has disappeared. + closeStream() + closeStream = server.accept() + server.expectSubscriptions(foo.Name, ads.WildcardSubscription, gcURL.String()) + + nonce = respond[*Timestamp]( + server, + // The only remaining resource is bar + []*ads.Resource[*Timestamp]{bar}, + // These are explicitly subscribed to but do not exist, so explicit removals are expected + []string{foo.Name, gcURL.String()}, + 0, + ) + server.respondUpdates(0, bar) + server.expectACK(nonce) + globH.WaitForDelete(t, barGlob.Name) + wildcardH.WaitForDelete(t, barGlob.Name) + + // This is an edge case, but bar is known because of the wildcard subscription. Therefore, even while + // the client is offline, subscribing to bar should deliver the notification. + closeStream() + barH := make(testutils.ChanSubscriptionHandler[*Timestamp], 1) + Watch(client, bar.Name, ChanWatcher[*Timestamp](barH)) + barH.WaitForUpdate(t, bar) + closeStream = server.accept() + // There should be an explicit subscription sent, but because bar is already known, no further + // updates should be received. + server.expectSubscriptions(foo.Name, bar.Name, ads.WildcardSubscription, gcURL.String()) + nonce = server.respondUpdates(0, bar) + server.expectACK(nonce) + checkNoUpdate(t, barH) + + // Delete bar, the final resource + nonce = server.respondDeletes(0, bar.Name) + server.expectACK(nonce) + + barH.WaitForDelete(t, bar.Name) + wildcardH.WaitForDelete(t, bar.Name) + + // Disconnect again to test what happens when Watch is called while offline for glob and wildcards. + closeStream() + allResources := new(map[string]*ads.Resource[*Timestamp]) + Watch(client, ads.WildcardSubscription, OnceWatcher(allResources)) + // This should be immediately ready, as data has been received and far as the client knows, there are + // no resources. + require.NotNil(t, *allResources) + require.Empty(t, *allResources) + + // Same behavior expected for glob + allGlobResource := new(map[string]*ads.Resource[*Timestamp]) + Watch(client, gcURL.String(), OnceWatcher(allGlobResource)) + require.NotNil(t, *allGlobResource) + require.Empty(t, *allGlobResource) +} + +type FuncWatcher[T proto.Message] struct { + notify func(resources iter.Seq2[string, *ads.Resource[T]]) error +} + +func (f FuncWatcher[T]) Notify(resources iter.Seq2[string, *ads.Resource[T]]) error { + return f.notify(resources) +} + +type ChanWatcher[T proto.Message] testutils.ChanSubscriptionHandler[T] + +func (c ChanWatcher[T]) Notify(resources iter.Seq2[string, *ads.Resource[T]]) error { + for name, resource := range resources { + testutils.ChanSubscriptionHandler[T](c).Notify(name, resource, ads.SubscriptionMetadata{}) + } + return nil +} + +func checkNoUpdate[T proto.Message](t *testing.T, h testutils.ChanSubscriptionHandler[T]) { + select { + case n := <-h: + require.FailNow(t, "handler should not receive any messages", n) + case <-time.After(500 * time.Millisecond): + } +} + +func OnceWatcher[T proto.Message](m *map[string]*ads.Resource[T]) Watcher[T] { + var once sync.Once + return &FuncWatcher[T]{notify: func(resources iter.Seq2[string, *ads.Resource[T]]) error { + once.Do(func() { + *m = maps.Collect(resources) + }) + return nil + }} +} + +type mockServer struct { + t *testing.T + requests chan *ads.DeltaDiscoveryRequest + responses chan *ads.DeltaDiscoveryResponse + kill chan chan struct{} + group errgroup.Group +} + +func newMockServer(t *testing.T) *mockServer { + ms := &mockServer{ + t: t, + requests: make(chan *ads.DeltaDiscoveryRequest), + responses: make(chan *ads.DeltaDiscoveryResponse), + kill: make(chan chan struct{}), + } + return ms +} + +func (ms *mockServer) StreamAggregatedResources(ads.SotWStream) error { + return status.Errorf(codes.Unimplemented, "not implemented") +} + +func (ms *mockServer) DeltaAggregatedResources(stream ads.DeltaStream) error { + kill := <-ms.kill + ms.group.Go(func() error { + for { + select { + case res := <-ms.responses: + ms.t.Logf("Responding: %+v", res) + err := stream.Send(res) + if err != nil { + return nil + } + case <-stream.Context().Done(): + return nil + } + } + }) + ms.group.Go(func() error { + for { + req, err := stream.Recv() + if err != nil { + return nil + } + ms.t.Logf("Received request: %+v", req) + select { + case ms.requests <- req: + case <-stream.Context().Done(): + return nil + } + } + }) + <-kill + ms.t.Log("Stream killed") + return context.Canceled +} + +func (ms *mockServer) accept() context.CancelFunc { + ch := make(chan struct{}) + ms.kill <- ch + return sync.OnceFunc(func() { + close(ch) + require.NoError(ms.t, ms.group.Wait()) + }) +} + +func (ms *mockServer) respondUpdates( + remainingChunks int, + resources ...*ads.Resource[*Timestamp], +) string { + return respond[*Timestamp](ms, resources, nil, remainingChunks) +} + +func (ms *mockServer) respondDeletes( + remainingChunks int, + removedResources ...string, +) string { + return respond[*Timestamp](ms, nil, removedResources, remainingChunks) +} + +func respond[T proto.Message]( + ms *mockServer, + resources []*ads.Resource[T], + removedResources []string, + remainingChunks int, +) string { + var marshaled []*ads.RawResource + for _, resource := range resources { + raw, err := resource.Marshal() + require.NoError(ms.t, err) + marshaled = append(marshaled, raw) + } + nonce := utils.NewNonce(remainingChunks) + ms.responses <- &ads.DeltaDiscoveryResponse{ + Resources: marshaled, + TypeUrl: utils.GetTypeURL[T](), + RemovedResources: removedResources, + Nonce: nonce, + } + return nonce +} + +func (ms *mockServer) expectACK(nonce string) { + req := <-ms.requests + require.Equal(ms.t, utils.GetTypeURL[*Timestamp](), req.TypeUrl) + require.Equal(ms.t, nonce, req.ResponseNonce) +} + +func (ms *mockServer) expectSubscriptions(subscriptions ...string) { + req := <-ms.requests + require.Equal(ms.t, utils.GetTypeURL[*Timestamp](), req.TypeUrl) + require.Empty(ms.t, req.ResponseNonce) + require.ElementsMatch(ms.t, subscriptions, req.ResourceNamesSubscribe) +} diff --git a/doc.go b/doc.go index f63cd0d..a5beffd 100644 --- a/doc.go +++ b/doc.go @@ -71,7 +71,7 @@ This means that if resource names are [xdstp:// URNs], they will be automaticall corresponding glob collection, if applicable. These resources are still available for subscription by their full URN, but will also be available for subscription by subscribing to the parent glob collection. More details available at [diderot.Cache.Subscribe], [ads.ParseGlobCollectionURL] and -[ads.ExtractGlobCollectionURLFromResourceURN]. +[ads.ParseGlobCollectionURN]. [xDS spec]: https://www.envoyproxy.io/docs/envoy/latest/api-docs/xds_protocol#how-the-client-specifies-what-resources-to-return [xdstp:// URNs]: https://github.com/cncf/xds/blob/main/proposals/TP1-xds-transport-next.md#uri-based-xds-resource-names diff --git a/go.mod b/go.mod index 85f8537..81274dc 100644 --- a/go.mod +++ b/go.mod @@ -3,32 +3,31 @@ module github.com/linkedin/diderot go 1.23.0 require ( - github.com/envoyproxy/go-control-plane v0.13.0 + github.com/envoyproxy/go-control-plane v0.13.4 + github.com/envoyproxy/go-control-plane/envoy v1.32.4 github.com/google/go-cmp v0.6.0 - github.com/stretchr/testify v1.9.0 + github.com/stretchr/testify v1.10.0 golang.org/x/time v0.5.0 - google.golang.org/genproto/googleapis/rpc v0.0.0-20240709173604-40e1e62336c5 - google.golang.org/grpc v1.66.0 - google.golang.org/protobuf v1.34.2 + google.golang.org/genproto/googleapis/rpc v0.0.0-20241202173237-19429a94021a + google.golang.org/grpc v1.70.0 + google.golang.org/protobuf v1.36.4 ) require ( - cel.dev/expr v0.15.0 // indirect - cloud.google.com/go/compute/metadata v0.3.0 // indirect - github.com/census-instrumentation/opencensus-proto v0.4.1 // indirect + cel.dev/expr v0.19.0 // indirect + cloud.google.com/go/compute/metadata v0.5.2 // indirect github.com/cespare/xxhash/v2 v2.3.0 // indirect - github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b // indirect + github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 // indirect github.com/davecgh/go-spew v1.1.1 // indirect - github.com/envoyproxy/protoc-gen-validate v1.0.4 // indirect + github.com/envoyproxy/protoc-gen-validate v1.2.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/planetscale/vtprotobuf v0.6.1-0.20240319094008-0393e58bdf10 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect - github.com/rogpeppe/go-internal v1.12.0 // indirect - golang.org/x/net v0.27.0 // indirect - golang.org/x/oauth2 v0.21.0 // indirect - golang.org/x/sync v0.7.0 // indirect - golang.org/x/sys v0.22.0 // indirect - golang.org/x/text v0.16.0 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20240709173604-40e1e62336c5 // indirect + golang.org/x/net v0.34.0 // indirect + golang.org/x/oauth2 v0.24.0 // indirect + golang.org/x/sync v0.10.0 // indirect + golang.org/x/sys v0.29.0 // indirect + golang.org/x/text v0.21.0 // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20241202173237-19429a94021a // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 013d5b8..57c938e 100644 --- a/go.sum +++ b/go.sum @@ -1,20 +1,28 @@ -cel.dev/expr v0.15.0 h1:O1jzfJCQBfL5BFoYktaxwIhuttaQPsVWerH9/EEKx0w= -cel.dev/expr v0.15.0/go.mod h1:TRSuuV7DlVCE/uwv5QbAiW/v8l5O8C4eEPHeu7gf7Sg= -cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= -cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= -github.com/census-instrumentation/opencensus-proto v0.4.1 h1:iKLQ0xPNFxR/2hzXZMrBo8f1j86j5WHzznCCQxV/b8g= -github.com/census-instrumentation/opencensus-proto v0.4.1/go.mod h1:4T9NM4+4Vw91VeyqjLS6ao50K5bOcLKN6Q42XnYaRYw= +cel.dev/expr v0.19.0 h1:lXuo+nDhpyJSpWxpPVi5cPUwzKb+dsdOiw6IreM5yt0= +cel.dev/expr v0.19.0/go.mod h1:MrpN08Q+lEBs+bGYdLxxHkZoUSsCp0nSKTs0nTymJgw= +cloud.google.com/go/compute/metadata v0.5.2 h1:UxK4uu/Tn+I3p2dYWTfiX4wva7aYlKixAHn3fyqngqo= +cloud.google.com/go/compute/metadata v0.5.2/go.mod h1:C66sj2AluDcIqakBq/M8lw8/ybHgOZqin2obFxa/E5k= github.com/cespare/xxhash/v2 v2.3.0 h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UFvs= github.com/cespare/xxhash/v2 v2.3.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= -github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b h1:ga8SEFjZ60pxLcmhnThWgvH2wg8376yUJmPhEH4H3kw= -github.com/cncf/xds/go v0.0.0-20240423153145-555b57ec207b/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= +github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78 h1:QVw89YDxXxEe+l8gU8ETbOasdwEV+avkR75ZzsVV9WI= +github.com/cncf/xds/go v0.0.0-20240905190251-b4127c9b8d78/go.mod h1:W+zGtBO5Y1IgJhy4+A9GOqVhqLpfZi+vwmdNXUehLA8= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= -github.com/envoyproxy/go-control-plane v0.13.0 h1:HzkeUz1Knt+3bK+8LG1bxOO/jzWZmdxpwC51i202les= -github.com/envoyproxy/go-control-plane v0.13.0/go.mod h1:GRaKG3dwvFoTg4nj7aXdZnvMg4d7nvT/wl9WgVXn3Q8= -github.com/envoyproxy/protoc-gen-validate v1.0.4 h1:gVPz/FMfvh57HdSJQyvBtF00j8JU4zdyUgIUNhlgg0A= -github.com/envoyproxy/protoc-gen-validate v1.0.4/go.mod h1:qys6tmnRsYrQqIhm2bvKZH4Blx/1gTIZ2UKVY1M+Yew= +github.com/envoyproxy/go-control-plane v0.13.4 h1:zEqyPVyku6IvWCFwux4x9RxkLOMUL+1vC9xUFv5l2/M= +github.com/envoyproxy/go-control-plane v0.13.4/go.mod h1:kDfuBlDVsSj2MjrLEtRWtHlsWIFcGyB2RMO44Dc5GZA= +github.com/envoyproxy/go-control-plane/envoy v1.32.4 h1:jb83lalDRZSpPWW2Z7Mck/8kXZ5CQAFYVjQcdVIr83A= +github.com/envoyproxy/go-control-plane/envoy v1.32.4/go.mod h1:Gzjc5k8JcJswLjAx1Zm+wSYE20UrLtt7JZMWiWQXQEw= +github.com/envoyproxy/go-control-plane/ratelimit v0.1.0 h1:/G9QYbddjL25KvtKTv3an9lx6VBE2cnb8wp1vEGNYGI= +github.com/envoyproxy/go-control-plane/ratelimit v0.1.0/go.mod h1:Wk+tMFAFbCXaJPzVVHnPgRKdUdwW/KdbRt94AzgRee4= +github.com/envoyproxy/protoc-gen-validate v1.2.1 h1:DEo3O99U8j4hBFwbJfrz9VtgcDfUKS7KJ7spH3d86P8= +github.com/envoyproxy/protoc-gen-validate v1.2.1/go.mod h1:d/C80l/jxXLdfEIhX1W2TmLfsJ31lvEjwamM4DxlWXU= +github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY= +github.com/go-logr/logr v1.4.2/go.mod h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY= +github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag= +github.com/go-logr/stdr v1.2.2/go.mod h1:mMo/vtBO5dYbehREoey6XUKy/eSumjCCveDpRre4VKE= +github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= +github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI= github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= @@ -29,28 +37,38 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= -github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= -github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= -golang.org/x/net v0.27.0 h1:5K3Njcw06/l2y9vpGCSdcxWOYHOUk3dVNGDXN+FvAys= -golang.org/x/net v0.27.0/go.mod h1:dDi0PyhWNoiUOrAS8uXv/vnScO4wnHQO4mj9fn/RytE= -golang.org/x/oauth2 v0.21.0 h1:tsimM75w1tF/uws5rbeHzIWxEqElMehnc+iW793zsZs= -golang.org/x/oauth2 v0.21.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= -golang.org/x/sync v0.7.0 h1:YsImfSBoP9QPYL0xyKJPq0gcaJdG3rInoqxTWbfQu9M= -golang.org/x/sync v0.7.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= -golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= -golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= -golang.org/x/text v0.16.0 h1:a94ExnEXNtEwYLGJSIUxnWoxoRz/ZcCsV63ROupILh4= -golang.org/x/text v0.16.0/go.mod h1:GhwF1Be+LQoKShO3cGOHzqOgRrGaYc9AvblQOmPVHnI= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +go.opentelemetry.io/otel v1.32.0 h1:WnBN+Xjcteh0zdk01SVqV55d/m62NJLJdIyb4y/WO5U= +go.opentelemetry.io/otel v1.32.0/go.mod h1:00DCVSB0RQcnzlwyTfqtxSm+DRr9hpYrHjNGiBHVQIg= +go.opentelemetry.io/otel/metric v1.32.0 h1:xV2umtmNcThh2/a/aCP+h64Xx5wsj8qqnkYZktzNa0M= +go.opentelemetry.io/otel/metric v1.32.0/go.mod h1:jH7CIbbK6SH2V2wE16W05BHCtIDzauciCRLoc/SyMv8= +go.opentelemetry.io/otel/sdk v1.32.0 h1:RNxepc9vK59A8XsgZQouW8ue8Gkb4jpWtJm9ge5lEG4= +go.opentelemetry.io/otel/sdk v1.32.0/go.mod h1:LqgegDBjKMmb2GC6/PrTnteJG39I8/vJCAP9LlJXEjU= +go.opentelemetry.io/otel/sdk/metric v1.32.0 h1:rZvFnvmvawYb0alrYkjraqJq0Z4ZUJAiyYCU9snn1CU= +go.opentelemetry.io/otel/sdk/metric v1.32.0/go.mod h1:PWeZlq0zt9YkYAp3gjKZ0eicRYvOh1Gd+X99x6GHpCQ= +go.opentelemetry.io/otel/trace v1.32.0 h1:WIC9mYrXf8TmY/EXuULKc8hR17vE+Hjv2cssQDe03fM= +go.opentelemetry.io/otel/trace v1.32.0/go.mod h1:+i4rkvCraA+tG6AzwloGaCtkx53Fa+L+V8e9a7YvhT8= +golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= +golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= +golang.org/x/oauth2 v0.24.0 h1:KTBBxWqUa0ykRPLtV69rRto9TLXcqYkeswu48x/gvNE= +golang.org/x/oauth2 v0.24.0/go.mod h1:XYTD2NtWslqkgxebSiOHnXEap4TF09sJSc7H1sXbhtI= +golang.org/x/sync v0.10.0 h1:3NQrjDixjgGwUOCaF8w2+VYHv0Ve/vGYSbdkTa98gmQ= +golang.org/x/sync v0.10.0/go.mod h1:Czt+wKu1gCyEFDUtn0jG5QVvpJ6rzVqr5aXyt9drQfk= +golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= +golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= +golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.5.0 h1:o7cqy6amK/52YcAKIPlM3a+Fpj35zvRj2TP+e1xFSfk= golang.org/x/time v0.5.0/go.mod h1:3BpzKBy/shNhVucY/MWOyx10tF3SFh9QdLuxbVysPQM= -google.golang.org/genproto/googleapis/api v0.0.0-20240709173604-40e1e62336c5 h1:a/Z0jgw03aJ2rQnp5PlPpznJqJft0HyvyrcUcxgzPwY= -google.golang.org/genproto/googleapis/api v0.0.0-20240709173604-40e1e62336c5/go.mod h1:mw8MG/Qz5wfgYr6VqVCiZcHe/GJEfI+oGGDCohaVgB0= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240709173604-40e1e62336c5 h1:SbSDUWW1PAO24TNpLdeheoYPd7kllICcLU52x6eD4kQ= -google.golang.org/genproto/googleapis/rpc v0.0.0-20240709173604-40e1e62336c5/go.mod h1:Ue6ibwXGpU+dqIcODieyLOcgj7z8+IcskoNIgZxtrFY= -google.golang.org/grpc v1.66.0 h1:DibZuoBznOxbDQxRINckZcUvnCEvrW9pcWIE2yF9r1c= -google.golang.org/grpc v1.66.0/go.mod h1:s3/l6xSSCURdVfAnL+TqCNMyTDAGN6+lZeVxnZR128Y= -google.golang.org/protobuf v1.34.2 h1:6xV6lTsCfpGD21XK49h7MhtcApnLqkfYgPcdHftf6hg= -google.golang.org/protobuf v1.34.2/go.mod h1:qYOHts0dSfpeUzUFpOMr/WGzszTmLH+DiWniOlNbLDw= +google.golang.org/genproto/googleapis/api v0.0.0-20241202173237-19429a94021a h1:OAiGFfOiA0v9MRYsSidp3ubZaBnteRUyn3xB2ZQ5G/E= +google.golang.org/genproto/googleapis/api v0.0.0-20241202173237-19429a94021a/go.mod h1:jehYqy3+AhJU9ve55aNOaSml7wUXjF9x6z2LcCfpAhY= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241202173237-19429a94021a h1:hgh8P4EuoxpsuKMXX/To36nOFD7vixReXgn8lPGnt+o= +google.golang.org/genproto/googleapis/rpc v0.0.0-20241202173237-19429a94021a/go.mod h1:5uTbfoYQed2U9p3KIj2/Zzm02PYhndfdmML0qC3q3FU= +google.golang.org/grpc v1.70.0 h1:pWFv03aZoHzlRKHWicjsZytKAiYCtNS0dHbXnIdq7jQ= +google.golang.org/grpc v1.70.0/go.mod h1:ofIJqVKDXx/JiXrwr2IG4/zwdH9txy3IlF40RmcJSQw= +google.golang.org/protobuf v1.36.4 h1:6A3ZDJHn/eNqc1i+IdefRzy/9PokBTPvcqMySR7NNIM= +google.golang.org/protobuf v1.36.4/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= diff --git a/internal/client/watchers.go b/internal/client/watchers.go new file mode 100644 index 0000000..cac2d56 --- /dev/null +++ b/internal/client/watchers.go @@ -0,0 +1,352 @@ +package internal + +import ( + "errors" + "fmt" + "iter" + "maps" + "sync" + + "github.com/linkedin/diderot/ads" + "github.com/linkedin/diderot/internal/utils" + "google.golang.org/protobuf/proto" +) + +// Watcher is a copy of the interface of the same name in the root package, to avoid import cycles. +type Watcher[T proto.Message] interface { + Notify(resources iter.Seq2[string, *ads.Resource[T]]) error +} + +// RawResourceHandler is a non-generic interface implemented by [ResourceHandler]. Used by the +// non-generic [github.com/linkedin/diderot.ADSClient]. +type RawResourceHandler interface { + // AllSubscriptions returns a sequence of all the subscriptions in this client. + AllSubscriptions() iter.Seq[string] + // HandleResponses should be called whenever responses are received. Accepts a slice of responses, + // since the client may support chunking. The given boolean parameter indicates whether this is the + // set of responses received for a stream. This is used to determine whether any resources were + // deleted while the client was disconnected. + HandleResponses(isFirst bool, responses []*ads.DeltaDiscoveryResponse) error +} + +// Ensure that [ResourceHandler] implements the [RawResourceHandler] interface. +var _ RawResourceHandler = (*ResourceHandler[proto.Message])(nil) + +// ResourceHandler implements the core logic of managing notifications for watchers. +type ResourceHandler[T proto.Message] struct { + lock sync.Mutex + + // All the resources currently known by this client. + resources map[string]*ads.Resource[T] + + // Maps resource name to watchers of the resource. + subscriptions map[string]*subscription[T] + // Maps glob collection URL to watchers of the collection. + globSubscriptions map[ads.GlobCollectionURL]*globSubscription[T] + // Contains the set of wildcard watchers, if present. + wildcardSubscription *subscription[T] +} + +// resolve returns an [iter.Seq2] that resolves the resources for the given sequence of resource +// names. The [ads.Resource] will be nil if no such resource is known. +func (h *ResourceHandler[T]) resolve(in iter.Seq[string]) iter.Seq2[string, *ads.Resource[T]] { + return func(yield func(string, *ads.Resource[T]) bool) { + for name := range in { + if !yield(name, h.resources[name]) { + return + } + } + } +} + +func (h *ResourceHandler[T]) resolveSingle(name string) iter.Seq2[string, *ads.Resource[T]] { + return func(yield func(string, *ads.Resource[T]) bool) { + yield(name, h.resources[name]) + } +} + +// setResource updates the map of known resources. Returns a boolean indicating whether the resource +// was actually changed. +func (h *ResourceHandler[T]) setResource(name string, resource *ads.Resource[T]) (updated bool) { + var previous *ads.Resource[T] + previous, ok := h.resources[name] + if !ok && resource == nil { + // Ignore deletions for unknown resources + return false + } + if ok && resource.Equals(previous) { + // Ignore updates identical to the most recently seen resource. + return false + } + + if resource == nil { + delete(h.resources, name) + } else { + h.resources[name] = resource + } + return true +} + +func NewResourceHandler[T proto.Message]() *ResourceHandler[T] { + return &ResourceHandler[T]{ + resources: make(map[string]*ads.Resource[T]), + subscriptions: make(map[string]*subscription[T]), + globSubscriptions: make(map[ads.GlobCollectionURL]*globSubscription[T]), + } +} + +// AddWatcher registers the given [Watcher] against the given resource name. The watcher will be +// notified whenever the resource is created, updated or deleted. The returned boolean indicates +// whether the watcher was a new registration, or was already previously registered. If a value for +// the given resource is already known, the watcher is immediately notified. +func (h *ResourceHandler[T]) AddWatcher(name string, w Watcher[T]) bool { + h.lock.Lock() + defer h.lock.Unlock() + + // contains the set of watchers to update + var watchers utils.Set[Watcher[T]] + // set if a value for the resource is already known. + var resources iter.Seq2[string, *ads.Resource[T]] + + if name == ads.WildcardSubscription { + if h.wildcardSubscription == nil { + h.wildcardSubscription = newSubscription[T]() + } + watchers = h.wildcardSubscription.watchers + if h.wildcardSubscription.initialized { + resources = maps.All(h.resources) + } + } else if gcURL, err := ads.ParseGlobCollectionURL[T](name); err == nil { + globSub, ok := h.globSubscriptions[gcURL] + if !ok { + globSub = &globSubscription[T]{ + subscription: *newSubscription[T](), + entries: make(utils.Set[string]), + } + h.globSubscriptions[gcURL] = globSub + } + watchers = globSub.watchers + if globSub.initialized { + resources = h.resolve(globSub.entries.Values()) + } + } else { + sub, ok := h.subscriptions[name] + if !ok { + sub = newSubscription[T]() + h.subscriptions[name] = sub + } + // In the event that there is already data from another subscription for this specific resource, + // immediately satisfy the watcher. + _, sub.initialized = h.resources[name] + watchers = sub.watchers + if sub.initialized { + resources = h.resolveSingle(name) + } + } + + if resources != nil { + _ = w.Notify(resources) + } + + return watchers.Add(w) +} + +func (h *ResourceHandler[T]) AllSubscriptions() iter.Seq[string] { + return func(yield func(string) bool) { + h.lock.Lock() + defer h.lock.Unlock() + + for k := range h.subscriptions { + if !yield(k) { + return + } + } + for k := range h.globSubscriptions { + if !yield(k.String()) { + return + } + } + if h.wildcardSubscription != nil { + yield(ads.WildcardSubscription) + } + } +} + +func (h *ResourceHandler[T]) HandleResponses(isFirst bool, responses []*ads.DeltaDiscoveryResponse) error { + h.lock.Lock() + defer h.lock.Unlock() + + var errs []error + addErr := func(err error) { + errs = append(errs, err) + } + notifyWatchers := func(sub *subscription[T], seq iter.Seq2[string, *ads.Resource[T]]) { + for w := range sub.watchers { + err := w.Notify(seq) + if err != nil { + addErr(err) + } + } + } + + totalAddedResources := 0 + totalDeletedResources := 0 + for _, response := range responses { + totalAddedResources += len(response.Resources) + totalDeletedResources += len(response.RemovedResources) + } + + if totalAddedResources+totalDeletedResources == 0 { + return fmt.Errorf("empty response") + } + + // Contains the set of resource names that wildcard watchers should be notified of. Only set if any + // wildcard watchers are registered. + var wildcardUpdates utils.Set[string] + if h.wildcardSubscription != nil { + wildcardUpdates = make(utils.Set[string], totalAddedResources+totalDeletedResources) + } + // Contains the set of resource names received. Only set if this is the first set of responses for + // the stream, as it is used to determine whether any resources were deleted while the client was + // disconnected. For example, suppose resources foo and bar are present on the ADS server. If a + // wildcard watcher is registered, it will initially receive updates for those two resources. Then + // the client disconnects, reconnects and resubmits its wildcard subscription. If bar was deleted + // during the disconnect, the server will only send back an update for foo, but never an explicit + // deletion for bar. This set is therefore used to compare against h.resources, i.e. the set + // known/previously received resources to see if wildcard and glob collection watchers need to be + // notified of any deletions. + var receivedResources utils.Set[string] + if isFirst { + receivedResources = make(utils.Set[string], totalAddedResources) + } + + globUpdates := make(map[*globSubscription[T]]utils.Set[string]) + + for name, r := range iterateResources(responses) { + sub := h.subscriptions[name] + + var globSub *globSubscription[T] + gcURL, gcResourceName, err := ads.ParseGlobCollectionURN[T](name) + if err == nil { + globSub = h.globSubscriptions[gcURL] + } + + if sub == nil && globSub == nil && h.wildcardSubscription == nil { + addErr(fmt.Errorf("not subscribed to resource %q", name)) + continue + } + + var resource *ads.Resource[T] + if r != nil { + resource, err = ads.UnmarshalRawResource[T](r) + if err != nil { + addErr(err) + continue + } + if isFirst { + receivedResources.Add(name) + } + } + + updated := h.setResource(name, resource) + + if sub != nil && (!sub.initialized || updated) { + sub.initialized = true + notifyWatchers(sub, h.resolve(func(yield func(string) bool) { yield(name) })) + } + + if globSub != nil { + updates := GetNestedMap(globUpdates, globSub) + if resource == nil && gcResourceName == ads.WildcardSubscription { + maps.Copy(updates, globSub.entries) + clear(globSub.entries) + continue + } else if !globSub.initialized || updated { + updates.Add(name) + if resource != nil { + globSub.entries.Add(name) + } else { + globSub.entries.Remove(name) + } + } + } + + if h.wildcardSubscription != nil && (!h.wildcardSubscription.initialized || updated) { + wildcardUpdates.Add(name) + } + } + + if isFirst { + for name := range h.resources { + if _, ok := receivedResources[name]; !ok { + delete(h.resources, name) + if h.wildcardSubscription != nil { + wildcardUpdates.Add(name) + } + } + } + } + + if h.wildcardSubscription != nil { + h.wildcardSubscription.initialized = true + + if len(wildcardUpdates) > 0 { + notifyWatchers(h.wildcardSubscription, h.resolve(wildcardUpdates.Values())) + } + } + + for globSub, updates := range globUpdates { + globSub.initialized = true + if len(updates) > 0 { + notifyWatchers(&globSub.subscription, h.resolve(updates.Values())) + } + } + + return errors.Join(errs...) +} + +// iterateResources returns an [iter.Seq2] that iterates over all the resources in the given +// response. If the [ads.RawResource] is nil, the resource is being deleted. +func iterateResources(responses []*ads.DeltaDiscoveryResponse) iter.Seq2[string, *ads.RawResource] { + return func(yield func(string, *ads.RawResource) bool) { + for _, res := range responses { + for _, r := range res.Resources { + if !yield(r.Name, r) { + return + } + } + for _, name := range res.RemovedResources { + if !yield(name, nil) { + return + } + } + } + } +} + +func newSubscription[T proto.Message]() *subscription[T] { + return &subscription[T]{ + watchers: make(utils.Set[Watcher[T]]), + } +} + +type subscription[T proto.Message] struct { + initialized bool + watchers utils.Set[Watcher[T]] +} + +type globSubscription[T proto.Message] struct { + subscription[T] + entries utils.Set[string] +} + +// GetNestedMap is a utility function for nested maps. It will create the map at the given key if it +// does not already exist, then returns the corresponding map. +func GetNestedMap[K1, K2 comparable, V any, M ~map[K2]V](m map[K1]M, k K1) M { + v, ok := m[k] + if !ok { + v = make(M) + m[k] = v + } + return v +} diff --git a/internal/utils/set.go b/internal/utils/set.go index ec62521..ab6bc02 100644 --- a/internal/utils/set.go +++ b/internal/utils/set.go @@ -2,6 +2,7 @@ package utils import ( "fmt" + "iter" "maps" "slices" ) @@ -42,3 +43,7 @@ func (s Set[T]) Remove(t T) bool { func (s Set[T]) String() string { return fmt.Sprint(slices.Collect(maps.Keys(s))) } + +func (s Set[T]) Values() iter.Seq[T] { + return maps.Keys(s) +} diff --git a/server.go b/server.go index 7f8054e..9137175 100644 --- a/server.go +++ b/server.go @@ -211,7 +211,9 @@ func (s *ADSServer) StreamAggregatedResources(stream ads.SotWStream) (err error) }, } - return h.loop() + err = h.loop() + slog.DebugContext(h.streamCtx, "Closing stream", "err", err) + return err } // DeltaAggregatedResources is the implementation of the delta/incremental variant of the ADS @@ -245,7 +247,9 @@ func (s *ADSServer) DeltaAggregatedResources(stream ads.DeltaStream) (err error) }, } - return h.loop() + err = h.loop() + slog.DebugContext(h.streamCtx, "Closing stream", "err", err) + return err } type adsDiscoveryRequest interface { diff --git a/server_test.go b/server_test.go index eaecfc5..a80ebad 100644 --- a/server_test.go +++ b/server_test.go @@ -650,10 +650,10 @@ func TestSubscriptionManagerSubscriptions(t *testing.T) { } } -type mockResourceLocator func(typeURL, resourceName string) func() +type mockResourceLocator func(typeURL, resourceName string, h ads.RawSubscriptionHandler) func() -func (m mockResourceLocator) Subscribe(_ context.Context, typeURL, resourceName string, _ ads.RawSubscriptionHandler) func() { - return m(typeURL, resourceName) +func (m mockResourceLocator) Subscribe(_ context.Context, typeURL, resourceName string, h ads.RawSubscriptionHandler) func() { + return m(typeURL, resourceName, h) } func TestImplicitWildcardSubscription(t *testing.T) { @@ -664,7 +664,7 @@ func TestImplicitWildcardSubscription(t *testing.T) { newMockLocator := func(t *testing.T) (l mockResourceLocator, wildcardSub, fooSub chan struct{}) { wildcardSub = make(chan struct{}, 1) fooSub = make(chan struct{}, 1) - l = func(actualTypeURL, resourceName string) func() { + l = func(actualTypeURL, resourceName string, _ ads.RawSubscriptionHandler) func() { require.Equal(t, typeURL, actualTypeURL) switch resourceName { case ads.WildcardSubscription: @@ -843,7 +843,7 @@ func TestSubscriptionManagerUnsubscribeAll(t *testing.T) { var wg sync.WaitGroup - l := mockResourceLocator(func(_, resourceName string) func() { + l := mockResourceLocator(func(_, resourceName string, _ ads.RawSubscriptionHandler) func() { wg.Done() return func() { wg.Done() @@ -866,7 +866,7 @@ func TestSubscriptionManagerUnsubscribeAll(t *testing.T) { t.Run("on context expiry", func(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) var wg sync.WaitGroup - l := mockResourceLocator(func(_, _ string) func() { + l := mockResourceLocator(func(_, _ string, _ ads.RawSubscriptionHandler) func() { wg.Done() return func() { wg.Done() diff --git a/testutils/testutils.go b/testutils/testutils.go index d3ee2c6..bc24f45 100644 --- a/testutils/testutils.go +++ b/testutils/testutils.go @@ -117,7 +117,7 @@ func (c ChanSubscriptionHandler[T]) WaitForNotifications(t testingT, notificatio var n Notification[T] select { case n = <-c: - case <-time.After(5 * time.Second): + case <-time.After(5 * time.Hour): t.Fatalf("Did not receive expected notification for one of: %v", slices.Collect(maps.Keys(expectedNotifications))) } @@ -227,6 +227,9 @@ func (ts *TestServer) Dial(opts ...grpc.DialOption) *grpc.ClientConn { opts = append([]grpc.DialOption{grpc.WithTransportCredentials(insecure.NewCredentials())}, opts...) conn, err := grpc.NewClient(ts.AddrString(), opts...) require.NoError(ts.t, err) + ts.t.Cleanup(func() { + require.NoError(ts.t, conn.Close()) + }) return conn }