From 36379126de8eb7f4aade6c8e45f3b8c961a935d0 Mon Sep 17 00:00:00 2001 From: Dmitry Caiman Date: Fri, 28 Feb 2025 21:57:35 +0300 Subject: [PATCH] Added NotificationHandler assignment via SocketConfig. --- sctp.go | 12 +++++++----- sctp_linux.go | 18 ++++++++++-------- sctp_linux_test.go | 42 ++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 57 insertions(+), 15 deletions(-) diff --git a/sctp.go b/sctp.go index add951a..d5a2985 100644 --- a/sctp.go +++ b/sctp.go @@ -636,8 +636,9 @@ func (c *SCTPConn) SetWriteDeadline(t time.Time) error { } type SCTPListener struct { - fd int - m sync.Mutex + fd int + m sync.Mutex + notificationHandler NotificationHandler } func (ln *SCTPListener) Addr() net.Addr { @@ -723,15 +724,16 @@ type SocketConfig struct { // If Control is not nil it is called after the socket is created but before // it is bound or connected. Control func(network, address string, c syscall.RawConn) error - + // NotificationHandler defines actions taken on received notifications when MSG_NOTIFICATION flag is set. + NotificationHandler NotificationHandler // InitMsg is the options to send in the initial SCTP message InitMsg InitMsg } func (cfg *SocketConfig) Listen(net string, laddr *SCTPAddr) (*SCTPListener, error) { - return listenSCTPExtConfig(net, laddr, cfg.InitMsg, cfg.Control) + return listenSCTPExtConfig(net, laddr, cfg.InitMsg, cfg.Control, cfg.NotificationHandler) } func (cfg *SocketConfig) Dial(net string, laddr, raddr *SCTPAddr) (*SCTPConn, error) { - return dialSCTPExtConfig(net, laddr, raddr, cfg.InitMsg, cfg.Control) + return dialSCTPExtConfig(net, laddr, raddr, cfg.InitMsg, cfg.Control, cfg.NotificationHandler) } diff --git a/sctp_linux.go b/sctp_linux.go index 8c36547..10b3645 100644 --- a/sctp_linux.go +++ b/sctp_linux.go @@ -174,11 +174,11 @@ func ListenSCTP(net string, laddr *SCTPAddr) (*SCTPListener, error) { // ListenSCTPExt - start listener on specified address/port with given SCTP options func ListenSCTPExt(network string, laddr *SCTPAddr, options InitMsg) (*SCTPListener, error) { - return listenSCTPExtConfig(network, laddr, options, nil) + return listenSCTPExtConfig(network, laddr, options, nil, nil) } // listenSCTPExtConfig - start listener on specified address/port with given SCTP options and socket configuration -func listenSCTPExtConfig(network string, laddr *SCTPAddr, options InitMsg, control func(network, address string, c syscall.RawConn) error) (*SCTPListener, error) { +func listenSCTPExtConfig(network string, laddr *SCTPAddr, options InitMsg, control func(network, address string, c syscall.RawConn) error, notificationHandler NotificationHandler) (*SCTPListener, error) { af, ipv6only := favoriteAddrFamily(network, laddr, nil, "listen") sock, err := syscall.Socket( af, @@ -232,14 +232,16 @@ func listenSCTPExtConfig(network string, laddr *SCTPAddr, options InitMsg, contr return nil, err } return &SCTPListener{ - fd: sock, - }, nil + fd: sock, + notificationHandler: notificationHandler, + }, + nil } // AcceptSCTP waits for and returns the next SCTP connection to the listener. func (ln *SCTPListener) AcceptSCTP() (*SCTPConn, error) { fd, _, err := syscall.Accept4(ln.fd, 0) - return NewSCTPConn(fd, nil), err + return NewSCTPConn(fd, ln.notificationHandler), err } // Accept waits for and returns the next connection connection to the listener. @@ -259,11 +261,11 @@ func DialSCTP(net string, laddr, raddr *SCTPAddr) (*SCTPConn, error) { // DialSCTPExt - same as DialSCTP but with given SCTP options func DialSCTPExt(network string, laddr, raddr *SCTPAddr, options InitMsg) (*SCTPConn, error) { - return dialSCTPExtConfig(network, laddr, raddr, options, nil) + return dialSCTPExtConfig(network, laddr, raddr, options, nil, nil) } // dialSCTPExtConfig - same as DialSCTP but with given SCTP options and socket configuration -func dialSCTPExtConfig(network string, laddr, raddr *SCTPAddr, options InitMsg, control func(network, address string, c syscall.RawConn) error) (*SCTPConn, error) { +func dialSCTPExtConfig(network string, laddr, raddr *SCTPAddr, options InitMsg, control func(network, address string, c syscall.RawConn) error, notificationHandler NotificationHandler) (*SCTPConn, error) { af, ipv6only := favoriteAddrFamily(network, laddr, raddr, "dial") sock, err := syscall.Socket( af, @@ -315,5 +317,5 @@ func dialSCTPExtConfig(network string, laddr, raddr *SCTPAddr, options InitMsg, if err != nil { return nil, err } - return NewSCTPConn(sock, nil), nil + return NewSCTPConn(sock, notificationHandler), nil } diff --git a/sctp_linux_test.go b/sctp_linux_test.go index 17504bb..aa796bb 100644 --- a/sctp_linux_test.go +++ b/sctp_linux_test.go @@ -19,18 +19,56 @@ package sctp import ( + "errors" "net" "strings" "syscall" "testing" ) +func TestNotificationHandlerAssignmentOnDialing(t *testing.T) { + network := "sctp" + addr := &SCTPAddr{IPAddrs: []net.IPAddr{{IP: net.IPv4(127, 0, 0, 1)}}, Port: 54321} + testErr := errors.New("test error") + notificationHandler := func([]byte) error { return testErr } + + listener, err := ListenSCTP(network, addr) + if err != nil { + t.Fatal(err) + } + conn, err := dialSCTPExtConfig(network, nil, addr, InitMsg{}, nil, notificationHandler) + if err != nil { + t.Fatalf("failed to establish connection due to: %v", err) + } + if conn == nil || conn.notificationHandler(nil) != testErr { + t.Fatalf("notification handler has not been assigned") + } + listener.Close() + conn.Close() +} + +func TestNotificationHandlerAssignmentOnListening(t *testing.T) { + network := "sctp" + addr := &SCTPAddr{IPAddrs: []net.IPAddr{{IP: net.IPv4(127, 0, 0, 1)}}, Port: 54321} + testErr := errors.New("test error") + notificationHandler := func([]byte) error { return testErr } + + listener, err := listenSCTPExtConfig(network, addr, InitMsg{}, nil, notificationHandler) + if err != nil { + t.Fatalf("failed to start listening due to: %v", err) + } + if listener == nil || listener.notificationHandler(nil) != testErr { + t.Fatalf("notification handler has not been assigned") + } + listener.Close() +} + func TestDialUseControlFuncWithoutLocalAddress(t *testing.T) { network := "sctp" raddr := &SCTPAddr{IPAddrs: []net.IPAddr{net.IPAddr{IP: net.IPv4(127, 0, 0, 1)}}} initMsg := InitMsg{} customControlFunc := validationControlFunc(t, network) - conn, err := dialSCTPExtConfig(network, nil, raddr, initMsg, customControlFunc) + conn, err := dialSCTPExtConfig(network, nil, raddr, initMsg, customControlFunc, nil) if err != nil && !strings.Contains(err.Error(), "connection refused") { t.Fatalf("failed to dial connection due to: %v", err) } @@ -41,7 +79,7 @@ func TestListenUseControlFuncWithoutLocalAddress(t *testing.T) { network := "sctp" initMsg := InitMsg{} customControlFunc := validationControlFunc(t, network) - listener, err := listenSCTPExtConfig(network, nil, initMsg, customControlFunc) + listener, err := listenSCTPExtConfig(network, nil, initMsg, customControlFunc, nil) if err != nil { t.Fatalf("failed to start listener: %v", err) }