From f119439a45deba7a6add66859b0df6f5f8f71f26 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Sun, 11 Aug 2024 19:40:44 +0200 Subject: [PATCH 01/33] performance impr.: avoid repeated allocation of "lastTick" on heap --- timeout.go | 14 ++++++++------ timeout_test.go | 8 ++++---- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/timeout.go b/timeout.go index c1b4c398b..26beb25b5 100644 --- a/timeout.go +++ b/timeout.go @@ -16,7 +16,8 @@ type TimerWheel[T any] struct { wheelLen int // Last time we ticked, since we are lazy ticking - lastTick *time.Time + lastTickValid bool + lastTick time.Time // Durations of a tick and the entire wheel tickDuration time.Duration @@ -168,13 +169,14 @@ func (tw *TimerWheel[T]) findWheel(timeout time.Duration) (i int) { // Advance will move the wheel forward by the appropriate number of ticks for the provided time and all items // passed over will be moved to the expired list. Calling Purge is necessary to remove them entirely. -func (tw *TimerWheel[T]) Advance(now time.Time) { - if tw.lastTick == nil { - tw.lastTick = &now +func (tw *TimerWheel[T]) Advance(now1 time.Time) { + if !tw.lastTickValid { + tw.lastTick = now1 + tw.lastTickValid = true } // We want to round down - ticks := int(now.Sub(*tw.lastTick) / tw.tickDuration) + ticks := int(now1.Sub(tw.lastTick) / tw.tickDuration) adv := ticks if ticks > tw.wheelLen { ticks = tw.wheelLen @@ -203,7 +205,7 @@ func (tw *TimerWheel[T]) Advance(now time.Time) { // Advance the tick based on duration to avoid losing some accuracy newTick := tw.lastTick.Add(tw.tickDuration * time.Duration(adv)) - tw.lastTick = &newTick + tw.lastTick = newTick } func (lw *LockingTimerWheel[T]) Add(v T, timeout time.Duration) *TimeoutItem[T] { diff --git a/timeout_test.go b/timeout_test.go index 4c6364ef5..616d83f74 100644 --- a/timeout_test.go +++ b/timeout_test.go @@ -14,7 +14,7 @@ func TestNewTimerWheel(t *testing.T) { tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10) 
assert.Equal(t, 12, tw.wheelLen) assert.Equal(t, 0, tw.current) - assert.Nil(t, tw.lastTick) + assert.Equal(t, false, tw.lastTickValid) assert.Equal(t, time.Second*1, tw.tickDuration) assert.Equal(t, time.Second*10, tw.wheelDuration) assert.Len(t, tw.wheel, 12) @@ -110,9 +110,9 @@ func TestTimerWheel_Add(t *testing.T) { func TestTimerWheel_Purge(t *testing.T) { // First advance should set the lastTick and do nothing else tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10) - assert.Nil(t, tw.lastTick) + assert.Equal(t, false, tw.lastTickValid) tw.Advance(time.Now()) - assert.NotNil(t, tw.lastTick) + assert.Equal(t, true, tw.lastTickValid) assert.Equal(t, 0, tw.current) fps := []firewall.Packet{ @@ -128,7 +128,7 @@ func TestTimerWheel_Purge(t *testing.T) { tw.Add(fps[3], time.Second*2) ta := time.Now().Add(time.Second * 3) - lastTick := *tw.lastTick + lastTick := tw.lastTick tw.Advance(ta) assert.Equal(t, 3, tw.current) assert.True(t, tw.lastTick.After(lastTick)) From 6eaf41865c091361df9ab77516290aaa80e47c41 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Mon, 5 Aug 2024 23:34:54 +0200 Subject: [PATCH 02/33] support unsafe_routes for use with port_forwarding and user_tun --- overlay/user.go | 42 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 3 deletions(-) diff --git a/overlay/user.go b/overlay/user.go index 1bb4ef5f7..5e5180930 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -3,16 +3,42 @@ package overlay import ( "io" "net/netip" + "sync/atomic" + "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" ) func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { - return NewUserDevice(tunCidr) + d, err := NewUserDevice(tunCidr) + if err != nil { + return nil, err + } + + _, routes, err := getAllRoutesFromConfig(c, tunCidr, true) + if err != nil { + return nil, err + } + + routeTree, err := makeRouteTree(l, routes, true) + 
if err != nil { + return nil, err + } + + newDefaultMTU := c.GetInt("tun.mtu", DefaultMTU) + for i, r := range routes { + if r.MTU == 0 { + routes[i].MTU = newDefaultMTU + } + } + + d.routeTree.Store(routeTree) + + return d, nil } -func NewUserDevice(tunCidr netip.Prefix) (Device, error) { +func NewUserDevice(tunCidr netip.Prefix) (*UserDevice, error) { // these pipes guarantee each write/read will match 1:1 or, ow := io.Pipe() ir, iw := io.Pipe() @@ -33,6 +59,8 @@ type UserDevice struct { inboundReader *io.PipeReader inboundWriter *io.PipeWriter + + routeTree atomic.Pointer[bart.Table[netip.Addr]] } func (d *UserDevice) Activate() error { @@ -40,7 +68,15 @@ func (d *UserDevice) Activate() error { } func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr } func (d *UserDevice) Name() string { return "faketun0" } -func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip } +func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { + ptr := d.routeTree.Load() + if ptr != nil { + r, _ := d.routeTree.Load().Lookup(ip) + return r + } else { + return ip + } +} func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return d, nil } From 2388e2c93a69ff367b67632bce8e04ccb53c09db Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Sun, 11 Aug 2024 19:44:20 +0200 Subject: [PATCH 03/33] performance 3: avoid repeated allocation of []byte{1} on heap --- connection_manager.go | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/connection_manager.go b/connection_manager.go index d2e861647..9a2d310d4 100644 --- a/connection_manager.go +++ b/connection_manager.go @@ -26,6 +26,11 @@ const ( sendTestPacket trafficDecision = 6 ) +// The data written into this variable is never used. +// Its there to avoid a fresh dynamic memory allocation of 1 byte +// for each time its used. 
+var BYTE_SLICE_ONE []byte = []byte{1} + type connectionManager struct { in map[uint32]struct{} inLock *sync.RWMutex @@ -463,12 +468,12 @@ func (n *connectionManager) sendPunch(hostinfo *HostInfo) { if n.punchy.GetTargetEverything() { hostinfo.remotes.ForEach(n.hostMap.GetPreferredRanges(), func(addr netip.AddrPort, preferred bool) { n.metricsTxPunchy.Inc(1) - n.intf.outside.WriteTo([]byte{1}, addr) + n.intf.outside.WriteTo(BYTE_SLICE_ONE, addr) }) } else if hostinfo.remote.IsValid() { n.metricsTxPunchy.Inc(1) - n.intf.outside.WriteTo([]byte{1}, hostinfo.remote) + n.intf.outside.WriteTo(BYTE_SLICE_ONE, hostinfo.remote) } } From aa415264ba33affc13a44e59fa6d88ff1e7557ec Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Sun, 11 Aug 2024 19:49:36 +0200 Subject: [PATCH 04/33] performance4: use buffer.view based channels instead of pipe --- overlay/user.go | 52 +++++++++++++++++++++++++++++++--------------- service/service.go | 28 ++++++------------------- 2 files changed, 41 insertions(+), 39 deletions(-) diff --git a/overlay/user.go b/overlay/user.go index 5e5180930..90329f2cb 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -8,6 +8,7 @@ import ( "github.com/gaissmai/bart" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" + "gvisor.dev/gvisor/pkg/buffer" ) func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix, routines int) (Device, error) { @@ -40,25 +41,19 @@ func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix func NewUserDevice(tunCidr netip.Prefix) (*UserDevice, error) { // these pipes guarantee each write/read will match 1:1 - or, ow := io.Pipe() - ir, iw := io.Pipe() return &UserDevice{ - tunCidr: tunCidr, - outboundReader: or, - outboundWriter: ow, - inboundReader: ir, - inboundWriter: iw, + tunCidr: tunCidr, + outboundChannel: make(chan *buffer.View, 16), + inboundChannel: make(chan *buffer.View, 16), }, nil } type UserDevice struct { tunCidr netip.Prefix - outboundReader 
*io.PipeReader - outboundWriter *io.PipeWriter + outboundChannel chan *buffer.View + inboundChannel chan *buffer.View - inboundReader *io.PipeReader - inboundWriter *io.PipeWriter routeTree atomic.Pointer[bart.Table[netip.Addr]] } @@ -81,18 +76,41 @@ func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) { return d, nil } -func (d *UserDevice) Pipe() (*io.PipeReader, *io.PipeWriter) { - return d.inboundReader, d.outboundWriter +func (d *UserDevice) Pipe() (<-chan *buffer.View, chan<- *buffer.View) { + return d.inboundChannel, d.outboundChannel } func (d *UserDevice) Read(p []byte) (n int, err error) { - return d.outboundReader.Read(p) + view, ok := <-d.outboundChannel + if !ok { + return 0, io.EOF + } + return view.Read(p) +} +func (d *UserDevice) WriteTo(w io.Writer) (n int64, err error) { + view, ok := <-d.outboundChannel + if !ok { + return 0, io.EOF + } + return view.WriteTo(w) } + func (d *UserDevice) Write(p []byte) (n int, err error) { - return d.inboundWriter.Write(p) + view := buffer.NewViewWithData(p) + d.inboundChannel <- view + return view.Size(), nil } +func (d *UserDevice) ReadFrom(r io.Reader) (n int64, err error) { + view := buffer.NewViewSize(2048) + n, err = view.ReadFrom(r) + if n > 0 { + d.inboundChannel <- view + } + return +} + func (d *UserDevice) Close() error { - d.inboundWriter.Close() - d.outboundWriter.Close() + close(d.inboundChannel) + close(d.outboundChannel) return nil } diff --git a/service/service.go b/service/service.go index 50c1d4a11..c1091a272 100644 --- a/service/service.go +++ b/service/service.go @@ -106,31 +106,19 @@ func New(config *config.C) (*Service, error) { tcpFwd := tcp.NewForwarder(s.ipstack, tcpReceiveBufferSize, maxInFlightConnectionAttempts, s.tcpHandler) s.ipstack.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpFwd.HandlePacket) - reader, writer := device.Pipe() - - go func() { - <-ctx.Done() - reader.Close() - writer.Close() - }() + nebula_tun_reader, nebula_tun_writer := device.Pipe() // 
create Goroutines to forward packets between Nebula and Gvisor eg.Go(func() error { - buf := make([]byte, header.IPv4MaximumHeaderSize+header.IPv4MaximumPayloadSize) for { - // this will read exactly one packet - n, err := reader.Read(buf) - if err != nil { - return err + view, ok := <-nebula_tun_reader + if !ok { + return nil } packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: buffer.MakeWithData(bytes.Clone(buf[:n])), + Payload: buffer.MakeWithView(view), }) linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf) - - if err := ctx.Err(); err != nil { - return err - } } }) eg.Go(func() error { @@ -142,11 +130,7 @@ func New(config *config.C) (*Service, error) { } continue } - bufView := packet.ToView() - if _, err := bufView.WriteTo(writer); err != nil { - return err - } - bufView.Release() + nebula_tun_writer <- packet.ToView() } }) From c245a309666be91afc603cb6b3a8ad706df7fbfc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jens=20=C3=85kerlund?= Date: Sun, 21 Jul 2024 16:45:07 +0200 Subject: [PATCH 05/33] Add script for speedtesting port forwarding --- e2e/forwarding/.gitignore | 3 +++ e2e/forwarding/README.md | 27 ++++++++++++++++++++ e2e/forwarding/a_config.yml | 27 ++++++++++++++++++++ e2e/forwarding/b_config.yml | 27 ++++++++++++++++++++ e2e/forwarding/generate_certificates.sh | 3 +++ e2e/forwarding/speedtest.sh | 34 +++++++++++++++++++++++++ 6 files changed, 121 insertions(+) create mode 100644 e2e/forwarding/.gitignore create mode 100644 e2e/forwarding/README.md create mode 100644 e2e/forwarding/a_config.yml create mode 100644 e2e/forwarding/b_config.yml create mode 100755 e2e/forwarding/generate_certificates.sh create mode 100755 e2e/forwarding/speedtest.sh diff --git a/e2e/forwarding/.gitignore b/e2e/forwarding/.gitignore new file mode 100644 index 000000000..ccdfcbd6d --- /dev/null +++ b/e2e/forwarding/.gitignore @@ -0,0 +1,3 @@ +*.out +*.crt +*.key diff --git a/e2e/forwarding/README.md b/e2e/forwarding/README.md new file mode 100644 
index 000000000..3b6d39b63 --- /dev/null +++ b/e2e/forwarding/README.md @@ -0,0 +1,27 @@ +# Userspace port forwarding +A simple speedtest for userspace port forwarding that can run without root access. + +## A side +Nebula running at port 10000, forwarding inbound TCP connections on port 5201 to 127.0.0.1:15001. + +## B side +Nebula running at port 10001, forwarding outbound TCP connections from 127.0.0.1:15002 to port 5201 of the A side. + +## Speedtest + + ┌──────────────────────┐:10001 :10002┌──────────────────────┐ + │ Nebula A side ├─────────────────┤ Nebula B side │ + │ │ │ │ + │ 192.168.100.1 │ TCP 5201 │ 192.168.100.2 │ + │ ┌───────────┼─────────────────┼──────────┐ │ + │ │ ├─────────────────┤ │ │ + └──────────▼───────────┘ └──────────▲───────────┘ + │ │ 127.0.0.1:15002 + │ │ + ┌──────────▼───────────┐ ┌──────────┴───────────┐ + │ │ │ │ + │ │ │ │ + │ iperf3 -s -p 15001 │ │ iperf3 -c -p 15001 │ + │ │ │ │ + │ │ │ │ + └──────────────────────┘ └──────────────────────┘ diff --git a/e2e/forwarding/a_config.yml b/e2e/forwarding/a_config.yml new file mode 100644 index 000000000..8469a34fe --- /dev/null +++ b/e2e/forwarding/a_config.yml @@ -0,0 +1,27 @@ +pki: + ca: ca.crt + cert: a.crt + key: a.key + +static_host_map: + "192.168.100.2": ["127.0.0.1:10002"] + +listen: + host: 127.0.0.1 + port: 10001 + +forwarding: + inbound: + - listen: ":5201" + dial: "127.0.0.1:15001" + proto: tcp + +tun: + disabled: true + mtu: 1300 + +firewall: + inbound: + - port: 5201 + proto: tcp + host: any diff --git a/e2e/forwarding/b_config.yml b/e2e/forwarding/b_config.yml new file mode 100644 index 000000000..b0707ddc9 --- /dev/null +++ b/e2e/forwarding/b_config.yml @@ -0,0 +1,27 @@ +pki: + ca: ca.crt + cert: b.crt + key: b.key + +static_host_map: + "192.168.100.1": ["127.0.0.1:10001"] + +listen: + host: 127.0.0.1 + port: 10002 + +forwarding: + outbound: + - listen: "127.0.0.1:15002" + dial: "192.168.100.1:5201" + proto: tcp + +tun: + disabled: true + mtu: 1300 + +firewall: + outbound: + 
- port: 5201 + proto: tcp + host: any diff --git a/e2e/forwarding/generate_certificates.sh b/e2e/forwarding/generate_certificates.sh new file mode 100755 index 000000000..5052a8925 --- /dev/null +++ b/e2e/forwarding/generate_certificates.sh @@ -0,0 +1,3 @@ +../../nebula-cert ca -name "E2E test CA" +../../nebula-cert sign -name "A" -ip "192.168.100.1/24" -out-crt a.crt -out-key a.key +../../nebula-cert sign -name "B" -ip "192.168.100.2/24" -out-crt b.crt -out-key b.key diff --git a/e2e/forwarding/speedtest.sh b/e2e/forwarding/speedtest.sh new file mode 100755 index 000000000..762ce4eb6 --- /dev/null +++ b/e2e/forwarding/speedtest.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +cd "$(dirname "$0")" + +if ! test -f ca.key; then + echo "Generating new test certificates" + ./generate_certificates.sh +fi + + +../../nebula -config "$(pwd)/a_config.yml" &>a.out & +A_PID=$! +../../nebula -config "$(pwd)/b_config.yml" &>b.out & +B_PID=$! + +iperf3 -s -p 15001 & +IPERF_SERVER_PID=$! + +sleep 1 +iperf3 -c 127.0.0.1 -p 15002 -P 10 + +# Cleanup +kill $IPERF_SERVER_PID $A_PID $B_PID + +echo "##########################################" +echo "A side logs:" +echo "##########################################" +cat a.out + +echo "##########################################" +echo "B side logs:" +echo "##########################################" +cat b.out +rm a.out b.out From ed937ab4a97547976d63d8b1f9d1795f1cb952c5 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Sun, 11 Aug 2024 19:52:28 +0200 Subject: [PATCH 06/33] Revert "Add script for speedtesting port forwarding" This reverts commit c245a309666be91afc603cb6b3a8ad706df7fbfc. 
--- e2e/forwarding/.gitignore | 3 --- e2e/forwarding/README.md | 27 -------------------- e2e/forwarding/a_config.yml | 27 -------------------- e2e/forwarding/b_config.yml | 27 -------------------- e2e/forwarding/generate_certificates.sh | 3 --- e2e/forwarding/speedtest.sh | 34 ------------------------- 6 files changed, 121 deletions(-) delete mode 100644 e2e/forwarding/.gitignore delete mode 100644 e2e/forwarding/README.md delete mode 100644 e2e/forwarding/a_config.yml delete mode 100644 e2e/forwarding/b_config.yml delete mode 100755 e2e/forwarding/generate_certificates.sh delete mode 100755 e2e/forwarding/speedtest.sh diff --git a/e2e/forwarding/.gitignore b/e2e/forwarding/.gitignore deleted file mode 100644 index ccdfcbd6d..000000000 --- a/e2e/forwarding/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -*.out -*.crt -*.key diff --git a/e2e/forwarding/README.md b/e2e/forwarding/README.md deleted file mode 100644 index 3b6d39b63..000000000 --- a/e2e/forwarding/README.md +++ /dev/null @@ -1,27 +0,0 @@ -# Userspace port forwarding -A simple speedtest for userspace port forwarding that can run without root access. - -## A side -Nebula running at port 10000, forwarding inbound TCP connections on port 5201 to 127.0.0.1:15001. - -## B side -Nebula running at port 10001, forwarding outbound TCP connections from 127.0.0.1:15002 to port 5201 of the A side. 
- -## Speedtest - - ┌──────────────────────┐:10001 :10002┌──────────────────────┐ - │ Nebula A side ├─────────────────┤ Nebula B side │ - │ │ │ │ - │ 192.168.100.1 │ TCP 5201 │ 192.168.100.2 │ - │ ┌───────────┼─────────────────┼──────────┐ │ - │ │ ├─────────────────┤ │ │ - └──────────▼───────────┘ └──────────▲───────────┘ - │ │ 127.0.0.1:15002 - │ │ - ┌──────────▼───────────┐ ┌──────────┴───────────┐ - │ │ │ │ - │ │ │ │ - │ iperf3 -s -p 15001 │ │ iperf3 -c -p 15001 │ - │ │ │ │ - │ │ │ │ - └──────────────────────┘ └──────────────────────┘ diff --git a/e2e/forwarding/a_config.yml b/e2e/forwarding/a_config.yml deleted file mode 100644 index 8469a34fe..000000000 --- a/e2e/forwarding/a_config.yml +++ /dev/null @@ -1,27 +0,0 @@ -pki: - ca: ca.crt - cert: a.crt - key: a.key - -static_host_map: - "192.168.100.2": ["127.0.0.1:10002"] - -listen: - host: 127.0.0.1 - port: 10001 - -forwarding: - inbound: - - listen: ":5201" - dial: "127.0.0.1:15001" - proto: tcp - -tun: - disabled: true - mtu: 1300 - -firewall: - inbound: - - port: 5201 - proto: tcp - host: any diff --git a/e2e/forwarding/b_config.yml b/e2e/forwarding/b_config.yml deleted file mode 100644 index b0707ddc9..000000000 --- a/e2e/forwarding/b_config.yml +++ /dev/null @@ -1,27 +0,0 @@ -pki: - ca: ca.crt - cert: b.crt - key: b.key - -static_host_map: - "192.168.100.1": ["127.0.0.1:10001"] - -listen: - host: 127.0.0.1 - port: 10002 - -forwarding: - outbound: - - listen: "127.0.0.1:15002" - dial: "192.168.100.1:5201" - proto: tcp - -tun: - disabled: true - mtu: 1300 - -firewall: - outbound: - - port: 5201 - proto: tcp - host: any diff --git a/e2e/forwarding/generate_certificates.sh b/e2e/forwarding/generate_certificates.sh deleted file mode 100755 index 5052a8925..000000000 --- a/e2e/forwarding/generate_certificates.sh +++ /dev/null @@ -1,3 +0,0 @@ -../../nebula-cert ca -name "E2E test CA" -../../nebula-cert sign -name "A" -ip "192.168.100.1/24" -out-crt a.crt -out-key a.key -../../nebula-cert sign -name "B" -ip 
"192.168.100.2/24" -out-crt b.crt -out-key b.key diff --git a/e2e/forwarding/speedtest.sh b/e2e/forwarding/speedtest.sh deleted file mode 100755 index 762ce4eb6..000000000 --- a/e2e/forwarding/speedtest.sh +++ /dev/null @@ -1,34 +0,0 @@ -#!/bin/bash - -cd "$(dirname "$0")" - -if ! test -f ca.key; then - echo "Generating new test certificates" - ./generate_certificates.sh -fi - - -../../nebula -config "$(pwd)/a_config.yml" &>a.out & -A_PID=$! -../../nebula -config "$(pwd)/b_config.yml" &>b.out & -B_PID=$! - -iperf3 -s -p 15001 & -IPERF_SERVER_PID=$! - -sleep 1 -iperf3 -c 127.0.0.1 -p 15002 -P 10 - -# Cleanup -kill $IPERF_SERVER_PID $A_PID $B_PID - -echo "##########################################" -echo "A side logs:" -echo "##########################################" -cat a.out - -echo "##########################################" -echo "B side logs:" -echo "##########################################" -cat b.out -rm a.out b.out From 48745ebd501a7a3588570dec6e009dc6e9fd52f1 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Sun, 11 Aug 2024 19:54:39 +0200 Subject: [PATCH 07/33] modified and extended speedtest --- e2e/forwarding/.gitignore | 3 +++ e2e/forwarding/README.md | 27 +++++++++++++++++++ e2e/forwarding/a_config.yml | 35 ++++++++++++++++++++++++ e2e/forwarding/b_config.yml | 34 +++++++++++++++++++++++ e2e/forwarding/generate_certificates.sh | 5 ++++ e2e/forwarding/speedtest.sh | 36 +++++++++++++++++++++++++ e2e/forwarding/speedtest_udp.sh | 5 ++++ 7 files changed, 145 insertions(+) create mode 100644 e2e/forwarding/.gitignore create mode 100644 e2e/forwarding/README.md create mode 100644 e2e/forwarding/a_config.yml create mode 100644 e2e/forwarding/b_config.yml create mode 100755 e2e/forwarding/generate_certificates.sh create mode 100755 e2e/forwarding/speedtest.sh create mode 100755 e2e/forwarding/speedtest_udp.sh diff --git a/e2e/forwarding/.gitignore b/e2e/forwarding/.gitignore new file mode 100644 index 000000000..ccdfcbd6d --- /dev/null +++ 
b/e2e/forwarding/.gitignore @@ -0,0 +1,3 @@ +*.out +*.crt +*.key diff --git a/e2e/forwarding/README.md b/e2e/forwarding/README.md new file mode 100644 index 000000000..3b6d39b63 --- /dev/null +++ b/e2e/forwarding/README.md @@ -0,0 +1,27 @@ +# Userspace port forwarding +A simple speedtest for userspace port forwarding that can run without root access. + +## A side +Nebula running at port 10000, forwarding inbound TCP connections on port 5201 to 127.0.0.1:15001. + +## B side +Nebula running at port 10001, forwarding outbound TCP connections from 127.0.0.1:15002 to port 5201 of the A side. + +## Speedtest + + ┌──────────────────────┐:10001 :10002┌──────────────────────┐ + │ Nebula A side ├─────────────────┤ Nebula B side │ + │ │ │ │ + │ 192.168.100.1 │ TCP 5201 │ 192.168.100.2 │ + │ ┌───────────┼─────────────────┼──────────┐ │ + │ │ ├─────────────────┤ │ │ + └──────────▼───────────┘ └──────────▲───────────┘ + │ │ 127.0.0.1:15002 + │ │ + ┌──────────▼───────────┐ ┌──────────┴───────────┐ + │ │ │ │ + │ │ │ │ + │ iperf3 -s -p 15001 │ │ iperf3 -c -p 15001 │ + │ │ │ │ + │ │ │ │ + └──────────────────────┘ └──────────────────────┘ diff --git a/e2e/forwarding/a_config.yml b/e2e/forwarding/a_config.yml new file mode 100644 index 000000000..01b8e120e --- /dev/null +++ b/e2e/forwarding/a_config.yml @@ -0,0 +1,35 @@ +pki: + ca: ca.crt + cert: a.crt + key: a.key + +static_host_map: + "192.168.100.2": ["127.0.0.1:10002"] + +logging: + level: info + +listen: + host: 127.0.0.1 + port: 10001 + +port_forwarding: + enable_without_rules: true + inbound: + - listen_port: 5201 + dial_address: "127.0.0.1:15001" + protocols: [tcp, udp] + +tun: + disabled: true + mtu: 1300 + +firewall: + outbound: + - port: any + proto: udp + host: any + inbound: + - port: 5201 + proto: any + host: any diff --git a/e2e/forwarding/b_config.yml b/e2e/forwarding/b_config.yml new file mode 100644 index 000000000..b7e486cbc --- /dev/null +++ b/e2e/forwarding/b_config.yml @@ -0,0 +1,34 @@ +pki: + ca: ca.crt + cert: 
b.crt + key: b.key + +static_host_map: + "192.168.100.1": ["127.0.0.1:10001"] + +logging: + level: info + +listen: + host: 127.0.0.1 + port: 10002 + +port_forwarding: + enable_without_rules: true + outbound: + - listen_address: "127.0.0.1:15002" + dial_address: "192.168.100.1:5201" + protocols: [tcp, udp] + +tun: + disabled: true + mtu: 1300 + +firewall: + outbound: + - port: any + proto: udp + host: any + - port: 5201 + proto: any + host: any diff --git a/e2e/forwarding/generate_certificates.sh b/e2e/forwarding/generate_certificates.sh new file mode 100755 index 000000000..4df299c4e --- /dev/null +++ b/e2e/forwarding/generate_certificates.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +../../nebula-cert ca -name "E2E test CA" +../../nebula-cert sign -name "A" -ip "192.168.100.1/24" -out-crt a.crt -out-key a.key +../../nebula-cert sign -name "B" -ip "192.168.100.2/24" -out-crt b.crt -out-key b.key diff --git a/e2e/forwarding/speedtest.sh b/e2e/forwarding/speedtest.sh new file mode 100755 index 000000000..9d166de99 --- /dev/null +++ b/e2e/forwarding/speedtest.sh @@ -0,0 +1,36 @@ +#!/bin/bash + +cd "$(dirname "$0")" + +if ! test -f ca.key; then + echo "Generating new test certificates" + ./generate_certificates.sh +fi + +../../nebula -config "$(pwd)/a_config.yml" &>a.out & +A_PID=$! +../../nebula -config "$(pwd)/b_config.yml" &>b.out & +B_PID=$! + +iperf3 -s -p 15001 & +IPERF_SERVER_PID=$! 
+ +sleep 1 +iperf3 -c 127.0.0.1 -p 15002 -P 10 "$@" + +# Cleanup +kill $IPERF_SERVER_PID $A_PID $B_PID + +# wait for shutdown logs are written to files +sleep 1 + +echo "##########################################" +echo "A side logs:" +echo "##########################################" +cat a.out + +echo "##########################################" +echo "B side logs:" +echo "##########################################" +cat b.out +rm a.out b.out diff --git a/e2e/forwarding/speedtest_udp.sh b/e2e/forwarding/speedtest_udp.sh new file mode 100755 index 000000000..cb42d05e8 --- /dev/null +++ b/e2e/forwarding/speedtest_udp.sh @@ -0,0 +1,5 @@ +#!/bin/bash + +cd "$(dirname "$0")" + +./speedtest.sh --udp --bidir --bitrate=100MiB "$@" From c51638345b7904f60afb15928f5e397df5959259 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Sun, 11 Aug 2024 19:56:27 +0200 Subject: [PATCH 08/33] add automated functional tests for port forwarding --- port_forwarder/port_forwarder_tcp_test.go | 288 ++++++++++++++++++++++ port_forwarder/port_forwarder_udp_test.go | 132 ++++++++++ service/service_test.go | 90 +------ service/service_testhelpers.go | 101 ++++++++ 4 files changed, 522 insertions(+), 89 deletions(-) create mode 100644 port_forwarder/port_forwarder_tcp_test.go create mode 100644 port_forwarder/port_forwarder_udp_test.go create mode 100644 service/service_testhelpers.go diff --git a/port_forwarder/port_forwarder_tcp_test.go b/port_forwarder/port_forwarder_tcp_test.go new file mode 100644 index 000000000..2f86c37a4 --- /dev/null +++ b/port_forwarder/port_forwarder_tcp_test.go @@ -0,0 +1,288 @@ +package port_forwarder + +import ( + "net" + "testing" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/service" + "github.com/stretchr/testify/assert" +) + +func doTestTcpCommunication( + t *testing.T, + msg string, + senderConn net.Conn, + receiverConn net.Conn, +) { + data_sent := []byte(msg) + n, err := senderConn.Write(data_sent) + assert.Nil(t, err) + assert.Equal(t, n, 
len(data_sent)) + + buf := make([]byte, 100) + n, err = receiverConn.Read(buf) + assert.Nil(t, err) + assert.Equal(t, n, len(data_sent)) + assert.Equal(t, data_sent, buf[:n]) +} + +func doTestTcpCommunicationFail( + t *testing.T, + msg string, + senderConn net.Conn, + receiverConn net.Conn, +) { + data_sent := []byte(msg) + n, err := senderConn.Write(data_sent) + if err != nil { + return + } + assert.Nil(t, err) + assert.Equal(t, n, len(data_sent)) + + buf := make([]byte, 100) + _, err = receiverConn.Read(buf) + assert.NotNil(t, err) +} + +func TestTcpInOut2Clients(t *testing.T) { + l := logrus.New() + server, client := service.CreateTwoConnectedServices(4247) + defer client.Close() + defer server.Close() + + server_pf, err := createPortForwarderFromConfigString(l, server, ` +port_forwarding: + inbound: + - listen_port: 4495 + dial_address: 127.0.0.1:5595 + protocols: [tcp] +`) + assert.Nil(t, err) + + assert.Len(t, server_pf.portForwardings, 1) + + client_pf, err := createPortForwarderFromConfigString(l, client, ` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3395 + dial_address: 10.0.0.1:4495 + protocols: [tcp] +`) + assert.Nil(t, err) + + assert.Len(t, client_pf.portForwardings, 1) + + client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3395") + assert.Nil(t, err) + server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5595") + assert.Nil(t, err) + + server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr) + assert.Nil(t, err) + client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) + assert.Nil(t, err) + client1_server_side_conn, err := server_listen_conn.Accept() + assert.Nil(t, err) + client2_conn, err := net.DialTCP("tcp", nil, client_conn_addr) + assert.Nil(t, err) + client2_server_side_conn, err := server_listen_conn.Accept() + assert.Nil(t, err) + + doTestTcpCommunication(t, "Hello from client 1 side!", + client1_conn, client1_server_side_conn) + doTestTcpCommunication(t, "Hello from client two side!", + 
client2_conn, client2_server_side_conn) + + doTestTcpCommunication(t, "Hello from server first side!", + client1_server_side_conn, client1_conn) + doTestTcpCommunication(t, "Hello from server second side!", + client2_server_side_conn, client2_conn) + doTestTcpCommunication(t, "Hello from server third side!", + client1_server_side_conn, client1_conn) + + doTestTcpCommunication(t, "Hello from client two side AGAIN!", + client2_conn, client2_server_side_conn) + +} + +func TestTcpInOut1ClientConfigReload(t *testing.T) { + l := logrus.New() + server, client := service.CreateTwoConnectedServices(4246) + defer client.Close() + defer server.Close() + + server_pf, err := createPortForwarderFromConfigString(l, server, ` +port_forwarding: + inbound: + - listen_port: 4497 + dial_address: 127.0.0.1:5597 + protocols: [tcp] +`) + assert.Nil(t, err) + + assert.Len(t, server_pf.portForwardings, 1) + + client_pf, err := createPortForwarderFromConfigString(l, client, ` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3397 + dial_address: 10.0.0.1:4497 + protocols: [tcp] +`) + assert.Nil(t, err) + + assert.Len(t, client_pf.portForwardings, 1) + + client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3397") + assert.Nil(t, err) + server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5597") + assert.Nil(t, err) + + server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr) + assert.Nil(t, err) + client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) + assert.Nil(t, err) + client1_server_side_conn, err := server_listen_conn.Accept() + assert.Nil(t, err) + + doTestTcpCommunication(t, "Hello from client 1 side!", + client1_conn, client1_server_side_conn) + + doTestTcpCommunication(t, "Hello from server first side!", + client1_server_side_conn, client1_conn) + doTestTcpCommunication(t, "Hello from server third side!", + client1_server_side_conn, client1_conn) + + doTestTcpCommunication(t, "Hello from client one side AGAIN!", + client1_conn, 
client1_server_side_conn) + + new_server_fwd_list, err := loadPortFwdConfigFromString(l, ` +port_forwarding: + inbound: + - listen_port: 4496 + dial_address: 127.0.0.1:5596 + protocols: [tcp] +`) + assert.Nil(t, err) + + assert.Len(t, server_pf.portForwardings, 1) + + new_client_fwd_list, err := loadPortFwdConfigFromString(l, ` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3396 + dial_address: 10.0.0.1:4496 + protocols: [tcp] +`) + assert.Nil(t, err) + + err = client_pf.ApplyChangesByNewFwdList(new_client_fwd_list) + assert.Nil(t, err) + + doTestTcpCommunicationFail(t, "Hello from client 1 side!", + client1_conn, client1_server_side_conn) + + doTestTcpCommunicationFail(t, "Hello from server first side!", + client1_server_side_conn, client1_conn) + + err = server_pf.ApplyChangesByNewFwdList(new_server_fwd_list) + assert.Nil(t, err) + + doTestTcpCommunicationFail(t, "Hello from client 1 side!", + client1_conn, client1_server_side_conn) + + doTestTcpCommunicationFail(t, "Hello from server first side!", + client1_server_side_conn, client1_conn) +} + +func TestTcpInOut1ClientConfigReload_inverseCloseOrder(t *testing.T) { + l := logrus.New() + server, client := service.CreateTwoConnectedServices(4245) + defer client.Close() + defer server.Close() + + server_pf, err := createPortForwarderFromConfigString(l, server, ` +port_forwarding: + inbound: + - listen_port: 4499 + dial_address: 127.0.0.1:5599 + protocols: [tcp] +`) + assert.Nil(t, err) + + assert.Len(t, server_pf.portForwardings, 1) + + client_pf, err := createPortForwarderFromConfigString(l, client, ` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 10.0.0.1:4499 + protocols: [tcp] +`) + assert.Nil(t, err) + + assert.Len(t, client_pf.portForwardings, 1) + + client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3399") + assert.Nil(t, err) + server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5599") + assert.Nil(t, err) + + server_listen_conn, err := 
net.ListenTCP("tcp", server_conn_addr) + assert.Nil(t, err) + client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) + assert.Nil(t, err) + client1_server_side_conn, err := server_listen_conn.Accept() + assert.Nil(t, err) + + doTestTcpCommunication(t, "Hello from client 1 side!", + client1_conn, client1_server_side_conn) + + doTestTcpCommunication(t, "Hello from server first side!", + client1_server_side_conn, client1_conn) + doTestTcpCommunication(t, "Hello from server third side!", + client1_server_side_conn, client1_conn) + + doTestTcpCommunication(t, "Hello from client one side AGAIN!", + client1_conn, client1_server_side_conn) + + new_server_fwd_list, err := loadPortFwdConfigFromString(l, ` +port_forwarding: + inbound: + - listen_port: 4498 + dial_address: 127.0.0.1:5598 + protocols: [tcp] +`) + assert.Nil(t, err) + + assert.Len(t, server_pf.portForwardings, 1) + + new_client_fwd_list, err := loadPortFwdConfigFromString(l, ` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3398 + dial_address: 10.0.0.1:4498 + protocols: [tcp] +`) + assert.Nil(t, err) + + err = server_pf.ApplyChangesByNewFwdList(new_server_fwd_list) + assert.Nil(t, err) + + doTestTcpCommunicationFail(t, "Hello from client 1 side!", + client1_conn, client1_server_side_conn) + + doTestTcpCommunicationFail(t, "Hello from server first side!", + client1_server_side_conn, client1_conn) + + err = client_pf.ApplyChangesByNewFwdList(new_client_fwd_list) + assert.Nil(t, err) + + doTestTcpCommunicationFail(t, "Hello from client 1 side!", + client1_conn, client1_server_side_conn) + + doTestTcpCommunicationFail(t, "Hello from server first side!", + client1_server_side_conn, client1_conn) +} diff --git a/port_forwarder/port_forwarder_udp_test.go b/port_forwarder/port_forwarder_udp_test.go new file mode 100644 index 000000000..27ea5800d --- /dev/null +++ b/port_forwarder/port_forwarder_udp_test.go @@ -0,0 +1,132 @@ +package port_forwarder + +import ( + "net" + "testing" + + 
"github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/service" + "github.com/stretchr/testify/assert" +) + +func loadPortFwdConfigFromString(l *logrus.Logger, configStr string) (*PortForwardingList, error) { + c := config.NewC(l) + err := c.LoadString(configStr) + if err != nil { + return nil, err + } + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + if err != nil { + return nil, err + } + + return &fwd_list, nil +} + +func createPortForwarderFromConfigString(l *logrus.Logger, srv *service.Service, configStr string) (*PortForwardingService, error) { + + fwd_list, err := loadPortFwdConfigFromString(l, configStr) + if err != nil { + return nil, err + } + + pf, err := ConstructFromInitialFwdList(srv, l, fwd_list) + if err != nil { + return nil, err + } + + err = pf.Activate() + if err != nil { + return nil, err + } + + return pf, nil +} + +func doTestUdpCommunication( + t *testing.T, + msg string, + senderConn *net.UDPConn, + toAddr net.Addr, + receiverConn *net.UDPConn, +) (senderAddr net.Addr) { + data_sent := []byte(msg) + var n int + var err error + if toAddr != nil { + n, err = senderConn.WriteTo(data_sent, toAddr) + } else { + n, err = senderConn.Write(data_sent) + } + assert.Nil(t, err) + assert.Equal(t, n, len(data_sent)) + + buf := make([]byte, 100) + n, senderAddr, err = receiverConn.ReadFrom(buf) + assert.Nil(t, err) + assert.Equal(t, n, len(data_sent)) + assert.Equal(t, data_sent, buf[:n]) + return +} + +func TestUdpInOut2Clients(t *testing.T) { + l := logrus.New() + server, client := service.CreateTwoConnectedServices(4244) + defer client.Close() + defer server.Close() + + server_pf, err := createPortForwarderFromConfigString(l, server, ` +port_forwarding: + inbound: + - listen_port: 4499 + dial_address: 127.0.0.1:5599 + protocols: [udp] +`) + assert.Nil(t, err) + + assert.Len(t, server_pf.portForwardings, 1) + + client_pf, err := createPortForwarderFromConfigString(l, client, ` 
+port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 10.0.0.1:4499 + protocols: [udp] +`) + assert.Nil(t, err) + + assert.Len(t, client_pf.portForwardings, 1) + + client_conn_addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3399") + assert.Nil(t, err) + server_conn_addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5599") + assert.Nil(t, err) + + server_listen_conn, err := net.ListenUDP("udp", server_conn_addr) + assert.Nil(t, err) + client1_conn, err := net.DialUDP("udp", nil, client_conn_addr) + assert.Nil(t, err) + client2_conn, err := net.DialUDP("udp", nil, client_conn_addr) + assert.Nil(t, err) + + client1_addr := doTestUdpCommunication(t, "Hello from client 1 side!", + client1_conn, nil, server_listen_conn) + assert.NotNil(t, client1_addr) + client2_addr := doTestUdpCommunication(t, "Hello from client two side!", + client2_conn, nil, server_listen_conn) + assert.NotNil(t, client2_addr) + + doTestUdpCommunication(t, "Hello from server first side!", + server_listen_conn, client1_addr, client1_conn) + doTestUdpCommunication(t, "Hello from server second side!", + server_listen_conn, client2_addr, client2_conn) + doTestUdpCommunication(t, "Hello from server third side!", + server_listen_conn, client1_addr, client1_conn) + + doTestUdpCommunication(t, "Hello from client two side AGAIN!", + client2_conn, nil, server_listen_conn) + +} diff --git a/service/service_test.go b/service/service_test.go index 31762090d..b9098c34f 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -4,101 +4,13 @@ import ( "bytes" "context" "errors" - "net/netip" "testing" - "time" - "dario.cat/mergo" - "github.com/slackhq/nebula/cert" - "github.com/slackhq/nebula/config" - "github.com/slackhq/nebula/e2e" "golang.org/x/sync/errgroup" - "gopkg.in/yaml.v2" ) -type m map[string]interface{} - -func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { - _, _, myPrivKey, myPEM := 
e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{}) - caB, err := caCrt.MarshalToPEM() - if err != nil { - panic(err) - } - - mc := m{ - "pki": m{ - "ca": string(caB), - "cert": string(myPEM), - "key": string(myPrivKey), - }, - //"tun": m{"disabled": true}, - "firewall": m{ - "outbound": []m{{ - "proto": "any", - "port": "any", - "host": "any", - }}, - "inbound": []m{{ - "proto": "any", - "port": "any", - "host": "any", - }}, - }, - "timers": m{ - "pending_deletion_interval": 2, - "connection_alive_interval": 2, - }, - "handshakes": m{ - "try_interval": "200ms", - }, - } - - if overrides != nil { - err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) - if err != nil { - panic(err) - } - mc = overrides - } - - cb, err := yaml.Marshal(mc) - if err != nil { - panic(err) - } - - var c config.C - if err := c.LoadString(string(cb)); err != nil { - panic(err) - } - - s, err := New(&c) - if err != nil { - panic(err) - } - return s -} - func TestService(t *testing.T) { - ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) - a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ - "static_host_map": m{}, - "lighthouse": m{ - "am_lighthouse": true, - }, - "listen": m{ - "host": "0.0.0.0", - "port": 4243, - }, - }) - b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{ - "static_host_map": m{ - "10.0.0.1": []string{"localhost:4243"}, - }, - "lighthouse": m{ - "hosts": []string{"10.0.0.1"}, - "interval": 1, - }, - }) + a, b := CreateTwoConnectedServices(4243) ln, err := a.Listen("tcp", ":1234") if err != nil { diff --git a/service/service_testhelpers.go b/service/service_testhelpers.go new file mode 100644 index 000000000..28661865b --- /dev/null +++ b/service/service_testhelpers.go @@ -0,0 +1,101 @@ +package service + +import ( + "fmt" + "net/netip" + "time" + + "dario.cat/mergo" + "github.com/sirupsen/logrus" + 
"github.com/slackhq/nebula/cert" + "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/e2e" + "gopkg.in/yaml.v2" +) + +type m map[string]interface{} + +func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { + _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{}) + caB, err := caCrt.MarshalToPEM() + if err != nil { + panic(err) + } + + mc := m{ + "pki": m{ + "ca": string(caB), + "cert": string(myPEM), + "key": string(myPrivKey), + }, + //"tun": m{"disabled": true}, + "firewall": m{ + "outbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + }}, + "inbound": []m{{ + "proto": "any", + "port": "any", + "host": "any", + }}, + }, + "timers": m{ + "pending_deletion_interval": 2, + "connection_alive_interval": 2, + }, + "handshakes": m{ + "try_interval": "200ms", + }, + } + + if overrides != nil { + err = mergo.Merge(&overrides, mc, mergo.WithAppendSlice) + if err != nil { + panic(err) + } + mc = overrides + } + + cb, err := yaml.Marshal(mc) + if err != nil { + panic(err) + } + + var c config.C + if err := c.LoadString(string(cb)); err != nil { + panic(err) + } + + l := logrus.New() + s, err := New(&c, l) + if err != nil { + panic(err) + } + return s +} + +func CreateTwoConnectedServices(port int) (*Service, *Service) { + ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ + "static_host_map": m{}, + "lighthouse": m{ + "am_lighthouse": true, + }, + "listen": m{ + "host": "0.0.0.0", + "port": port, + }, + }) + b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{ + "static_host_map": m{ + "10.0.0.1": []string{fmt.Sprintf("localhost:%d", port)}, + }, + "lighthouse": m{ + "hosts": []string{"10.0.0.1"}, + "interval": 1, + }, + }) + return a, b +} 
From 067410f2fbba2b3328ed1c0bf3bbaaeb4626d82e Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Sun, 11 Aug 2024 20:00:28 +0200 Subject: [PATCH 09/33] TCP/UDP port fwd. for disabled tun with config reload support - reload config supported - killing existing live connections for removed fwd. configs --- cmd/nebula/main.go | 74 ++++- examples/config.yml | 29 +- examples/go_service/main.go | 4 +- interface.go | 2 +- overlay/user.go | 1 - port_forwarder/builder.go | 233 ++++++++++++++++ port_forwarder/config.go | 27 ++ port_forwarder/config_test.go | 275 +++++++++++++++++++ port_forwarder/fwd_tcp.go | 225 +++++++++++++++ port_forwarder/fwd_udp.go | 302 +++++++++++++++++++++ port_forwarder/lockfree_timeout_counter.go | 33 +++ port_forwarder/port_forwarding_service.go | 54 ++++ service/service.go | 33 ++- 13 files changed, 1277 insertions(+), 15 deletions(-) create mode 100644 port_forwarder/builder.go create mode 100644 port_forwarder/config.go create mode 100644 port_forwarder/config_test.go create mode 100644 port_forwarder/fwd_tcp.go create mode 100644 port_forwarder/fwd_udp.go create mode 100644 port_forwarder/lockfree_timeout_counter.go create mode 100644 port_forwarder/port_forwarding_service.go diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index 5cf0a028a..f3e6a35b8 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -1,13 +1,20 @@ package main import ( + "context" + "errors" "flag" "fmt" + "io" "os" + "os/signal" + "syscall" "github.com/sirupsen/logrus" "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" + "github.com/slackhq/nebula/port_forwarder" + "github.com/slackhq/nebula/service" "github.com/slackhq/nebula/util" ) @@ -52,16 +59,67 @@ func main() { os.Exit(1) } - ctrl, err := nebula.Main(c, *configTest, Build, l, nil) - if err != nil { - util.LogWithContextIfNeeded("Failed to start", err, l) - os.Exit(1) + fwd_list := port_forwarder.NewPortForwardingList() + disabled_tun := c.GetBool("tun.disabled", false) + 
activate_service_anyway := c.GetBool("port_forwarding.enable_without_rules", false) + if disabled_tun { + port_forwarder.ParseConfig(l, c, fwd_list) } - if !*configTest { - ctrl.Start() - notifyReady(l) - ctrl.ShutdownBlock() + if !*configTest && disabled_tun && (activate_service_anyway || !fwd_list.IsEmpty()) { + l.Infof("Configuring user-tun instead of disabled-tun as port forwarding is configured") + + service, err := service.New(c, l) + if err != nil { + util.LogWithContextIfNeeded("Failed to create service", err, l) + os.Exit(1) + } + + // initialize port forwarding: + pf_service, err := port_forwarder.ConstructFromInitialFwdList(service, l, &fwd_list) + if err != nil { + util.LogWithContextIfNeeded("Failed to start", err, l) + os.Exit(1) + } + + c.RegisterReloadCallback(func(c *config.C) { + pf_service.ReloadConfigAndApplyChanges(c) + }) + + pf_service.Activate() + + // wait for termination request + signalChannel := make(chan os.Signal, 1) + signal.Notify(signalChannel, syscall.SIGINT, syscall.SIGTERM) + fmt.Println("Running, press ctrl+c to shutdown...") + <-signalChannel + + // shutdown: + service.Close() + if err := service.Wait(); err != nil { + if errors.Is(err, os.ErrClosed) || + errors.Is(err, io.EOF) || + errors.Is(err, context.Canceled) { + l.Debugf("Stop of user-tun service returned: %v", err) + } else { + util.LogWithContextIfNeeded("Unclean stop", err, l) + } + } + + } else { + + l.Info("Configuring for disabled or kernel tun. no port forwarding provided") + ctrl, err := nebula.Main(c, *configTest, Build, l, nil) + if err != nil { + util.LogWithContextIfNeeded("Failed to start", err, l) + os.Exit(1) + } + + if !*configTest { + ctrl.Start() + notifyReady(l) + ctrl.ShutdownBlock() + } } os.Exit(0) diff --git a/examples/config.yml b/examples/config.yml index c74ffc68f..62c800573 100644 --- a/examples/config.yml +++ b/examples/config.yml @@ -206,7 +206,11 @@ relay: # Configure the private interface. 
Note: addr is baked into the nebula certificate tun: - # When tun is disabled, a lighthouse can be started without a local tun interface (and therefore without root) + # When tun is disabled, a feature-limited Nebula can be started without root privileges. + # In this limited mode, Nebula can + # - run a lighthouse node + # - offer access from and to the nebula network via port forwarding + # - respond to ping requests disabled: false # Name of the device. If not set, a default will be chosen by the OS. # For macOS: if set, must be in the form `utun[0-9]+`. @@ -368,3 +372,26 @@ firewall: proto: tcp group: remote_client local_cidr: 192.168.100.1/24 + +# By using port forwarding (port tunnels) it's possible to establish connections +# from/into the nebula-network without using a tun/tap device and thus without requiring root access +# on the host. Port forwarding is only supported when "tun.disabled" is set to true. +# In this case, a user-tun is instantiated instead of a real one to allow the port forwarding. +# IMPORTANT: For incoming tunnels, don't forget to also open the firewall for the relevant ports. +port_forwarding: + # Forces activation of the user tun, even when there is no rule specified. + # This can be useful when rules are planned to be added later by reload. + # A config reload cannot change the tun type. 
+ enable_without_rules: false + outbound: + # format of listen- and dial-address: <host>:<port> + #- listen_address: 127.0.0.1:3399 + # dial_address: 192.168.100.92:4499 + # format of protocols lists (yml-list): [tcp], [udp], [tcp, udp] + # protocols: [tcp, udp] + inbound: + # format of dial_address: <host>:<port> + #- listen_port: 5599 + # dial_address: 127.0.0.1:5599 + # format of protocols lists (yml-list): [tcp], [udp], [tcp, udp] + # protocols: [tcp, udp] diff --git a/examples/go_service/main.go b/examples/go_service/main.go index f46273acf..cd07b7fb2 100644 --- a/examples/go_service/main.go +++ b/examples/go_service/main.go @@ -5,6 +5,7 @@ import ( "fmt" "log" + "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/service" ) @@ -58,7 +59,8 @@ pki: if err := config.LoadString(configStr); err != nil { return err } - service, err := service.New(&config) + l := logrus.New() + service, err := service.New(&config, l) if err != nil { return err } diff --git a/interface.go b/interface.go index f2519076c..26c25340d 100644 --- a/interface.go +++ b/interface.go @@ -301,7 +301,7 @@ func (f *Interface) listenIn(reader io.ReadWriteCloser, i int) { for { n, err := reader.Read(packet) if err != nil { - if errors.Is(err, os.ErrClosed) && f.closed.Load() { + if (errors.Is(err, os.ErrClosed) && f.closed.Load()) || errors.Is(err, io.EOF) { return } diff --git a/overlay/user.go b/overlay/user.go index 90329f2cb..b938d4762 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -54,7 +54,6 @@ type UserDevice struct { outboundChannel chan *buffer.View inboundChannel chan *buffer.View - routeTree atomic.Pointer[bart.Table[netip.Addr]] } diff --git a/port_forwarder/builder.go b/port_forwarder/builder.go new file mode 100644 index 000000000..906069e9a --- /dev/null +++ b/port_forwarder/builder.go @@ -0,0 +1,233 @@ +package port_forwarder + +import ( + "fmt" + "io" + "strconv" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + 
"github.com/slackhq/nebula/service" +) + +func ymlGetStringOfNode(node interface{}) string { + return fmt.Sprintf("%v", node) +} + +func ymlMapGetStringEntry(k string, m map[interface{}]interface{}) string { + v, ok := m[k] + if !ok { + return "" + } + return fmt.Sprintf("%v", v) +} + +type ymlListNode = []interface{} +type ymlMapNode = map[interface{}]interface{} +type configFactoryFn = func(yml_node ymlMapNode) error +type configFactoryFnMap = map[string]configFactoryFn + +type builderData struct { + l *logrus.Logger + target ConfigList + factories map[string]configFactoryFnMap +} + +func ParseConfig( + l *logrus.Logger, + c *config.C, + target ConfigList, +) error { + builder := builderData{ + l: l, + target: target, + factories: map[string]configFactoryFnMap{}, + } + + in := configFactoryFnMap{} + in["udp"] = func(yml_node ymlMapNode) error { + return builder.convertToForwardConfigIncoming(l, yml_node, false) + } + in["tcp"] = func(yml_node ymlMapNode) error { + return builder.convertToForwardConfigIncoming(l, yml_node, true) + } + builder.factories["inbound"] = in + + out := configFactoryFnMap{} + out["udp"] = func(yml_node ymlMapNode) error { + return builder.convertToForwardConfigOutgoing(l, yml_node, false) + } + out["tcp"] = func(yml_node ymlMapNode) error { + return builder.convertToForwardConfigOutgoing(l, yml_node, true) + } + builder.factories["outbound"] = out + + for _, direction := range [...]string{"inbound", "outbound"} { + cfg_fwds := c.Get("port_forwarding." 
+ direction) + if cfg_fwds == nil { + continue + } + + cfg_fwds_list, ok := cfg_fwds.(ymlListNode) + if !ok { + return fmt.Errorf("yml node \"port_forwarding.%s\" needs to be a list", direction) + } + + for fwd_idx, node := range cfg_fwds_list { + node_map, ok := node.(ymlMapNode) + if !ok { + return fmt.Errorf("child yml node of \"port_forwarding.%s\" needs to be a map", direction) + } + + protocols, ok := node_map["protocols"] + if !ok { + l.Infof("child yml node of \"port_forwarding.%s\" should have a child \"protocols\"", direction) + continue + } + + protocols_list, ok := protocols.(ymlListNode) + if !ok { + return fmt.Errorf("child yml node of \"port_forwarding.%s\" needs to have a child \"protocols\" that is a yml list", direction) + } + + for _, proto := range protocols_list { + proto_str := ymlGetStringOfNode(proto) + factoryFn, ok := builder.factories[direction][proto_str] + if !ok { + return fmt.Errorf("child yml node of \"port_forwarding.%s.%d.protocols\" doesn't support: %s", direction, fwd_idx, proto_str) + } + + factoryFn(node_map) + } + } + } + + return nil +} + +func ConstructFromInitialFwdList( + tunService *service.Service, + l *logrus.Logger, + fwd_list *PortForwardingList, +) (*PortForwardingService, error) { + + pfService := &PortForwardingService{ + l: l, + tunService: tunService, + configPortForwardings: fwd_list.configPortForwardings, + portForwardings: make(map[string]io.Closer), + } + + return pfService, nil +} + +func NewPortForwardingList() PortForwardingList { + return PortForwardingList{ + configPortForwardings: map[string]ForwardConfig{}, + } +} + +type PortForwardingList struct { + configPortForwardings map[string]ForwardConfig +} + +func (pfl PortForwardingList) AddConfig(cfg ForwardConfig) { + pfl.configPortForwardings[cfg.ConfigDescriptor()] = cfg +} + +func (pfl PortForwardingList) IsEmpty() bool { + return len(pfl.configPortForwardings) == 0 +} + +func (s *PortForwardingService) ReloadConfigAndApplyChanges( + c *config.C, +) 
error { + + s.l.Infof("reloading port forwarding configuration...") + + pflNew := NewPortForwardingList() + + err := ParseConfig(s.l, c, pflNew) + if err != nil { + return err + } + + return s.ApplyChangesByNewFwdList(&pflNew) +} + +func (s *PortForwardingService) ApplyChangesByNewFwdList( + pflNew *PortForwardingList, +) error { + + to_be_closed := []string{} + for old := range s.configPortForwardings { + _, corresponding_new_exists := pflNew.configPortForwardings[old] + if !corresponding_new_exists { + to_be_closed = append(to_be_closed, old) + } + } + + s.CloseSelective(to_be_closed) + + to_be_added := map[string]ForwardConfig{} + for new, cfg := range pflNew.configPortForwardings { + _, corresponding_old_exists := s.configPortForwardings[new] + if !corresponding_old_exists { + to_be_added[cfg.ConfigDescriptor()] = cfg + } + } + + s.ActivateNew(to_be_added) + + return nil +} + +func (builder *builderData) convertToForwardConfigOutgoing( + _ *logrus.Logger, + m ymlMapNode, + isTcp bool, +) error { + fwd_port := ForwardConfigOutgoing{ + localListen: ymlMapGetStringEntry("listen_address", m), + remoteConnect: ymlMapGetStringEntry("dial_address", m), + } + + var cfg ForwardConfig + if isTcp { + cfg = ForwardConfigOutgoingTcp{fwd_port} + } else { + cfg = ForwardConfigOutgoingUdp{fwd_port} + } + + builder.target.AddConfig(cfg) + + return nil +} + +func (builder *builderData) convertToForwardConfigIncoming( + _ *logrus.Logger, + m ymlMapNode, + isTcp bool, +) error { + + v, err := strconv.ParseUint(ymlMapGetStringEntry("listen_port", m), 10, 32) + if err != nil { + return err + } + + fwd_port := ForwardConfigIncoming{ + port: uint32(v), + forwardLocalAddress: ymlMapGetStringEntry("dial_address", m), + } + + var cfg ForwardConfig + if isTcp { + cfg = ForwardConfigIncomingTcp{fwd_port} + } else { + cfg = ForwardConfigIncomingUdp{fwd_port} + } + + builder.target.AddConfig(cfg) + + return nil +} diff --git a/port_forwarder/config.go b/port_forwarder/config.go new file mode 
100644 index 000000000..6eeb07aa4 --- /dev/null +++ b/port_forwarder/config.go @@ -0,0 +1,27 @@ +package port_forwarder + +import ( + "io" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/service" +) + +type ForwardConfig interface { + SetupPortForwarding(tunService *service.Service, l *logrus.Logger) (io.Closer, error) + ConfigDescriptor() string +} + +type ConfigList interface { + AddConfig(cfg ForwardConfig) +} + +type ForwardConfigOutgoing struct { + localListen string + remoteConnect string +} + +type ForwardConfigIncoming struct { + port uint32 + forwardLocalAddress string +} diff --git a/port_forwarder/config_test.go b/port_forwarder/config_test.go new file mode 100644 index 000000000..944f5a16c --- /dev/null +++ b/port_forwarder/config_test.go @@ -0,0 +1,275 @@ +package port_forwarder + +import ( + "testing" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/config" + "github.com/stretchr/testify/assert" +) + +func TestEmptyConfig(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString("bla:") + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 0) + assert.True(t, fwd_list.IsEmpty()) +} + +func TestConfigWithNoProtocols(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + protocols: [] + inbound: + - listen_port: 5599 + dial_address: 127.0.0.1:5599 + protocols: [] +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 0) + assert.True(t, fwd_list.IsEmpty()) +} + +func TestConfigWithNoProtocols_commentedProtos(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: + outbound: + - listen_address: 
127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + # protocols: [tcp, udp] + inbound: + - listen_port: 5599 + dial_address: 127.0.0.1:5599 + # protocols: [tc, udp] +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 0) + assert.True(t, fwd_list.IsEmpty()) +} + +func TestConfigWithNoProtocols_missing_in_out(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 0) + assert.True(t, fwd_list.IsEmpty()) +} + +func TestConfigWithTcpIn(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + protocols: [] + inbound: + - listen_port: 5580 + dial_address: 127.0.0.1:5599 + protocols: [tcp] +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 1) + assert.False(t, fwd_list.IsEmpty()) + + fwd1 := fwd_list.configPortForwardings["inbound.tcp.5580.127.0.0.1:5599"].(ForwardConfigIncomingTcp) + assert.NotNil(t, fwd1) + assert.Equal(t, fwd1.forwardLocalAddress, "127.0.0.1:5599") + assert.Equal(t, int(fwd1.port), 5580) +} + +func TestConfigWithTcpOut(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + protocols: [tcp] + inbound: + - listen_port: 5580 + dial_address: 127.0.0.1:5599 + protocols: [] +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 1) + 
assert.False(t, fwd_list.IsEmpty()) + + fwd1 := fwd_list.configPortForwardings["outbound.tcp.127.0.0.1:3399.192.168.100.92:4499"].(ForwardConfigOutgoingTcp) + assert.NotNil(t, fwd1) + assert.Equal(t, fwd1.localListen, "127.0.0.1:3399") + assert.Equal(t, fwd1.remoteConnect, "192.168.100.92:4499") +} + +func TestConfigWithUdpIn(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + protocols: [] + inbound: + - listen_port: 5580 + dial_address: 127.0.0.1:5599 + protocols: [udp] +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 1) + assert.False(t, fwd_list.IsEmpty()) + + fwd1 := fwd_list.configPortForwardings["inbound.udp.5580.127.0.0.1:5599"].(ForwardConfigIncomingUdp) + assert.NotNil(t, fwd1) + assert.Equal(t, fwd1.forwardLocalAddress, "127.0.0.1:5599") + assert.Equal(t, int(fwd1.port), 5580) +} + +func TestConfigWithUdpOut(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + protocols: [udp] + inbound: + - listen_port: 5580 + dial_address: 127.0.0.1:5599 + protocols: [] +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 1) + assert.False(t, fwd_list.IsEmpty()) + + fwd1 := fwd_list.configPortForwardings["outbound.udp.127.0.0.1:3399.192.168.100.92:4499"].(ForwardConfigOutgoingUdp) + assert.NotNil(t, fwd1) + assert.Equal(t, fwd1.localListen, "127.0.0.1:3399") + assert.Equal(t, fwd1.remoteConnect, "192.168.100.92:4499") +} + +func TestConfigWithMultipleMixed(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: + outbound: + - 
listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + protocols: [udp, tcp] + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:5499 + protocols: [tcp] + inbound: + - listen_port: 5580 + dial_address: 127.0.0.1:5599 + protocols: [tcp, udp] + - listen_port: 5570 + dial_address: 127.0.0.1:5555 + protocols: [udp] +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 6) + assert.False(t, fwd_list.IsEmpty()) + + assert.NotNil(t, fwd_list.configPortForwardings["outbound.udp.127.0.0.1:3399.192.168.100.92:4499"].(ForwardConfigOutgoingUdp)) + assert.NotNil(t, fwd_list.configPortForwardings["outbound.tcp.127.0.0.1:3399.192.168.100.92:4499"].(ForwardConfigOutgoingTcp)) + assert.NotNil(t, fwd_list.configPortForwardings["outbound.tcp.127.0.0.1:3399.192.168.100.92:5499"].(ForwardConfigOutgoingTcp)) + assert.NotNil(t, fwd_list.configPortForwardings["inbound.tcp.5580.127.0.0.1:5599"].(ForwardConfigIncomingTcp)) + assert.NotNil(t, fwd_list.configPortForwardings["inbound.udp.5580.127.0.0.1:5599"].(ForwardConfigIncomingUdp)) + assert.NotNil(t, fwd_list.configPortForwardings["inbound.udp.5570.127.0.0.1:5555"].(ForwardConfigIncomingUdp)) +} + +func TestConfigWithOverlappingRulesNoDuplicatesInResult(t *testing.T) { + l := logrus.New() + c := config.NewC(l) + err := c.LoadString(` +port_forwarding: + outbound: + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + protocols: [udp, tcp, udp] + - listen_address: 127.0.0.1:3399 + dial_address: 192.168.100.92:4499 + protocols: [tcp] + inbound: + - listen_port: 5580 + dial_address: 127.0.0.1:5599 + protocols: [tcp, udp] + - listen_port: 5580 + dial_address: 127.0.0.1:5599 + protocols: [udp, udp] +`) + assert.Nil(t, err) + + fwd_list := NewPortForwardingList() + err = ParseConfig(l, c, fwd_list) + assert.Nil(t, err) + + assert.Len(t, fwd_list.configPortForwardings, 4) + 
assert.False(t, fwd_list.IsEmpty()) + + assert.NotNil(t, fwd_list.configPortForwardings["outbound.udp.127.0.0.1:3399.192.168.100.92:4499"].(ForwardConfigOutgoingUdp)) + assert.NotNil(t, fwd_list.configPortForwardings["outbound.tcp.127.0.0.1:3399.192.168.100.92:4499"].(ForwardConfigOutgoingTcp)) + assert.NotNil(t, fwd_list.configPortForwardings["inbound.tcp.5580.127.0.0.1:5599"].(ForwardConfigIncomingTcp)) + assert.NotNil(t, fwd_list.configPortForwardings["inbound.udp.5580.127.0.0.1:5599"].(ForwardConfigIncomingUdp)) +} diff --git a/port_forwarder/fwd_tcp.go b/port_forwarder/fwd_tcp.go new file mode 100644 index 000000000..c345876ef --- /dev/null +++ b/port_forwarder/fwd_tcp.go @@ -0,0 +1,225 @@ +package port_forwarder + +import ( + "context" + "fmt" + "io" + "net" + "time" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/service" + "golang.org/x/sync/errgroup" +) + +type ForwardConfigOutgoingTcp struct { + ForwardConfigOutgoing +} + +func (cfg ForwardConfigOutgoingTcp) ConfigDescriptor() string { + return fmt.Sprintf("outbound.tcp.%s.%s", cfg.localListen, cfg.remoteConnect) +} + +type ForwardConfigIncomingTcp struct { + ForwardConfigIncoming +} + +func (cfg ForwardConfigIncomingTcp) ConfigDescriptor() string { + return fmt.Sprintf("inbound.tcp.%d.%s", cfg.port, cfg.forwardLocalAddress) +} + +type PortForwardingCommonTcp struct { + ctx context.Context + l *logrus.Logger + tunService *service.Service + localListenConnection net.Listener +} + +func (fwd PortForwardingCommonTcp) Close() error { + fwd.localListenConnection.Close() + return nil +} + +type PortForwardingOutgoingTcp struct { + PortForwardingCommonTcp + cfg ForwardConfigOutgoingTcp +} + +func (cf ForwardConfigOutgoingTcp) SetupPortForwarding( + tunService *service.Service, + l *logrus.Logger, +) (io.Closer, error) { + localTcpListenAddr, err := net.ResolveTCPAddr("tcp", cf.localListen) + if err != nil { + return nil, err + } + localListenPort, err := net.ListenTCP("tcp", localTcpListenAddr) + 
if err != nil { + return nil, err + } + + l.Infof("TCP port forwarding to '%v': listening on local TCP addr: '%v'", + cf.remoteConnect, localTcpListenAddr) + + ctx, cancel := context.WithCancel(context.Background()) + + portForwarding := &PortForwardingOutgoingTcp{ + PortForwardingCommonTcp: PortForwardingCommonTcp{ + ctx: ctx, + l: l, + tunService: tunService, + localListenConnection: localListenPort, + }, + cfg: cf, + } + + go func() { + defer cancel() + portForwarding.acceptOnLocalListenPort_generic(portForwarding.handleClientConnectionWithErrorReturn) + }() + + return portForwarding, nil +} + +func (pt *PortForwardingCommonTcp) acceptOnLocalListenPort_generic( + handleClientConnectionWithErrorReturn func(localConnection net.Conn) error, +) error { + for { + pt.l.Debug("listening on local TCP port ...") + connection, err := pt.localListenConnection.Accept() + if err != nil { + fmt.Println(err) + return err + } + + pt.l.Debugf("accept TCP connect from local TCP port: %v", connection.RemoteAddr()) + + go func() { + defer connection.Close() + <-pt.ctx.Done() + }() + + go func() { + err := handleClientConnectionWithErrorReturn(connection) + if err != nil { + pt.l.Debugf("Closed TCP client connection %s. 
Err: %+v", + connection.LocalAddr().String(), err) + } + }() + } +} + +func (pt *PortForwardingOutgoingTcp) handleClientConnectionWithErrorReturn(localConnection net.Conn) error { + + remoteConnection, err := pt.tunService.DialContext(context.Background(), "tcp", pt.cfg.remoteConnect) + if err != nil { + return err + } + return handleTcpClientConnectionPair_generic(pt.l, localConnection, remoteConnection) +} + +func handleTcpClientConnectionPair_generic(l *logrus.Logger, connA, connB net.Conn) error { + + dataTransferHandler := func(from, to net.Conn) error { + + name := fmt.Sprintf("%s -> %s", from.LocalAddr().String(), to.LocalAddr().String()) + + defer from.Close() + defer to.Close() + + // no write/read timeout + to.SetDeadline(time.Time{}) + from.SetDeadline(time.Time{}) + megabyte := (1 << 20) + buf := make([]byte, 1*megabyte) + if false { + // this variant seems to be slightly slower on the local speed-test. 1.60GiB/s vs. 1.70GiB/s + n, err := io.CopyBuffer(to, from, buf) + l.WithError(err). + WithField("payloadSize", n). + WithField("from", from.RemoteAddr()). + WithField("to", to.RemoteAddr()). + WithField("localFrom", from.LocalAddr()). + WithField("localTo", to.LocalAddr()). 
+ Debug("stopped data forwarding") + return err + } else { + for { + rn, r_err := from.Read(buf) + l.Tracef("%s read(%d), err: %v", name, rn, r_err) + for i := 0; i < rn; { + wn, w_err := to.Write(buf[i:rn]) + if w_err != nil { + l.Debugf("%s writing(%d) to to-connection failed: %v", name, rn, w_err) + return w_err + } + i += wn + } + if r_err != nil { + l.Debugf("%s reading(%d) from from-connection failed: %v", name, rn, r_err) + return r_err + } + } + } + } + + errGroup := errgroup.Group{} + + errGroup.Go(func() error { return dataTransferHandler(connA, connB) }) + errGroup.Go(func() error { return dataTransferHandler(connB, connA) }) + + return errGroup.Wait() +} + +type PortForwardingIncomingTcp struct { + PortForwardingCommonTcp + cfg ForwardConfigIncomingTcp +} + +func (cf ForwardConfigIncomingTcp) SetupPortForwarding( + tunService *service.Service, + l *logrus.Logger, +) (io.Closer, error) { + + localListenPort, err := tunService.Listen("tcp", fmt.Sprintf(":%d", cf.port)) + if err != nil { + return nil, err + } + + l.Infof("TCP port forwarding to '%v': listening on local, outside TCP addr: ':%d'", + cf.forwardLocalAddress, cf.port) + + ctx, cancel := context.WithCancel(context.Background()) + + portForwarding := &PortForwardingIncomingTcp{ + PortForwardingCommonTcp: PortForwardingCommonTcp{ + ctx: ctx, + l: l, + tunService: tunService, + localListenConnection: localListenPort, + }, + cfg: cf, + } + + go func() { + defer cancel() + portForwarding.acceptOnLocalListenPort_generic(portForwarding.handleClientConnectionWithErrorReturn) + }() + + return portForwarding, nil +} + +func (pt *PortForwardingIncomingTcp) handleClientConnectionWithErrorReturn(outsideConnection net.Conn) error { + + fwdAddr, err := net.ResolveTCPAddr("tcp", pt.cfg.forwardLocalAddress) + if err != nil { + return err + } + + localConnection, err := net.DialTCP("tcp", nil, fwdAddr) + if err != nil { + return err + } + + return handleTcpClientConnectionPair_generic(pt.l, outsideConnection, 
localConnection) +} diff --git a/port_forwarder/fwd_udp.go b/port_forwarder/fwd_udp.go new file mode 100644 index 000000000..789ae78b6 --- /dev/null +++ b/port_forwarder/fwd_udp.go @@ -0,0 +1,302 @@ +package port_forwarder + +import ( + "errors" + "fmt" + "io" + "net" + "time" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/service" + "gvisor.dev/gvisor/pkg/tcpip/adapters/gonet" +) + +type ForwardConfigOutgoingUdp struct { + ForwardConfigOutgoing +} + +func (cfg ForwardConfigOutgoingUdp) ConfigDescriptor() string { + return fmt.Sprintf("outbound.udp.%s.%s", cfg.localListen, cfg.remoteConnect) +} + +type ForwardConfigIncomingUdp struct { + ForwardConfigIncoming +} + +func (cfg ForwardConfigIncomingUdp) ConfigDescriptor() string { + return fmt.Sprintf("inbound.udp.%d.%s", cfg.port, cfg.forwardLocalAddress) +} + +// use UDP timeout of 300 seconds according to +// https://support.goto.com/connect/help/what-are-the-recommended-nat-keep-alive-settings +var UDP_CONNECTION_TIMEOUT_SECONDS uint32 = 300 + +type udpConnInterface interface { + WriteTo(b []byte, addr net.Addr) (int, error) + Write(b []byte) (int, error) + ReadFrom(b []byte) (int, net.Addr, error) +} + +func handleUdpDestinationPortResponseReading[destConn net.Conn, srcConn udpConnInterface]( + l *logrus.Logger, + loggingFields logrus.Fields, + closedConnections *chan string, + sourceAddr net.Addr, + destConnection *TimedConnection[destConn], + localListenConnection srcConn, +) error { + // net.Conn is thread-safe according to: https://pkg.go.dev/net#Conn + // no need for remoteConnection to protect by mutex + + defer func() { (*closedConnections) <- sourceAddr.String() }() + + l.WithFields(loggingFields).Debug("begin reading responses ...") + buf := make([]byte, 2*(1<<16)) + for { + destConnection.connection.SetDeadline(time.Now().Add(time.Second * 10)) + l.WithFields(loggingFields).Trace("response read ...") + n, err := destConnection.connection.Read(buf) + if n == 0 { + if netErr, ok := 
err.(net.Error); ok && netErr.Timeout() { + l.WithFields(loggingFields).Debug("response read - timeout tick") + if destConnection.timeout_counter.Increment(10) { + l.WithFields(loggingFields).Debug("response read - closed due to timeout") + return nil + } + continue + } else { + l.WithFields(loggingFields).WithError(err).Debugf("response read - close due to error") + return err + } + } + + destConnection.timeout_counter.Reset() + l.WithFields(loggingFields). + WithField("payloadSize", n). + Debug("response forward") + n, err = localListenConnection.WriteTo(buf[:n], sourceAddr) + if n == 0 && (err != nil) { + l.WithFields(loggingFields).WithError(err).Debugf("response forward - write error") + return err + } + } +} + +func handleClosedConnections[C any]( + l *logrus.Logger, + closedConnections *chan string, + portReaders *map[string]bool, + remoteConnections *map[string]*TimedConnection[C], +) { +cleanup: + for { + select { + case closedOne := <-(*closedConnections): + l.Debugf("closing connection to %s", closedOne) + delete(*remoteConnections, closedOne) + delete(*portReaders, closedOne) + default: + break cleanup + } + } +} + +type PortForwardingCommonUdp struct { + l *logrus.Logger + tunService *service.Service + // net.Conn is thread-safe according to: https://pkg.go.dev/net#Conn + // no need for localListenConnection to protect by mutex + localListenConnection io.Closer +} + +func (fwd PortForwardingCommonUdp) Close() error { + fwd.localListenConnection.Close() + return nil +} + +type PortForwardingOutgoingUdp struct { + PortForwardingCommonUdp + cfg ForwardConfigOutgoingUdp +} + +func (cfg ForwardConfigOutgoingUdp) SetupPortForwarding( + tunService *service.Service, + l *logrus.Logger, +) (io.Closer, error) { + localUdpListenAddr, err := net.ResolveUDPAddr("udp", cfg.localListen) + if err != nil { + return nil, err + } + + localListenConnection, err := net.ListenUDP("udp", localUdpListenAddr) + if err != nil { + return nil, err + } + + l.Infof("UDP port 
forwarding to '%v': listening on local UDP addr: '%v'", + cfg.remoteConnect, localUdpListenAddr) + + portForwarding := &PortForwardingOutgoingUdp{ + PortForwardingCommonUdp: PortForwardingCommonUdp{ + l: l, + tunService: tunService, + localListenConnection: localListenConnection, + }, + cfg: cfg, + } + + logPrefix := logrus.Fields{ + "a": "UDP fwd out", + "listen": localListenConnection.LocalAddr(), + "dial": cfg.remoteConnect, + } + + go func() { + err := listenLocalPort_generic( + l, + logPrefix, + localListenConnection, + func(address string) (*gonet.UDPConn, error) { + return tunService.DialUDP(address) + }, + cfg.remoteConnect, + ) + if err != nil { + l.WithFields(logPrefix).WithError(err). + Error("listening stopped with error") + } + }() + + return portForwarding, nil +} + +func listenLocalPort_generic[destConn net.Conn]( + l *logrus.Logger, + loggingFields logrus.Fields, + localListenConnection udpConnInterface, + dial func(address string) (destConn, error), + remoteConnect string, +) error { + dialConnResponseReaders := make(map[string]bool) + dialConnections := make(map[string]*TimedConnection[destConn]) + closedConnections := make(chan string) + + l.WithFields(loggingFields).Debug("start listening ...") + var buf [512 * 1024]byte + for { + handleClosedConnections(l, &closedConnections, &dialConnResponseReaders, &dialConnections) + + l.WithFields(loggingFields).Trace("reading data ...") + n, localSourceAddr, err := localListenConnection.ReadFrom(buf[0:]) + if err != nil { + if errors.Is(err, io.EOF) { + return nil + } + l.WithFields(loggingFields).Error("listen for data failed. stop.") + return err + } + + l.WithFields(loggingFields). + WithField("source", localSourceAddr). + WithField("payloadSize", n). 
+ Trace("read data") + + dialConnection, ok := dialConnections[localSourceAddr.String()] + if !ok { + newDialConn, err := dial(remoteConnect) + if err != nil { + l.WithFields(loggingFields).WithError(err).Error("dialing dial address failed") + continue + } + dialConnection = &TimedConnection[destConn]{ + connection: newDialConn, + timeout_counter: NewTimeoutCounter(UDP_CONNECTION_TIMEOUT_SECONDS), + } + dialConnections[localSourceAddr.String()] = dialConnection + } + + l.WithFields(loggingFields). + WithField("source", localSourceAddr). + WithField("dialSource", dialConnection.connection.LocalAddr()). + WithField("payloadSize", n). + Debug("forward") + + dialConnection.timeout_counter.Reset() + dialConnection.connection.Write(buf[:n]) + + _, ok = dialConnResponseReaders[localSourceAddr.String()] + if !ok { + loggingFieldsRsp := logrus.Fields{ + "source": localSourceAddr, + "dialSource": dialConnection.connection.LocalAddr(), + } + for k, v := range loggingFields { + loggingFieldsRsp[k] = v + } + dialConnResponseReaders[localSourceAddr.String()] = true + go func() error { + return handleUdpDestinationPortResponseReading( + l, loggingFieldsRsp, &closedConnections, localSourceAddr, + dialConnection, localListenConnection) + }() + } + } +} + +type PortForwardingIncomingUdp struct { + PortForwardingCommonUdp + cfg ForwardConfigIncomingUdp +} + +func (cfg ForwardConfigIncomingUdp) SetupPortForwarding( + tunService *service.Service, + l *logrus.Logger, +) (io.Closer, error) { + + conn, err := tunService.ListenUDP(fmt.Sprintf(":%d", cfg.port)) + if err != nil { + return nil, err + } + + l.Infof("UDP port forwarding to '%v': listening on outside UDP addr: ':%d'", + cfg.forwardLocalAddress, cfg.port) + + logPrefix := logrus.Fields{ + "a": "UDP fwd in", + "listenPort": cfg.port, + "dial": cfg.forwardLocalAddress, + } + + forwarding := &PortForwardingIncomingUdp{ + PortForwardingCommonUdp: PortForwardingCommonUdp{ + l: l, + tunService: tunService, + localListenConnection: 
conn, + }, + cfg: cfg, + } + + go func() { + err := listenLocalPort_generic( + l, + logPrefix, + conn, + func(address string) (*net.UDPConn, error) { + fwdAddr, err := net.ResolveUDPAddr("udp", cfg.forwardLocalAddress) + if err != nil { + l.WithFields(logPrefix).Error("resolve of dial address failed") + return nil, err + } + return net.DialUDP("udp", nil, fwdAddr) + }, + cfg.forwardLocalAddress, + ) + if err != nil { + l.WithFields(logPrefix).WithError(err). + Error("listening stopped with error") + } + }() + + return forwarding, nil +} diff --git a/port_forwarder/lockfree_timeout_counter.go b/port_forwarder/lockfree_timeout_counter.go new file mode 100644 index 000000000..79d30d167 --- /dev/null +++ b/port_forwarder/lockfree_timeout_counter.go @@ -0,0 +1,33 @@ +package port_forwarder + +import "sync/atomic" + +type TimeoutCounter struct { + counter atomic.Uint32 + threshold uint32 +} + +func NewTimeoutCounter(threshold uint32) TimeoutCounter { + return TimeoutCounter{ + counter: atomic.Uint32{}, + threshold: threshold, + } +} + +func (tc *TimeoutCounter) Increment(step uint32) bool { + tc.counter.Add(step) + return tc.IsTimeout() +} + +func (tc *TimeoutCounter) Reset() { + tc.counter.Store(0) +} + +func (tc *TimeoutCounter) IsTimeout() bool { + return tc.counter.Load() > tc.threshold +} + +type TimedConnection[C any] struct { + connection C + timeout_counter TimeoutCounter +} diff --git a/port_forwarder/port_forwarding_service.go b/port_forwarder/port_forwarding_service.go new file mode 100644 index 000000000..8bad84e4d --- /dev/null +++ b/port_forwarder/port_forwarding_service.go @@ -0,0 +1,54 @@ +package port_forwarder + +import ( + "io" + + "github.com/sirupsen/logrus" + "github.com/slackhq/nebula/service" +) + +type PortForwardingService struct { + l *logrus.Logger + tunService *service.Service + + configPortForwardings map[string]ForwardConfig + portForwardings map[string]io.Closer +} + +func (t *PortForwardingService) AddConfig(cfg ForwardConfig) { + 
t.configPortForwardings[cfg.ConfigDescriptor()] = cfg +} + +func (t *PortForwardingService) Activate() error { + return t.ActivateNew(t.configPortForwardings) +} + +func (t *PortForwardingService) ActivateNew(newForwards map[string]ForwardConfig) error { + + for descriptor, config := range newForwards { + fwd_instance, err := config.SetupPortForwarding(t.tunService, t.l) + if err == nil { + t.configPortForwardings[config.ConfigDescriptor()] = config + t.portForwardings[config.ConfigDescriptor()] = fwd_instance + } else { + t.l.Errorf("failed to setup port forwarding #%s: %s", descriptor, config.ConfigDescriptor()) + } + } + + return nil +} + +func (t *PortForwardingService) CloseSelective(descriptors []string) error { + + for _, descriptor := range descriptors { + delete(t.configPortForwardings, descriptor) + pf, ok := t.portForwardings[descriptor] + if ok { + t.l.Infof("closing port forwarding: %s", descriptor) + pf.Close() + delete(t.portForwardings, descriptor) + } + } + + return nil +} diff --git a/service/service.go b/service/service.go index c1091a272..6969404c1 100644 --- a/service/service.go +++ b/service/service.go @@ -45,8 +45,7 @@ type Service struct { } } -func New(config *config.C) (*Service, error) { - logger := logrus.New() +func New(config *config.C, logger *logrus.Logger) (*Service, error) { logger.Out = os.Stdout control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) @@ -138,7 +137,7 @@ func New(config *config.C) (*Service, error) { } // DialContext dials the provided address. Currently only TCP is supported. 
-func (s *Service) DialContext(ctx context.Context, network, address string) (net.Conn, error) { +func (s *Service) DialContext(ctx context.Context, network, address string) (*gonet.TCPConn, error) { if network != "tcp" && network != "tcp4" { return nil, errors.New("only tcp is supported") } @@ -157,6 +156,21 @@ func (s *Service) DialContext(ctx context.Context, network, address string) (net return gonet.DialContextTCP(ctx, s.ipstack, fullAddr, ipv4.ProtocolNumber) } +func (s *Service) DialUDP(address string) (*gonet.UDPConn, error) { + addr, err := net.ResolveUDPAddr("udp", address) + if err != nil { + return nil, err + } + + fullAddr := tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(addr.IP), + Port: uint16(addr.Port), + } + + return gonet.DialUDP(s.ipstack, nil, &fullAddr, ipv4.ProtocolNumber) +} + // Listen listens on the provided address. Currently only TCP with wildcard // addresses are supported. func (s *Service) Listen(network, address string) (net.Listener, error) { @@ -196,6 +210,19 @@ func (s *Service) Listen(network, address string) (net.Listener, error) { return l, nil } +func (s *Service) ListenUDP(address string) (*gonet.UDPConn, error) { + addr, err := net.ResolveUDPAddr("udp", address) + if err != nil { + return nil, err + } + return gonet.DialUDP(s.ipstack, &tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.AddrFromSlice(addr.IP), + Port: uint16(addr.Port), + LinkAddr: "", + }, nil, ipv4.ProtocolNumber) +} + func (s *Service) Wait() error { return s.eg.Wait() } From ea77cded33f77313100d4ae095462a47b03d3c27 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Sun, 11 Aug 2024 20:05:15 +0200 Subject: [PATCH 10/33] fix fmt --- overlay/user.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/overlay/user.go b/overlay/user.go index b938d4762..d34c1d96a 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -60,8 +60,8 @@ type UserDevice struct { func (d *UserDevice) Activate() error { return nil } -func (d *UserDevice) 
Cidr() netip.Prefix { return d.tunCidr } -func (d *UserDevice) Name() string { return "faketun0" } +func (d *UserDevice) Cidr() netip.Prefix { return d.tunCidr } +func (d *UserDevice) Name() string { return "faketun0" } func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { ptr := d.routeTree.Load() if ptr != nil { From ac016c9d60a9f51434fb413cb197fe5bc2453f3a Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Mon, 12 Aug 2024 12:02:08 +0200 Subject: [PATCH 11/33] try to fix instability of the service level tests --- service/service_testhelpers.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/service/service_testhelpers.go b/service/service_testhelpers.go index 28661865b..4b8d913cb 100644 --- a/service/service_testhelpers.go +++ b/service/service_testhelpers.go @@ -77,7 +77,10 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, } func CreateTwoConnectedServices(port int) (*Service, *Service) { - ca, _, caKey, _ := e2e.NewTestCaCert(time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{}) + ca, _, caKey, _ := e2e.NewTestCaCert( + time.Now().Add(-5*time.Minute), // ensure that there is no issue due to rounding + time.Now().Add(30*time.Minute), // ensure that the certificate is valid for at least the time ot the test execution + nil, nil, []string{}) a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ "static_host_map": m{}, "lighthouse": m{ From a678cff41f551cbeb9526e8454758568d70b23f1 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Mon, 12 Aug 2024 19:55:21 +0200 Subject: [PATCH 12/33] lets see if randomization of the port helps --- service/service_testhelpers.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/service/service_testhelpers.go b/service/service_testhelpers.go index 4b8d913cb..7384de4d4 100644 --- a/service/service_testhelpers.go +++ b/service/service_testhelpers.go @@ -2,6 +2,7 @@ package service import ( "fmt" + "math/rand" "net/netip" "time" @@ -77,6 
+78,7 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, } func CreateTwoConnectedServices(port int) (*Service, *Service) { + port += 100 * (rand.Int() % 10) ca, _, caKey, _ := e2e.NewTestCaCert( time.Now().Add(-5*time.Minute), // ensure that there is no issue due to rounding time.Now().Add(30*time.Minute), // ensure that the certificate is valid for at least the time ot the test execution From 9d60a1b4bbb73cf94beecc2762ee40da6449de48 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Mon, 12 Aug 2024 20:04:30 +0200 Subject: [PATCH 13/33] avoid panic due to writing to closed channel --- overlay/user.go | 2 +- service/service.go | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/overlay/user.go b/overlay/user.go index d34c1d96a..282c7852e 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -110,6 +110,6 @@ func (d *UserDevice) ReadFrom(r io.Reader) (n int64, err error) { func (d *UserDevice) Close() error { close(d.inboundChannel) - close(d.outboundChannel) + // close(d.outboundChannel) outbound channel needs to be closed from writer side return nil } diff --git a/service/service.go b/service/service.go index 6969404c1..faf1bdfa4 100644 --- a/service/service.go +++ b/service/service.go @@ -121,6 +121,7 @@ func New(config *config.C, logger *logrus.Logger) (*Service, error) { } }) eg.Go(func() error { + defer close(nebula_tun_writer) for { packet := linkEP.ReadContext(ctx) if packet == nil { From bd553623a149c2090fe029a60f28ad52f9dd85f5 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Tue, 13 Aug 2024 23:27:00 +0200 Subject: [PATCH 14/33] extend validity time range to avoid race conditions in CI --- service/service_testhelpers.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/service/service_testhelpers.go b/service/service_testhelpers.go index 7384de4d4..2d4504f76 100644 --- a/service/service_testhelpers.go +++ b/service/service_testhelpers.go @@ -17,7 +17,10 @@ import ( type m 
map[string]interface{} func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { - _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", time.Now(), time.Now().Add(5*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{}) + _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", + time.Now().Add(-5*time.Minute), + time.Now().Add(30*time.Minute), + netip.PrefixFrom(udpIp, 24), nil, []string{}) caB, err := caCrt.MarshalToPEM() if err != nil { panic(err) From a9b0b1d6b4084c28ec79c954ec333d829e990a07 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Wed, 14 Aug 2024 10:28:33 +0200 Subject: [PATCH 15/33] ensure that node certs lifetime doesn't outlife ca cert lifetime --- service/service_testhelpers.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/service/service_testhelpers.go b/service/service_testhelpers.go index 2d4504f76..0af404b45 100644 --- a/service/service_testhelpers.go +++ b/service/service_testhelpers.go @@ -18,7 +18,7 @@ type m map[string]interface{} func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", - time.Now().Add(-5*time.Minute), + time.Now().Add(-3*time.Minute), time.Now().Add(30*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{}) caB, err := caCrt.MarshalToPEM() @@ -83,8 +83,8 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, func CreateTwoConnectedServices(port int) (*Service, *Service) { port += 100 * (rand.Int() % 10) ca, _, caKey, _ := e2e.NewTestCaCert( - time.Now().Add(-5*time.Minute), // ensure that there is no issue due to rounding - time.Now().Add(30*time.Minute), // ensure that the certificate is valid for at least the time ot the test execution + time.Now().Add(-9*time.Minute), // ensure that there is no issue due to rounding + time.Now().Add(40*time.Minute), // 
ensure that the certificate is valid for at least the time ot the test execution nil, nil, []string{}) a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ "static_host_map": m{}, From d22fe212fa8e3551fa4876958bfccf55a56d0408 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Wed, 14 Aug 2024 19:59:18 +0200 Subject: [PATCH 16/33] consider injected name for certificate configuration --- service/service_testhelpers.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/service/service_testhelpers.go b/service/service_testhelpers.go index 0af404b45..c0cbef057 100644 --- a/service/service_testhelpers.go +++ b/service/service_testhelpers.go @@ -17,7 +17,7 @@ import ( type m map[string]interface{} func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { - _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, "a", + _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, name, time.Now().Add(-3*time.Minute), time.Now().Add(30*time.Minute), netip.PrefixFrom(udpIp, 24), nil, []string{}) From 670b3abb4f046079d39483d62c698c4dfe56626d Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Wed, 14 Aug 2024 20:17:08 +0200 Subject: [PATCH 17/33] try to unique node names in tests --- service/service_testhelpers.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/service/service_testhelpers.go b/service/service_testhelpers.go index c0cbef057..8c4c8e001 100644 --- a/service/service_testhelpers.go +++ b/service/service_testhelpers.go @@ -86,7 +86,7 @@ func CreateTwoConnectedServices(port int) (*Service, *Service) { time.Now().Add(-9*time.Minute), // ensure that there is no issue due to rounding time.Now().Add(40*time.Minute), // ensure that the certificate is valid for at least the time ot the test execution nil, nil, []string{}) - a := newSimpleService(ca, caKey, "a", netip.MustParseAddr("10.0.0.1"), m{ + a := newSimpleService(ca, caKey, fmt.Sprintf("a_port_%d", 
port), netip.MustParseAddr("10.0.0.1"), m{ "static_host_map": m{}, "lighthouse": m{ "am_lighthouse": true, @@ -96,7 +96,7 @@ func CreateTwoConnectedServices(port int) (*Service, *Service) { "port": port, }, }) - b := newSimpleService(ca, caKey, "b", netip.MustParseAddr("10.0.0.2"), m{ + b := newSimpleService(ca, caKey, fmt.Sprintf("b_port_%d", port), netip.MustParseAddr("10.0.0.2"), m{ "static_host_map": m{ "10.0.0.1": []string{fmt.Sprintf("localhost:%d", port)}, }, From 6d03850d8dee8577d4e77080ad47ef38f71f317b Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Thu, 15 Aug 2024 09:52:22 +0200 Subject: [PATCH 18/33] add name of the test to the certificate --- port_forwarder/port_forwarder_tcp_test.go | 6 +++--- port_forwarder/port_forwarder_udp_test.go | 2 +- service/service_test.go | 2 +- service/service_testhelpers.go | 7 ++++--- 4 files changed, 9 insertions(+), 8 deletions(-) diff --git a/port_forwarder/port_forwarder_tcp_test.go b/port_forwarder/port_forwarder_tcp_test.go index 2f86c37a4..742b1e3fd 100644 --- a/port_forwarder/port_forwarder_tcp_test.go +++ b/port_forwarder/port_forwarder_tcp_test.go @@ -48,7 +48,7 @@ func doTestTcpCommunicationFail( func TestTcpInOut2Clients(t *testing.T) { l := logrus.New() - server, client := service.CreateTwoConnectedServices(4247) + server, client := service.CreateTwoConnectedServices(t, 4247) defer client.Close() defer server.Close() @@ -109,7 +109,7 @@ port_forwarding: func TestTcpInOut1ClientConfigReload(t *testing.T) { l := logrus.New() - server, client := service.CreateTwoConnectedServices(4246) + server, client := service.CreateTwoConnectedServices(t, 4246) defer client.Close() defer server.Close() @@ -199,7 +199,7 @@ port_forwarding: func TestTcpInOut1ClientConfigReload_inverseCloseOrder(t *testing.T) { l := logrus.New() - server, client := service.CreateTwoConnectedServices(4245) + server, client := service.CreateTwoConnectedServices(t, 4245) defer client.Close() defer server.Close() diff --git 
a/port_forwarder/port_forwarder_udp_test.go b/port_forwarder/port_forwarder_udp_test.go index 27ea5800d..c0bf9b505 100644 --- a/port_forwarder/port_forwarder_udp_test.go +++ b/port_forwarder/port_forwarder_udp_test.go @@ -74,7 +74,7 @@ func doTestUdpCommunication( func TestUdpInOut2Clients(t *testing.T) { l := logrus.New() - server, client := service.CreateTwoConnectedServices(4244) + server, client := service.CreateTwoConnectedServices(t, 4244) defer client.Close() defer server.Close() diff --git a/service/service_test.go b/service/service_test.go index b9098c34f..fb488f0f6 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -10,7 +10,7 @@ import ( ) func TestService(t *testing.T) { - a, b := CreateTwoConnectedServices(4243) + a, b := CreateTwoConnectedServices(t, 4243) ln, err := a.Listen("tcp", ":1234") if err != nil { diff --git a/service/service_testhelpers.go b/service/service_testhelpers.go index 8c4c8e001..5715c2dff 100644 --- a/service/service_testhelpers.go +++ b/service/service_testhelpers.go @@ -4,6 +4,7 @@ import ( "fmt" "math/rand" "net/netip" + "testing" "time" "dario.cat/mergo" @@ -80,13 +81,13 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, return s } -func CreateTwoConnectedServices(port int) (*Service, *Service) { +func CreateTwoConnectedServices(t *testing.T, port int) (*Service, *Service) { port += 100 * (rand.Int() % 10) ca, _, caKey, _ := e2e.NewTestCaCert( time.Now().Add(-9*time.Minute), // ensure that there is no issue due to rounding time.Now().Add(40*time.Minute), // ensure that the certificate is valid for at least the time ot the test execution nil, nil, []string{}) - a := newSimpleService(ca, caKey, fmt.Sprintf("a_port_%d", port), netip.MustParseAddr("10.0.0.1"), m{ + a := newSimpleService(ca, caKey, fmt.Sprintf("a_port_%d_test_name_%s", port, t.Name()), netip.MustParseAddr("10.0.0.1"), m{ "static_host_map": m{}, "lighthouse": m{ "am_lighthouse": true, @@ -96,7 +97,7 @@ func 
CreateTwoConnectedServices(port int) (*Service, *Service) { "port": port, }, }) - b := newSimpleService(ca, caKey, fmt.Sprintf("b_port_%d", port), netip.MustParseAddr("10.0.0.2"), m{ + b := newSimpleService(ca, caKey, fmt.Sprintf("b_port_%d_test_name_%s", port, t.Name()), netip.MustParseAddr("10.0.0.2"), m{ "static_host_map": m{ "10.0.0.1": []string{fmt.Sprintf("localhost:%d", port)}, }, From 3fd775b156881f086d6699aea014cf6f40cb2ba4 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Thu, 15 Aug 2024 10:16:14 +0200 Subject: [PATCH 19/33] add logging prefix to differentiate the output in a test with 2 services --- examples/go_service/main.go | 2 ++ service/service.go | 3 --- service/service_testhelpers.go | 17 +++++++++++++++++ 3 files changed, 19 insertions(+), 3 deletions(-) diff --git a/examples/go_service/main.go b/examples/go_service/main.go index cd07b7fb2..80d5a9893 100644 --- a/examples/go_service/main.go +++ b/examples/go_service/main.go @@ -4,6 +4,7 @@ import ( "bufio" "fmt" "log" + "os" "github.com/sirupsen/logrus" "github.com/slackhq/nebula/config" @@ -60,6 +61,7 @@ pki: return err } l := logrus.New() + l.Out = os.Stdout service, err := service.New(&config, l) if err != nil { return err diff --git a/service/service.go b/service/service.go index faf1bdfa4..1d5f069fd 100644 --- a/service/service.go +++ b/service/service.go @@ -8,7 +8,6 @@ import ( "log" "math" "net" - "os" "strings" "sync" @@ -46,8 +45,6 @@ type Service struct { } func New(config *config.C, logger *logrus.Logger) (*Service, error) { - logger.Out = os.Stdout - control, err := nebula.Main(config, false, "custom-app", logger, overlay.NewUserDeviceFromConfig) if err != nil { return nil, err diff --git a/service/service_testhelpers.go b/service/service_testhelpers.go index 5715c2dff..39b4d2603 100644 --- a/service/service_testhelpers.go +++ b/service/service_testhelpers.go @@ -2,6 +2,7 @@ package service import ( "fmt" + "io" "math/rand" "net/netip" "testing" @@ -17,6 +18,16 @@ import ( type m 
map[string]interface{} +type LogOutputWithPrefix struct { + prefix string + out io.Writer +} + +func (o LogOutputWithPrefix) Write(p []byte) (n int, err error) { + fmt.Fprintf(o.out, "[%s] ", o.prefix) + return o.out.Write(p) +} + func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, name, time.Now().Add(-3*time.Minute), @@ -74,6 +85,12 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, } l := logrus.New() + prefixWriter := LogOutputWithPrefix{ + prefix: name, + out: l.Out, + } + l.SetOutput(prefixWriter) + s, err := New(&c, l) if err != nil { panic(err) From 99b11b33de678a2522fdcebfbd8b39111f1f70ad Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Fri, 16 Aug 2024 10:11:58 +0200 Subject: [PATCH 20/33] tests: check for clean service shutdown and improve logging --- cmd/nebula/main.go | 14 +------- port_forwarder/port_forwarder_tcp_test.go | 42 ++++++++++------------- port_forwarder/port_forwarder_udp_test.go | 11 +++--- service/service.go | 22 ++++++++++++ service/service_test.go | 2 +- service/service_testhelpers.go | 12 +++---- 6 files changed, 54 insertions(+), 49 deletions(-) diff --git a/cmd/nebula/main.go b/cmd/nebula/main.go index f3e6a35b8..e0d26a28d 100644 --- a/cmd/nebula/main.go +++ b/cmd/nebula/main.go @@ -1,11 +1,8 @@ package main import ( - "context" - "errors" "flag" "fmt" - "io" "os" "os/signal" "syscall" @@ -95,16 +92,7 @@ func main() { <-signalChannel // shutdown: - service.Close() - if err := service.Wait(); err != nil { - if errors.Is(err, os.ErrClosed) || - errors.Is(err, io.EOF) || - errors.Is(err, context.Canceled) { - l.Debugf("Stop of user-tun service returned: %v", err) - } else { - util.LogWithContextIfNeeded("Unclean stop", err, l) - } - } + service.CloseAndWait() } else { diff --git a/port_forwarder/port_forwarder_tcp_test.go b/port_forwarder/port_forwarder_tcp_test.go index 
742b1e3fd..a0167fe88 100644 --- a/port_forwarder/port_forwarder_tcp_test.go +++ b/port_forwarder/port_forwarder_tcp_test.go @@ -4,7 +4,6 @@ import ( "net" "testing" - "github.com/sirupsen/logrus" "github.com/slackhq/nebula/service" "github.com/stretchr/testify/assert" ) @@ -47,12 +46,11 @@ func doTestTcpCommunicationFail( } func TestTcpInOut2Clients(t *testing.T) { - l := logrus.New() - server, client := service.CreateTwoConnectedServices(t, 4247) - defer client.Close() - defer server.Close() + server, sl, client, cl := service.CreateTwoConnectedServices(t, 4247) + defer assert.NoError(t, client.CloseAndWait()) + defer assert.NoError(t, server.CloseAndWait()) - server_pf, err := createPortForwarderFromConfigString(l, server, ` + server_pf, err := createPortForwarderFromConfigString(sl, server, ` port_forwarding: inbound: - listen_port: 4495 @@ -63,7 +61,7 @@ port_forwarding: assert.Len(t, server_pf.portForwardings, 1) - client_pf, err := createPortForwarderFromConfigString(l, client, ` + client_pf, err := createPortForwarderFromConfigString(cl, client, ` port_forwarding: outbound: - listen_address: 127.0.0.1:3395 @@ -108,12 +106,11 @@ port_forwarding: } func TestTcpInOut1ClientConfigReload(t *testing.T) { - l := logrus.New() - server, client := service.CreateTwoConnectedServices(t, 4246) - defer client.Close() - defer server.Close() + server, sl, client, cl := service.CreateTwoConnectedServices(t, 4246) + defer assert.NoError(t, client.CloseAndWait()) + defer assert.NoError(t, server.CloseAndWait()) - server_pf, err := createPortForwarderFromConfigString(l, server, ` + server_pf, err := createPortForwarderFromConfigString(sl, server, ` port_forwarding: inbound: - listen_port: 4497 @@ -124,7 +121,7 @@ port_forwarding: assert.Len(t, server_pf.portForwardings, 1) - client_pf, err := createPortForwarderFromConfigString(l, client, ` + client_pf, err := createPortForwarderFromConfigString(cl, client, ` port_forwarding: outbound: - listen_address: 127.0.0.1:3397 @@ -158,7 
+155,7 @@ port_forwarding: doTestTcpCommunication(t, "Hello from client one side AGAIN!", client1_conn, client1_server_side_conn) - new_server_fwd_list, err := loadPortFwdConfigFromString(l, ` + new_server_fwd_list, err := loadPortFwdConfigFromString(sl, ` port_forwarding: inbound: - listen_port: 4496 @@ -169,7 +166,7 @@ port_forwarding: assert.Len(t, server_pf.portForwardings, 1) - new_client_fwd_list, err := loadPortFwdConfigFromString(l, ` + new_client_fwd_list, err := loadPortFwdConfigFromString(cl, ` port_forwarding: outbound: - listen_address: 127.0.0.1:3396 @@ -198,12 +195,11 @@ port_forwarding: } func TestTcpInOut1ClientConfigReload_inverseCloseOrder(t *testing.T) { - l := logrus.New() - server, client := service.CreateTwoConnectedServices(t, 4245) - defer client.Close() - defer server.Close() + server, sl, client, cl := service.CreateTwoConnectedServices(t, 4245) + defer assert.NoError(t, client.CloseAndWait()) + defer assert.NoError(t, server.CloseAndWait()) - server_pf, err := createPortForwarderFromConfigString(l, server, ` + server_pf, err := createPortForwarderFromConfigString(sl, server, ` port_forwarding: inbound: - listen_port: 4499 @@ -214,7 +210,7 @@ port_forwarding: assert.Len(t, server_pf.portForwardings, 1) - client_pf, err := createPortForwarderFromConfigString(l, client, ` + client_pf, err := createPortForwarderFromConfigString(cl, client, ` port_forwarding: outbound: - listen_address: 127.0.0.1:3399 @@ -248,7 +244,7 @@ port_forwarding: doTestTcpCommunication(t, "Hello from client one side AGAIN!", client1_conn, client1_server_side_conn) - new_server_fwd_list, err := loadPortFwdConfigFromString(l, ` + new_server_fwd_list, err := loadPortFwdConfigFromString(sl, ` port_forwarding: inbound: - listen_port: 4498 @@ -259,7 +255,7 @@ port_forwarding: assert.Len(t, server_pf.portForwardings, 1) - new_client_fwd_list, err := loadPortFwdConfigFromString(l, ` + new_client_fwd_list, err := loadPortFwdConfigFromString(cl, ` port_forwarding: outbound: - 
listen_address: 127.0.0.1:3398 diff --git a/port_forwarder/port_forwarder_udp_test.go b/port_forwarder/port_forwarder_udp_test.go index c0bf9b505..4cb638c92 100644 --- a/port_forwarder/port_forwarder_udp_test.go +++ b/port_forwarder/port_forwarder_udp_test.go @@ -73,12 +73,11 @@ func doTestUdpCommunication( } func TestUdpInOut2Clients(t *testing.T) { - l := logrus.New() - server, client := service.CreateTwoConnectedServices(t, 4244) - defer client.Close() - defer server.Close() + server, sl, client, cl := service.CreateTwoConnectedServices(t, 4244) + defer assert.NoError(t, client.CloseAndWait()) + defer assert.NoError(t, server.CloseAndWait()) - server_pf, err := createPortForwarderFromConfigString(l, server, ` + server_pf, err := createPortForwarderFromConfigString(sl, server, ` port_forwarding: inbound: - listen_port: 4499 @@ -89,7 +88,7 @@ port_forwarding: assert.Len(t, server_pf.portForwardings, 1) - client_pf, err := createPortForwarderFromConfigString(l, client, ` + client_pf, err := createPortForwarderFromConfigString(cl, client, ` port_forwarding: outbound: - listen_address: 127.0.0.1:3399 diff --git a/service/service.go b/service/service.go index 1d5f069fd..55cffa5c0 100644 --- a/service/service.go +++ b/service/service.go @@ -5,9 +5,11 @@ import ( "context" "errors" "fmt" + "io" "log" "math" "net" + "os" "strings" "sync" @@ -15,6 +17,7 @@ import ( "github.com/slackhq/nebula" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/overlay" + "github.com/slackhq/nebula/util" "golang.org/x/sync/errgroup" "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/tcpip" @@ -33,6 +36,7 @@ import ( const nicID = 1 type Service struct { + l *logrus.Logger eg *errgroup.Group control *nebula.Control ipstack *stack.Stack @@ -54,6 +58,7 @@ func New(config *config.C, logger *logrus.Logger) (*Service, error) { ctx := control.Context() eg, ctx := errgroup.WithContext(ctx) s := Service{ + l: logger, eg: eg, control: control, } @@ -230,6 +235,23 @@ func (s *Service) 
Close() error { return nil } +func (s *Service) CloseAndWait() error { + s.Close() + if err := s.Wait(); err != nil { + if errors.Is(err, os.ErrClosed) || + errors.Is(err, io.EOF) || + errors.Is(err, context.Canceled) { + s.l.Debugf("Stop of nebula service returned: %v", err) + return nil + } else { + util.LogWithContextIfNeeded("Unclean stop", err, s.l) + return err + } + } + + return nil +} + func (s *Service) tcpHandler(r *tcp.ForwarderRequest) { endpointID := r.ID() diff --git a/service/service_test.go b/service/service_test.go index fb488f0f6..f466cd513 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -10,7 +10,7 @@ import ( ) func TestService(t *testing.T) { - a, b := CreateTwoConnectedServices(t, 4243) + a, _, b, _ := CreateTwoConnectedServices(t, 4243) ln, err := a.Listen("tcp", ":1234") if err != nil { diff --git a/service/service_testhelpers.go b/service/service_testhelpers.go index 39b4d2603..62173ef6b 100644 --- a/service/service_testhelpers.go +++ b/service/service_testhelpers.go @@ -28,7 +28,7 @@ func (o LogOutputWithPrefix) Write(p []byte) (n int, err error) { return o.out.Write(p) } -func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) *Service { +func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp netip.Addr, overrides m) (*Service, *logrus.Logger) { _, _, myPrivKey, myPEM := e2e.NewTestCert(caCrt, caKey, name, time.Now().Add(-3*time.Minute), time.Now().Add(30*time.Minute), @@ -95,16 +95,16 @@ func newSimpleService(caCrt *cert.NebulaCertificate, caKey []byte, name string, if err != nil { panic(err) } - return s + return s, l } -func CreateTwoConnectedServices(t *testing.T, port int) (*Service, *Service) { +func CreateTwoConnectedServices(t *testing.T, port int) (*Service, *logrus.Logger, *Service, *logrus.Logger) { port += 100 * (rand.Int() % 10) ca, _, caKey, _ := e2e.NewTestCaCert( time.Now().Add(-9*time.Minute), // ensure that there 
is no issue due to rounding time.Now().Add(40*time.Minute), // ensure that the certificate is valid for at least the time ot the test execution nil, nil, []string{}) - a := newSimpleService(ca, caKey, fmt.Sprintf("a_port_%d_test_name_%s", port, t.Name()), netip.MustParseAddr("10.0.0.1"), m{ + a, al := newSimpleService(ca, caKey, fmt.Sprintf("a_port_%d_test_name_%s", port, t.Name()), netip.MustParseAddr("10.0.0.1"), m{ "static_host_map": m{}, "lighthouse": m{ "am_lighthouse": true, @@ -114,7 +114,7 @@ func CreateTwoConnectedServices(t *testing.T, port int) (*Service, *Service) { "port": port, }, }) - b := newSimpleService(ca, caKey, fmt.Sprintf("b_port_%d_test_name_%s", port, t.Name()), netip.MustParseAddr("10.0.0.2"), m{ + b, bl := newSimpleService(ca, caKey, fmt.Sprintf("b_port_%d_test_name_%s", port, t.Name()), netip.MustParseAddr("10.0.0.2"), m{ "static_host_map": m{ "10.0.0.1": []string{fmt.Sprintf("localhost:%d", port)}, }, @@ -123,5 +123,5 @@ func CreateTwoConnectedServices(t *testing.T, port int) (*Service, *Service) { "interval": 1, }, }) - return a, b + return a, al, b, bl } From ca4383278e457b206fca9c3ed365512376305050 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Sat, 17 Aug 2024 21:50:36 +0200 Subject: [PATCH 21/33] try to make it more stable by using channels and waitgroups --- port_forwarder/fwd_tcp.go | 15 ++ port_forwarder/fwd_udp.go | 270 +++++++++++++-------- port_forwarder/lockfree_timeout_counter.go | 33 --- port_forwarder/port_forwarder_tcp_test.go | 59 +++-- port_forwarder/port_forwarder_udp_test.go | 12 +- port_forwarder/port_forwarding_service.go | 11 + service/service_test.go | 8 - service/service_testhelpers.go | 7 + 8 files changed, 252 insertions(+), 163 deletions(-) delete mode 100644 port_forwarder/lockfree_timeout_counter.go diff --git a/port_forwarder/fwd_tcp.go b/port_forwarder/fwd_tcp.go index c345876ef..184031ffc 100644 --- a/port_forwarder/fwd_tcp.go +++ b/port_forwarder/fwd_tcp.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" 
+ "sync" "time" "github.com/sirupsen/logrus" @@ -30,6 +31,7 @@ func (cfg ForwardConfigIncomingTcp) ConfigDescriptor() string { type PortForwardingCommonTcp struct { ctx context.Context + wg *sync.WaitGroup l *logrus.Logger tunService *service.Service localListenConnection net.Listener @@ -37,6 +39,7 @@ type PortForwardingCommonTcp struct { func (fwd PortForwardingCommonTcp) Close() error { fwd.localListenConnection.Close() + fwd.wg.Wait() return nil } @@ -62,10 +65,12 @@ func (cf ForwardConfigOutgoingTcp) SetupPortForwarding( cf.remoteConnect, localTcpListenAddr) ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} portForwarding := &PortForwardingOutgoingTcp{ PortForwardingCommonTcp: PortForwardingCommonTcp{ ctx: ctx, + wg: wg, l: l, tunService: tunService, localListenConnection: localListenPort, @@ -73,7 +78,9 @@ func (cf ForwardConfigOutgoingTcp) SetupPortForwarding( cfg: cf, } + wg.Add(1) go func() { + defer wg.Done() defer cancel() portForwarding.acceptOnLocalListenPort_generic(portForwarding.handleClientConnectionWithErrorReturn) }() @@ -94,12 +101,16 @@ func (pt *PortForwardingCommonTcp) acceptOnLocalListenPort_generic( pt.l.Debugf("accept TCP connect from local TCP port: %v", connection.RemoteAddr()) + pt.wg.Add(1) go func() { + defer pt.wg.Done() defer connection.Close() <-pt.ctx.Done() }() + pt.wg.Add(1) go func() { + defer pt.wg.Done() err := handleClientConnectionWithErrorReturn(connection) if err != nil { pt.l.Debugf("Closed TCP client connection %s. 
Err: %+v", @@ -190,10 +201,12 @@ func (cf ForwardConfigIncomingTcp) SetupPortForwarding( cf.forwardLocalAddress, cf.port) ctx, cancel := context.WithCancel(context.Background()) + wg := &sync.WaitGroup{} portForwarding := &PortForwardingIncomingTcp{ PortForwardingCommonTcp: PortForwardingCommonTcp{ ctx: ctx, + wg: wg, l: l, tunService: tunService, localListenConnection: localListenPort, @@ -201,7 +214,9 @@ func (cf ForwardConfigIncomingTcp) SetupPortForwarding( cfg: cf, } + wg.Add(1) go func() { + defer wg.Done() defer cancel() portForwarding.acceptOnLocalListenPort_generic(portForwarding.handleClientConnectionWithErrorReturn) }() diff --git a/port_forwarder/fwd_udp.go b/port_forwarder/fwd_udp.go index 789ae78b6..4eeb0c07b 100644 --- a/port_forwarder/fwd_udp.go +++ b/port_forwarder/fwd_udp.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "net" + "sync" "time" "github.com/sirupsen/logrus" @@ -33,17 +34,29 @@ func (cfg ForwardConfigIncomingUdp) ConfigDescriptor() string { var UDP_CONNECTION_TIMEOUT_SECONDS uint32 = 300 type udpConnInterface interface { + io.Closer WriteTo(b []byte, addr net.Addr) (int, error) Write(b []byte) (int, error) ReadFrom(b []byte) (int, net.Addr, error) + LocalAddr() net.Addr } -func handleUdpDestinationPortResponseReading[destConn net.Conn, srcConn udpConnInterface]( +func resetTimer(t *time.Timer, d time.Duration) { + if !t.Stop() { + select { + case <-t.C: + default: + } + } + t.Reset(d) +} + +func handleUdpDestinationPortResponseReading[destConn udpConnInterface, srcConn udpConnInterface]( l *logrus.Logger, loggingFields logrus.Fields, closedConnections *chan string, sourceAddr net.Addr, - destConnection *TimedConnection[destConn], + destConnection destConn, localListenConnection srcConn, ) error { // net.Conn is thread-safe according to: https://pkg.go.dev/net#Conn @@ -52,57 +65,41 @@ func handleUdpDestinationPortResponseReading[destConn net.Conn, srcConn udpConnI defer func() { (*closedConnections) <- sourceAddr.String() }() 
l.WithFields(loggingFields).Debug("begin reading responses ...") - buf := make([]byte, 2*(1<<16)) - for { - destConnection.connection.SetDeadline(time.Now().Add(time.Second * 10)) - l.WithFields(loggingFields).Trace("response read ...") - n, err := destConnection.connection.Read(buf) - if n == 0 { - if netErr, ok := err.(net.Error); ok && netErr.Timeout() { - l.WithFields(loggingFields).Debug("response read - timeout tick") - if destConnection.timeout_counter.Increment(10) { - l.WithFields(loggingFields).Debug("response read - closed due to timeout") - return nil - } - continue - } else { - l.WithFields(loggingFields).WithError(err).Debugf("response read - close due to error") - return err - } - } + wg := &sync.WaitGroup{} + defer wg.Wait() - destConnection.timeout_counter.Reset() - l.WithFields(loggingFields). - WithField("payloadSize", n). - Debug("response forward") - n, err = localListenConnection.WriteTo(buf[:n], sourceAddr) - if n == 0 && (err != nil) { - l.WithFields(loggingFields).WithError(err).Debugf("response forward - write error") - return err - } - } -} + timeout := time.Second * time.Duration(UDP_CONNECTION_TIMEOUT_SECONDS) + timer := time.NewTimer(timeout) -func handleClosedConnections[C any]( - l *logrus.Logger, - closedConnections *chan string, - portReaders *map[string]bool, - remoteConnections *map[string]*TimedConnection[C], -) { -cleanup: + rr := newUdpPortReader(wg, l, loggingFields, destConnection) + defer close(rr.receivedDataDone) for { select { - case closedOne := <-(*closedConnections): - l.Debugf("closing connection to %s", closedOne) - delete(*remoteConnections, closedOne) - delete(*portReaders, closedOne) - default: - break cleanup + case <-timer.C: + destConnection.Close() + l.WithFields(loggingFields).Debug("response read - closed due to timeout") + return nil + case data, ok := <-rr.receivedData: + if !ok { + return nil + } + resetTimer(timer, timeout) + + l.WithFields(loggingFields). + WithField("payloadSize", data.n). 
+ Debug("response forward") + n, err := localListenConnection.WriteTo(rr.buf[:data.n], sourceAddr) + rr.receivedDataDone <- 1 + if (n == 0) && (err != nil) { + l.WithFields(loggingFields).WithError(err).Debugf("response forward - write error") + return err + } } } } type PortForwardingCommonUdp struct { + wg *sync.WaitGroup l *logrus.Logger tunService *service.Service // net.Conn is thread-safe according to: https://pkg.go.dev/net#Conn @@ -112,6 +109,7 @@ type PortForwardingCommonUdp struct { func (fwd PortForwardingCommonUdp) Close() error { fwd.localListenConnection.Close() + fwd.wg.Wait() return nil } @@ -137,8 +135,11 @@ func (cfg ForwardConfigOutgoingUdp) SetupPortForwarding( l.Infof("UDP port forwarding to '%v': listening on local UDP addr: '%v'", cfg.remoteConnect, localUdpListenAddr) + wg := &sync.WaitGroup{} + portForwarding := &PortForwardingOutgoingUdp{ PortForwardingCommonUdp: PortForwardingCommonUdp{ + wg: wg, l: l, tunService: tunService, localListenConnection: localListenConnection, @@ -152,8 +153,11 @@ func (cfg ForwardConfigOutgoingUdp) SetupPortForwarding( "dial": cfg.remoteConnect, } + wg.Add(1) go func() { + defer wg.Done() err := listenLocalPort_generic( + wg, l, logPrefix, localListenConnection, @@ -171,7 +175,67 @@ func (cfg ForwardConfigOutgoingUdp) SetupPortForwarding( return portForwarding, nil } -func listenLocalPort_generic[destConn net.Conn]( +type readData struct { + n int + addr net.Addr +} + +type readerRoutine struct { + buf []byte + receivedData chan readData + receivedDataDone chan int +} + +func newUdpPortReader( + wg *sync.WaitGroup, + l *logrus.Logger, + loggingFields logrus.Fields, + conn udpConnInterface, +) *readerRoutine { + r := &readerRoutine{ + buf: make([]byte, 512*1024), + receivedData: make(chan readData), + receivedDataDone: make(chan int, 1), + } + r.receivedDataDone <- 1 + + wg.Add(1) + go func() { + defer wg.Done() + defer close(r.receivedData) + l.WithFields(loggingFields). + WithField("addr", conn.LocalAddr()). 
+ Debug("start listening") + for { + _, ok := <-r.receivedDataDone + if !ok { + return + } + l.WithFields(loggingFields). + WithField("addr", conn.LocalAddr()). + Trace("reading data ...") + n, addr, err := conn.ReadFrom(r.buf[0:]) + if err != nil { + if errors.Is(err, io.EOF) { + return + } + l.WithFields(loggingFields). + WithField("addr", conn.LocalAddr()). + WithError(err).Error("listen for data failed. stop.") + return + } + r.receivedData <- readData{ + n: n, + addr: addr, + } + } + }() + + return r +} + +func listenLocalPort_generic[destConn udpConnInterface]( + wg *sync.WaitGroup, l *logrus.Logger, loggingFields logrus.Fields, localListenConnection udpConnInterface, @@ -179,67 +243,71 @@ func listenLocalPort_generic[destConn net.Conn]( remoteConnect string, ) error { dialConnResponseReaders := make(map[string]bool) - dialConnections := make(map[string]*TimedConnection[destConn]) - closedConnections := make(chan string) + dialConnections := make(map[string]destConn) + closedConnections := make(chan string, 5) + mr := newUdpPortReader(wg, l, loggingFields, localListenConnection) + defer close(mr.receivedDataDone) + + defer func() { + // close and wait for remaining connections + for _, connection := range dialConnections { + connection.Close() + } + for range dialConnResponseReaders { + <-closedConnections + } + }() - l.WithFields(loggingFields).Debug("start listening ...") - var buf [512 * 1024]byte for { - handleClosedConnections(l, &closedConnections, &dialConnResponseReaders, &dialConnections) - - l.WithFields(loggingFields).Trace("reading data ...") - n, localSourceAddr, err := localListenConnection.ReadFrom(buf[0:]) - if err != nil { - if errors.Is(err, io.EOF) { + select { + case closedOne := <-closedConnections: + l.Debugf("closing connection to %s", closedOne) + delete(dialConnections, closedOne) + delete(dialConnResponseReaders, closedOne) + case data, ok := <-mr.receivedData: + if !ok { return nil } - l.WithFields(loggingFields).Error("listen for 
data failed. stop.") - return err - } + l.WithFields(loggingFields). + WithField("source", data.addr). + WithField("payloadSize", data.n). + Trace("read data") + dialConnection, ok := dialConnections[data.addr.String()] + if !ok { + newConnection, err := dial(remoteConnect) + if err != nil { + l.WithFields(loggingFields).WithError(err).Error("dialing dial address failed") + continue + } + dialConnections[data.addr.String()] = newConnection + dialConnection = newConnection + } - l.WithFields(loggingFields). - WithField("source", localSourceAddr). - WithField("payloadSize", n). - Trace("read data") + l.WithFields(loggingFields). + WithField("source", data.addr). + WithField("dialSource", dialConnection.LocalAddr()). + WithField("payloadSize", data.n). + Debug("forward") - dialConnection, ok := dialConnections[localSourceAddr.String()] - if !ok { - newDialConn, err := dial(remoteConnect) - if err != nil { - l.WithFields(loggingFields).WithError(err).Error("dialing dial address failed") - continue - } - dialConnection = &TimedConnection[destConn]{ - connection: newDialConn, - timeout_counter: NewTimeoutCounter(UDP_CONNECTION_TIMEOUT_SECONDS), - } - dialConnections[localSourceAddr.String()] = dialConnection - } + dialConnection.Write(mr.buf[:data.n]) + mr.receivedDataDone <- 1 - l.WithFields(loggingFields). - WithField("source", localSourceAddr). - WithField("dialSource", dialConnection.connection.LocalAddr()). - WithField("payloadSize", n). 
- Debug("forward") - - dialConnection.timeout_counter.Reset() - dialConnection.connection.Write(buf[:n]) - - _, ok = dialConnResponseReaders[localSourceAddr.String()] - if !ok { - loggingFieldsRsp := logrus.Fields{ - "source": localSourceAddr, - "dialSource": dialConnection.connection.LocalAddr(), - } - for k, v := range loggingFields { - loggingFieldsRsp[k] = v + _, ok = dialConnResponseReaders[data.addr.String()] + if !ok { + loggingFieldsRsp := logrus.Fields{ + "source": data.addr, + "dialSource": dialConnection.LocalAddr(), + } + for k, v := range loggingFields { + loggingFieldsRsp[k] = v + } + dialConnResponseReaders[data.addr.String()] = true + go func() error { + return handleUdpDestinationPortResponseReading( + l, loggingFieldsRsp, &closedConnections, data.addr, + dialConnection, localListenConnection) + }() } - dialConnResponseReaders[localSourceAddr.String()] = true - go func() error { - return handleUdpDestinationPortResponseReading( - l, loggingFieldsRsp, &closedConnections, localSourceAddr, - dialConnection, localListenConnection) - }() } } } @@ -268,8 +336,11 @@ func (cfg ForwardConfigIncomingUdp) SetupPortForwarding( "dial": cfg.forwardLocalAddress, } + wg := &sync.WaitGroup{} + forwarding := &PortForwardingIncomingUdp{ PortForwardingCommonUdp: PortForwardingCommonUdp{ + wg: wg, l: l, tunService: tunService, localListenConnection: conn, @@ -277,8 +348,11 @@ func (cfg ForwardConfigIncomingUdp) SetupPortForwarding( cfg: cfg, } + wg.Add(1) go func() { + defer wg.Done() err := listenLocalPort_generic( + wg, l, logPrefix, conn, diff --git a/port_forwarder/lockfree_timeout_counter.go b/port_forwarder/lockfree_timeout_counter.go deleted file mode 100644 index 79d30d167..000000000 --- a/port_forwarder/lockfree_timeout_counter.go +++ /dev/null @@ -1,33 +0,0 @@ -package port_forwarder - -import "sync/atomic" - -type TimeoutCounter struct { - counter atomic.Uint32 - threshold uint32 -} - -func NewTimeoutCounter(threshold uint32) TimeoutCounter { - return 
TimeoutCounter{ - counter: atomic.Uint32{}, - threshold: threshold, - } -} - -func (tc *TimeoutCounter) Increment(step uint32) bool { - tc.counter.Add(step) - return tc.IsTimeout() -} - -func (tc *TimeoutCounter) Reset() { - tc.counter.Store(0) -} - -func (tc *TimeoutCounter) IsTimeout() bool { - return tc.counter.Load() > tc.threshold -} - -type TimedConnection[C any] struct { - connection C - timeout_counter TimeoutCounter -} diff --git a/port_forwarder/port_forwarder_tcp_test.go b/port_forwarder/port_forwarder_tcp_test.go index a0167fe88..b16d3db8f 100644 --- a/port_forwarder/port_forwarder_tcp_test.go +++ b/port_forwarder/port_forwarder_tcp_test.go @@ -3,6 +3,7 @@ package port_forwarder import ( "net" "testing" + "time" "github.com/slackhq/nebula/service" "github.com/stretchr/testify/assert" @@ -45,12 +46,23 @@ func doTestTcpCommunicationFail( assert.NotNil(t, err) } +func tcpListenerNAccept(t *testing.T, listener *net.TCPListener, n int) <-chan net.Conn { + c := make(chan net.Conn, 1) + go func() { + defer close(c) + for range n { + conn, err := listener.Accept() + assert.Nil(t, err) + c <- conn + } + }() + return c +} + func TestTcpInOut2Clients(t *testing.T) { server, sl, client, cl := service.CreateTwoConnectedServices(t, 4247) - defer assert.NoError(t, client.CloseAndWait()) - defer assert.NoError(t, server.CloseAndWait()) - server_pf, err := createPortForwarderFromConfigString(sl, server, ` + server_pf, err := createPortForwarderFromConfigString(t, sl, server, ` port_forwarding: inbound: - listen_port: 4495 @@ -61,7 +73,7 @@ port_forwarding: assert.Len(t, server_pf.portForwardings, 1) - client_pf, err := createPortForwarderFromConfigString(cl, client, ` + client_pf, err := createPortForwarderFromConfigString(t, cl, client, ` port_forwarding: outbound: - listen_address: 127.0.0.1:3395 @@ -79,13 +91,15 @@ port_forwarding: server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr) assert.Nil(t, err) + server_listen_conn_accepts := 
tcpListenerNAccept(t, server_listen_conn, 2) + client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) assert.Nil(t, err) - client1_server_side_conn, err := server_listen_conn.Accept() - assert.Nil(t, err) + client1_server_side_conn := <-server_listen_conn_accepts + client2_conn, err := net.DialTCP("tcp", nil, client_conn_addr) assert.Nil(t, err) - client2_server_side_conn, err := server_listen_conn.Accept() + client2_server_side_conn := <-server_listen_conn_accepts assert.Nil(t, err) doTestTcpCommunication(t, "Hello from client 1 side!", @@ -107,10 +121,8 @@ port_forwarding: func TestTcpInOut1ClientConfigReload(t *testing.T) { server, sl, client, cl := service.CreateTwoConnectedServices(t, 4246) - defer assert.NoError(t, client.CloseAndWait()) - defer assert.NoError(t, server.CloseAndWait()) - server_pf, err := createPortForwarderFromConfigString(sl, server, ` + server_pf, err := createPortForwarderFromConfigString(t, sl, server, ` port_forwarding: inbound: - listen_port: 4497 @@ -121,7 +133,9 @@ port_forwarding: assert.Len(t, server_pf.portForwardings, 1) - client_pf, err := createPortForwarderFromConfigString(cl, client, ` + time.Sleep(100 * time.Millisecond) + + client_pf, err := createPortForwarderFromConfigString(t, cl, client, ` port_forwarding: outbound: - listen_address: 127.0.0.1:3397 @@ -132,6 +146,8 @@ port_forwarding: assert.Len(t, client_pf.portForwardings, 1) + time.Sleep(100 * time.Millisecond) + client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3397") assert.Nil(t, err) server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5597") @@ -139,10 +155,15 @@ port_forwarding: server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr) assert.Nil(t, err) + + server_listen_conn_accepts := tcpListenerNAccept(t, server_listen_conn, 1) + + time.Sleep(100 * time.Millisecond) + client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) assert.Nil(t, err) - client1_server_side_conn, err := server_listen_conn.Accept() - 
assert.Nil(t, err) + + client1_server_side_conn := <-server_listen_conn_accepts doTestTcpCommunication(t, "Hello from client 1 side!", client1_conn, client1_server_side_conn) @@ -196,10 +217,8 @@ port_forwarding: func TestTcpInOut1ClientConfigReload_inverseCloseOrder(t *testing.T) { server, sl, client, cl := service.CreateTwoConnectedServices(t, 4245) - defer assert.NoError(t, client.CloseAndWait()) - defer assert.NoError(t, server.CloseAndWait()) - server_pf, err := createPortForwarderFromConfigString(sl, server, ` + server_pf, err := createPortForwarderFromConfigString(t, sl, server, ` port_forwarding: inbound: - listen_port: 4499 @@ -210,7 +229,7 @@ port_forwarding: assert.Len(t, server_pf.portForwardings, 1) - client_pf, err := createPortForwarderFromConfigString(cl, client, ` + client_pf, err := createPortForwarderFromConfigString(t, cl, client, ` port_forwarding: outbound: - listen_address: 127.0.0.1:3399 @@ -228,10 +247,12 @@ port_forwarding: server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr) assert.Nil(t, err) + server_listen_conn_accepts := tcpListenerNAccept(t, server_listen_conn, 1) + client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) assert.Nil(t, err) - client1_server_side_conn, err := server_listen_conn.Accept() - assert.Nil(t, err) + + client1_server_side_conn := <-server_listen_conn_accepts doTestTcpCommunication(t, "Hello from client 1 side!", client1_conn, client1_server_side_conn) diff --git a/port_forwarder/port_forwarder_udp_test.go b/port_forwarder/port_forwarder_udp_test.go index 4cb638c92..45bb293d3 100644 --- a/port_forwarder/port_forwarder_udp_test.go +++ b/port_forwarder/port_forwarder_udp_test.go @@ -26,7 +26,7 @@ func loadPortFwdConfigFromString(l *logrus.Logger, configStr string) (*PortForwa return &fwd_list, nil } -func createPortForwarderFromConfigString(l *logrus.Logger, srv *service.Service, configStr string) (*PortForwardingService, error) { +func createPortForwarderFromConfigString(t *testing.T, l 
*logrus.Logger, srv *service.Service, configStr string) (*PortForwardingService, error) { fwd_list, err := loadPortFwdConfigFromString(l, configStr) if err != nil { @@ -43,6 +43,10 @@ func createPortForwarderFromConfigString(l *logrus.Logger, srv *service.Service, return nil, err } + t.Cleanup(func() { + pf.CloseAll() + }) + return pf, nil } @@ -74,10 +78,8 @@ func doTestUdpCommunication( func TestUdpInOut2Clients(t *testing.T) { server, sl, client, cl := service.CreateTwoConnectedServices(t, 4244) - defer assert.NoError(t, client.CloseAndWait()) - defer assert.NoError(t, server.CloseAndWait()) - server_pf, err := createPortForwarderFromConfigString(sl, server, ` + server_pf, err := createPortForwarderFromConfigString(t, sl, server, ` port_forwarding: inbound: - listen_port: 4499 @@ -88,7 +90,7 @@ port_forwarding: assert.Len(t, server_pf.portForwardings, 1) - client_pf, err := createPortForwarderFromConfigString(cl, client, ` + client_pf, err := createPortForwarderFromConfigString(t, cl, client, ` port_forwarding: outbound: - listen_address: 127.0.0.1:3399 diff --git a/port_forwarder/port_forwarding_service.go b/port_forwarder/port_forwarding_service.go index 8bad84e4d..ed2910d72 100644 --- a/port_forwarder/port_forwarding_service.go +++ b/port_forwarder/port_forwarding_service.go @@ -52,3 +52,14 @@ func (t *PortForwardingService) CloseSelective(descriptors []string) error { return nil } + +func (t *PortForwardingService) CloseAll() error { + + for descriptor, pf := range t.portForwardings { + t.l.Infof("closing port forwarding: %s", descriptor) + pf.Close() + delete(t.portForwardings, descriptor) + } + + return nil +} diff --git a/service/service_test.go b/service/service_test.go index f466cd513..e1c1d4cc4 100644 --- a/service/service_test.go +++ b/service/service_test.go @@ -62,12 +62,4 @@ func TestService(t *testing.T) { if !bytes.Equal(data, []byte("server msg")) { t.Fatal("got invalid message from client") } - - if err := c.Close(); err != nil { - t.Fatal(err) 
- } - - if err := eg.Wait(); err != nil { - t.Fatal(err) - } } diff --git a/service/service_testhelpers.go b/service/service_testhelpers.go index 62173ef6b..c77c1c574 100644 --- a/service/service_testhelpers.go +++ b/service/service_testhelpers.go @@ -13,6 +13,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/e2e" + "github.com/stretchr/testify/assert" "gopkg.in/yaml.v2" ) @@ -114,6 +115,9 @@ func CreateTwoConnectedServices(t *testing.T, port int) (*Service, *logrus.Logge "port": port, }, }) + t.Cleanup(func() { + assert.NoError(t, a.CloseAndWait()) + }) b, bl := newSimpleService(ca, caKey, fmt.Sprintf("b_port_%d_test_name_%s", port, t.Name()), netip.MustParseAddr("10.0.0.2"), m{ "static_host_map": m{ "10.0.0.1": []string{fmt.Sprintf("localhost:%d", port)}, @@ -123,5 +127,8 @@ func CreateTwoConnectedServices(t *testing.T, port int) (*Service, *logrus.Logge "interval": 1, }, }) + t.Cleanup(func() { + assert.NoError(t, b.CloseAndWait()) + }) return a, al, b, bl } From 12a0dd8a390ff76d3623a6846ada7ad7813dfb0c Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Sat, 17 Aug 2024 21:51:06 +0200 Subject: [PATCH 22/33] improve stopping logic for UserDevice --- overlay/user.go | 5 +++-- service/service.go | 17 +++++++++++------ 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/overlay/user.go b/overlay/user.go index 282c7852e..f8a64fcaf 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -109,7 +109,8 @@ func (d *UserDevice) ReadFrom(r io.Reader) (n int64, err error) { } func (d *UserDevice) Close() error { - close(d.inboundChannel) - // close(d.outboundChannel) outbound channel needs to be closed from writer side + // There is nothing to be done for the UserDevice. + // It doesn't start any goroutines on its own. + // It doesn't manage any resources that needs closing. 
return nil } diff --git a/service/service.go b/service/service.go index 55cffa5c0..1e79dd081 100644 --- a/service/service.go +++ b/service/service.go @@ -111,15 +111,20 @@ func New(config *config.C, logger *logrus.Logger) (*Service, error) { // create Goroutines to forward packets between Nebula and Gvisor eg.Go(func() error { + defer linkEP.Close() for { - view, ok := <-nebula_tun_reader - if !ok { + select { + case <-ctx.Done(): return nil + case view, ok := <-nebula_tun_reader: + if !ok { + return nil + } + packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Payload: buffer.MakeWithView(view), + }) + linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf) } - packetBuf := stack.NewPacketBuffer(stack.PacketBufferOptions{ - Payload: buffer.MakeWithView(view), - }) - linkEP.InjectInbound(header.IPv4ProtocolNumber, packetBuf) } }) eg.Go(func() error { From b1ea9f53da46e12ded74c788189e337180f8cc86 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Mon, 19 Aug 2024 22:01:38 +0200 Subject: [PATCH 23/33] improving test code to get stability improved - still fails with stress --- port_forwarder/port_forwarder_tcp_test.go | 98 ++++++++++++++++++----- port_forwarder/port_forwarder_udp_test.go | 55 ++++++++++--- 2 files changed, 118 insertions(+), 35 deletions(-) diff --git a/port_forwarder/port_forwarder_tcp_test.go b/port_forwarder/port_forwarder_tcp_test.go index b16d3db8f..ef9014600 100644 --- a/port_forwarder/port_forwarder_tcp_test.go +++ b/port_forwarder/port_forwarder_tcp_test.go @@ -1,6 +1,7 @@ package port_forwarder import ( + "fmt" "net" "testing" "time" @@ -9,19 +10,53 @@ import ( "github.com/stretchr/testify/assert" ) +func startReadToChannel(receiverConn net.Conn) <-chan []byte { + rcv_chan := make(chan []byte, 10) + r := make(chan bool, 1) + go func() { + defer close(rcv_chan) + r <- true + for { + buf := make([]byte, 100) + n, err := receiverConn.Read(buf) + if err != nil { + break + } + rcv_chan <- buf[0:n] + } + }() + <-r + time.Sleep(50 * 
time.Millisecond) + return rcv_chan +} + func doTestTcpCommunication( t *testing.T, msg string, senderConn net.Conn, - receiverConn net.Conn, + receiverConn <-chan []byte, ) { + var n int = 0 + var err error = nil data_sent := []byte(msg) - n, err := senderConn.Write(data_sent) - assert.Nil(t, err) - assert.Equal(t, n, len(data_sent)) - - buf := make([]byte, 100) - n, err = receiverConn.Read(buf) + var buf []byte = nil + for { + fmt.Println("sending ...") + t.Log("sending ...") + n, err = senderConn.Write(data_sent) + assert.Nil(t, err) + assert.Equal(t, n, len(data_sent)) + + fmt.Println("receiving ...") + t.Log("receiving ...") + var ok bool = false + buf, ok = <-receiverConn + if ok { + break + } + } + fmt.Println("DONE") + t.Log("DONE") assert.Nil(t, err) assert.Equal(t, n, len(data_sent)) assert.Equal(t, data_sent, buf[:n]) @@ -48,14 +83,20 @@ func doTestTcpCommunicationFail( func tcpListenerNAccept(t *testing.T, listener *net.TCPListener, n int) <-chan net.Conn { c := make(chan net.Conn, 1) + r := make(chan bool, 1) go func() { defer close(c) + r <- true for range n { conn, err := listener.Accept() assert.Nil(t, err) c <- conn } }() + + <-r + time.Sleep(50 * time.Millisecond) + return c } @@ -95,27 +136,30 @@ port_forwarding: client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) assert.Nil(t, err) + client1_rcv_chan := startReadToChannel(client1_conn) client1_server_side_conn := <-server_listen_conn_accepts + client1_server_side_rcv_chan := startReadToChannel(client1_server_side_conn) client2_conn, err := net.DialTCP("tcp", nil, client_conn_addr) assert.Nil(t, err) + client2_rcv_chan := startReadToChannel(client2_conn) client2_server_side_conn := <-server_listen_conn_accepts - assert.Nil(t, err) + client2_server_side_rcv_chan := startReadToChannel(client2_server_side_conn) doTestTcpCommunication(t, "Hello from client 1 side!", - client1_conn, client1_server_side_conn) + client1_conn, client1_server_side_rcv_chan) doTestTcpCommunication(t, "Hello from 
client two side!", - client2_conn, client2_server_side_conn) + client2_conn, client2_server_side_rcv_chan) doTestTcpCommunication(t, "Hello from server first side!", - client1_server_side_conn, client1_conn) + client1_server_side_conn, client1_rcv_chan) doTestTcpCommunication(t, "Hello from server second side!", - client2_server_side_conn, client2_conn) + client2_server_side_conn, client2_rcv_chan) doTestTcpCommunication(t, "Hello from server third side!", - client1_server_side_conn, client1_conn) + client1_server_side_conn, client1_rcv_chan) doTestTcpCommunication(t, "Hello from client two side AGAIN!", - client2_conn, client2_server_side_conn) + client2_conn, client2_server_side_rcv_chan) } @@ -155,6 +199,7 @@ port_forwarding: server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr) assert.Nil(t, err) + defer server_listen_conn.Close() server_listen_conn_accepts := tcpListenerNAccept(t, server_listen_conn, 1) @@ -162,19 +207,23 @@ port_forwarding: client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) assert.Nil(t, err) + defer client1_conn.Close() + client1_rcv_chan := startReadToChannel(client1_conn) client1_server_side_conn := <-server_listen_conn_accepts + defer client1_server_side_conn.Close() + client1_server_side_rcv_chan := startReadToChannel(client1_server_side_conn) doTestTcpCommunication(t, "Hello from client 1 side!", - client1_conn, client1_server_side_conn) + client1_conn, client1_server_side_rcv_chan) doTestTcpCommunication(t, "Hello from server first side!", - client1_server_side_conn, client1_conn) + client1_server_side_conn, client1_rcv_chan) doTestTcpCommunication(t, "Hello from server third side!", - client1_server_side_conn, client1_conn) + client1_server_side_conn, client1_rcv_chan) doTestTcpCommunication(t, "Hello from client one side AGAIN!", - client1_conn, client1_server_side_conn) + client1_conn, client1_server_side_rcv_chan) new_server_fwd_list, err := loadPortFwdConfigFromString(sl, ` port_forwarding: @@ -247,23 
+296,28 @@ port_forwarding: server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr) assert.Nil(t, err) + defer server_listen_conn.Close() server_listen_conn_accepts := tcpListenerNAccept(t, server_listen_conn, 1) client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) assert.Nil(t, err) + defer client1_conn.Close() + client1_rcv_chan := startReadToChannel(client1_conn) client1_server_side_conn := <-server_listen_conn_accepts + defer client1_server_side_conn.Close() + client1_server_side_rcv_chan := startReadToChannel(client1_server_side_conn) doTestTcpCommunication(t, "Hello from client 1 side!", - client1_conn, client1_server_side_conn) + client1_conn, client1_server_side_rcv_chan) doTestTcpCommunication(t, "Hello from server first side!", - client1_server_side_conn, client1_conn) + client1_server_side_conn, client1_rcv_chan) doTestTcpCommunication(t, "Hello from server third side!", - client1_server_side_conn, client1_conn) + client1_server_side_conn, client1_rcv_chan) doTestTcpCommunication(t, "Hello from client one side AGAIN!", - client1_conn, client1_server_side_conn) + client1_conn, client1_server_side_rcv_chan) new_server_fwd_list, err := loadPortFwdConfigFromString(sl, ` port_forwarding: diff --git a/port_forwarder/port_forwarder_udp_test.go b/port_forwarder/port_forwarder_udp_test.go index 45bb293d3..d2c75606b 100644 --- a/port_forwarder/port_forwarder_udp_test.go +++ b/port_forwarder/port_forwarder_udp_test.go @@ -55,8 +55,8 @@ func doTestUdpCommunication( msg string, senderConn *net.UDPConn, toAddr net.Addr, - receiverConn *net.UDPConn, -) (senderAddr net.Addr) { + receiverConn <-chan Pair[[]byte, net.Addr], +) net.Addr { data_sent := []byte(msg) var n int var err error @@ -68,12 +68,33 @@ func doTestUdpCommunication( assert.Nil(t, err) assert.Equal(t, n, len(data_sent)) - buf := make([]byte, 100) - n, senderAddr, err = receiverConn.ReadFrom(buf) + pair := <-receiverConn assert.Nil(t, err) - assert.Equal(t, n, len(data_sent)) - 
assert.Equal(t, data_sent, buf[:n]) - return + assert.Equal(t, data_sent, pair.a) + return pair.b +} + +type Pair[A any, B any] struct { + a A + b B +} + +func readUdpConnectionToChannel(conn *net.UDPConn) <-chan Pair[[]byte, net.Addr] { + rcv_chan := make(chan Pair[[]byte, net.Addr]) + + go func() { + defer close(rcv_chan) + for { + buf := make([]byte, 100) + n, addr, err := conn.ReadFrom(buf) + if err != nil { + return + } + rcv_chan <- Pair[[]byte, net.Addr]{buf[0:n], addr} + } + }() + + return rcv_chan } func TestUdpInOut2Clients(t *testing.T) { @@ -108,26 +129,34 @@ port_forwarding: server_listen_conn, err := net.ListenUDP("udp", server_conn_addr) assert.Nil(t, err) + defer server_listen_conn.Close() + server_listen_rcv_chan := readUdpConnectionToChannel(server_listen_conn) + client1_conn, err := net.DialUDP("udp", nil, client_conn_addr) assert.Nil(t, err) + defer client1_conn.Close() + client1_rcv_chan := readUdpConnectionToChannel(client1_conn) + client2_conn, err := net.DialUDP("udp", nil, client_conn_addr) assert.Nil(t, err) + defer client2_conn.Close() + client2_rcv_chan := readUdpConnectionToChannel(client2_conn) client1_addr := doTestUdpCommunication(t, "Hello from client 1 side!", - client1_conn, nil, server_listen_conn) + client1_conn, nil, server_listen_rcv_chan) assert.NotNil(t, client1_addr) client2_addr := doTestUdpCommunication(t, "Hello from client two side!", - client2_conn, nil, server_listen_conn) + client2_conn, nil, server_listen_rcv_chan) assert.NotNil(t, client2_addr) doTestUdpCommunication(t, "Hello from server first side!", - server_listen_conn, client1_addr, client1_conn) + server_listen_conn, client1_addr, client1_rcv_chan) doTestUdpCommunication(t, "Hello from server second side!", - server_listen_conn, client2_addr, client2_conn) + server_listen_conn, client2_addr, client2_rcv_chan) doTestUdpCommunication(t, "Hello from server third side!", - server_listen_conn, client1_addr, client1_conn) + server_listen_conn, client1_addr, 
client1_rcv_chan) doTestUdpCommunication(t, "Hello from client two side AGAIN!", - client2_conn, nil, server_listen_conn) + client2_conn, nil, server_listen_rcv_chan) } From 84d1a268d506248d705bab9d7dd228c177624343 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Sat, 14 Sep 2024 22:40:10 +0200 Subject: [PATCH 24/33] fix issue with survival of nebula service from previous testrun --- udp/udp_linux.go | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 2eee76ee2..ac9cba796 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -22,10 +22,11 @@ import ( //TODO: make it support reload as best you can! type StdConn struct { - sysFd int - isV4 bool - l *logrus.Logger - batch int + sysFd int + closed bool + isV4 bool + l *logrus.Logger + batch int } func maybeIPV4(ip net.IP) (net.IP, bool) { @@ -142,6 +143,11 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } + if u.closed { + u.l.Debug("flag for closing connection is set, exiting read loop") + return + } + //metric.Update(int64(n)) for i := 0; i < n; i++ { if u.isV4 { @@ -315,6 +321,10 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { func (u *StdConn) Close() error { //TODO: this will not interrupt the read loop + if u.closed { + return nil + } + u.closed = true return syscall.Close(u.sysFd) } From ad2dbdc6cd62ed5062170eba04483e37b8463b00 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Sun, 15 Sep 2024 12:39:38 +0200 Subject: [PATCH 25/33] require instead of assert; fix missing close connectons in one test --- port_forwarder/port_forwarder_tcp_test.go | 68 ++++++++++++----------- port_forwarder/port_forwarder_udp_test.go | 19 ++++--- 2 files changed, 47 insertions(+), 40 deletions(-) diff --git a/port_forwarder/port_forwarder_tcp_test.go b/port_forwarder/port_forwarder_tcp_test.go index ef9014600..9a883b8a5 100644 --- a/port_forwarder/port_forwarder_tcp_test.go +++ 
b/port_forwarder/port_forwarder_tcp_test.go @@ -8,6 +8,7 @@ import ( "github.com/slackhq/nebula/service" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func startReadToChannel(receiverConn net.Conn) <-chan []byte { @@ -44,7 +45,7 @@ func doTestTcpCommunication( fmt.Println("sending ...") t.Log("sending ...") n, err = senderConn.Write(data_sent) - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, n, len(data_sent)) fmt.Println("receiving ...") @@ -57,7 +58,7 @@ func doTestTcpCommunication( } fmt.Println("DONE") t.Log("DONE") - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, n, len(data_sent)) assert.Equal(t, data_sent, buf[:n]) } @@ -73,7 +74,7 @@ func doTestTcpCommunicationFail( if err != nil { return } - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, n, len(data_sent)) buf := make([]byte, 100) @@ -89,7 +90,7 @@ func tcpListenerNAccept(t *testing.T, listener *net.TCPListener, n int) <-chan n r <- true for range n { conn, err := listener.Accept() - assert.Nil(t, err) + require.Nil(t, err) c <- conn } }() @@ -110,7 +111,7 @@ port_forwarding: dial_address: 127.0.0.1:5595 protocols: [tcp] `) - assert.Nil(t, err) + require.Nil(t, err) assert.Len(t, server_pf.portForwardings, 1) @@ -121,27 +122,32 @@ port_forwarding: dial_address: 10.0.0.1:4495 protocols: [tcp] `) - assert.Nil(t, err) + require.Nil(t, err) assert.Len(t, client_pf.portForwardings, 1) client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3395") - assert.Nil(t, err) + require.Nil(t, err) server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5595") - assert.Nil(t, err) + require.Nil(t, err) server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr) - assert.Nil(t, err) + require.Nil(t, err) + defer server_listen_conn.Close() server_listen_conn_accepts := tcpListenerNAccept(t, server_listen_conn, 2) client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) - assert.Nil(t, err) + require.Nil(t, err) + defer 
client1_conn.Close() + client1_rcv_chan := startReadToChannel(client1_conn) client1_server_side_conn := <-server_listen_conn_accepts client1_server_side_rcv_chan := startReadToChannel(client1_server_side_conn) client2_conn, err := net.DialTCP("tcp", nil, client_conn_addr) - assert.Nil(t, err) + require.Nil(t, err) + defer client2_conn.Close() + client2_rcv_chan := startReadToChannel(client2_conn) client2_server_side_conn := <-server_listen_conn_accepts client2_server_side_rcv_chan := startReadToChannel(client2_server_side_conn) @@ -173,7 +179,7 @@ port_forwarding: dial_address: 127.0.0.1:5597 protocols: [tcp] `) - assert.Nil(t, err) + require.Nil(t, err) assert.Len(t, server_pf.portForwardings, 1) @@ -186,19 +192,19 @@ port_forwarding: dial_address: 10.0.0.1:4497 protocols: [tcp] `) - assert.Nil(t, err) + require.Nil(t, err) assert.Len(t, client_pf.portForwardings, 1) time.Sleep(100 * time.Millisecond) client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3397") - assert.Nil(t, err) + require.Nil(t, err) server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5597") - assert.Nil(t, err) + require.Nil(t, err) server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr) - assert.Nil(t, err) + require.Nil(t, err) defer server_listen_conn.Close() server_listen_conn_accepts := tcpListenerNAccept(t, server_listen_conn, 1) @@ -206,7 +212,7 @@ port_forwarding: time.Sleep(100 * time.Millisecond) client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) - assert.Nil(t, err) + require.Nil(t, err) defer client1_conn.Close() client1_rcv_chan := startReadToChannel(client1_conn) @@ -232,7 +238,7 @@ port_forwarding: dial_address: 127.0.0.1:5596 protocols: [tcp] `) - assert.Nil(t, err) + require.Nil(t, err) assert.Len(t, server_pf.portForwardings, 1) @@ -243,10 +249,10 @@ port_forwarding: dial_address: 10.0.0.1:4496 protocols: [tcp] `) - assert.Nil(t, err) + require.Nil(t, err) err = client_pf.ApplyChangesByNewFwdList(new_client_fwd_list) - assert.Nil(t, 
err) + require.Nil(t, err) doTestTcpCommunicationFail(t, "Hello from client 1 side!", client1_conn, client1_server_side_conn) @@ -255,7 +261,7 @@ port_forwarding: client1_server_side_conn, client1_conn) err = server_pf.ApplyChangesByNewFwdList(new_server_fwd_list) - assert.Nil(t, err) + require.Nil(t, err) doTestTcpCommunicationFail(t, "Hello from client 1 side!", client1_conn, client1_server_side_conn) @@ -274,7 +280,7 @@ port_forwarding: dial_address: 127.0.0.1:5599 protocols: [tcp] `) - assert.Nil(t, err) + require.Nil(t, err) assert.Len(t, server_pf.portForwardings, 1) @@ -285,22 +291,22 @@ port_forwarding: dial_address: 10.0.0.1:4499 protocols: [tcp] `) - assert.Nil(t, err) + require.Nil(t, err) assert.Len(t, client_pf.portForwardings, 1) client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3399") - assert.Nil(t, err) + require.Nil(t, err) server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5599") - assert.Nil(t, err) + require.Nil(t, err) server_listen_conn, err := net.ListenTCP("tcp", server_conn_addr) - assert.Nil(t, err) + require.Nil(t, err) defer server_listen_conn.Close() server_listen_conn_accepts := tcpListenerNAccept(t, server_listen_conn, 1) client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) - assert.Nil(t, err) + require.Nil(t, err) defer client1_conn.Close() client1_rcv_chan := startReadToChannel(client1_conn) @@ -326,7 +332,7 @@ port_forwarding: dial_address: 127.0.0.1:5598 protocols: [tcp] `) - assert.Nil(t, err) + require.Nil(t, err) assert.Len(t, server_pf.portForwardings, 1) @@ -337,10 +343,10 @@ port_forwarding: dial_address: 10.0.0.1:4498 protocols: [tcp] `) - assert.Nil(t, err) + require.Nil(t, err) err = server_pf.ApplyChangesByNewFwdList(new_server_fwd_list) - assert.Nil(t, err) + require.Nil(t, err) doTestTcpCommunicationFail(t, "Hello from client 1 side!", client1_conn, client1_server_side_conn) @@ -349,7 +355,7 @@ port_forwarding: client1_server_side_conn, client1_conn) err = 
client_pf.ApplyChangesByNewFwdList(new_client_fwd_list) - assert.Nil(t, err) + require.Nil(t, err) doTestTcpCommunicationFail(t, "Hello from client 1 side!", client1_conn, client1_server_side_conn) diff --git a/port_forwarder/port_forwarder_udp_test.go b/port_forwarder/port_forwarder_udp_test.go index d2c75606b..4a629a3ea 100644 --- a/port_forwarder/port_forwarder_udp_test.go +++ b/port_forwarder/port_forwarder_udp_test.go @@ -8,6 +8,7 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/service" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func loadPortFwdConfigFromString(l *logrus.Logger, configStr string) (*PortForwardingList, error) { @@ -65,11 +66,11 @@ func doTestUdpCommunication( } else { n, err = senderConn.Write(data_sent) } - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, n, len(data_sent)) pair := <-receiverConn - assert.Nil(t, err) + require.Nil(t, err) assert.Equal(t, data_sent, pair.a) return pair.b } @@ -107,7 +108,7 @@ port_forwarding: dial_address: 127.0.0.1:5599 protocols: [udp] `) - assert.Nil(t, err) + require.Nil(t, err) assert.Len(t, server_pf.portForwardings, 1) @@ -118,27 +119,27 @@ port_forwarding: dial_address: 10.0.0.1:4499 protocols: [udp] `) - assert.Nil(t, err) + require.Nil(t, err) assert.Len(t, client_pf.portForwardings, 1) client_conn_addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:3399") - assert.Nil(t, err) + require.Nil(t, err) server_conn_addr, err := net.ResolveUDPAddr("udp", "127.0.0.1:5599") - assert.Nil(t, err) + require.Nil(t, err) server_listen_conn, err := net.ListenUDP("udp", server_conn_addr) - assert.Nil(t, err) + require.Nil(t, err) defer server_listen_conn.Close() server_listen_rcv_chan := readUdpConnectionToChannel(server_listen_conn) client1_conn, err := net.DialUDP("udp", nil, client_conn_addr) - assert.Nil(t, err) + require.Nil(t, err) defer client1_conn.Close() client1_rcv_chan := readUdpConnectionToChannel(client1_conn) client2_conn, err := 
net.DialUDP("udp", nil, client_conn_addr) - assert.Nil(t, err) + require.Nil(t, err) defer client2_conn.Close() client2_rcv_chan := readUdpConnectionToChannel(client2_conn) From fa7d1204ad33cbf2cbf9d878445391eff9ee32e4 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Sun, 15 Sep 2024 12:51:47 +0200 Subject: [PATCH 26/33] service: fix missing destruction of ipstack --- service/service.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/service/service.go b/service/service.go index aacb59591..469b55e31 100644 --- a/service/service.go +++ b/service/service.go @@ -257,7 +257,11 @@ func (s *Service) ListenUDP(address string) (*gonet.UDPConn, error) { } func (s *Service) Wait() error { - return s.eg.Wait() + err := s.eg.Wait() + + s.ipstack.Destroy() + + return err } func (s *Service) Close() error { From e6bcba24ee8a8a670a80597ac290ad1e99413eca Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Mon, 16 Sep 2024 11:25:08 +0200 Subject: [PATCH 27/33] add comment to "unsafe_routes" initialisation --- overlay/user.go | 1 + 1 file changed, 1 insertion(+) diff --git a/overlay/user.go b/overlay/user.go index f8a64fcaf..9ab1430c6 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -34,6 +34,7 @@ func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, tunCidr netip.Prefix } } + // this is needed to enable the "unsafe_routes" feature in combination with port forwarding. 
d.routeTree.Store(routeTree) return d, nil From f908c1020add51e96a3ea4298149c5732a0509fd Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Mon, 16 Sep 2024 11:29:16 +0200 Subject: [PATCH 28/33] add comment to performance improvement in user-tun --- overlay/user.go | 1 + 1 file changed, 1 insertion(+) diff --git a/overlay/user.go b/overlay/user.go index 9ab1430c6..632ab1817 100644 --- a/overlay/user.go +++ b/overlay/user.go @@ -52,6 +52,7 @@ func NewUserDevice(tunCidr netip.Prefix) (*UserDevice, error) { type UserDevice struct { tunCidr netip.Prefix + // using channel of *buffer.View significantly improves performance outboundChannel chan *buffer.View inboundChannel chan *buffer.View From ba8a03717e41ed9ecc3cf215d0f721fd792238ad Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Mon, 16 Sep 2024 11:41:20 +0200 Subject: [PATCH 29/33] add missing error handling when calling fwd factorie functions --- port_forwarder/builder.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/port_forwarder/builder.go b/port_forwarder/builder.go index 906069e9a..1bd997e82 100644 --- a/port_forwarder/builder.go +++ b/port_forwarder/builder.go @@ -97,7 +97,10 @@ func ParseConfig( return fmt.Errorf("child yml node of \"port_forwarding.%s.%d.protocols\" doesn't support: %s", direction, fwd_idx, proto_str) } - factoryFn(node_map) + err := factoryFn(node_map) + if err != nil { + return fmt.Errorf("child yml node of \"port_forwarding.%s.%d.protocols\" with proto %s - failed to instantiate forwarder: %v", direction, fwd_idx, proto_str, err) + } } } } From ba7880a563490eaa59325904d633603a72feca59 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Mon, 16 Sep 2024 21:27:49 +0200 Subject: [PATCH 30/33] remove all sleeps in tests - no need for them --- port_forwarder/port_forwarder_tcp_test.go | 9 --------- 1 file changed, 9 deletions(-) diff --git a/port_forwarder/port_forwarder_tcp_test.go b/port_forwarder/port_forwarder_tcp_test.go index 9a883b8a5..ca090d0a1 100644 --- 
a/port_forwarder/port_forwarder_tcp_test.go +++ b/port_forwarder/port_forwarder_tcp_test.go @@ -4,7 +4,6 @@ import ( "fmt" "net" "testing" - "time" "github.com/slackhq/nebula/service" "github.com/stretchr/testify/assert" @@ -27,7 +26,6 @@ func startReadToChannel(receiverConn net.Conn) <-chan []byte { } }() <-r - time.Sleep(50 * time.Millisecond) return rcv_chan } @@ -96,7 +94,6 @@ func tcpListenerNAccept(t *testing.T, listener *net.TCPListener, n int) <-chan n }() <-r - time.Sleep(50 * time.Millisecond) return c } @@ -183,8 +180,6 @@ port_forwarding: assert.Len(t, server_pf.portForwardings, 1) - time.Sleep(100 * time.Millisecond) - client_pf, err := createPortForwarderFromConfigString(t, cl, client, ` port_forwarding: outbound: @@ -196,8 +191,6 @@ port_forwarding: assert.Len(t, client_pf.portForwardings, 1) - time.Sleep(100 * time.Millisecond) - client_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:3397") require.Nil(t, err) server_conn_addr, err := net.ResolveTCPAddr("tcp", "127.0.0.1:5597") @@ -209,8 +202,6 @@ port_forwarding: server_listen_conn_accepts := tcpListenerNAccept(t, server_listen_conn, 1) - time.Sleep(100 * time.Millisecond) - client1_conn, err := net.DialTCP("tcp", nil, client_conn_addr) require.Nil(t, err) defer client1_conn.Close() From 39d8332f146920c535d42883560b90d64f1e79f0 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Mon, 16 Sep 2024 22:59:17 +0200 Subject: [PATCH 31/33] nitpick: use atomic bool instead of bool for usynced thread access --- udp/udp_linux.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/udp/udp_linux.go b/udp/udp_linux.go index ac9cba796..6f06c0424 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/netip" + "sync/atomic" "syscall" "unsafe" @@ -23,7 +24,7 @@ import ( type StdConn struct { sysFd int - closed bool + closed atomic.Bool isV4 bool l *logrus.Logger batch int From cd510b3e70b460c7a8381de9685e7fb4e165448d Mon Sep 17 00:00:00 2001 From: 
Ulrich Hornung Date: Mon, 16 Sep 2024 23:05:21 +0200 Subject: [PATCH 32/33] fix race condition where "CloseAndZero" is executed while still used --- udp/udp_rio_windows.go | 20 +++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/udp/udp_rio_windows.go b/udp/udp_rio_windows.go index ee7e1e002..eec96c91f 100644 --- a/udp/udp_rio_windows.go +++ b/udp/udp_rio_windows.go @@ -14,6 +14,7 @@ import ( "sync" "sync/atomic" "syscall" + "time" "unsafe" "github.com/sirupsen/logrus" @@ -178,10 +179,11 @@ func (u *RIOConn) receive(buf []byte) (int, windows.RawSockaddrInet6, error) { retry: count = 0 for tries := 0; count == 0 && tries < receiveSpins; tries++ { + if !u.isOpen.Load() { // might have changed since first check before the mutex lock + return 0, windows.RawSockaddrInet6{}, net.ErrClosed + } + if tries > 0 { - if !u.isOpen.Load() { - return 0, windows.RawSockaddrInet6{}, net.ErrClosed - } procyield(1) } @@ -247,6 +249,10 @@ func (u *RIOConn) WriteTo(buf []byte, ip netip.AddrPort) error { u.tx.mu.Lock() defer u.tx.mu.Unlock() + if !u.isOpen.Load() { // might have changed since first check before the mutex lock + return net.ErrClosed + } + count := winrio.DequeueCompletion(u.tx.cq, u.results[:]) if count == 0 && u.tx.isFull { err := winrio.Notify(u.tx.cq) @@ -323,6 +329,14 @@ func (u *RIOConn) Close() error { windows.PostQueuedCompletionStatus(u.rx.iocp, 0, 0, nil) windows.PostQueuedCompletionStatus(u.tx.iocp, 0, 0, nil) + u.rx.mu.Lock() // for waiting till active reader is done + time.Sleep(time.Millisecond * 0) // avoid warning about empty critical section + u.rx.mu.Unlock() + + u.tx.mu.Lock() // for waiting till active writer is done + time.Sleep(time.Millisecond * 0) // avoid warning about empty critical section + u.tx.mu.Unlock() + u.rx.CloseAndZero() u.tx.CloseAndZero() if u.sock != 0 { From c2e4dd94ebc17a295f4b909398f04d3c5dfadf51 Mon Sep 17 00:00:00 2001 From: Ulrich Hornung Date: Tue, 17 Sep 2024 21:34:44 +0200 Subject: [PATCH 
33/33] fix the closing of the linux udp reading loop using Shutdown --- udp/udp_linux.go | 37 ++++++++++++++++++++++++++++++++----- 1 file changed, 32 insertions(+), 5 deletions(-) diff --git a/udp/udp_linux.go b/udp/udp_linux.go index 6f06c0424..ae44e582d 100644 --- a/udp/udp_linux.go +++ b/udp/udp_linux.go @@ -8,6 +8,7 @@ import ( "fmt" "net" "net/netip" + "sync" "sync/atomic" "syscall" "unsafe" @@ -25,6 +26,7 @@ import ( type StdConn struct { sysFd int closed atomic.Bool + wg *sync.WaitGroup isV4 bool l *logrus.Logger batch int @@ -81,7 +83,14 @@ func NewListener(l *logrus.Logger, ip netip.Addr, port int, multi bool, batch in //v, err := unix.GetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_INCOMING_CPU) //l.Println(v, err) - return &StdConn{sysFd: fd, isV4: ip.Is4(), l: l, batch: batch}, err + return &StdConn{ + sysFd: fd, + closed: atomic.Bool{}, + wg: &sync.WaitGroup{}, + isV4: ip.Is4(), + l: l, + batch: batch, + }, err } func (u *StdConn) Rebind() error { @@ -123,6 +132,15 @@ func (u *StdConn) LocalAddr() (netip.AddrPort, error) { } func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firewall.ConntrackCacheTicker, q int) { + + u.wg.Add(1) + defer func() { + u.wg.Done() + }() + if u.closed.Load() { + return + } + plaintext := make([]byte, MTU) h := &header.H{} fwPacket := &firewall.Packet{} @@ -144,7 +162,7 @@ func (u *StdConn) ListenOut(r EncReader, lhf LightHouseHandlerFunc, cache *firew return } - if u.closed { + if u.closed.Load() { u.l.Debug("flag for closing connection is set, exiting read loop") return } @@ -321,11 +339,20 @@ func (u *StdConn) getMemInfo(meminfo *[unix.SK_MEMINFO_VARS]uint32) error { } func (u *StdConn) Close() error { - //TODO: this will not interrupt the read loop - if u.closed { + if !u.closed.CompareAndSwap(false, true) { + // already closed by e.g. 
other thread return nil } - u.closed = true + err := syscall.Shutdown(u.sysFd, syscall.SHUT_RDWR) + if err != nil { + errno, ok := err.(syscall.Errno) + // connection might have been terminated by remote before + wasDisconnected := ok && (errno == syscall.ENOTCONN) + if !wasDisconnected { + panic(fmt.Sprintf("error while shutdown of UDP socket: %v", err)) + } + } + u.wg.Wait() return syscall.Close(u.sysFd) }