Skip to content

Commit

Permalink
Allow firewall marking packets on Linux (#436)
Browse files Browse the repository at this point in the history
  • Loading branch information
Doridian authored Jan 7, 2024
1 parent 0f40387 commit 398a8b8
Show file tree
Hide file tree
Showing 11 changed files with 115 additions and 2 deletions.
1 change: 1 addition & 0 deletions client/cli/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@ func reloadConfig(configPtr *string, client *clients.Client) error {
}
client.SetLocalFeature(feat, en)
}
client.FirewallMark = config.FirewallMark
client.SetDefaultGateway = config.Tunnel.SetDefaultGateway
client.ServerURL = dest
client.InterfaceConfig = &config.Interface
Expand Down
2 changes: 2 additions & 0 deletions client/cli/client.example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ interface:
persist: false
component-id: root\tap0901 # Windows only. Defaults: root\tap0901 or tap0901

firewall-mark: 0 # Linux only. Set to positive integer to mark packets with this value in the firewall

scripts:
# These scripts get run as "args... operation subnet interface"
# Pass in an array, first argument is the executable, further arguments
Expand Down
3 changes: 2 additions & 1 deletion client/cli/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,8 @@ type Config struct {
Features features.Config `yaml:"features"`
} `yaml:"tunnel"`

Interface iface.InterfaceConfig `yaml:"interface"`
Interface iface.InterfaceConfig `yaml:"interface"`
FirewallMark int `yaml:"firewall-mark"`

Scripts shared.EventConfig `yaml:"scripts"`

Expand Down
1 change: 1 addition & 0 deletions client/clients/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ type Client struct {
ProxyURL *url.URL
ServerURL *url.URL
Headers http.Header
FirewallMark int
SetDefaultGateway bool
SocketConfigurator sockets.SocketConfigurator
InterfaceConfig *iface.InterfaceConfig
Expand Down
8 changes: 8 additions & 0 deletions client/clients/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package clients
import (
"crypto/tls"
"fmt"
"net"
"net/http"
"net/url"
"strings"
Expand All @@ -28,6 +29,13 @@ func (c *Client) GetServerURL() *url.URL {
return c.ServerURL
}

func (c *Client) EnhanceConn(conn net.Conn) error {
if c.FirewallMark <= 0 {
return nil
}
return setFirewallMark(conn, c.FirewallMark)
}

func (c *Client) RegisterDefaultConnectors() {
c.registerConnector(connectors.NewWebSocketConnector())
c.registerConnector(connectors.NewWebTransportConnector())
Expand Down
52 changes: 52 additions & 0 deletions client/clients/connection_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//go:build linux

package clients

import (
"crypto/tls"
"errors"
"log"
"net"
"syscall"

"golang.org/x/sys/unix"
)

const fwmarkIoctl int = 36 /* unix.SO_MARK */

var ErrUnknownConnType = errors.New("not a known conn type")

func setFirewallMark(conn net.Conn, mark int) error {
var err error
var syscallConn syscall.Conn

switch typedConn := conn.(type) {
case syscall.Conn:
syscallConn = typedConn
case *tls.Conn:
return setFirewallMark(typedConn.NetConn(), mark)
default:
log.Printf("Unknown conn type: %T (%v)", typedConn, typedConn)
err = ErrUnknownConnType
}

if err != nil {
return err
}

var operr error
fd, err := syscallConn.SyscallConn()
if err != nil {
return err
}

err = fd.Control(func(fd uintptr) {
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
})

if err == nil {
return operr
}

return err
}
11 changes: 11 additions & 0 deletions client/clients/connection_other.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
//go:build !linux

package clients

import (
"net"
)

func setFirewallMark(conn net.Conn, mark int) error {
return nil
}
2 changes: 2 additions & 0 deletions client/connectors/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package connectors

import (
"crypto/tls"
"net"
"net/http"
"net/url"

Expand All @@ -18,4 +19,5 @@ type SocketConnectorConfig interface {
GetTLSConfig() *tls.Config
GetHeaders() http.Header
GetServerURL() *url.URL
EnhanceConn(conn net.Conn) error
}
5 changes: 5 additions & 0 deletions client/connectors/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ func (c *WebSocketConnector) Dial(config SocketConnectorConfig) (adapters.Socket
return nil, err
}

err = config.EnhanceConn(conn)
if err != nil {
return nil, err
}

serializationType := readSerializationType(respHeaders)
return adapters.NewWebSocketAdapter(conn, serializationType, false, reader), nil
}
Expand Down
31 changes: 30 additions & 1 deletion client/connectors/webtransport.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ package connectors

import (
"context"
"crypto/tls"
"errors"
"net"

"github.com/Doridian/wsvpn/shared/sockets/adapters"
"github.com/quic-go/quic-go"
Expand All @@ -19,13 +21,40 @@ func NewWebTransportConnector() *WebTransportConnector {
return &WebTransportConnector{}
}

type quicDialer struct {
transport *quic.Transport
}

func (d *quicDialer) Dial(ctx context.Context, addr string, tlsCfg *tls.Config, cfg *quic.Config) (quic.EarlyConnection, error) {
udpAddr, err := net.ResolveUDPAddr("udp", addr)
if err != nil {
return nil, err
}
return d.transport.DialEarly(ctx, udpAddr, tlsCfg, cfg)
}

func (c *WebTransportConnector) Dial(config SocketConnectorConfig) (adapters.SocketAdapter, error) {
serverURL := *config.GetServerURL()
serverURL.Scheme = "https"

udpConn, err := net.ListenUDP("udp", nil)
if err != nil {
return nil, err
}
err = config.EnhanceConn(udpConn)
if err != nil {
_ = udpConn.Close()
return nil, err
}
quicDialer := &quicDialer{
transport: &quic.Transport{Conn: udpConn},
}

var dialer webtransport.Dialer
if dialer.RoundTripper == nil {
dialer.RoundTripper = &http3.RoundTripper{}
dialer.RoundTripper = &http3.RoundTripper{
Dial: quicDialer.Dial,
}
}
dialer.TLSClientConfig = config.GetTLSConfig()

Expand Down
1 change: 1 addition & 0 deletions tests/bins.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def __init__(self, proj: str) -> None:
self.cfg["client"]["headers"] = {
"X-Test-Header": ["test-header-value"]
}
self.cfg["firewall-mark"] = 1337

self.proc_wait_cond = Condition()
self.is_ready_or_done = False
Expand Down

0 comments on commit 398a8b8

Please sign in to comment.