-
Notifications
You must be signed in to change notification settings - Fork 1
/
conn_udp_singleport.go
108 lines (85 loc) · 2.41 KB
/
conn_udp_singleport.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
// SPDX-FileCopyrightText: 2023 Steffen Vogel <[email protected]>
// SPDX-License-Identifier: Apache-2.0
//go:build linux
package rosenpass
import (
	"errors"
	"fmt"
	"log/slog"
	"net"

	ebpfx "cunicu.li/go-rosenpass/internal/ebpf"
	netx "cunicu.li/go-rosenpass/internal/net"
)
// Compile-time check that *SinglePortUDPConn satisfies the Conn interface.
var _ Conn = (*SinglePortUDPConn)(nil)
// SinglePortUDPConn is a Conn implementation backed by raw UDP sockets
// (netx.RawUDPConn) that are opened by Open with an eBPF filter attached.
type SinglePortUDPConn struct {
// listenAddrs holds the addresses passed to NewSinglePortUDPConn; Open
// listens on them.
listenAddrs []*net.UDPAddr
// conns maps the network name a socket was opened with (e.g. "udp4",
// "udp6") to that socket. It is populated by Open.
conns map[string]*netx.RawUDPConn
// logger receives debug messages; NewSinglePortUDPConn sets it to
// slog.Default().
logger *slog.Logger
}
// NewSinglePortUDPConn creates a SinglePortUDPConn for the listen addresses la.
//
// No sockets are opened here; the returned value starts with an empty socket
// table and uses slog.Default() for logging. The slice la is stored as-is
// (not copied). The returned error is always nil.
func NewSinglePortUDPConn(la []*net.UDPAddr) (*SinglePortUDPConn, error) {
	c := new(SinglePortUDPConn)
	c.listenAddrs = la
	c.conns = make(map[string]*netx.RawUDPConn)
	c.logger = slog.Default()

	return c, nil
}
// Close closes every socket held in c.conns.
//
// All sockets are closed even when some of them fail, so a single failing
// socket does not leave the remaining ones open. The individual failures are
// combined with errors.Join; the result is nil when every socket closed
// cleanly (or when there are no sockets).
func (c *SinglePortUDPConn) Close() error {
	var errs []error

	for _, conn := range c.conns {
		if err := conn.Close(); err != nil {
			errs = append(errs, err)
		}
	}

	// errors.Join returns nil if errs is empty.
	return errors.Join(errs...)
}
// Send passes pl and spkt to sendToConn, addressed to the endpoint ep.
//
// ep must be a *UDPEndpoint, otherwise errInvalidEndpoint is returned.
// The socket is selected by the network name networkFromAddr reports for the
// destination address. If no socket is registered under that name, the socket
// registered under the generic "udp" name is used instead. An error is
// returned when neither exists.
func (c *SinglePortUDPConn) Send(pl payload, spkt spk, ep Endpoint) error {
	udpEp, isUDP := ep.(*UDPEndpoint)
	if !isUDP {
		return errInvalidEndpoint
	}

	dst := (*net.UDPAddr)(udpEp)

	// Prefer the socket matching the destination's network and fall back to
	// a generic "udp" socket if one is registered.
	sock, found := c.conns[networkFromAddr(dst)]
	if !found {
		sock, found = c.conns["udp"]
	}
	if !found {
		return fmt.Errorf("failed to find socket with matching address family")
	}

	return sendToConn(sock, dst, pl, spkt)
}
// Open starts listening on the configured addresses and returns one
// ReceiveFunc per opened socket.
//
// A listen address for which networkFromAddr reports the generic "udp"
// network is opened twice, once as "udp4" and once as "udp6". Only one
// address is kept per network: if several listen addresses map to the same
// network, the last one wins.
//
// Each socket gets the eBPF program built by ebpfx.RosenpassFilterEbpf for
// its listen port attached before it is stored in c.conns. If attaching the
// filter fails, that socket is closed before the error is returned, because
// it is not yet tracked in c.conns and Close could never reach it. Sockets
// stored by earlier iterations stay in c.conns and are released by Close.
func (c *SinglePortUDPConn) Open() (recvFncs []ReceiveFunc, err error) {
	// Collapse the listen addresses into one address per network name.
	networks := map[string]*net.UDPAddr{}
	for _, la := range c.listenAddrs {
		if network := networkFromAddr(la); network == "udp" {
			networks["udp4"] = la
			networks["udp6"] = la
		} else {
			networks[network] = la
		}
	}

	for network, lAddr := range networks {
		conn, err := netx.ListenRawUDP(network, lAddr)
		if err != nil {
			return nil, fmt.Errorf("failed to listen: %w", err)
		}

		if err = conn.FilterEBpf(ebpfx.RosenpassFilterEbpf(lAddr.Port)); err != nil {
			// Best-effort cleanup: the filter error is the one worth
			// reporting, so a secondary close error is dropped on purpose.
			_ = conn.Close()

			return nil, fmt.Errorf("failed to apply eBPF filter: %w", err)
		}

		c.logger.Debug("Started listening", slog.Any("addr", lAddr))

		c.conns[network] = conn
		recvFncs = append(recvFncs, receiveFromConn(conn))
	}

	return recvFncs, nil
}
// LocalEndpoints returns the local address of every socket in c.conns as an
// Endpoint (each a *UDPEndpoint).
//
// The order of the result is not stable because it follows map iteration.
// An error is returned, along with a nil slice, if a socket reports a local
// address that is not a *net.UDPAddr.
func (c *SinglePortUDPConn) LocalEndpoints() (eps []Endpoint, err error) {
	for _, conn := range c.conns {
		udpAddr, isUDP := conn.LocalAddr().(*net.UDPAddr)
		if !isUDP {
			return nil, fmt.Errorf("invalid address type encountered")
		}

		eps = append(eps, (*UDPEndpoint)(udpAddr))
	}

	return eps, nil
}