Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add XDP redirect example #5

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 215 additions & 0 deletions bpf/xdp_udp.bpf.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
// +build ignore

#include <linux/bpf_common.h>
#include <linux/if_ether.h>
#include <linux/in.h>
#include <linux/ip.h>
#include <linux/udp.h>
#include <linux/bpf.h>
#include <bpf_helpers.h>
#include <bpf_endian.h>

// License string ("GPL") emitted into the "license" section of the object.
char __license[] SEC("license") = "GPL";

// Map memcpy onto the compiler builtin unless a memcpy macro already exists.
#ifndef memcpy
#define memcpy(dest, src, n) __builtin_memcpy((dest), (src), (n))
#endif

// Capacity (max_entries) of the backends map.
#define MAX_BACKENDS 128
// Upper bound on the byte index walked by the summing loop in udp_checksum().
#define MAX_UDP_LENGTH 1480

// Takes a UDP length field x as read from the packet, converts it with
// bpf_htons(), subtracts sizeof(struct udphdr) and scales by 8 / 4.
// The result is therefore twice the payload byte count (the number of
// 4-bit nibbles in the payload, header excluded).
// NOTE(review): the subtraction involves sizeof, so it is presumably done in
// an unsigned type and would wrap for a length field below the header size;
// confirm that callers never pass such a value.
#define UDP_PAYLOAD_SIZE(x) (unsigned int)(((bpf_htons(x) - sizeof(struct udphdr)) * 8 ) / 4)
astoycos marked this conversation as resolved.
Show resolved Hide resolved

/*
 * Split an IPv4 address into its four dotted-quad octets.
 *
 * buf: output array with room for at least four elements; buf[0] receives
 *      the first (leftmost) octet and buf[3] the last.
 * ip:  IPv4 address in network byte order, exactly as stored in the packet.
 *
 * The octets are read from memory in address order instead of being shifted
 * out of the integer value. Shifting only yields the right order on a
 * little-endian host; reading bytes gives the same result there and is also
 * correct on a big-endian one.
 */
static __always_inline void ip_from_int(__u32 *buf, __be32 ip) {
    const unsigned char *octets = (const unsigned char *)&ip;

    /* Written out rather than looped so no loop reaches the BPF verifier. */
    buf[0] = octets[0];
    buf[1] = octets[1];
    buf[2] = octets[2];
    buf[3] = octets[3];
}

/*
 * Write an IPv4 address to the BPF trace output in dotted-quad form.
 *
 * ip: address to print; it is broken into octets by ip_from_int().
 */
static __always_inline void bpf_printk_ip(__be32 ip) {
    __u32 octets[4];

    ip_from_int(octets, ip);
    bpf_printk("%d.%d.%d.%d", octets[0], octets[1], octets[2], octets[3]);
}

/*
 * Fold a wide ones-complement sum down to 16 bits and complement it.
 *
 * csum: running sum whose upper bits hold accumulated carries.
 *
 * Each of the four passes adds everything above bit 15 back into the low
 * 16 bits; a pass is a no-op once nothing is left above bit 15.
 * Returns the bitwise complement of the folded sum, truncated to 16 bits.
 */
static __always_inline __u16 csum_fold_helper(__u64 csum) {
#pragma unroll
    for (int pass = 0; pass < 4; pass++) {
        __u64 carry = csum >> 16;

        if (carry != 0)
            csum = (csum & 0xffff) + carry;
    }
    return (__u16)~csum;
}

/*
 * Recompute the IPv4 header checksum from scratch.
 *
 * iph: IPv4 header to checksum. Its check field is set to 0 before summing
 *      and is left at 0; the caller stores the returned value.
 *
 * The sum is taken with bpf_csum_diff() over sizeof(struct iphdr) bytes,
 * starting from a zero seed, and then folded by csum_fold_helper().
 * NOTE(review): only the fixed-size header is covered, so a header carrying
 * IP options would presumably get a wrong checksum - TODO confirm.
 */
static __always_inline __u16 iph_csum(struct iphdr *iph) {
iph->check = 0;
unsigned long long csum = bpf_csum_diff(0, 0, (unsigned int *)iph, sizeof(struct iphdr), 0);
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be possible to use bpf_csum_diff to just recompute the checksum for the changed fields, i.e. just the src and dest ip pseudo headers in the UDP csum. Need to see if we can find an example of a bpf program that does this.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed I had tried that in the past with little luck which is why I did it so manually here, but it would definitely be much easier

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So I tried this

static __always_inline __u16 udp_csum_diff(struct udphdr *udp) {
    udp->check = 0;
    unsigned long long csum = bpf_csum_diff(0, 0, (unsigned int *)udp, sizeof(struct udphdr), 0);
    return csum_fold_helper(csum);
}

But it doesn't return the right values (I know the Manual Cksum calc is working)

[016] dNs3. 354704.109793: bpf_trace_printk: Manual Cksum: C8A7 Diff Cksum 8B14

return csum_fold_helper(csum);
}

/*
 * Compute a UDP checksum by summing 16-bit words by hand.
 *
 * ip:       IPv4 header; saddr, daddr and protocol feed the pseudo-header.
 * udp:      UDP header; udp->check is zeroed before summing and udp->len is
 *           added to the pseudo-header sum.
 * data_end: end of the packet data; no read goes past it.
 *
 * Returns the folded, complemented sum from csum_fold_helper(). The result
 * is not stored here; the caller assigns it to udp->check.
 *
 * The word loop starts at the UDP header and stops at the first of:
 *   - the byte index exceeding MAX_UDP_LENGTH,
 *   - the byte index exceeding UDP_PAYLOAD_SIZE(udp->len),
 *   - fewer than two bytes remaining before data_end.
 * Afterwards one more byte is added if at least one byte remains before
 * data_end.
 *
 * NOTE(review): UDP_PAYLOAD_SIZE() yields twice the payload byte count and
 * not the datagram length, and the trailing-byte check looks only at
 * data_end. Bytes beyond udp->len (for example link-layer padding) can
 * therefore presumably end up in the sum - TODO confirm and rework the
 * bounds in plain bytes.
 */
static __always_inline __u16 udp_checksum(struct iphdr *ip, struct udphdr * udp, void * data_end) {
udp->check = 0;

// 32-bit accumulator so the 16-bit additions below have room to carry
__u32 csum_total = 0;
__u16 *buf = (void *)udp;

csum_total += (__u16)ip->saddr;
csum_total += (__u16)(ip->saddr >> 16);
csum_total += (__u16)ip->daddr;
csum_total += (__u16)(ip->daddr >> 16);
csum_total += (__u16)(ip->protocol << 8);
csum_total += udp->len;

// Twice the UDP payload byte count (header excluded), see UDP_PAYLOAD_SIZE
unsigned int udp_packet_nibbles = UDP_PAYLOAD_SIZE(udp->len);

// Intent: sum only the datagram itself,
// NOT any trailing bytes that follow it in the frame
for (int i = 0; i <= MAX_UDP_LENGTH; i += 2) {
if (i > udp_packet_nibbles) {
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this test will ever be true because nibbles is 2 * number of bytes. The if statement on line 78 will end the loop.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I'm still getting confused here... maybe it's because i'm using the term nibble incorrectly

So the issue is, data_end is a PTR which includes the payload AND any trailing bits, while I forced udp_packet_nibbles to be the number of "digits" in the raw hex payload...

For example in the following packet

13:23:15.756911 06:56:87:ec:fd:1f > 86:ad:33:29:ff:5e, ethertype IPv4 (0x0800), length 60: (tos 0x0, ttl 57, id 20891, offset 0, flags [DF], proto UDP (17), length 33)
    10.8.125.12.58980 > 192.168.10.2.sapv1: [bad udp cksum 0xd301 -> 0xaf43!] UDP, length 5
        0x0000:  86ad 3329 ff5e 0656 87ec fd1f 0800 4500
        0x0010:  0021 519b 4000 3911 9e72 0a08 7d0c c0a8
        0x0020:  0a02 e664 2693 000d d301 7465 7374 0a00
        0x0030:  0000 0000 d2f2 935d 0000 0000

The UDP header + Data is

e664 2693 000d d301 7465 7374 0a00
0000 0000 d2f2 935d 0000 0000

so

    for (int i = 0; i <= MAX_UDP_LENGTH; i += 2) {
        if ((void *)(buf + 1) > data_end) {
            break;
        }
        csum_total += *buf;
        buf++;
    }

would iterate through all of that ^^ (i.e increment buff and read 13 times == 26 bytes of data)

Here packet length is 000d(13 bytes) so I know that I want to exit after incrementing buff 6.5 times ignoring those last 6.5 iterations (13 bytes of data)

The confusing part here is that i is a nibble index (1 hex digit (4 bytes)) (maybe) here, so that we can represent those 6 16 bit iterations and 1 8 bit iteration

I guess I'm confused trying to calculate 6 from the packet data

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A nibble is half a byte, i.e. 4 bits, 0..15 (what I think you are calling a hex digit). It's not useful to think about nibbles or try and do anything in terms of nibbles, except if you are writing your own code to convert numbers to hex strings.

__u16 *buf = (void *)udp;

This declares buf to be a pointer that will point to __u16 width values. buf++ will increment it by 2, to point at the next __u16 width value:

__u16 *buf = 0xf00;
printf("%x\n", buf);
printf("%x\n", buf + 1);
f00
f02

This means that if ((void *)(buf + 1) > data_end) would trigger too early because buf + 1 is actually +2 and would exit even if there were 2 bytes to be read.

It's safer to do all the arithmetic in bytes since the payload is a byte stream and the udp->len is in bytes.

Copy link
Author

@astoycos astoycos Nov 3, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I updated this to be much simpler

break;
}

if ((void *)(buf + 1) > data_end) {
break;
}
csum_total += *buf;
buf++;
}

if ((void *)buf + 1 <= data_end) {
astoycos marked this conversation as resolved.
Show resolved Hide resolved
csum_total += (*(__u8 *)buf);
}

return csum_fold_helper(csum_total);
}

// Rewrite and redirect parameters for one destination; this is the value
// type of the backends map. The fields add up to 28 bytes.
struct backend {
// Written unchanged into ip->saddr of a matching packet.
__u32 saddr;
// Written unchanged into ip->daddr of a matching packet.
__u32 daddr;
// New UDP destination port; xdp_prog_func converts it with bpf_ntohs()
// before comparing it with and storing it into udp->dest.
__u16 dport;
// Copied into eth->h_source of a matching packet.
__u8 shwaddr[6];
// Copied into eth->h_dest of a matching packet.
__u8 dhwaddr[6];
// Interface index handed to bpf_redirect().
// NOTE(review): only 16 bits wide - confirm this covers every ifindex in use.
__u16 ifindex;
// When non-zero, udp->check is left at 0 instead of being recomputed.
// A checksum is optional for UDP (over IPv4), see:
// https://en.wikipedia.org/wiki/User_Datagram_Protocol
__u8 nocksum;
// Explicit trailing padding.
__u8 pad[3];
};


// Lookup key for the backends map: the destination address and port of an
// incoming UDP packet. The fields add up to 8 bytes.
struct vip_key {
// Destination IPv4 address, copied unchanged from ip->daddr.
__u32 vip;
// Destination UDP port after bpf_ntohs() has been applied to udp->dest.
__u16 port;
// Explicit padding; not named in the initializer in xdp_prog_func, so it
// is zero there.
__u8 pad[2];
};

// Hash map from struct vip_key to struct backend with room for MAX_BACKENDS
// entries. xdp_prog_func only reads it; no code shown here inserts entries,
// so they are presumably installed from userspace.
struct {
__uint(type, BPF_MAP_TYPE_HASH);
__uint(max_entries, MAX_BACKENDS);
__type(key, struct vip_key);
__type(value, struct backend);
} backends SEC(".maps");

/*
 * XDP entry point: steer IPv4/UDP packets addressed to a known VIP to the
 * configured backend.
 *
 * Flow:
 *   1. Bounds-check and parse the Ethernet, IPv4 and UDP headers.
 *   2. Look up (destination address, destination port) in the backends map.
 *   3. On a hit, rewrite the IP addresses, the UDP destination port and both
 *      MAC addresses, refresh the checksums and redirect to bk->ifindex.
 *
 * ctx: XDP context supplying the packet start and end pointers.
 *
 * Returns:
 *   XDP_ABORTED  when the frame is too short for the header being parsed;
 *   XDP_PASS     for non-IPv4 frames, non-UDP packets, IPv4 headers that
 *                carry options, and packets with no backend entry;
 *   otherwise whatever bpf_redirect() returns.
 */
SEC("xdp")
int xdp_prog_func(struct xdp_md *ctx) {
    // ---------------------------------------------------------------------------
    // Initialize
    // ---------------------------------------------------------------------------

    void *data = (void *)(long)ctx->data;
    void *data_end = (void *)(long)ctx->data_end;

    struct ethhdr *eth = data;
    if (data + sizeof(struct ethhdr) > data_end) {
        bpf_printk("ABORTED: bad ethhdr!");
        return XDP_ABORTED;
    }

    if (bpf_ntohs(eth->h_proto) != ETH_P_IP) {
        bpf_printk("PASS: not IP protocol!");
        return XDP_PASS;
    }

    struct iphdr *ip = data + sizeof(struct ethhdr);
    if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) > data_end) {
        bpf_printk("ABORTED: bad iphdr!");
        return XDP_ABORTED;
    }

    if (ip->protocol != IPPROTO_UDP)
        return XDP_PASS;

    // Everything below assumes the UDP header sits right after a fixed-size
    // IPv4 header, and iph_csum() only sums sizeof(struct iphdr) bytes. A
    // header with options (ihl > 5) would be misparsed and end up with a bad
    // checksum, so hand such packets to the regular stack untouched.
    if (ip->ihl != sizeof(struct iphdr) / 4) {
        bpf_printk("PASS: IP options not supported!");
        return XDP_PASS;
    }

    struct udphdr *udp = data + sizeof(struct ethhdr) + sizeof(struct iphdr);
    if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr) > data_end) {
        bpf_printk("ABORTED: bad udphdr!");
        return XDP_ABORTED;
    }

    bpf_printk("UDP packet received - daddr:%x, port:%d", ip->daddr, bpf_ntohs(udp->dest));

    // ---------------------------------------------------------------------------
    // Routing
    // ---------------------------------------------------------------------------

    // The address is used as found in the packet; the port is converted to
    // host byte order. Members not named here (pad) are zero.
    struct vip_key key = {
        .vip = ip->daddr,
        .port = bpf_ntohs(udp->dest)
    };

    struct backend *bk;
    bk = bpf_map_lookup_elem(&backends, &key);
    if (!bk) {
        bpf_printk("no backends for ip %x:%x", key.vip, key.port);
        return XDP_PASS;
    }

    bpf_printk("got UDP traffic, source address:");
    bpf_printk_ip(ip->saddr);
    bpf_printk("destination address:");
    bpf_printk_ip(ip->daddr);

    ip->saddr = bk->saddr;
    ip->daddr = bk->daddr;

    bpf_printk("updated saddr to:");
    bpf_printk_ip(ip->saddr);
    bpf_printk("updated daddr to:");
    bpf_printk_ip(ip->daddr);

    // bk->dport is kept in host byte order; convert it for the wire.
    if (udp->dest != bpf_htons(bk->dport)) {
        udp->dest = bpf_htons(bk->dport);
        bpf_printk("updated dport to: %d", bk->dport);
    }

    memcpy(eth->h_source, bk->shwaddr, sizeof(eth->h_source));
    bpf_printk("new source hwaddr %x:%x:%x:%x:%x:%x", eth->h_source[0], eth->h_source[1], eth->h_source[2], eth->h_source[3], eth->h_source[4], eth->h_source[5]);

    memcpy(eth->h_dest, bk->dhwaddr, sizeof(eth->h_dest));
    bpf_printk("new dest hwaddr %x:%x:%x:%x:%x:%x", eth->h_dest[0], eth->h_dest[1], eth->h_dest[2], eth->h_dest[3], eth->h_dest[4], eth->h_dest[5]);

    // The addresses changed, so both checksums are stale. The UDP checksum
    // stays 0 when the backend opts out via nocksum.
    ip->check = iph_csum(ip);
    udp->check = 0;

    if (!bk->nocksum) {
        udp->check = udp_checksum(ip, udp, data_end);
    }

    bpf_printk("destination interface index %d", bk->ifindex);

    int action = bpf_redirect(bk->ifindex, 0);

    bpf_printk("redirect action: %d", action);

    return action;
}

/*
 * Minimal XDP program: logs that a packet arrived and lets it continue up
 * the stack. Judging by its name and log text it is presumably meant to be
 * attached to the interface that packets are redirected to.
 *
 * ctx: XDP context (unused).
 * Returns XDP_PASS for every packet.
 */
SEC("xdp")
int bpf_redirect_placeholder(struct xdp_md *ctx) {
    const int verdict = XDP_PASS;

    bpf_printk("received a packet on dest interface");
    return verdict;
}
File renamed without changes.
27 changes: 27 additions & 0 deletions userspace-go/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@


# Tag for built artifacts; not referenced by any rule shown here.
TAG ?= latest

# Location of the libbpf sources, added to the BPF include path below.
LIBBPF ?= ../libbpf/src
# Compiler for the BPF object; handed to `go generate` as BPF_CLANG.
CLANG ?= clang
# BPF compile flags; any CFLAGS supplied by the caller are appended.
CFLAGS := -O2 -g -Wall -Werror -Wno-unused-value -Wno-pointer-sign -Wcompare-distinct-pointer-types -I$(LIBBPF) $(CFLAGS)

.PHONY: all
all: build

# Remove the generated Go bindings, the BPF objects and the built binary.
.PHONY: clean
clean:
	rm -f bpf_bpfeb.go
	rm -f bpf_bpfeb.o
	rm -f bpf_bpfel.go
	rm -f bpf_bpfel.o
	rm -f blixt-dataplane

# Run the go:generate directives with the compiler and flags chosen above.
.PHONY: generate
generate: export BPF_CLANG := $(CLANG)
generate: export BPF_CFLAGS := $(CFLAGS)
generate:
	go generate ./...

# Build the userspace binary after regenerating the BPF bindings.
.PHONY: build
build: generate
	go build -o blixt-dataplane
Loading