Merge commit 'd803bfe2b1fe7f5e219e50ac20d6801a0a58ac75' as 'vendor/ruvector'
This commit is contained in:
112
vendor/ruvector/crates/rvf/rvf-ebpf/bpf/socket_filter.c
vendored
Normal file
112
vendor/ruvector/crates/rvf/rvf-ebpf/bpf/socket_filter.c
vendored
Normal file
@@ -0,0 +1,112 @@
|
||||
// SPDX-License-Identifier: GPL-2.0
|
||||
//
|
||||
// RVF Socket Filter: Port-Based Access Control
|
||||
//
|
||||
// This BPF socket filter enforces a simple port allow-list for RVF
|
||||
// deployments. Only packets destined for explicitly allowed ports are
|
||||
// passed through; everything else is dropped.
|
||||
//
|
||||
// Allowed ports are stored in a BPF hash map so they can be updated at
|
||||
// runtime from userspace without reloading the program.
|
||||
//
|
||||
// Default allowed ports (populated by userspace loader):
|
||||
// - 8080: RVF API / vector query endpoint
|
||||
// - 2222: SSH management access
|
||||
// - 9090: Prometheus metrics scraping
|
||||
// - 6379: Optional Redis sidecar for caching
|
||||
//
|
||||
// Attach point: SO_ATTACH_BPF on a raw socket, or cgroup/skb.
|
||||
|
||||
#include "vmlinux.h"
|
||||
|
||||
/* ── Configuration ───────────────────────────────────────────────── */
|
||||
|
||||
#define MAX_ALLOWED_PORTS 64
|
||||
|
||||
/* ── BPF maps ────────────────────────────────────────────────────── */
|
||||
|
||||
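/* Hash map: allowed destination ports. Key = port number in host byte
 * order (the filter applies bpf_ntohs() before the lookup); value = 1.
 * Only the presence of the key is checked, the value itself is not read. */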
|
||||
struct {
|
||||
__uint(type, BPF_MAP_TYPE_HASH);
|
||||
__uint(max_entries, MAX_ALLOWED_PORTS);
|
||||
__type(key, __u16);
|
||||
__type(value, __u8);
|
||||
} allowed_ports SEC(".maps");
|
||||
|
||||
/* Per-CPU array: drop/pass counters for observability */
|
||||
struct port_stats {
|
||||
__u64 passed;
|
||||
__u64 dropped;
|
||||
};
|
||||
|
||||
struct {
|
||||
__uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
|
||||
__uint(max_entries, 1);
|
||||
__type(key, __u32);
|
||||
__type(value, struct port_stats);
|
||||
} stats SEC(".maps");
|
||||
|
||||
/* ── Helpers ─────────────────────────────────────────────────────── */
|
||||
|
||||
static __always_inline void bump_stat(int is_pass)
|
||||
{
|
||||
__u32 zero = 0;
|
||||
struct port_stats *s = bpf_map_lookup_elem(&stats, &zero);
|
||||
if (s) {
|
||||
if (is_pass)
|
||||
s->passed++;
|
||||
else
|
||||
s->dropped++;
|
||||
}
|
||||
}
|
||||
|
||||
/* ── Socket filter entry point ───────────────────────────────────── */
|
||||
|
||||
/*
 * Socket-filter entry point: pass TCP/UDP packets only when their
 * destination port is a key in `allowed_ports`.
 *
 * The packet is parsed as IPv4 starting at offset 0 of the skb, i.e. the
 * program expects to be attached where the data begins at the IP header
 * (no Ethernet header in front).
 * NOTE(review): the IP version nibble is not checked, so an IPv6 packet
 * would be mis-parsed -- confirm the attach point only delivers IPv4.
 *
 * Decisions:
 *   - header bytes cannot be read, or IHL < 5  -> drop (fail closed)
 *   - protocol is neither TCP nor UDP          -> pass (e.g. ICMP)
 *   - TCP/UDP with destination port in the map -> pass
 *   - anything else                            -> drop
 *
 * Returns skb->len to pass the whole packet, or 0 to drop it. Every
 * decision is recorded through bump_stat().
 */
SEC("socket")
int rvf_port_filter(struct __sk_buff *skb)
{
    __u8 vihl = 0;      /* byte 0: version (high nibble) | IHL (low nibble) */
    __u8 protocol = 0;  /* byte 9: IPv4 protocol field */

    /* A packet too short to hold these bytes cannot be classified;
     * an allow-list filter must not let it through by accident. */
    if (bpf_skb_load_bytes(skb, 0, &vihl, 1) < 0 ||
        bpf_skb_load_bytes(skb, 9, &protocol, 1) < 0) {
        bump_stat(0);
        return 0;
    }

    /* Non-TCP/UDP traffic: pass through (e.g. ICMP for health checks) */
    if (protocol != IPPROTO_TCP && protocol != IPPROTO_UDP) {
        bump_stat(1);
        return skb->len;
    }

    /* IHL counts 32-bit words; anything below 5 words (20 bytes) is not a
     * valid IPv4 header and would make us read the "port" from inside the
     * IP header itself. */
    __u32 ip_hdr_len = (__u32)(vihl & 0x0F) * 4;
    if (ip_hdr_len < 20) {
        bump_stat(0);
        return 0;
    }

    /* TCP and UDP both carry the destination port at offset 2 of the
     * transport header, so one read serves both protocols. */
    __be16 raw_port = 0;
    if (bpf_skb_load_bytes(skb, ip_hdr_len + 2, &raw_port, 2) < 0) {
        bump_stat(0);
        return 0;
    }

    /* Map keys are host-order port numbers. */
    __u16 dport = bpf_ntohs(raw_port);

    /* Look up the destination port in the allow-list */
    __u8 *allowed = bpf_map_lookup_elem(&allowed_ports, &dport);
    if (allowed) {
        bump_stat(1);
        return skb->len;    /* Pass: return original packet length */
    }

    bump_stat(0);
    return 0;               /* Drop: returning 0 truncates the packet */
}
|
||||
|
||||
char _license[] SEC("license") = "GPL";
|
||||
156
vendor/ruvector/crates/rvf/rvf-ebpf/bpf/tc_query_route.c
vendored
Normal file
156
vendor/ruvector/crates/rvf/rvf-ebpf/bpf/tc_query_route.c
vendored
Normal file
@@ -0,0 +1,156 @@
|
||||
// SPDX-License-Identifier: GPL-2.0
|
||||
//
|
||||
// RVF TC Query Router: Priority-Based Query Classification
|
||||
//
|
||||
// This TC (Traffic Control) classifier inspects incoming UDP packets
|
||||
// destined for the RVF query port and classifies them into priority
|
||||
// tiers based on the query type encoded in the RVF protocol header.
|
||||
//
|
||||
// Classification tiers (set via skb->tc_classid):
|
||||
// TC_H_MAKE(1, 1) = "hot" queries (low-latency, cached vectors)
|
||||
// TC_H_MAKE(1, 2) = "warm" queries (standard priority)
|
||||
// TC_H_MAKE(1, 3) = "cold" queries (batch/bulk, best-effort)
|
||||
//
|
||||
// The query type is determined by inspecting the flags field in the
|
||||
// RVF query header that follows the UDP header.
|
||||
//
|
||||
// Attach: tc filter add dev <iface> ingress bpf da obj tc_query_route.o
|
||||
|
||||
#include "vmlinux.h"
|
||||
|
||||
/* ── Configuration ───────────────────────────────────────────────── */
|
||||
|
||||
#define RVF_PORT 8080
|
||||
#define RVF_MAGIC 0x52564600 /* "RVF\0" big-endian */
|
||||
|
||||
/* TC classid helpers: major:minor */
|
||||
#define TC_H_MAKE(maj, min) (((maj) << 16) | (min))
|
||||
|
||||
/* Priority classes */
|
||||
#define CLASS_HOT TC_H_MAKE(1, 1)
|
||||
#define CLASS_WARM TC_H_MAKE(1, 2)
|
||||
#define CLASS_COLD TC_H_MAKE(1, 3)
|
||||
|
||||
/* RVF query flag bits (in the flags field of the extended header) */
|
||||
#define RVF_FLAG_HOT_CACHE 0x01 /* Request L0 (BPF map) cache lookup */
|
||||
#define RVF_FLAG_BATCH 0x02 /* Batch query mode */
|
||||
#define RVF_FLAG_PREFETCH 0x04 /* Prefetch hint for warming cache */
|
||||
#define RVF_FLAG_PRIORITY 0x08 /* Caller-requested high priority */
|
||||
|
||||
/* ── RVF query header (xdp_distance.c layout plus a trailing flags field) ── */
|
||||
|
||||
struct rvf_query_hdr {
|
||||
__u32 magic; /* RVF_MAGIC */
|
||||
__u16 dimension; /* vector dimension (network byte order) */
|
||||
__u16 k; /* top-k requested */
|
||||
__u64 query_id; /* caller-chosen query identifier */
|
||||
__u32 flags; /* query flags (network byte order) */
|
||||
} __attribute__((packed));
|
||||
|
||||
/* ── BPF maps ────────────────────────────────────────────────────── */
|
||||
|
||||
/* Per-CPU counters for each priority class */
|
||||
struct class_stats {
|
||||
__u64 hot;
|
||||
__u64 warm;
|
||||
__u64 cold;
|
||||
__u64 passthrough;
|
||||
};
|
||||
|
||||
struct {
|
||||
__uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
|
||||
__uint(max_entries, 1);
|
||||
__type(key, __u32);
|
||||
__type(value, struct class_stats);
|
||||
} tc_stats SEC(".maps");
|
||||
|
||||
/* ── Helpers ─────────────────────────────────────────────────────── */
|
||||
|
||||
static __always_inline void bump_class(int class_idx)
|
||||
{
|
||||
__u32 zero = 0;
|
||||
struct class_stats *s = bpf_map_lookup_elem(&tc_stats, &zero);
|
||||
if (!s)
|
||||
return;
|
||||
|
||||
switch (class_idx) {
|
||||
case 0: s->hot++; break;
|
||||
case 1: s->warm++; break;
|
||||
case 2: s->cold++; break;
|
||||
default: s->passthrough++; break;
|
||||
}
|
||||
}
|
||||
|
||||
/* ── TC classifier entry point ───────────────────────────────────── */
|
||||
|
||||
/*
 * TC classifier entry point: tag RVF query packets with a priority class.
 *
 * tc (cls_bpf) programs attached to an Ethernet device see the frame
 * starting at the Ethernet header, so all IPv4/UDP offsets below are
 * taken relative to the end of struct ethhdr, and only frames whose
 * skb->protocol is ETH_P_IP are parsed.
 *
 * A packet is treated as an RVF query when it is IPv4/UDP to RVF_PORT
 * and its payload starts with RVF_MAGIC. The class is then chosen from
 * the header's flags field:
 *   PRIORITY or HOT_CACHE set -> CLASS_HOT
 *   otherwise BATCH set       -> CLASS_COLD
 *   otherwise                 -> CLASS_WARM
 *
 * Everything else leaves skb->tc_classid untouched and is counted as
 * passthrough. The program never drops: it always returns TC_ACT_OK.
 */
SEC("tc")
int rvf_query_classify(struct __sk_buff *skb)
{
    /* Offset of the IPv4 header inside the frame. */
    const __u32 l3_off = sizeof(struct ethhdr);

    /* skb->protocol holds the EtherType in network byte order. */
    if (skb->protocol != bpf_htons(ETH_P_IP))
        goto passthrough;

    /* ── Parse IP protocol and header length ─────────────────────── */
    __u8 ihl_byte = 0;
    if (bpf_skb_load_bytes(skb, l3_off, &ihl_byte, 1) < 0)
        goto passthrough;

    /* IHL counts 32-bit words; fewer than 5 is not a valid header. */
    __u32 ip_hdr_len = (__u32)(ihl_byte & 0x0F) * 4;
    if (ip_hdr_len < 20)
        goto passthrough;

    __u8 protocol = 0;
    if (bpf_skb_load_bytes(skb, l3_off + 9, &protocol, 1) < 0)
        goto passthrough;

    if (protocol != IPPROTO_UDP)
        goto passthrough;

    /* ── Parse UDP destination port (offset 2 of the UDP header) ─── */
    __u32 l4_off = l3_off + ip_hdr_len;
    __be16 raw_dport = 0;
    if (bpf_skb_load_bytes(skb, l4_off + 2, &raw_dport, 2) < 0)
        goto passthrough;

    if (bpf_ntohs(raw_dport) != RVF_PORT)
        goto passthrough;

    /* ── Parse RVF query header (after 8-byte UDP header) ────────── */
    struct rvf_query_hdr qhdr;
    __bpf_memset(&qhdr, 0, sizeof(qhdr));
    if (bpf_skb_load_bytes(skb, l4_off + 8, &qhdr, sizeof(qhdr)) < 0)
        goto passthrough;

    /* magic is compared in wire (network) byte order. */
    if (qhdr.magic != bpf_htonl(RVF_MAGIC))
        goto passthrough;

    /* ── Classify based on flags ─────────────────────────────────── */
    __u32 flags = bpf_ntohl(qhdr.flags);

    if (flags & (RVF_FLAG_PRIORITY | RVF_FLAG_HOT_CACHE)) {
        /* Hot path: low-latency cached query */
        skb->tc_classid = CLASS_HOT;
        bump_class(0);
        return TC_ACT_OK;
    }

    if (flags & RVF_FLAG_BATCH) {
        /* Cold path: bulk/batch query, best-effort */
        skb->tc_classid = CLASS_COLD;
        bump_class(2);
        return TC_ACT_OK;
    }

    /* Default: warm / standard priority */
    skb->tc_classid = CLASS_WARM;
    bump_class(1);
    return TC_ACT_OK;

passthrough:
    /* Not an RVF query (or unparsable): count it and let it continue. */
    bump_class(3);
    return TC_ACT_OK;
}
|
||||
|
||||
char _license[] SEC("license") = "GPL";
|
||||
243
vendor/ruvector/crates/rvf/rvf-ebpf/bpf/vmlinux.h
vendored
Normal file
243
vendor/ruvector/crates/rvf/rvf-ebpf/bpf/vmlinux.h
vendored
Normal file
@@ -0,0 +1,243 @@
|
||||
/* SPDX-License-Identifier: GPL-2.0 */
|
||||
/* Minimal BPF type stubs for RVF eBPF programs.
|
||||
*
|
||||
* This header provides the essential kernel type definitions so that
|
||||
* BPF C programs can compile without requiring the full kernel headers.
|
||||
* In production, replace this with the vmlinux.h generated by:
|
||||
* bpftool btf dump file /sys/kernel/btf/vmlinux format c
|
||||
*/
|
||||
|
||||
#ifndef __VMLINUX_H__
|
||||
#define __VMLINUX_H__
|
||||
|
||||
/* ── Scalar typedefs ─────────────────────────────────────────────── */
|
||||
|
||||
typedef unsigned char __u8;
|
||||
typedef unsigned short __u16;
|
||||
typedef unsigned int __u32;
|
||||
typedef unsigned long long __u64;
|
||||
typedef signed char __s8;
|
||||
typedef signed short __s16;
|
||||
typedef signed int __s32;
|
||||
typedef signed long long __s64;
|
||||
|
||||
typedef __u16 __be16;
|
||||
typedef __u32 __be32;
|
||||
typedef __u64 __be64;
|
||||
|
||||
typedef __u16 __sum16;
|
||||
|
||||
/* ── Ethernet ────────────────────────────────────────────────────── */
|
||||
|
||||
#define ETH_ALEN 6
|
||||
#define ETH_P_IP 0x0800
|
||||
#define ETH_P_IPV6 0x86DD
|
||||
|
||||
struct ethhdr {
|
||||
unsigned char h_dest[ETH_ALEN];
|
||||
unsigned char h_source[ETH_ALEN];
|
||||
__be16 h_proto;
|
||||
} __attribute__((packed));
|
||||
|
||||
/* ── IPv4 ────────────────────────────────────────────────────────── */
|
||||
|
||||
#define IPPROTO_TCP 6
|
||||
#define IPPROTO_UDP 17
|
||||
|
||||
struct iphdr {
|
||||
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||
__u8 ihl:4,
|
||||
version:4;
|
||||
#else
|
||||
__u8 version:4,
|
||||
ihl:4;
|
||||
#endif
|
||||
__u8 tos;
|
||||
__be16 tot_len;
|
||||
__be16 id;
|
||||
__be16 frag_off;
|
||||
__u8 ttl;
|
||||
__u8 protocol;
|
||||
__sum16 check;
|
||||
__be32 saddr;
|
||||
__be32 daddr;
|
||||
} __attribute__((packed));
|
||||
|
||||
/* ── UDP ─────────────────────────────────────────────────────────── */
|
||||
|
||||
struct udphdr {
|
||||
__be16 source;
|
||||
__be16 dest;
|
||||
__be16 len;
|
||||
__sum16 check;
|
||||
} __attribute__((packed));
|
||||
|
||||
/* ── TCP ─────────────────────────────────────────────────────────── */
|
||||
|
||||
struct tcphdr {
|
||||
__be16 source;
|
||||
__be16 dest;
|
||||
__be32 seq;
|
||||
__be32 ack_seq;
|
||||
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
|
||||
__u16 res1:4,
|
||||
doff:4,
|
||||
fin:1,
|
||||
syn:1,
|
||||
rst:1,
|
||||
psh:1,
|
||||
ack:1,
|
||||
urg:1,
|
||||
ece:1,
|
||||
cwr:1;
|
||||
#else
|
||||
__u16 doff:4,
|
||||
res1:4,
|
||||
cwr:1,
|
||||
ece:1,
|
||||
urg:1,
|
||||
ack:1,
|
||||
psh:1,
|
||||
rst:1,
|
||||
syn:1,
|
||||
fin:1;
|
||||
#endif
|
||||
__be16 window;
|
||||
__sum16 check;
|
||||
__be16 urg_ptr;
|
||||
} __attribute__((packed));
|
||||
|
||||
/* ── XDP context ─────────────────────────────────────────────────── */
|
||||
|
||||
struct xdp_md {
|
||||
__u32 data;
|
||||
__u32 data_end;
|
||||
__u32 data_meta;
|
||||
__u32 ingress_ifindex;
|
||||
__u32 rx_queue_index;
|
||||
__u32 egress_ifindex;
|
||||
};
|
||||
|
||||
/* XDP return codes */
|
||||
#define XDP_ABORTED 0
|
||||
#define XDP_DROP 1
|
||||
#define XDP_PASS 2
|
||||
#define XDP_TX 3
|
||||
#define XDP_REDIRECT 4
|
||||
|
||||
/* ── TC (Traffic Control) context ────────────────────────────────── */
|
||||
|
||||
struct __sk_buff {
|
||||
__u32 len;
|
||||
__u32 pkt_type;
|
||||
__u32 mark;
|
||||
__u32 queue_mapping;
|
||||
__u32 protocol;
|
||||
__u32 vlan_present;
|
||||
__u32 vlan_tci;
|
||||
__u32 vlan_proto;
|
||||
__u32 priority;
|
||||
__u32 ingress_ifindex;
|
||||
__u32 ifindex;
|
||||
__u32 tc_index;
|
||||
__u32 cb[5];
|
||||
__u32 hash;
|
||||
__u32 tc_classid;
|
||||
__u32 data;
|
||||
__u32 data_end;
|
||||
__u32 napi_id;
|
||||
__u32 family;
|
||||
__u32 remote_ip4;
|
||||
__u32 local_ip4;
|
||||
__u32 remote_ip6[4];
|
||||
__u32 local_ip6[4];
|
||||
__u32 remote_port;
|
||||
__u32 local_port;
|
||||
__u32 data_meta;
|
||||
};
|
||||
|
||||
/* TC action return codes */
|
||||
#define TC_ACT_UNSPEC (-1)
|
||||
#define TC_ACT_OK 0
|
||||
#define TC_ACT_RECLASSIFY 1
|
||||
#define TC_ACT_SHOT 2
|
||||
#define TC_ACT_PIPE 3
|
||||
#define TC_ACT_STOLEN 4
|
||||
#define TC_ACT_QUEUED 5
|
||||
#define TC_ACT_REPEAT 6
|
||||
#define TC_ACT_REDIRECT 7
|
||||
|
||||
/* ── BPF map type constants ──────────────────────────────────────── */
|
||||
|
||||
#define BPF_MAP_TYPE_HASH 1
|
||||
#define BPF_MAP_TYPE_ARRAY 2
|
||||
#define BPF_MAP_TYPE_PERCPU_ARRAY 6
|
||||
#define BPF_MAP_TYPE_LRU_HASH 9
|
||||
|
||||
/* ── Section, inline and map-definition macros ───────────────────── */
|
||||
|
||||
/* SEC / __always_inline macros (if not using bpf/bpf_helpers.h) */
|
||||
#ifndef SEC
|
||||
#define SEC(name) \
|
||||
_Pragma("GCC diagnostic push") \
|
||||
_Pragma("GCC diagnostic ignored \"-Wignored-attributes\"") \
|
||||
__attribute__((section(name), used)) \
|
||||
_Pragma("GCC diagnostic pop")
|
||||
#endif
|
||||
|
||||
#ifndef __always_inline
|
||||
#define __always_inline inline __attribute__((always_inline))
|
||||
#endif
|
||||
|
||||
#ifndef __uint
|
||||
#define __uint(name, val) int (*name)[val]
|
||||
#endif
|
||||
|
||||
#ifndef __type
|
||||
#define __type(name, val) typeof(val) *name
|
||||
#endif
|
||||
|
||||
/* ── BPF helper declarations (helper IDs from linux/bpf.h) ──────── */
|
||||
|
||||
static void *(*bpf_map_lookup_elem)(void *map, const void *key) = (void *) 1;
|
||||
static long (*bpf_map_update_elem)(void *map, const void *key,
|
||||
const void *value, __u64 flags) = (void *) 2;
|
||||
static long (*bpf_map_delete_elem)(void *map, const void *key) = (void *) 3;
|
||||
static __u64 (*bpf_ktime_get_ns)(void) = (void *) 5;
|
||||
static long (*bpf_trace_printk)(const char *fmt, __u32 fmt_size, ...) = (void *) 6;
|
||||
static __u32 (*bpf_get_smp_processor_id)(void) = (void *) 8;
|
||||
static long (*bpf_skb_store_bytes)(struct __sk_buff *skb, __u32 offset,
|
||||
const void *from, __u32 len,
|
||||
__u64 flags) = (void *) 9;
|
||||
static long (*bpf_skb_load_bytes)(const struct __sk_buff *skb, __u32 offset,
|
||||
void *to, __u32 len) = (void *) 26;
|
||||
static __u32 (*bpf_get_prandom_u32)(void) = (void *) 7;
|
||||
|
||||
/* ── Endian helpers ──────────────────────────────────────────────── */
|
||||
|
||||
/*
 * Host <-> network (big-endian) byte-order conversion.
 *
 * On a big-endian target the host representation already is network
 * order, so the value must be returned unchanged; only little-endian
 * targets swap. When the compiler does not define __BYTE_ORDER__ the
 * helpers fall back to swapping.
 */
#if defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
#define __rvf_net16(x) ((unsigned short)(x))
#define __rvf_net32(x) ((unsigned int)(x))
#else
#define __rvf_net16(x) __builtin_bswap16(x)
#define __rvf_net32(x) __builtin_bswap32(x)
#endif

#ifndef bpf_htons
#define bpf_htons(x) __rvf_net16(x)
#endif

#ifndef bpf_ntohs
#define bpf_ntohs(x) __rvf_net16(x)
#endif

#ifndef bpf_htonl
#define bpf_htonl(x) __rvf_net32(x)
#endif

#ifndef bpf_ntohl
#define bpf_ntohl(x) __rvf_net32(x)
#endif
|
||||
|
||||
/* memcpy/memset for BPF -- must use builtins */
|
||||
#ifndef __bpf_memcpy
|
||||
#define __bpf_memcpy __builtin_memcpy
|
||||
#endif
|
||||
|
||||
#ifndef __bpf_memset
|
||||
#define __bpf_memset __builtin_memset
|
||||
#endif
|
||||
|
||||
#endif /* __VMLINUX_H__ */
|
||||
247
vendor/ruvector/crates/rvf/rvf-ebpf/bpf/xdp_distance.c
vendored
Normal file
247
vendor/ruvector/crates/rvf/rvf-ebpf/bpf/xdp_distance.c
vendored
Normal file
@@ -0,0 +1,247 @@
|
||||
// SPDX-License-Identifier: GPL-2.0
|
||||
//
|
||||
// RVF XDP Vector Distance Computation
|
||||
//
|
||||
// Computes squared L2 distance between a query vector received in a
|
||||
// UDP packet and stored vectors cached in a BPF LRU hash map. Results
|
||||
// are written to a per-CPU array map for lock-free retrieval by
|
||||
// userspace via bpf_map_lookup_elem.
|
||||
//
|
||||
// Wire format of an RVF query packet:
|
||||
// Ethernet | IPv4 | UDP (dst port RVF_PORT) | rvf_query_hdr | f32[dim]
|
||||
//
|
||||
// The program only handles packets destined for RVF_PORT and bearing
|
||||
// the correct magic number. All other traffic is passed through
|
||||
// unchanged (XDP_PASS).
|
||||
|
||||
#include "vmlinux.h"
|
||||
|
||||
#define MAX_DIM 512
|
||||
#define MAX_K 64
|
||||
#define RVF_PORT 8080
|
||||
#define RVF_MAGIC 0x52564600 /* "RVF\0" in big-endian */
|
||||
|
||||
/* ── RVF query packet header (follows UDP) ───────────────────────── */
|
||||
|
||||
struct rvf_query_hdr {
|
||||
__u32 magic; /* RVF_MAGIC */
|
||||
__u16 dimension; /* vector dimension (network byte order) */
|
||||
__u16 k; /* top-k neighbours requested */
|
||||
__u64 query_id; /* caller-chosen query identifier */
|
||||
} __attribute__((packed));
|
||||
|
||||
/* ── Per-query result structure ──────────────────────────────────── */
|
||||
|
||||
struct query_result {
|
||||
__u64 query_id;
|
||||
__u32 count;
|
||||
__u64 ids[MAX_K];
|
||||
__u32 distances[MAX_K]; /* squared L2, fixed-point */
|
||||
};
|
||||
|
||||
/* ── BPF maps ────────────────────────────────────────────────────── */
|
||||
|
||||
/* LRU hash map: caches hot vectors (vector_id -> f32[MAX_DIM]) */
|
||||
struct {
|
||||
__uint(type, BPF_MAP_TYPE_LRU_HASH);
|
||||
__uint(max_entries, 4096);
|
||||
__type(key, __u64);
|
||||
__type(value, __u8[MAX_DIM * 4]);
|
||||
} vector_cache SEC(".maps");
|
||||
|
||||
/* Per-CPU array: one result slot per CPU for lock-free writes */
|
||||
struct {
|
||||
__uint(type, BPF_MAP_TYPE_PERCPU_ARRAY);
|
||||
__uint(max_entries, 1);
|
||||
__type(key, __u32);
|
||||
__type(value, struct query_result);
|
||||
} results SEC(".maps");
|
||||
|
||||
/* Array map: list of cached vector IDs for iteration */
|
||||
struct {
|
||||
__uint(type, BPF_MAP_TYPE_ARRAY);
|
||||
__uint(max_entries, 4096);
|
||||
__type(key, __u32);
|
||||
__type(value, __u64);
|
||||
} vector_ids SEC(".maps");
|
||||
|
||||
/* Array map: single entry holding the count of populated IDs */
|
||||
struct {
|
||||
__uint(type, BPF_MAP_TYPE_ARRAY);
|
||||
__uint(max_entries, 1);
|
||||
__type(key, __u32);
|
||||
__type(value, __u32);
|
||||
} id_count SEC(".maps");
|
||||
|
||||
/* ── Helpers ─────────────────────────────────────────────────────── */
|
||||
|
||||
/*
|
||||
* Compute squared L2 distance between two vectors stored as raw bytes.
|
||||
*
|
||||
* Both `a` and `b` point to dim * 4 bytes of IEEE-754 f32 data.
|
||||
* We reinterpret each 4-byte group as a __u32 and use integer
|
||||
* subtraction as a rough fixed-point proxy -- this is an approximation
|
||||
* suitable for ranking, not exact float arithmetic, because the BPF
|
||||
* verifier does not support floating-point instructions.
|
||||
*/
|
||||
static __always_inline __u64 l2_distance_sq(
|
||||
const __u8 *a, const __u8 *b, __u16 dim)
|
||||
{
|
||||
__u64 sum = 0;
|
||||
__u16 i;
|
||||
|
||||
/* Bounded loop: the verifier requires a compile-time upper bound. */
|
||||
#pragma unroll
|
||||
for (i = 0; i < MAX_DIM; i++) {
|
||||
if (i >= dim)
|
||||
break;
|
||||
|
||||
__u32 va, vb;
|
||||
__builtin_memcpy(&va, a + (__u32)i * 4, 4);
|
||||
__builtin_memcpy(&vb, b + (__u32)i * 4, 4);
|
||||
|
||||
__s32 diff = (__s32)va - (__s32)vb;
|
||||
sum += (__u64)((__s64)diff * (__s64)diff);
|
||||
}
|
||||
return sum;
|
||||
}
|
||||
|
||||
/*
|
||||
* Insert a (distance, id) pair into a max-heap of size k stored in the
|
||||
* result arrays. We keep the worst (largest) distance at index 0 so
|
||||
* eviction is O(1). This is a simplified sift-down for bounded k.
|
||||
*/
|
||||
static __always_inline void heap_insert(
|
||||
struct query_result *res, __u32 k, __u64 vid, __u32 dist)
|
||||
{
|
||||
if (res->count < k) {
|
||||
__u32 idx = res->count;
|
||||
if (idx < MAX_K) {
|
||||
res->ids[idx] = vid;
|
||||
res->distances[idx] = dist;
|
||||
res->count++;
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
/* Find the current worst (max) distance in the heap */
|
||||
__u32 worst_idx = 0;
|
||||
__u32 worst_dist = 0;
|
||||
__u32 i;
|
||||
|
||||
#pragma unroll
|
||||
for (i = 0; i < MAX_K; i++) {
|
||||
if (i >= res->count)
|
||||
break;
|
||||
if (res->distances[i] > worst_dist) {
|
||||
worst_dist = res->distances[i];
|
||||
worst_idx = i;
|
||||
}
|
||||
}
|
||||
|
||||
/* Evict the worst if the new distance is better */
|
||||
if (dist < worst_dist && worst_idx < MAX_K) {
|
||||
res->ids[worst_idx] = vid;
|
||||
res->distances[worst_idx] = dist;
|
||||
}
|
||||
}
|
||||
|
||||
/* ── XDP entry point ─────────────────────────────────────────────── */
|
||||
|
||||
/*
 * XDP entry point: for an RVF query packet, rank cached vectors against
 * the query vector and store the best k in this CPU's `results` slot.
 *
 * Expected layout: Ethernet | IPv4 | UDP (dst RVF_PORT) | rvf_query_hdr |
 * dim * 4 bytes of vector data. Packets that do not match, are truncated,
 * carry an invalid IP header length, or request dim/k outside
 * 1..MAX_DIM / 1..MAX_K are left untouched.
 *
 * Only the first MAX_SCAN entries of `vector_ids` are examined per packet.
 *
 * Always returns XDP_PASS: the packet continues up the stack whether or
 * not a result was produced.
 */
SEC("xdp")
int xdp_vector_distance(struct xdp_md *ctx)
{
    /* Number of cached vector IDs scanned per packet. The vector_ids map
     * is declared with 4096 entries, but the scan loop is bounded to
     * this many iterations. */
    enum { MAX_SCAN = 256 };

    void *data = (void *)(__u64)ctx->data;
    void *data_end = (void *)(__u64)ctx->data_end;

    /* ── Parse Ethernet ──────────────────────────────────────────── */
    struct ethhdr *eth = data;
    if ((void *)(eth + 1) > data_end)
        return XDP_PASS;

    if (eth->h_proto != bpf_htons(ETH_P_IP))
        return XDP_PASS;

    /* ── Parse IPv4 ──────────────────────────────────────────────── */
    struct iphdr *iph = (void *)(eth + 1);
    if ((void *)(iph + 1) > data_end)
        return XDP_PASS;

    if (iph->protocol != IPPROTO_UDP)
        return XDP_PASS;

    /* IHL counts 32-bit words. A value below 5 would place the "UDP
     * header" inside the IP header, so reject it. */
    __u32 ip_hdr_len = (__u32)iph->ihl * 4;
    if (ip_hdr_len < sizeof(*iph))
        return XDP_PASS;

    /* ── Parse UDP ───────────────────────────────────────────────── */
    struct udphdr *udph = (void *)((__u8 *)iph + ip_hdr_len);
    if ((void *)(udph + 1) > data_end)
        return XDP_PASS;

    if (bpf_ntohs(udph->dest) != RVF_PORT)
        return XDP_PASS;

    /* ── Parse RVF query header ──────────────────────────────────── */
    struct rvf_query_hdr *qhdr = (void *)(udph + 1);
    if ((void *)(qhdr + 1) > data_end)
        return XDP_PASS;

    if (qhdr->magic != bpf_htonl(RVF_MAGIC))
        return XDP_PASS;

    __u16 dim = bpf_ntohs(qhdr->dimension);
    __u16 k = bpf_ntohs(qhdr->k);

    if (dim == 0 || dim > MAX_DIM)
        return XDP_PASS;
    if (k == 0 || k > MAX_K)
        return XDP_PASS;

    /* ── Bounds-check the query vector payload ───────────────────── */
    __u8 *query_vec = (__u8 *)(qhdr + 1);
    if ((void *)(query_vec + (__u32)dim * 4) > data_end)
        return XDP_PASS;

    /* ── Get the result slot for this CPU ────────────────────────── */
    __u32 zero = 0;
    struct query_result *result = bpf_map_lookup_elem(&results, &zero);
    if (!result)
        return XDP_PASS;

    /* query_id is stored exactly as received (no byte-order conversion). */
    result->query_id = qhdr->query_id;
    result->count = 0;

    /* ── Get the number of cached vectors ────────────────────────── */
    __u32 *cnt_ptr = bpf_map_lookup_elem(&id_count, &zero);
    __u32 vec_count = cnt_ptr ? *cnt_ptr : 0;
    if (vec_count > MAX_SCAN)
        vec_count = MAX_SCAN;

    /* ── Scan cached vectors, maintaining the top-k set ──────────── */
    __u32 idx;
#pragma unroll
    for (idx = 0; idx < MAX_SCAN; idx++) {
        if (idx >= vec_count)
            break;

        __u64 *vid_ptr = bpf_map_lookup_elem(&vector_ids, &idx);
        if (!vid_ptr)
            continue;

        /* An ID may be listed but missing from the vector_cache map;
         * skip it. */
        __u64 vid = *vid_ptr;
        __u8 *stored = bpf_map_lookup_elem(&vector_cache, &vid);
        if (!stored)
            continue;

        __u64 dist_sq = l2_distance_sq(query_vec, stored, dim);

        /* Results hold 32-bit distances: clamp instead of truncating so
         * a large distance never wraps to a small one. */
        __u32 dist32 = (dist_sq > 0xFFFFFFFF) ? 0xFFFFFFFF : (__u32)dist_sq;

        heap_insert(result, k, vid, dist32);
    }

    /* Let the packet continue to userspace for full-index search.
     * The XDP path only accelerates the L0 cache lookup; userspace
     * merges the BPF result with the full RVF index result. */
    return XDP_PASS;
}
|
||||
|
||||
char _license[] SEC("license") = "GPL";
|
||||
Reference in New Issue
Block a user