Skip to content

Commit

Permalink
add tests for ws and multistream
Browse files Browse the repository at this point in the history
  • Loading branch information
sukunrt committed Oct 4, 2024
1 parent 8350885 commit ca4a0a6
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 40 deletions.
1 change: 0 additions & 1 deletion p2p/transport/tcpreuse/demultiplex.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ func ReadSampleFromConn(c net.Conn) (Sample, manet.Conn, error) {
if err != nil {
return Sample{}, nil, err
}

return sc.s, sc, nil
}

Expand Down
171 changes: 132 additions & 39 deletions p2p/transport/tcpreuse/listener_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,15 @@ import (
"crypto/x509/pkix"
"fmt"
"math/big"
"net"
"net/http"
"testing"
"time"

"github.com/gorilla/websocket"
ma "github.com/multiformats/go-multiaddr"
manet "github.com/multiformats/go-multiaddr/net"
"github.com/multiformats/go-multistream"
"github.com/stretchr/testify/require"
)

Expand All @@ -23,20 +27,11 @@ func selfSignedTLSConfig(t *testing.T) *tls.Config {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)

notBefore := time.Now()
notAfter := notBefore.Add(365 * 24 * time.Hour)

serialNumberLimit := new(big.Int).Lsh(big.NewInt(1), 128)
serialNumber, err := rand.Int(rand.Reader, serialNumberLimit)
require.NoError(t, err)

certTemplate := x509.Certificate{
SerialNumber: serialNumber,
SerialNumber: &big.Int{},
Subject: pkix.Name{
Organization: []string{"Test"},
},
NotBefore: notBefore,
NotAfter: notAfter,
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
Expand All @@ -61,45 +56,143 @@ func getTLSConn(t *testing.T, c manet.Conn) (manet.Conn, error) {
return manet.WrapNetConn(tls.Server(c, selfSignedTLSConfig(t)))
}

type wsHandler struct{ conns chan *websocket.Conn }

func (wh wsHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
u := websocket.Upgrader{}
c, _ := u.Upgrade(w, r, http.Header{})
wh.conns <- c
}

func TestListenerSingle(t *testing.T) {
listenAddr := ma.StringCast("/ip4/0.0.0.0/tcp/0")
const N = 32
for disableReuseport := range []bool{true, false} {
t.Run(fmt.Sprintf("TLS-reuseport:%v", disableReuseport), func(t *testing.T) {
t.Run(fmt.Sprintf("multistream-reuseport:%v", disableReuseport), func(t *testing.T) {
cm := NewConnMgr(false)
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS)
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_MultistreamSelect)
require.NoError(t, err)

go func() {
d := tls.Dialer{Config: &tls.Config{InsecureSkipVerify: true}}
for i := 0; i < 100; i++ {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
conn, err := d.DialContext(ctx, l.Addr().Network(), l.Addr().String())
if err != nil {
t.Error("failed to dial", err, i)
return
}
buf := make([]byte, 10)
_, err = conn.Write([]byte("hello"))
if err != nil {
t.Error(err)
}
_, err = conn.Read(buf)
if err == nil {
t.Error("expected EOF got nil")
}
d := net.Dialer{}
for i := 0; i < N; i++ {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
defer cancel()
conn, err := d.DialContext(ctx, l.Addr().Network(), l.Addr().String())
if err != nil {
t.Error("failed to dial", err, i)
return
}
lconn := multistream.NewMSSelect(conn, "a")
buf := make([]byte, 10)
_, err = lconn.Write([]byte(""))
if err != nil {
t.Error(err)
}
_, err = lconn.Read(buf)
if err == nil {
t.Error("expected EOF got nil")
}
}()
}
}()
for i := 0; i < 100; i++ {
for i := 0; i < N; i++ {
c, err := l.Accept()
require.NoError(t, err)
c, err = getTLSConn(t, c)
require.NoError(t, err)
buf := make([]byte, 10)
n, err := c.Read(buf)
require.NoError(t, err)
require.Equal(t, "hello", string(buf[:n]))
c.Close()
go func() {
cc := multistream.NewMSSelect(c, "a")
buf := make([]byte, 10)
n, err := cc.Read(buf)
require.NoError(t, err)
require.Equal(t, "hello", string(buf[:n]))
c.Close()
}()
}
})

t.Run(fmt.Sprintf("WebSocket-reuseport:%v", disableReuseport), func(t *testing.T) {
cm := NewConnMgr(false)
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_HTTP)
require.NoError(t, err)
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
go func() {
http.Serve(manet.NetListener(l), wh)
}()
go func() {
d := websocket.Dialer{}
for i := 0; i < N; i++ {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
conn, _, err := d.DialContext(ctx, fmt.Sprintf("ws://%s", l.Addr().String()), http.Header{})
if err != nil {
t.Error("failed to dial", err, i)
return
}
err = conn.WriteMessage(websocket.TextMessage, []byte("hello"))
if err != nil {
t.Error(err)
}
_, _, err = conn.ReadMessage()
if err == nil {
t.Error("expected EOF got nil")
}
}()
}
}()
for i := 0; i < N; i++ {
c := <-wh.conns
go func() {
msgType, buf, err := c.ReadMessage()
require.NoError(t, err)
require.Equal(t, msgType, websocket.TextMessage)
require.Equal(t, "hello", string(buf))
c.Close()
}()
}
})

t.Run(fmt.Sprintf("WebSocketTLS-reuseport:%v", disableReuseport), func(t *testing.T) {
cm := NewConnMgr(false)
l, err := cm.DemultiplexedListen(listenAddr, DemultiplexedConnType_TLS)
require.NoError(t, err)
defer l.Close()
wh := wsHandler{conns: make(chan *websocket.Conn, acceptQueueSize)}
go func() {
s := http.Server{Handler: wh, TLSConfig: selfSignedTLSConfig(t)}
s.ServeTLS(manet.NetListener(l), "", "")
}()
go func() {
d := websocket.Dialer{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}}
for i := 0; i < N; i++ {
go func() {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
conn, _, err := d.DialContext(ctx, fmt.Sprintf("wss://%s", l.Addr().String()), http.Header{})
if err != nil {
t.Error("failed to dial", err, i)
return
}
err = conn.WriteMessage(websocket.TextMessage, []byte("hello"))
if err != nil {
t.Error(err)
}
_, _, err = conn.ReadMessage()
if err == nil {
t.Error("expected EOF got nil")
}
}()
}
}()
for i := 0; i < N; i++ {
c := <-wh.conns
go func() {
msgType, buf, err := c.ReadMessage()
require.NoError(t, err)
require.Equal(t, msgType, websocket.TextMessage)
require.Equal(t, "hello", string(buf))
c.Close()
}()
}
})
}
Expand Down

0 comments on commit ca4a0a6

Please sign in to comment.