From ca4a0a6d15892571f47b7edae7f3cabed4b64cbd Mon Sep 17 00:00:00 2001 From: sukun Date: Fri, 4 Oct 2024 22:03:37 +0530 Subject: [PATCH] add tests for ws and multistream --- p2p/transport/tcpreuse/demultiplex.go | 1 - p2p/transport/tcpreuse/listener_test.go | 171 ++++++++++++++++++------ 2 files changed, 132 insertions(+), 40 deletions(-) diff --git a/p2p/transport/tcpreuse/demultiplex.go b/p2p/transport/tcpreuse/demultiplex.go index 7b69ee35d3..0701e570db 100644 --- a/p2p/transport/tcpreuse/demultiplex.go +++ b/p2p/transport/tcpreuse/demultiplex.go @@ -119,7 +119,6 @@ func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) { if err != nil { return Sample{}, nil, err } - return sc.s, sc, nil } diff --git a/p2p/transport/tcpreuse/listener_test.go b/p2p/transport/tcpreuse/listener_test.go index 8bfe397e0e..db55a055c6 100644 --- a/p2p/transport/tcpreuse/listener_test.go +++ b/p2p/transport/tcpreuse/listener_test.go @@ -10,11 +10,15 @@ import ( "crypto/x509/pkix" "fmt" "math/big" + "net" + "net/http" "testing" "time" + "github.com/gorilla/websocket" ma "github.com/multiformats/go-multiaddr" manet "github.com/multiformats/go-multiaddr/net" + "github.com/multiformats/go-multistream" "github.com/stretchr/testify/require" ) @@ -23,20 +27,11 @@ func selfSignedTLSConfig(t *testing.T) *tls.Config { priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) - notBefore := time.Now() - notAfter := notBefore.Add(365 * 24 * time.Hour) - - serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128) - serialNumber, err := rand.Int(rand.Reader, serialNumberLimit) - require.NoError(t, err) - certTemplate := x509.Certificate{ - SerialNumber: serialNumber, + SerialNumber: &big.Int{}, Subject: pkix.Name{ Organization: []string{"Test"}, }, - NotBefore: notBefore, - NotAfter: notAfter, KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, BasicConstraintsValid: true, @@ -61,45 +56,143 @@ func getTLSConn(t *testing.T, c manet.Conn) (manet.Conn, error) { return manet.WrapNetConn(tls.Server(c, selfSignedTLSConfig(t))) } +type wsHandler struct{ conns chan *websocket.Conn } + +func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + u := websocket.Upgrader{} + c, _ := u.Upgrade(w, r, http.Header{}) + wh.conns <- c +} + func TestListenerSingle(t *testing.T) { listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0") + const N = 32 for disableReuseport := range []bool{true, false} { - t.Run(fmt.Sprintf("TLS-reuseport:%v", disableReuseport), func(t *testing.T) { + t.Run(fmt.Sprintf("multistream-reuseport:%v", disableReuseport), func(t *testing.T) { cm := NewConnMgr(false) - l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS) + l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect) require.NoError(t, err) - go func() { - d := tls.Dialer{Config: &tls.Config{InsecureSkipVerify: true}} - for i := 0; i < 100; i++ { - ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) - defer cancel() - conn, err := d.DialContext(ctx, l.Addr().Network(), l.Addr().String()) - if err != nil { - t.Error("failed to dial", err, i) - return - } - buf := make([]byte, 10) - _, err = conn.Write([]byte("hello")) - if err != nil { - t.Error(err) - } - _, err = conn.Read(buf) - if err == nil { - t.Error("expected EOF got nil") - } + d := net.Dialer{} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + conn, err := d.DialContext(ctx, l.Addr().Network(), l.Addr().String()) + if err != nil { + t.Error("failed to dial", err, i) + return + } + lconn := multistream.NewMSSelect(conn, "a") + buf := make([]byte, 10) + _, err = lconn.Write([]byte("")) + if err != nil { + t.Error(err) + } + _, err = lconn.Read(buf) + if err == nil { + t.Error("expected EOF got nil") + } + }() } }() - for i := 0; i < 100; i++ { + for i := 0; i < N; i++ { c, err := l.Accept() require.NoError(t, err) - c, err = getTLSConn(t, c) - require.NoError(t, err) - buf := make([]byte, 10) - n, err := c.Read(buf) - require.NoError(t, err) - require.Equal(t, "hello", string(buf[:n])) - c.Close() + go func() { + cc := multistream.NewMSSelect(c, "a") + buf := make([]byte, 10) + n, err := cc.Read(buf) + require.NoError(t, err) + require.Equal(t, "hello", string(buf[:n])) + c.Close() + }() + } + }) + + t.Run(fmt.Sprintf("WebSocket-reuseport:%v", disableReuseport), func(t *testing.T) { + cm := NewConnMgr(false) + l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP) + require.NoError(t, err) + wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} + go func() { + http.Serve(manet.NetListener(l), wh) + }() + go func() { + d := websocket.Dialer{} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, _, err := d.DialContext(ctx, fmt.Sprintf("ws://%s", l.Addr().String()), http.Header{}) + if err != nil { + t.Error("failed to dial", err, i) + return + } + err = conn.WriteMessage(websocket.TextMessage, []byte("hello")) + if err != nil { + t.Error(err) + } + _, _, err = conn.ReadMessage() + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + for i := 0; i < N; i++ { + c := <-wh.conns + go func() { + msgType, buf, err := c.ReadMessage() + require.NoError(t, err) + require.Equal(t, msgType, websocket.TextMessage) + require.Equal(t, "hello", string(buf)) + c.Close() + }() + } + }) + + t.Run(fmt.Sprintf("WebSocketTLS-reuseport:%v", disableReuseport), func(t *testing.T) { + cm := NewConnMgr(false) + l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS) + require.NoError(t, err) + defer l.Close() + wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)} + go func() { + s := http.Server{Handler: wh, TLSConfig: selfSignedTLSConfig(t)} + s.ServeTLS(manet.NetListener(l), "", "") + }() + go func() { + d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + for i := 0; i < N; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + conn, _, err := d.DialContext(ctx, fmt.Sprintf("wss://%s", l.Addr().String()), http.Header{}) + if err != nil { + t.Error("failed to dial", err, i) + return + } + err = conn.WriteMessage(websocket.TextMessage, []byte("hello")) + if err != nil { + t.Error(err) + } + _, _, err = conn.ReadMessage() + if err == nil { + t.Error("expected EOF got nil") + } + }() + } + }() + for i := 0; i < N; i++ { + c := <-wh.conns + go func() { + msgType, buf, err := c.ReadMessage() + require.NoError(t, err) + require.Equal(t, msgType, websocket.TextMessage) + require.Equal(t, "hello", string(buf)) + c.Close() + }() } }) }