-
Notifications
You must be signed in to change notification settings - Fork 0
/
rakelimit.go
110 lines (92 loc) · 2.61 KB
/
rakelimit.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
package socklimit
import (
"errors"
"fmt"
"syscall"
"github.com/cilium/ebpf"
"github.com/cilium/ebpf/asm"
"golang.org/x/sys/unix"
)
//go:generate go run github.com/cilium/ebpf/cmd/bpf2go -cc clang-14 rake ./src/rakelimit.c -- -I./include -nostdinc -Os
// Limiter holds an instance of a rate limiter that can be applied to a socket.
// Instances are created by New and should be released with Close.
type Limiter struct {
	// domain is the socket's address family, as reported by SO_DOMAIN in New.
	domain int
	// program is the BPF filter that New attached to the socket; it is picked
	// according to domain (IPv4 or IPv6 variant).
	program *ebpf.Program
	// bpfObjects holds the loaded BPF objects; Close closes them.
	bpfObjects *rakeObjects
}
// New creates a new Rakelimit instance based on the specified ppsLimit and
// attaches it to the socket behind conn.
//
// The BPF programs are loaded with their LIMIT constant rewritten to ppsLimit,
// then the filter matching the socket's domain (AF_INET or AF_INET6) is
// attached via SO_ATTACH_BPF. Any other socket domain is rejected.
//
// On success the caller must call Close on the returned Limiter. On failure
// the BPF objects loaded so far are closed before returning, so an error
// return never leaks kernel resources. Returned errors wrap their cause and
// can be inspected with errors.Is / errors.As.
func New(conn syscall.Conn, ppsLimit uint32) (*Limiter, error) {
	spec, err := loadRake()
	if err != nil {
		return nil, fmt.Errorf("get elf spec: %w", err)
	}

	// The limit has to be patched into the spec before the programs are loaded.
	if err := rewriteConstant(spec, "LIMIT", uint64(ppsLimit)); err != nil {
		return nil, err
	}

	var objs rakeObjects
	if err := spec.LoadAndAssign(&objs, nil); err != nil {
		return nil, fmt.Errorf("load BPF: %w", err)
	}

	// From here on objs holds loaded BPF objects. Release them on every error
	// path; ownership only moves to the Limiter once the attach has succeeded.
	attached := false
	defer func() {
		if !attached {
			// Best effort: the error being returned is more useful to the
			// caller than a secondary failure while cleaning up.
			_ = objs.Close()
		}
	}()

	raw, err := conn.SyscallConn()
	if err != nil {
		return nil, fmt.Errorf("raw conn: %w", err)
	}

	// opErr carries failures out of the Control callback, which cannot return
	// an error itself.
	var (
		opErr  error
		domain int
		prog   *ebpf.Program
	)
	err = raw.Control(func(s uintptr) {
		domain, opErr = unix.GetsockoptInt(int(s), unix.SOL_SOCKET, unix.SO_DOMAIN)
		if opErr != nil {
			opErr = fmt.Errorf("can't retrieve domain: %w", opErr)
			return
		}

		switch domain {
		case unix.AF_INET:
			prog = objs.FilterIpv4
		case unix.AF_INET6:
			prog = objs.FilterIpv6
		default:
			opErr = fmt.Errorf("unsupported socket domain: %d", domain)
			return
		}

		opErr = unix.SetsockoptInt(int(s), unix.SOL_SOCKET, unix.SO_ATTACH_BPF, prog.FD())
		if errors.Is(opErr, unix.ENOMEM) {
			// Give the operator a hint about the likely cause of ENOMEM.
			opErr = fmt.Errorf("attach filter: net.core.optmem_max might be too low: %w", opErr)
			return
		}
		if opErr != nil {
			opErr = fmt.Errorf("attach filter: %w", opErr)
		}
	})
	if err != nil {
		return nil, fmt.Errorf("can't access fd: %w", err)
	}
	if opErr != nil {
		return nil, opErr
	}

	attached = true
	return &Limiter{domain: domain, program: prog, bpfObjects: &objs}, nil
}
// Close closes the BPF objects held by the Limiter and returns the resulting
// error, if any. Call it once the Limiter is no longer needed.
func (rl *Limiter) Close() error {
	objs := rl.bpfObjects
	return objs.Close()
}
// rewriteConstant patches, in every program of spec, each instruction that
// references symbol so that its constant becomes value.
//
// It returns an error if a referencing instruction is not a dword-sized
// constant load, or if no instruction in any program references symbol.
// Instructions are modified in place, so patches applied before an error is
// hit remain in spec.
func rewriteConstant(spec *ebpf.CollectionSpec, symbol string, value uint64) error {
	patched := 0

	for progName, p := range spec.Programs {
		insns := p.Instructions
		for idx := range insns {
			if insns[idx].Reference() != symbol {
				continue
			}

			// Take a pointer so the write lands in the spec's own slice.
			target := &insns[idx]
			if !target.IsConstantLoad(asm.DWord) {
				return fmt.Errorf("program %s: instruction %d: not a dword-sized constant load: %s", progName, idx, target)
			}

			target.Constant = int64(value)
			patched++
		}
	}

	if patched == 0 {
		return fmt.Errorf("symbol %s is not referenced", symbol)
	}
	return nil
}