diff --git a/.vscode/settings.json b/.vscode/settings.json index 7a73a41..2dcb4b7 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,2 +1,5 @@ { + "files.associations": { + "bpf_helpers.h": "c" + } } \ No newline at end of file diff --git a/bpf/tc_udp.bpf.c b/bpf/tc_udp.bpf.c new file mode 100644 index 0000000..99fab84 --- /dev/null +++ b/bpf/tc_udp.bpf.c @@ -0,0 +1,217 @@ +// +build ignore + +#include +#include +#include +#include +#include +#include +#include +#include + +char __license[] SEC("license") = "GPL"; + +#ifndef memcpy + #define memcpy(dest, src, n) __builtin_memcpy((dest), (src), (n)) +#endif + +#define MAX_BACKENDS 128 +#define MAX_UDP_LENGTH 1480 + +#define UDP_PAYLOAD_SIZE(x) (unsigned int)(((bpf_htons(x) - sizeof(struct udphdr)) * 8 ) / 4) + +static __always_inline void ip_from_int(__u32 *buf, __be32 ip) { + buf[0] = (ip >> 0 ) & 0xFF; + buf[1] = (ip >> 8 ) & 0xFF; + buf[2] = (ip >> 16 ) & 0xFF; + buf[3] = (ip >> 24 ) & 0xFF; +} + +static __always_inline void bpf_printk_ip(__be32 ip) { + __u32 ip_parts[4]; + ip_from_int((__u32 *)&ip_parts, ip); + bpf_printk("%d.%d.%d.%d", ip_parts[0], ip_parts[1], ip_parts[2], ip_parts[3]); +} + +static __always_inline __u16 csum_fold_helper(__u64 csum) { + int i; +#pragma unroll + for (i = 0; i < 4; i++) + { + if (csum >> 16) + csum = (csum & 0xffff) + (csum >> 16); + } + return ~csum; +} + +static __always_inline __u16 iph_csum(struct iphdr *iph) { + iph->check = 0; + unsigned long long csum = bpf_csum_diff(0, 0, (unsigned int *)iph, sizeof(struct iphdr), 0); + return csum_fold_helper(csum); +} + +// static __always_inline __u16 udp_checksum(struct iphdr *ip, struct udphdr * udp, void * data_end) { +// udp->check = 0; + +// // So we can overflow a bit make this __u32 +// __u32 csum_total = 0; +// __u16 csum; +// __u16 *buf = (void *)udp; + +// csum_total += (__u16)ip->saddr; +// csum_total += (__u16)(ip->saddr >> 16); +// csum_total += (__u16)ip->daddr; +// csum_total += (__u16)(ip->daddr >> 16); +// csum_total += (__u16)(ip->protocol << 8); +// csum_total += udp->len; + +// // The number of nibbles in the UDP header + Payload +// unsigned int udp_packet_nibbles = UDP_PAYLOAD_SIZE(udp->len); + +// // Here we only want to iterate through payload +// // NOT trailing bits +// for (int i = 0; i <= MAX_UDP_LENGTH; i += 2) { +// if (i > udp_packet_nibbles) { +// break; +// } + +// if ((void *)(buf + 1) > data_end) { +// break; +// } +// csum_total += *buf; +// buf++; +// } + +// if ((void *)buf + 1 <= data_end) { +// csum_total += (*(__u8 *)buf); +// } + +// // Add any cksum overflow back into __u16 +// csum = (__u16)csum_total + (__u16)(csum_total >> 16); + +// csum = ~csum; +// return csum; +// } + +struct backend { + __u32 saddr; + __u32 daddr; + __u16 dport; + __u16 ifindex; + // Cksum isn't required for UDP see: + // https://en.wikipedia.org/wiki/User_Datagram_Protocol + __u8 nocksum; + __u8 pad[3]; +}; + +struct vip_key { + __u32 vip; + __u16 port; + __u8 pad[2]; +}; + +struct { + __uint(type, BPF_MAP_TYPE_HASH); + __uint(max_entries, MAX_BACKENDS); + __type(key, struct vip_key); + __type(value, struct backend); +} backends SEC(".maps"); + +SEC("classifier") +int tc_prog_func(struct xdp_md *ctx) { + // --------------------------------------------------------------------------- + // Initialize + // --------------------------------------------------------------------------- + + void *data = (void *)(long)ctx->data; + void *data_end = (void *)(long)ctx->data_end; + + struct ethhdr *eth = data; + if (data + sizeof(struct ethhdr) > data_end) { + bpf_printk("ABORTED: bad ethhdr!"); + return XDP_ABORTED; + } + + if (bpf_ntohs(eth->h_proto) != ETH_P_IP) { + bpf_printk("PASS: not IP protocol!"); + return XDP_PASS; + } + + struct iphdr *ip = data + sizeof(struct ethhdr); + if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) > data_end) { + bpf_printk("ABORTED: bad iphdr!"); + return XDP_ABORTED; + } + + if (ip->protocol != IPPROTO_UDP) + return XDP_PASS; + + struct udphdr *udp = data + sizeof(struct ethhdr) + sizeof(struct iphdr); + if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr) > data_end) { + bpf_printk("ABORTED: bad udphdr!"); + return XDP_ABORTED; + } + + bpf_printk("UDP packet received - daddr:%x, port:%d", ip->daddr, bpf_ntohs(udp->dest)); + + // --------------------------------------------------------------------------- + // Routing + // --------------------------------------------------------------------------- + + struct vip_key key = { + .vip = ip->daddr, + .port = bpf_ntohs(udp->dest) + }; + + struct backend *bk; + bk = bpf_map_lookup_elem(&backends, &key); + if (!bk) { + bpf_printk("no backends for ip %x:%x", key.vip, key.port); + return XDP_PASS; + } + + bpf_printk("got UDP traffic, source address:"); + bpf_printk_ip(ip->saddr); + bpf_printk("destination address:"); + bpf_printk_ip(ip->daddr); + + ip->saddr = bk->saddr; + ip->daddr = bk->daddr; + + bpf_printk("updated saddr to:"); + bpf_printk_ip(ip->saddr); + bpf_printk("updated daddr to:"); + bpf_printk_ip(ip->daddr); + + if (udp->dest != bpf_ntohs(bk->dport)) { + udp->dest = bpf_ntohs(bk->dport); + bpf_printk("updated dport to: %d", bk->dport); + } + +// memcpy(eth->h_source, bk->shwaddr, sizeof(eth->h_source)); +// bpf_printk("new source hwaddr %x:%x:%x:%x:%x:%x", eth->h_source[0], eth->h_source[1], eth->h_source[2], eth->h_source[3], eth->h_source[4], eth->h_source[5]); + +// memcpy(eth->h_dest, bk->dhwaddr, sizeof(eth->h_dest)); +// bpf_printk("new dest hwaddr %x:%x:%x:%x:%x:%x", eth->h_dest[0], eth->h_dest[1], eth->h_dest[2], eth->h_dest[3], eth->h_dest[4], eth->h_dest[5]); + + ip->check = iph_csum(ip); + udp->check = 0; + + if (!bk->nocksum){ + udp->check = udp_checksum(ip, udp, data_end); + } + + bpf_printk("destination interface index %d", bk->ifindex); + + int action = bpf_redirect(bk->ifindex, 0); + + bpf_printk("redirect action: %d", action); + + return action; +} + +// SEC("xdp") +// int bpf_redirect_placeholder(struct xdp_md *ctx) { +// bpf_printk("received a packet on dest interface"); +// return XDP_PASS; +// } \ No newline at end of file diff --git a/bpf/xdp_udp.bpf.c b/bpf/xdp_udp.bpf.c index a513933..20a1c75 100644 --- a/bpf/xdp_udp.bpf.c +++ b/bpf/xdp_udp.bpf.c @@ -18,8 +18,6 @@ char __license[] SEC("license") = "GPL"; #define MAX_BACKENDS 128 #define MAX_UDP_LENGTH 1480 -#define UDP_PAYLOAD_SIZE(x) (unsigned int)(((bpf_htons(x) - sizeof(struct udphdr)) * 8 ) / 4) - static __always_inline void ip_from_int(__u32 *buf, __be32 ip) { buf[0] = (ip >> 0 ) & 0xFF; buf[1] = (ip >> 8 ) & 0xFF; @@ -36,7 +34,7 @@ static __always_inline void bpf_printk_ip(__be32 ip) { static __always_inline __u16 csum_fold_helper(__u64 csum) { int i; #pragma unroll - for (i = 0; i < 4; i++) + for (i = 0; i < 8; i++) { if (csum >> 16) csum = (csum & 0xffff) + (csum >> 16); @@ -50,12 +48,16 @@ static __always_inline __u16 iph_csum(struct iphdr *iph) { return csum_fold_helper(csum); } -static __always_inline __u16 udp_checksum(struct iphdr *ip, struct udphdr * udp, void * data_end) { +static __always_inline __u16 udp_csum_diff(struct udphdr *udp) { udp->check = 0; + unsigned long long csum = bpf_csum_diff(0, 0, (unsigned int *)udp, sizeof(struct udphdr), 0); + return csum_fold_helper(csum); +} +static __always_inline __u16 udp_checksum(struct iphdr *ip, struct udphdr * udp, void * data_end) { // So we can overflow a bit make this __u32 - __u32 csum_total = 0; - __u16 csum; + __u64 csum_total = 0; + __u16 *buf = (void *)udp; csum_total += (__u16)ip->saddr; @@ -65,32 +67,31 @@ static __always_inline __u16 udp_checksum(struct iphdr *ip, struct udphdr * udp, csum_total += (__u16)(ip->protocol << 8); csum_total += udp->len; - // The number of nibbles in the UDP header + Payload - unsigned int udp_packet_nibbles = UDP_PAYLOAD_SIZE(udp->len); + int udp_len = bpf_ntohs(udp->len); + + if (udp_len >= MAX_UDP_LENGTH) { + return 1; + } // Here we only want to iterate through payload // NOT trailing bits - for (int i = 0; i <= MAX_UDP_LENGTH; i += 2) { - if (i > udp_packet_nibbles) { - break; - } - + for (int i = 0; i < udp_len; i += 2) { if ((void *)(buf + 1) > data_end) { break; } - csum_total += *buf; - buf++; - } - if ((void *)buf + 1 <= data_end) { - csum_total += (*(__u8 *)buf); + // Last byte + if (i + 1 == udp_len) { + csum_total += (*(__u8 *)buf); + // Verifier fails without this print statement, I have no Idea why :/ + bpf_printk("Adding last byte %X to csum", (*(__u8 *)buf)); + } else { + csum_total += *buf; + } + buf+=1; } - // Add any cksum overflow back into __u16 - csum = (__u16)csum_total + (__u16)(csum_total >> 16); - - csum = ~csum; - return csum; + return csum_fold_helper(csum_total); } struct backend { @@ -198,10 +199,12 @@ int xdp_prog_func(struct xdp_md *ctx) { bpf_printk("new dest hwaddr %x:%x:%x:%x:%x:%x", eth->h_dest[0], eth->h_dest[1], eth->h_dest[2], eth->h_dest[3], eth->h_dest[4], eth->h_dest[5]); ip->check = iph_csum(ip); + udp->check = 0; - if (!bk->nocksum){ - udp->check = udp_checksum(ip, udp, data_end); + int tmp_check = udp_checksum(ip, udp, data_end); + bpf_printk("Manual Cksum: %X Diff Cksum %X", tmp_check, udp_csum_diff(udp)); + udp->check = tmp_check; } bpf_printk("destination interface index %d", bk->ifindex); diff --git a/userspace-go/bpf_bpfeb.o b/userspace-go/bpf_bpfeb.o index 965a33b..5fc5b77 100644 Binary files a/userspace-go/bpf_bpfeb.o and b/userspace-go/bpf_bpfeb.o differ diff --git a/userspace-go/bpf_bpfel.o b/userspace-go/bpf_bpfel.o index d31f382..301b81a 100644 Binary files a/userspace-go/bpf_bpfel.o and b/userspace-go/bpf_bpfel.o differ diff --git a/userspace-go/userspace-go b/userspace-go/userspace-go index a74ee32..976fddb 100755 Binary files a/userspace-go/userspace-go and b/userspace-go/userspace-go differ diff --git a/userspace-go/xdp_udp.go b/userspace-go/xdp_udp.go index 144b659..e41bbe0 100644 --- a/userspace-go/xdp_udp.go +++ b/userspace-go/xdp_udp.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "encoding/hex" + "errors" "fmt" "log" "net" @@ -32,10 +33,15 @@ func main() { if err != nil { log.Fatalf("lookup network iface %s: %s", ifaceName, err) } - + + var ve *ebpf.VerifierError objs := bpfObjects{} if err := loadBpfObjects(&objs, nil); err != nil { - log.Fatalf("loading objects: %s", err) + if errors.As(err, &ve) { + // Using %+v will print the whole verifier error, not just the last + // few lines. + fmt.Printf("Verifier error: %+v\n", ve) + } } defer objs.Close() @@ -61,20 +67,25 @@ func main() { log.Printf("Press Ctrl-C to exit and remove the program") b := bpfBackend{ + // Hardcoded Src IP (main Nic) Saddr: ip2int("10.8.125.12"), + // Hardcoded Dst IP (container) Daddr: ip2int("192.168.10.2"), + // Hardcoded Dst Port (UDP echo server) Dport: 9875, // Host-Side Veth Mac Shwaddr: hwaddr2bytes("06:56:87:ec:fd:1f"), // Container-Side Veth Mac Dhwaddr: hwaddr2bytes("86:ad:33:29:ff:5e"), - Nocksum: 1, + Nocksum: 0, + // Hardcoded Host side Veth index Ifindex: 8, } key := bpfVipKey{ + // Hardcoded main NIC IP Vip: ip2int("10.8.125.12"), - //Vip: ip2int("192.168.10.1"), + // Hardcoded main NIC port Port: 8888, }