diff --git a/client/cli/base.go b/client/cli/base.go index 588d4dd9..a6ac83ac 100644 --- a/client/cli/base.go +++ b/client/cli/base.go @@ -112,6 +112,7 @@ func reloadConfig(configPtr *string, client *clients.Client) error { } client.SetLocalFeature(feat, en) } + client.FirewallMark = config.FirewallMark client.SetDefaultGateway = config.Tunnel.SetDefaultGateway client.ServerURL = dest client.InterfaceConfig = &config.Interface diff --git a/client/cli/client.example.yml b/client/cli/client.example.yml index 76c3668f..b51f5354 100644 --- a/client/cli/client.example.yml +++ b/client/cli/client.example.yml @@ -12,6 +12,8 @@ interface: persist: false component-id: root\tap0901 # Windows only. Defaults: root\tap0901 or tap0901 +firewall-mark: 0 # Linux only. Set to positive integer to mark packets with this value in the firewall + scripts: # These scripts get run as "args... operation subnet interface" # Pass in an array, first argument is the executable, further arguments diff --git a/client/cli/config.go b/client/cli/config.go index a0308cc8..7f76d8cc 100644 --- a/client/cli/config.go +++ b/client/cli/config.go @@ -23,7 +23,8 @@ type Config struct { Features features.Config `yaml:"features"` } `yaml:"tunnel"` - Interface iface.InterfaceConfig `yaml:"interface"` + Interface iface.InterfaceConfig `yaml:"interface"` + FirewallMark int `yaml:"firewall-mark"` Scripts shared.EventConfig `yaml:"scripts"` diff --git a/client/clients/base.go b/client/clients/base.go index 5811fced..f1801e63 100644 --- a/client/clients/base.go +++ b/client/clients/base.go @@ -23,6 +23,7 @@ type Client struct { ProxyURL *url.URL ServerURL *url.URL Headers http.Header + FirewallMark int SetDefaultGateway bool SocketConfigurator sockets.SocketConfigurator InterfaceConfig *iface.InterfaceConfig diff --git a/client/clients/connection.go b/client/clients/connection.go index ed4e1c3d..e975a1e3 100644 --- a/client/clients/connection.go +++ b/client/clients/connection.go @@ -3,6 +3,7 @@ package clients import ( "crypto/tls" "fmt" + "net" "net/http" "net/url" "strings" @@ -28,6 +29,13 @@ func (c *Client) GetServerURL() *url.URL { return c.ServerURL } +func (c *Client) EnhanceConn(conn net.Conn) error { + if c.FirewallMark <= 0 { + return nil + } + return setFirewallMark(conn, c.FirewallMark) +} + func (c *Client) RegisterDefaultConnectors() { c.registerConnector(connectors.NewWebSocketConnector()) c.registerConnector(connectors.NewWebTransportConnector()) diff --git a/client/clients/connection_linux.go b/client/clients/connection_linux.go new file mode 100644 index 00000000..17b10dc2 --- /dev/null +++ b/client/clients/connection_linux.go @@ -0,0 +1,52 @@ +//go:build linux + +package clients + +import ( + "crypto/tls" + "errors" + "log" + "net" + "syscall" + + "golang.org/x/sys/unix" +) + +const fwmarkIoctl int = 36 /* unix.SO_MARK */ + +var ErrUnknownConnType = errors.New("not a known conn type") + +func setFirewallMark(conn net.Conn, mark int) error { + var err error + var syscallConn syscall.Conn + + switch typedConn := conn.(type) { + case syscall.Conn: + syscallConn = typedConn + case *tls.Conn: + return setFirewallMark(typedConn.NetConn(), mark) + default: + log.Printf("Unknown conn type: %T (%v)", typedConn, typedConn) + err = ErrUnknownConnType + } + + if err != nil { + return err + } + + var operr error + fd, err := syscallConn.SyscallConn() + if err != nil { + return err + } + + err = fd.Control(func(fd uintptr) { + operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark)) + }) + + if err == nil { + return operr + } + + return err +} diff --git a/client/clients/connection_other.go b/client/clients/connection_other.go new file mode 100644 index 00000000..0ae212d2 --- /dev/null +++ b/client/clients/connection_other.go @@ -0,0 +1,11 @@ +//go:build !linux + +package clients + +import ( + "net" +) + +func setFirewallMark(conn net.Conn, mark int) error { + return nil +} diff --git a/client/connectors/base.go b/client/connectors/base.go index 2de739c7..dcbe694a 100644 --- a/client/connectors/base.go +++ b/client/connectors/base.go @@ -2,6 +2,7 @@ package connectors import ( "crypto/tls" + "net" "net/http" "net/url" @@ -18,4 +19,5 @@ type SocketConnectorConfig interface { GetTLSConfig() *tls.Config GetHeaders() http.Header GetServerURL() *url.URL + EnhanceConn(conn net.Conn) error } diff --git a/client/connectors/websocket.go b/client/connectors/websocket.go index cc392ae0..79f74541 100644 --- a/client/connectors/websocket.go +++ b/client/connectors/websocket.go @@ -57,6 +57,11 @@ func (c *WebSocketConnector) Dial(config SocketConnectorConfig) (adapters.Socket return nil, err } + err = config.EnhanceConn(conn) + if err != nil { + return nil, err + } + serializationType := readSerializationType(respHeaders) return adapters.NewWebSocketAdapter(conn, serializationType, false, reader), nil } diff --git a/client/connectors/webtransport.go b/client/connectors/webtransport.go index da87a8e6..c2cc33e8 100644 --- a/client/connectors/webtransport.go +++ b/client/connectors/webtransport.go @@ -2,7 +2,9 @@ package connectors import ( "context" + "crypto/tls" "errors" + "net" "github.com/Doridian/wsvpn/shared/sockets/adapters" "github.com/quic-go/quic-go" @@ -19,13 +21,40 @@ func NewWebTransportConnector() *WebTransportConnector { return &WebTransportConnector{} } +type quicDialer struct { + transport *quic.Transport +} + +func (d *quicDialer) Dial(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) { + udpAddr, err := net.ResolveUDPAddr("udp", addr) + if err != nil { + return nil, err + } + return d.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg) +} + func (c *WebTransportConnector) Dial(config SocketConnectorConfig) (adapters.SocketAdapter, error) { serverURL := *config.GetServerURL() serverURL.Scheme = "https" + udpConn, err := net.ListenUDP("udp", nil) + if err != nil { + return nil, err + } + err = config.EnhanceConn(udpConn) + if err != nil { + _ = udpConn.Close() + return nil, err + } + quicDialer := &quicDialer{ + transport: &quic.Transport{Conn: udpConn}, + } + var dialer webtransport.Dialer if dialer.RoundTripper == nil { - dialer.RoundTripper = &http3.RoundTripper{} + dialer.RoundTripper = &http3.RoundTripper{ + Dial: quicDialer.Dial, + } } dialer.TLSClientConfig = config.GetTLSConfig() diff --git a/tests/bins.py b/tests/bins.py index d2c13e35..3f63e500 100644 --- a/tests/bins.py +++ b/tests/bins.py @@ -78,6 +78,7 @@ def __init__(self, proj: str) -> None: self.cfg["client"]["headers"] = { "X-Test-Header": ["test-header-value"] } + self.cfg["firewall-mark"] = 1337 self.proc_wait_cond = Condition() self.is_ready_or_done = False