From ce2e962e0a32cc4ef991d3b6e646a31bf580cca4 Mon Sep 17 00:00:00 2001 From: Artem Glazychev Date: Mon, 3 Jul 2023 14:26:04 +0700 Subject: [PATCH] Send REFRESH_REQUESTED if vl3 dnsServerIP is updated Signed-off-by: Artem Glazychev --- pkg/networkservice/chains/nsmgr/vl3_test.go | 109 ++++++++++++++++++ .../dnscontext/vl3dns/metadata.go | 41 +++++++ .../dnscontext/vl3dns/notifier.go | 68 +++++++++++ .../dnscontext/vl3dns/server.go | 92 ++++++++++----- 4 files changed, 279 insertions(+), 31 deletions(-) create mode 100644 pkg/networkservice/connectioncontext/dnscontext/vl3dns/metadata.go create mode 100644 pkg/networkservice/connectioncontext/dnscontext/vl3dns/notifier.go diff --git a/pkg/networkservice/chains/nsmgr/vl3_test.go b/pkg/networkservice/chains/nsmgr/vl3_test.go index 813e60e74..c6130591c 100644 --- a/pkg/networkservice/chains/nsmgr/vl3_test.go +++ b/pkg/networkservice/chains/nsmgr/vl3_test.go @@ -26,6 +26,8 @@ import ( "testing" "time" + "go.uber.org/atomic" + "github.com/edwarnicke/genericsync" "github.com/google/uuid" "github.com/stretchr/testify/require" @@ -38,8 +40,10 @@ import ( "github.com/networkservicemesh/api/pkg/api/registry" "github.com/networkservicemesh/sdk/pkg/networkservice/chains/client" + "github.com/networkservicemesh/sdk/pkg/networkservice/common/upstreamrefresh" "github.com/networkservicemesh/sdk/pkg/networkservice/connectioncontext/dnscontext/vl3dns" "github.com/networkservicemesh/sdk/pkg/networkservice/connectioncontext/ipcontext/vl3" + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkconnection" "github.com/networkservicemesh/sdk/pkg/networkservice/utils/checks/checkrequest" "github.com/networkservicemesh/sdk/pkg/tools/clock" "github.com/networkservicemesh/sdk/pkg/tools/dnsutils" @@ -130,6 +134,111 @@ func Test_NSC_ConnectsTo_vl3NSE(t *testing.T) { } } +func Test_NSC_RefreshOnVl3DnsAddressChange(t *testing.T) { + t.Cleanup(func() { goleak.VerifyNone(t) }) + + ctx, cancel := context.WithTimeout(context.Background(), time.Second*15) + defer cancel() + + domain := sandbox.NewBuilder(ctx, t). + SetNodesCount(1). + SetNSMgrProxySupplier(nil). + SetRegistryProxySupplier(nil). + Build() + + nsRegistryClient := domain.NewNSRegistryClient(ctx, sandbox.GenerateTestToken) + + nsReg, err := nsRegistryClient.Register(ctx, defaultRegistryService("vl3")) + require.NoError(t, err) + + nseReg := defaultRegistryEndpoint(nsReg.Name) + + var serverPrefixCh = make(chan *ipam.PrefixResponse, 1) + defer close(serverPrefixCh) + + serverPrefixCh <- &ipam.PrefixResponse{Prefix: "10.0.0.1/24"} + dnsServerIPCh := make(chan net.IP, 1) + + _ = domain.Nodes[0].NewEndpoint( + ctx, + nseReg, + sandbox.GenerateTestToken, + vl3dns.NewServer(ctx, + dnsServerIPCh, + vl3dns.WithDomainSchemes("{{ index .Labels \"podName\" }}.{{ .NetworkService }}."), + vl3dns.WithDNSPort(40053)), + vl3.NewServer(ctx, serverPrefixCh), + ) + + resolver := net.Resolver{ + PreferGo: true, + Dial: func(ctx context.Context, network, address string) (net.Conn, error) { + var dialer net.Dialer + return dialer.DialContext(ctx, network, "127.0.0.1:40053") + }, + } + + type nscInfo struct { + nsc networkservice.NetworkServiceClient + conn *networkservice.Connection + counter *atomic.Int32 + } + var nscInfos []*nscInfo + + for i := 0; i < 10; i++ { + var counter atomic.Int32 + nsc := domain.Nodes[0].NewClient(ctx, sandbox.GenerateTestToken, client.WithAdditionalFunctionality(upstreamrefresh.NewClient(ctx), + checkconnection.NewClient(t, func(t *testing.T, conn *networkservice.Connection) { + if counter.Load() > 0 { + require.Len(t, conn.GetContext().GetDnsContext().GetConfigs(), 1) + require.Len(t, conn.GetContext().GetDnsContext().GetConfigs()[0].DnsServerIps, 1) + } + counter.Inc() + }), + )) + + reqCtx, reqClose := context.WithTimeout(ctx, time.Second*1) + defer reqClose() + + req := defaultRequest(nsReg.Name) + req.Connection.Id = uuid.New().String() + + req.Connection.Labels["podName"] = nscName + fmt.Sprint(i) + + var resp *networkservice.Connection + resp, err = nsc.Request(reqCtx, req) + require.NoError(t, err) + + require.Len(t, resp.GetContext().GetDnsContext().GetConfigs(), 0) + + nscInfos = append(nscInfos, &nscInfo{ + nsc: nsc, + conn: resp, + counter: &counter, + }) + } + + dnsServerIPCh <- net.ParseIP("127.0.0.1") + + for i, n := range nscInfos { + nscInf := n + require.Eventually(t, func() bool { + return nscInf.counter.Load() == 2 + }, timeout, tick) + + reqCtx, reqClose := context.WithTimeout(ctx, time.Second*1) + defer reqClose() + + requireIPv4Lookup(ctx, t, &resolver, nscName+fmt.Sprint(i)+".vl3", fmt.Sprintf("10.0.0.%d", i+1)) + + _, err = nscInf.nsc.Close(reqCtx, nscInf.conn) + require.NoError(t, err) + + _, err = resolver.LookupIP(reqCtx, "ip4", nscName+fmt.Sprint(i)+".vl3") + require.Error(t, err) + } +} + func Test_vl3NSE_ConnectsTo_vl3NSE(t *testing.T) { t.Cleanup(func() { goleak.VerifyNone(t) }) diff --git a/pkg/networkservice/connectioncontext/dnscontext/vl3dns/metadata.go b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/metadata.go new file mode 100644 index 000000000..8387b1855 --- /dev/null +++ b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/metadata.go @@ -0,0 +1,41 @@ +// Copyright (c) 2023 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vl3dns + +import ( + "context" + + "github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata" +) + +type key struct{} + +// storeNotifierCancelFunc sets the context.CancelFunc stored in per Connection.Id metadata. +func storeNotifierCancelFunc(ctx context.Context, cancel context.CancelFunc) { + metadata.Map(ctx, true).Store(key{}, cancel) +} + +// loadAndDeleteNotifierCancelFunc deletes the context.CancelFunc stored in per Connection.Id metadata, +// returning the previous value if any. The loaded result reports whether the key was present. +func loadAndDeleteNotifierCancelFunc(ctx context.Context) (value context.CancelFunc, ok bool) { + rawValue, ok := metadata.Map(ctx, true).LoadAndDelete(key{}) + if !ok { + return + } + value, ok = rawValue.(context.CancelFunc) + return value, ok +} diff --git a/pkg/networkservice/connectioncontext/dnscontext/vl3dns/notifier.go b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/notifier.go new file mode 100644 index 000000000..cba4c4dfd --- /dev/null +++ b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/notifier.go @@ -0,0 +1,68 @@ +// Copyright (c) 2023 Cisco and/or its affiliates. +// +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vl3dns + +import ( + "github.com/edwarnicke/serialize" +) + +// notifier - notifies all subscribers of an event +type vl3DNSNotifier struct { + executor serialize.Executor + channels map[string]chan struct{} +} + +func newNotifier() *vl3DNSNotifier { + return &vl3DNSNotifier{ + channels: make(map[string]chan struct{}), + } +} + +func (n *vl3DNSNotifier) subscribe(id string) <-chan struct{} { + if n == nil { + return nil + } + var r chan struct{} + <-n.executor.AsyncExec(func() { + n.channels[id] = make(chan struct{}) + r = n.channels[id] + }) + return r +} + +func (n *vl3DNSNotifier) unsubscribe(id string) { + if n == nil { + return + } + <-n.executor.AsyncExec(func() { + if v, ok := n.channels[id]; ok { + close(v) + } + delete(n.channels, id) + }) +} + +func (n *vl3DNSNotifier) notify() { + if n == nil { + return + } + <-n.executor.AsyncExec(func() { + for _, v := range n.channels { + v <- struct{}{} + } + }) +} diff --git a/pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go index a012f76d7..9d8cad689 100644 --- a/pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go +++ b/pkg/networkservice/connectioncontext/dnscontext/vl3dns/server.go @@ -31,6 +31,7 @@ import ( "github.com/networkservicemesh/api/pkg/api/networkservice" "github.com/pkg/errors" + "github.com/networkservicemesh/sdk/pkg/networkservice/common/monitor" "github.com/networkservicemesh/sdk/pkg/networkservice/core/next" "github.com/networkservicemesh/sdk/pkg/networkservice/utils/metadata" "github.com/networkservicemesh/sdk/pkg/tools/dnsutils" @@ -41,9 +42,11 @@ import ( "github.com/networkservicemesh/sdk/pkg/tools/dnsutils/noloop" "github.com/networkservicemesh/sdk/pkg/tools/dnsutils/norecursion" "github.com/networkservicemesh/sdk/pkg/tools/ippool" + "github.com/networkservicemesh/sdk/pkg/tools/log" ) type vl3DNSServer struct { + chainCtx context.Context dnsServerRecords genericsync.Map[string, []net.IP] dnsConfigs *genericsync.Map[string, []*networkservice.DNSConfig] domainSchemeTemplates []*template.Template @@ -51,7 +54,7 @@ type vl3DNSServer struct { dnsServer dnsutils.Handler listenAndServeDNS func(ctx context.Context, handler dnsutils.Handler, listenOn string) dnsServerIP atomic.Value - dnsServerIPCh <-chan net.IP + notifier *vl3DNSNotifier } type clientDNSNameKey struct{} @@ -63,10 +66,11 @@ type clientDNSNameKey struct{} // opts configure vl3dns networkservice instance with specific behavior. func NewServer(chainCtx context.Context, dnsServerIPCh <-chan net.IP, opts ...Option) networkservice.NetworkServiceServer { var result = &vl3DNSServer{ + chainCtx: chainCtx, dnsPort: 53, listenAndServeDNS: dnsutils.ListenAndServe, dnsConfigs: new(genericsync.Map[string, []*networkservice.DNSConfig]), - dnsServerIPCh: dnsServerIPCh, + notifier: newNotifier(), } for _, opt := range opts { @@ -95,6 +99,7 @@ func NewServer(chainCtx context.Context, dnsServerIPCh <-chan net.IP, opts ...Op return } result.dnsServerIP.Store(addr) + result.notifier.notify() } } }() @@ -123,37 +128,36 @@ func (n *vl3DNSServer) Request(ctx context.Context, request *networkservice.Netw } } - dnsServerIPStr, err := n.addDNSContext(request.GetConnection(), recordNames) + resp, err := next.Server(ctx).Request(ctx, request) if err != nil { - return nil, err + return resp, err } - resp, err := next.Server(ctx).Request(ctx, request) - if err == nil { - ips := getSrcIPs(resp) - if len(ips) > 0 { - for _, recordName := range recordNames { - n.dnsServerRecords.Store(recordName, ips) - } - - metadata.Map(ctx, false).Store(clientDNSNameKey{}, recordNames) + dnsServerIPStr := n.addDNSContext(ctx, resp, recordNames) + ips := getSrcIPs(resp) + if len(ips) > 0 { + for _, recordName := range recordNames { + n.dnsServerRecords.Store(recordName, ips) } - configs := make([]*networkservice.DNSConfig, 0) - if srcRoutes := resp.GetContext().GetIpContext().GetSrcRoutes(); len(srcRoutes) > 0 { - var lastPrefix = srcRoutes[len(srcRoutes)-1].Prefix - for _, config := range clientsConfigs { - for _, serverIP := range config.DnsServerIps { - if dnsServerIPStr == serverIP { - continue - } - if withinPrefix(serverIP, lastPrefix) { - configs = append(configs, config) - } + + metadata.Map(ctx, false).Store(clientDNSNameKey{}, recordNames) + } + configs := make([]*networkservice.DNSConfig, 0) + if srcRoutes := resp.GetContext().GetIpContext().GetSrcRoutes(); len(srcRoutes) > 0 { + var lastPrefix = srcRoutes[len(srcRoutes)-1].Prefix + for _, config := range clientsConfigs { + for _, serverIP := range config.DnsServerIps { + if dnsServerIPStr == serverIP { + continue + } + if withinPrefix(serverIP, lastPrefix) { + configs = append(configs, config) } } } - n.dnsConfigs.Store(resp.GetId(), configs) } + n.dnsConfigs.Store(resp.GetId(), configs) + return resp, err } @@ -170,7 +174,8 @@ func (n *vl3DNSServer) Close(ctx context.Context, conn *networkservice.Connectio return next.Server(ctx).Close(ctx, conn) } -func (n *vl3DNSServer) addDNSContext(c *networkservice.Connection, dnsRecords []string) (serverIP string, err error) { +func (n *vl3DNSServer) addDNSContext(ctx context.Context, c *networkservice.Connection, dnsRecords []string) string { + var dnsServerIPString string if ip := n.dnsServerIP.Load(); ip != nil { dnsServerIP := ip.(net.IP) @@ -194,12 +199,37 @@ func (n *vl3DNSServer) addDNSContext(c *networkservice.Connection, dnsRecords [] if !dnsutils.ContainsDNSConfig(dnsContext.Configs, configToAdd) { dnsContext.Configs = append(dnsContext.Configs, configToAdd) } - return dnsServerIP.String(), nil - } else if c.GetPath().GetPathSegments()[0].Name == c.GetCurrentPathSegment().Name { - // If it calls itself - this is not an error, but a request to allocate a dns address - return "", nil + dnsServerIPString = dnsServerIP.String() } - return "", errors.New("DNS address is initializing") + + if eventConsumer, ok := monitor.LoadEventConsumer(ctx, metadata.IsClient(n)); ok { + if prevCancel, ok := loadAndDeleteNotifierCancelFunc(ctx); ok { + prevCancel() + } + cancelCtx, cancel := context.WithCancel(context.Background()) + storeNotifierCancelFunc(ctx, cancel) + go n.waitDNSServerIP(cancelCtx, c.Clone(), eventConsumer) + } else { + log.FromContext(ctx).Debug("eventConsumer is not presented") + } + return dnsServerIPString +} + +func (n *vl3DNSServer) waitDNSServerIP(cancelCtx context.Context, c *networkservice.Connection, eventConsumer monitor.EventConsumer) { + ch := n.notifier.subscribe(c.GetId()) + + select { + case <-n.chainCtx.Done(): + case <-cancelCtx.Done(): + case <-ch: + c.State = networkservice.State_REFRESH_REQUESTED + _ = eventConsumer.Send(&networkservice.ConnectionEvent{ + Type: networkservice.ConnectionEventType_UPDATE, + Connections: map[string]*networkservice.Connection{c.GetId(): c}, + }) + } + + n.notifier.unsubscribe(c.GetId()) } func (n *vl3DNSServer) buildSrcDNSRecords(c *networkservice.Connection) ([]string, error) {