// File: wifi-densepose/vendor/ruvector/crates/micro-hnsw-wasm/verilog/micro_hnsw.v
// (Listing metadata from the source viewer: 556 lines, 18 KiB, Verilog)

// Micro HNSW - ASIC Hardware Description
// Ultra-minimal HNSW accelerator for vector similarity search
//
// Design specifications:
// - Fixed-point arithmetic (Q8.8 format)
// - 256 max vectors, 64 dimensions
// - 8 neighbors per node, 4 levels
// - Pipelined distance computation
// - AXI-Lite interface for host communication
//
// Target: ASIC synthesis with <50K gates
`timescale 1ns / 1ps
// micro_hnsw: minimal vector index core.
//
// Commands (cmd_valid/cmd_ready handshake, sampled in STATE_IDLE):
//   0 = NOP    ignored
//   1 = INIT   clears the index and sets the active dimension count.
//              cmd_dims of 0 or > MAX_DIMS is replaced by MAX_DIMS.
//   2 = INSERT followed by one vector streamed on vec_*; the new node is linked
//              to the entry point on level 0 (bidirectionally while the entry
//              point still has a free neighbour slot).
//   3 = SEARCH followed by one query vector streamed on vec_*; evaluates the
//              entry point and its level-0 neighbours, then streams up to
//              min(cmd_k, 16) results on result_* in ascending distance order.
//
// A vector ends on vec_last or after the active number of dimensions,
// whichever comes first.
//
// result_dist is the squared L2 distance rescaled to the input fixed-point
// format (accumulator >> DATA_WIDTH/2) and saturated to DATA_WIDTH bits.
//
// Capacity: vector_count is ADDR_WIDTH bits wide, so at most
// min(MAX_VECTORS, 2^ADDR_WIDTH - 1) vectors are stored. An INSERT issued when
// the index is full still consumes its vector but is dropped (vector_count
// does not change).
//
// A SEARCH on an empty index, or with cmd_k == 0, produces no results and
// simply returns to idle (cmd_ready goes high again).
//
// NOTE(review): index widths for dimensions ([5:0]), neighbours ([3:0]) and
// levels ([1:0]) are fixed as in the original design; the parameters must stay
// within MAX_DIMS <= 64, MAX_NEIGHBORS <= 8, MAX_LEVELS <= 4.
module micro_hnsw #(
    parameter MAX_VECTORS   = 256,
    parameter MAX_DIMS      = 64,
    parameter MAX_NEIGHBORS = 8,
    parameter MAX_LEVELS    = 4,
    parameter DATA_WIDTH    = 16, // Q8.8 fixed-point
    parameter ADDR_WIDTH    = 8   // log2(MAX_VECTORS)
)(
    input  wire                  clk,
    input  wire                  rst_n,
    // Control interface
    input  wire                  cmd_valid,
    output reg                   cmd_ready,
    input  wire [2:0]            cmd_op,   // 0=NOP, 1=INIT, 2=INSERT, 3=SEARCH
    input  wire [7:0]            cmd_dims,
    input  wire [7:0]            cmd_k,
    // Vector data interface
    input  wire                  vec_valid,
    output wire                  vec_ready,
    input  wire [DATA_WIDTH-1:0] vec_data,
    input  wire                  vec_last,
    // Result interface
    output reg                   result_valid,
    input  wire                  result_ready,
    output reg  [ADDR_WIDTH-1:0] result_idx,
    output reg  [DATA_WIDTH-1:0] result_dist,
    output reg                   result_last,
    // Status
    output reg  [ADDR_WIDTH-1:0] vector_count
);
    // ============ Local Parameters ============
    localparam STATE_IDLE     = 3'd0;
    localparam STATE_LOAD_VEC = 3'd1;
    localparam STATE_COMPUTE  = 3'd2; // commit an INSERT
    localparam STATE_SEARCH   = 3'd3; // one-cycle search initialisation
    localparam STATE_OUTPUT   = 3'd4; // stream candidates to the result port
    localparam STATE_DIST     = 3'd5; // accumulate one dimension per cycle
    localparam STATE_VISIT    = 3'd6; // sorted insert of the evaluated node
    localparam STATE_EXPAND   = 3'd7; // pick the next neighbour to evaluate

    localparam OP_INIT   = 3'd1;
    localparam OP_INSERT = 3'd2;
    localparam OP_SEARCH = 3'd3;

    localparam MAX_CANDS = 16;                 // depth of the candidate list
    localparam FRAC_BITS = DATA_WIDTH / 2;     // fractional bits of the input format
    // |diff| < 2^DATA_WIDTH, so diff^2 < 2^(2*DATA_WIDTH); 8 guard bits cover
    // up to 256 accumulated terms without overflow.
    localparam ACC_WIDTH = 2 * DATA_WIDTH + 8;
    // Largest vector count that is both storable and representable.
    localparam [ADDR_WIDTH-1:0] COUNT_MAX =
        (MAX_VECTORS >= (1 << ADDR_WIDTH)) ? ((1 << ADDR_WIDTH) - 1) : MAX_VECTORS;

    // ============ Memories ============
    // Vector storage (256 x 64 x 16-bit = 256 Kbit = 32 KiB at default parameters)
    reg [DATA_WIDTH-1:0] vectors [0:MAX_VECTORS-1][0:MAX_DIMS-1];
    // Graph structure - neighbor lists
    reg [ADDR_WIDTH-1:0] neighbors [0:MAX_VECTORS-1][0:MAX_LEVELS-1][0:MAX_NEIGHBORS-1];
    reg [3:0] neighbor_count [0:MAX_VECTORS-1][0:MAX_LEVELS-1];
    reg [1:0] node_level [0:MAX_VECTORS-1];

    // ============ Registers ============
    reg [2:0] state;
    reg [2:0] active_op;          // command latched at acceptance time
    reg [ADDR_WIDTH-1:0] entry_point;
    reg [1:0] max_level;
    reg [7:0] current_dims;
    // Vector loading
    reg [DATA_WIDTH-1:0] query_buf  [0:MAX_DIMS-1];
    reg [DATA_WIDTH-1:0] insert_buf [0:MAX_DIMS-1];
    reg [5:0] load_idx;
    // Search state
    reg [ADDR_WIDTH-1:0] current_node;
    reg [1:0] current_level;
    reg [7:0] current_k;
    reg [4:0] k_limit;            // min(current_k, MAX_CANDS)
    reg [3:0] neighbor_idx;
    // Candidate buffer (kept sorted by ascending distance)
    reg [ADDR_WIDTH-1:0] candidates [0:MAX_CANDS-1];
    reg [DATA_WIDTH-1:0] cand_dist  [0:MAX_CANDS-1];
    reg [4:0] cand_count;         // 0..MAX_CANDS, hence 5 bits
    // Distance computation
    reg [ACC_WIDTH-1:0] dist_accum;
    reg [6:0] dist_dim;           // one bit wider than load_idx so it can reach MAX_DIMS
    reg [ADDR_WIDTH-1:0] dist_target;
    // Visited flags (bit vector)
    reg [MAX_VECTORS-1:0] visited;

    // Loop index and sorted-insert position (blocking temporaries).
    integer i;
    reg [4:0] ins_pos;

    // ============ Distance datapath ============
    // The difference and the square are formed in their own, fully signed
    // expressions so that negative differences are sign-extended correctly.
    wire signed [DATA_WIDTH:0] dist_diff =
        $signed(query_buf[dist_dim[5:0]]) - $signed(vectors[dist_target][dist_dim[5:0]]);
    wire [2*DATA_WIDTH+1:0] dist_sq = dist_diff * dist_diff;

    // Rescale the accumulated squared distance to the input format and saturate.
    wire [ACC_WIDTH-1:0]  dist_scaled = dist_accum >> FRAC_BITS;
    wire [DATA_WIDTH-1:0] dist_sat =
        (|dist_scaled[ACC_WIDTH-1:DATA_WIDTH]) ? {DATA_WIDTH{1'b1}}
                                               : dist_scaled[DATA_WIDTH-1:0];

    // ============ Vector Ready ============
    assign vec_ready = (state == STATE_LOAD_VEC);

    // ============ State Machine ============
    // All state, including the distance accumulator, is driven from this one
    // block so that no register has more than one driver.
    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            state         <= STATE_IDLE;
            active_op     <= 3'd0;
            cmd_ready     <= 1'b1;
            result_valid  <= 1'b0;
            result_idx    <= {ADDR_WIDTH{1'b0}};
            result_dist   <= {DATA_WIDTH{1'b0}};
            result_last   <= 1'b0;
            vector_count  <= {ADDR_WIDTH{1'b0}};
            entry_point   <= {ADDR_WIDTH{1'b0}};
            max_level     <= 2'd0;
            current_dims  <= 8'd32;
            load_idx      <= 6'd0;
            current_node  <= {ADDR_WIDTH{1'b0}};
            current_level <= 2'd0;
            current_k     <= 8'd0;
            k_limit       <= 5'd0;
            neighbor_idx  <= 4'd0;
            cand_count    <= 5'd0;
            dist_accum    <= {ACC_WIDTH{1'b0}};
            dist_dim      <= 7'd0;
            dist_target   <= {ADDR_WIDTH{1'b0}};
            visited       <= {MAX_VECTORS{1'b0}};
        end else begin
            case (state)
                STATE_IDLE: begin
                    result_valid <= 1'b0;
                    if (cmd_valid && cmd_ready) begin
                        case (cmd_op)
                            OP_INIT: begin
                                // Out-of-range dimension counts fall back to MAX_DIMS.
                                if (cmd_dims == 8'd0 || cmd_dims > MAX_DIMS)
                                    current_dims <= MAX_DIMS;
                                else
                                    current_dims <= cmd_dims;
                                vector_count <= {ADDR_WIDTH{1'b0}};
                                entry_point  <= {ADDR_WIDTH{1'b0}};
                                max_level    <= 2'd0;
                            end
                            OP_INSERT: begin
                                active_op <= OP_INSERT;
                                load_idx  <= 6'd0;
                                cmd_ready <= 1'b0;
                                state     <= STATE_LOAD_VEC;
                            end
                            OP_SEARCH: begin
                                active_op <= OP_SEARCH;
                                load_idx  <= 6'd0;
                                current_k <= cmd_k;
                                cmd_ready <= 1'b0;
                                state     <= STATE_LOAD_VEC;
                            end
                            default: ; // NOP / unknown opcode: stay ready
                        endcase
                    end
                end

                STATE_LOAD_VEC: begin
                    if (vec_valid) begin
                        // Route by the latched command, not the live cmd_op pin.
                        if (active_op == OP_INSERT)
                            insert_buf[load_idx] <= vec_data;
                        else
                            query_buf[load_idx] <= vec_data;

                        if (vec_last || ({2'b00, load_idx} == current_dims - 8'd1)) begin
                            state <= (active_op == OP_INSERT) ? STATE_COMPUTE : STATE_SEARCH;
                        end else begin
                            load_idx <= load_idx + 6'd1;
                        end
                    end
                end

                STATE_COMPUTE: begin
                    if (vector_count != COUNT_MAX) begin
                        // Store vector
                        for (i = 0; i < MAX_DIMS; i = i + 1)
                            vectors[vector_count][i] <= insert_buf[i];
                        // Generate pseudo-random level (simplified)
                        node_level[vector_count] <= vector_count[1:0];
                        // Initialize neighbors
                        for (i = 0; i < MAX_LEVELS; i = i + 1)
                            neighbor_count[vector_count][i] <= 4'd0;

                        if (vector_count == {ADDR_WIDTH{1'b0}}) begin
                            // First vector becomes the entry point
                            entry_point <= {ADDR_WIDTH{1'b0}};
                            max_level   <= 2'd0;
                        end else begin
                            // A fresh node has no neighbours yet, so slot 0 is
                            // always free. This assignment follows the clearing
                            // loop above and therefore wins for level 0.
                            neighbors[vector_count][0][0]   <= entry_point;
                            neighbor_count[vector_count][0] <= 4'd1;
                            // Bidirectional connection while the entry point has room
                            if (neighbor_count[entry_point][0] < MAX_NEIGHBORS) begin
                                neighbors[entry_point][0][neighbor_count[entry_point][0]] <= vector_count;
                                neighbor_count[entry_point][0] <= neighbor_count[entry_point][0] + 4'd1;
                            end
                        end
                        vector_count <= vector_count + 1'b1;
                    end
                    cmd_ready <= 1'b1;
                    state     <= STATE_IDLE;
                end

                STATE_SEARCH: begin
                    // One-cycle initialisation of the search.
                    visited       <= {MAX_VECTORS{1'b0}};
                    cand_count    <= 5'd0;
                    neighbor_idx  <= 4'd0;
                    current_node  <= entry_point;
                    current_level <= max_level;
                    k_limit       <= (current_k > MAX_CANDS) ? MAX_CANDS : current_k[4:0];
                    if (vector_count == {ADDR_WIDTH{1'b0}} || current_k == 8'd0) begin
                        // Nothing to search or nothing requested
                        state <= STATE_OUTPUT;
                    end else begin
                        dist_target <= entry_point;
                        dist_accum  <= {ACC_WIDTH{1'b0}};
                        dist_dim    <= 7'd0;
                        state       <= STATE_DIST;
                    end
                end

                STATE_DIST: begin
                    if (dist_dim < current_dims) begin
                        dist_accum <= dist_accum + dist_sq;
                        dist_dim   <= dist_dim + 7'd1;
                    end else begin
                        state <= STATE_VISIT;
                    end
                end

                STATE_VISIT: begin
                    visited[dist_target] <= 1'b1;
                    // Insertion position = number of stored entries that are not
                    // farther than the new one (keeps equal distances stable).
                    ins_pos = 5'd0;
                    for (i = 0; i < MAX_CANDS; i = i + 1)
                        if (i < cand_count && cand_dist[i] <= dist_sat)
                            ins_pos = ins_pos + 5'd1;
                    // Open a gap at ins_pos; an entry pushed past k_limit is
                    // simply no longer counted.
                    for (i = MAX_CANDS - 1; i > 0; i = i - 1)
                        if (i > ins_pos && i <= cand_count) begin
                            candidates[i] <= candidates[i-1];
                            cand_dist[i]  <= cand_dist[i-1];
                        end
                    if (ins_pos < k_limit) begin
                        candidates[ins_pos] <= dist_target;
                        cand_dist[ins_pos]  <= dist_sat;
                    end
                    if (cand_count < k_limit)
                        cand_count <= cand_count + 5'd1;
                    state <= STATE_EXPAND;
                end

                STATE_EXPAND: begin
                    // Walk the level-0 neighbour list of the entry point.
                    if (neighbor_idx < neighbor_count[current_node][0]) begin
                        neighbor_idx <= neighbor_idx + 4'd1;
                        if (!visited[neighbors[current_node][0][neighbor_idx]]) begin
                            dist_target <= neighbors[current_node][0][neighbor_idx];
                            dist_accum  <= {ACC_WIDTH{1'b0}};
                            dist_dim    <= 7'd0;
                            state       <= STATE_DIST;
                        end
                    end else begin
                        state <= STATE_OUTPUT;
                    end
                end

                STATE_OUTPUT: begin
                    // Advance when the current beat was accepted or none is pending.
                    if (result_ready || !result_valid) begin
                        if (cand_count > 5'd0) begin
                            result_valid <= 1'b1;
                            result_idx   <= candidates[0];
                            result_dist  <= cand_dist[0];
                            result_last  <= (cand_count == 5'd1);
                            // Shift candidates
                            for (i = 0; i < MAX_CANDS - 1; i = i + 1) begin
                                candidates[i] <= candidates[i+1];
                                cand_dist[i]  <= cand_dist[i+1];
                            end
                            cand_count <= cand_count - 5'd1;
                        end else begin
                            result_valid <= 1'b0;
                            cmd_ready    <= 1'b1;
                            state        <= STATE_IDLE;
                        end
                    end
                end

                default: begin
                    cmd_ready <= 1'b1;
                    state     <= STATE_IDLE;
                end
            endcase
        end
    end
endmodule
// ============ Distance Unit - Pipelined L2 ============
// distance_unit: sequential squared-L2 accumulator.
//
// Protocol, as implemented below:
//   * `start` while idle clears the accumulator and begins a run; `dims` is
//     compared against the element counter on every cycle of the run.
//   * On each following cycle with counter < dims, the pair on a_data/b_data
//     is consumed and (a - b)^2 is added to the accumulator. The caller is
//     presumably expected to present one element pair per cycle - confirm
//     against the instantiating design, which is not visible here.
//   * Once the counter reaches dims, `distance` is updated and `done` pulses
//     high for one cycle.
//
// a_data/b_data are treated as signed two's-complement values. The
// accumulator saturates at 32'hFFFF_FFFF instead of wrapping.
module distance_unit #(
    parameter DATA_WIDTH = 16,
    parameter MAX_DIMS   = 64
)(
    input  wire                  clk,
    input  wire                  rst_n,
    input  wire                  start,
    input  wire [5:0]            dims,
    input  wire [DATA_WIDTH-1:0] a_data,
    input  wire [DATA_WIDTH-1:0] b_data,
    output reg  [31:0]           distance,
    output reg                   done
);
    localparam SQ_WIDTH  = 2 * DATA_WIDTH + 2;
    localparam SUM_WIDTH = ((SQ_WIDTH > 32) ? SQ_WIDTH : 32) + 1;

    reg [5:0]  dim_idx;
    reg [31:0] accum;
    reg        computing;

    // The subtraction and the square live in their own all-signed expressions.
    // Writing `accum + (diff * diff)` directly would make the whole expression
    // unsigned (accum is unsigned), zero-extending a negative diff and
    // producing a wrong square.
    wire signed [DATA_WIDTH:0] diff    = $signed(a_data) - $signed(b_data);
    wire [SQ_WIDTH-1:0]        diff_sq = diff * diff;

    // One extra bit (at least) above 32 so that overflow can be detected.
    wire [SUM_WIDTH-1:0] sum      = accum + diff_sq;
    wire                 overflow = |sum[SUM_WIDTH-1:32];

    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            done      <= 1'b0;
            computing <= 1'b0;
            accum     <= 32'd0;
            distance  <= 32'd0;
            dim_idx   <= 6'd0;
        end else if (start && !computing) begin
            // Begin a new run
            computing <= 1'b1;
            dim_idx   <= 6'd0;
            accum     <= 32'd0;
            done      <= 1'b0;
        end else if (computing) begin
            if (dim_idx < dims) begin
                // Accumulate the squared difference, saturating on overflow
                accum   <= overflow ? 32'hFFFF_FFFF : sum[31:0];
                dim_idx <= dim_idx + 6'd1;
            end else begin
                distance  <= accum;
                done      <= 1'b1;
                computing <= 1'b0;
            end
        end else begin
            done <= 1'b0;
        end
    end
endmodule
// ============ Priority Queue for Candidates ============
// priority_queue: small register-based min-queue kept sorted by ascending
// distance (entry 0 is always the minimum).
//
//   * clear         - synchronous; empties the queue on the next clock edge.
//   * insert_valid  - accepted whenever the queue is not full (insert_ready).
//                     Equal distances keep their arrival order.
//   * pop_valid     - accepted whenever the queue is not empty. The removed
//                     entry appears on pop_idx/pop_dist on the following
//                     cycle, flagged by a one-cycle pop_ready pulse.
//   * insert + pop in the same cycle - both are honoured: the stored minimum
//                     is popped, the new entry is merged into the remaining
//                     ones, and `count` is unchanged.
//
// NOTE(review): `count` is a fixed 5-bit port, so DEPTH must not exceed 16.
module priority_queue #(
    parameter DEPTH      = 16,
    parameter IDX_WIDTH  = 8,
    parameter DIST_WIDTH = 16
)(
    input  wire                  clk,
    input  wire                  rst_n,
    input  wire                  clear,
    // Insert interface
    input  wire                  insert_valid,
    output wire                  insert_ready,
    input  wire [IDX_WIDTH-1:0]  insert_idx,
    input  wire [DIST_WIDTH-1:0] insert_dist,
    // Pop interface (returns min distance)
    input  wire                  pop_valid,
    output reg                   pop_ready,
    output reg  [IDX_WIDTH-1:0]  pop_idx,
    output reg  [DIST_WIDTH-1:0] pop_dist,
    // Status
    output reg  [4:0]            count,
    output wire                  empty,
    output wire                  full
);
    reg [IDX_WIDTH-1:0]  indices   [0:DEPTH-1];
    reg [DIST_WIDTH-1:0] distances [0:DEPTH-1];

    assign empty        = (count == 0);
    assign full         = (count == DEPTH);
    assign insert_ready = !full;

    wire do_insert = insert_valid && !full;
    wire do_pop    = pop_valid && !empty;

    integer i;
    reg [4:0] pos; // blocking temporary: insertion position for this cycle

    always @(posedge clk or negedge rst_n) begin
        if (!rst_n) begin
            count     <= 5'd0;
            pop_ready <= 1'b0;
            pop_idx   <= {IDX_WIDTH{1'b0}};
            pop_dist  <= {DIST_WIDTH{1'b0}};
        end else if (clear) begin
            // Synchronous clear, kept out of the asynchronous reset term
            count     <= 5'd0;
            pop_ready <= 1'b0;
        end else begin
            // Insertion position = number of stored entries whose distance is
            // not greater than the new one. Loop bounds are constant; `count`
            // only gates the comparison.
            pos = 5'd0;
            for (i = 0; i < DEPTH; i = i + 1)
                if (i < count && distances[i] <= insert_dist)
                    pos = pos + 5'd1;

            pop_ready <= do_pop;
            if (do_pop) begin
                pop_idx  <= indices[0];
                pop_dist <= distances[0];
            end

            case ({do_insert, do_pop})
                2'b10: begin
                    // Insert only: open a gap at pos
                    for (i = DEPTH - 1; i > 0; i = i - 1)
                        if (i > pos && i <= count) begin
                            indices[i]   <= indices[i-1];
                            distances[i] <= distances[i-1];
                        end
                    indices[pos]   <= insert_idx;
                    distances[pos] <= insert_dist;
                    count          <= count + 5'd1;
                end
                2'b01: begin
                    // Pop only: shift everything down by one
                    for (i = 0; i < DEPTH - 1; i = i + 1) begin
                        indices[i]   <= indices[i+1];
                        distances[i] <= distances[i+1];
                    end
                    count <= count - 5'd1;
                end
                2'b11: begin
                    // Pop the head and merge the new entry in the same cycle.
                    // Entries 1..pos-1 move down one slot, the new entry lands
                    // at pos-1 (or at 0 when it sorts before the old head) and
                    // entries from pos upwards stay where they are.
                    if (pos == 5'd0) begin
                        indices[0]   <= insert_idx;
                        distances[0] <= insert_dist;
                    end else begin
                        for (i = 0; i < DEPTH - 1; i = i + 1)
                            if (i + 1 < pos) begin
                                indices[i]   <= indices[i+1];
                                distances[i] <= distances[i+1];
                            end
                        indices[pos - 5'd1]   <= insert_idx;
                        distances[pos - 5'd1] <= insert_dist;
                    end
                end
                default: ; // nothing to do
            endcase
        end
    end
endmodule
// ============ AXI-Lite Wrapper ============
// micro_hnsw_axi: AXI4-Lite slave wrapper around the micro_hnsw core.
//
// Register map (byte offsets):
//   0x00 Control (W)     [2:0] cmd_op, [15:8] dims, [23:16] k
//   0x04 Status  (R)     [0] cmd_ready, [1] result_valid, [2] vec_ready,
//                        [15:8] vector_count
//   0x08 Vector Data (W) [15:0] element, [31] last-element flag
//   0x0C Result  (R)     [7:0] idx, [23:8] distance, [30] valid, [31] last
//                        Reading this register consumes the current result.
//
// Simplifications:
//   * A write is taken only when AWVALID and WVALID are both asserted.
//   * WSTRB is ignored; registers are written as whole words.
//   * A write to Control or Vector Data that the core cannot accept in that
//     cycle is dropped and answered with SLVERR; everything else gets OKAY.
//   * Only one write response and one read response may be outstanding.
module micro_hnsw_axi #(
    parameter C_S_AXI_DATA_WIDTH = 32,
    parameter C_S_AXI_ADDR_WIDTH = 8
)(
    // AXI-Lite interface
    input  wire                              S_AXI_ACLK,
    input  wire                              S_AXI_ARESETN,
    // Write address channel
    input  wire [C_S_AXI_ADDR_WIDTH-1:0]     S_AXI_AWADDR,
    input  wire                              S_AXI_AWVALID,
    output wire                              S_AXI_AWREADY,
    // Write data channel
    input  wire [C_S_AXI_DATA_WIDTH-1:0]     S_AXI_WDATA,
    input  wire [(C_S_AXI_DATA_WIDTH/8)-1:0] S_AXI_WSTRB,
    input  wire                              S_AXI_WVALID,
    output wire                              S_AXI_WREADY,
    // Write response channel
    output wire [1:0]                        S_AXI_BRESP,
    output wire                              S_AXI_BVALID,
    input  wire                              S_AXI_BREADY,
    // Read address channel
    input  wire [C_S_AXI_ADDR_WIDTH-1:0]     S_AXI_ARADDR,
    input  wire                              S_AXI_ARVALID,
    output wire                              S_AXI_ARREADY,
    // Read data channel
    output wire [C_S_AXI_DATA_WIDTH-1:0]     S_AXI_RDATA,
    output wire [1:0]                        S_AXI_RRESP,
    output wire                              S_AXI_RVALID,
    input  wire                              S_AXI_RREADY
);
    localparam [7:0] REG_CONTROL = 8'h00;
    localparam [7:0] REG_STATUS  = 8'h04;
    localparam [7:0] REG_VECTOR  = 8'h08;
    localparam [7:0] REG_RESULT  = 8'h0C;

    localparam [1:0] RESP_OKAY   = 2'b00;
    localparam [1:0] RESP_SLVERR = 2'b10;

    // ---- Core-side signals ----
    wire        cmd_valid, cmd_ready;
    wire [2:0]  cmd_op;
    wire [7:0]  cmd_dims, cmd_k;
    wire        vec_valid, vec_ready;
    wire [15:0] vec_data;
    wire        vec_last;
    wire        result_valid, result_ready;
    wire [7:0]  result_idx;
    wire [15:0] result_dist;
    wire        result_last;
    wire [7:0]  vector_count;

    // Last accepted command fields, held between commands.
    reg [2:0] cmd_op_q;
    reg [7:0] cmd_dims_q, cmd_k_q;

    // ---- AXI response state ----
    reg       b_valid, r_valid;
    reg [1:0] b_resp;
    reg [C_S_AXI_DATA_WIDTH-1:0] r_data;

    // A write fires when both address and data are present and no write
    // response is pending; a read fires when no read response is pending.
    wire wr_fire = S_AXI_AWVALID && S_AXI_WVALID && !b_valid;
    wire rd_fire = S_AXI_ARVALID && !r_valid;

    assign S_AXI_AWREADY = wr_fire;
    assign S_AXI_WREADY  = wr_fire;
    assign S_AXI_BRESP   = b_resp;
    assign S_AXI_BVALID  = b_valid;
    assign S_AXI_ARREADY = !r_valid;
    assign S_AXI_RDATA   = r_data;
    assign S_AXI_RRESP   = RESP_OKAY;
    assign S_AXI_RVALID  = r_valid;

    // Valid strobes and payload reach the core in the same cycle: the payload
    // is taken straight from WDATA while the strobe is high. Outside a command
    // write, the command fields show the last accepted command.
    assign cmd_valid = wr_fire && (S_AXI_AWADDR == REG_CONTROL);
    assign vec_valid = wr_fire && (S_AXI_AWADDR == REG_VECTOR);

    assign cmd_op   = cmd_valid ? S_AXI_WDATA[2:0]   : cmd_op_q;
    assign cmd_dims = cmd_valid ? S_AXI_WDATA[15:8]  : cmd_dims_q;
    assign cmd_k    = cmd_valid ? S_AXI_WDATA[23:16] : cmd_k_q;

    assign vec_data = S_AXI_WDATA[15:0];
    assign vec_last = S_AXI_WDATA[31];

    // A result is consumed exactly when a read of the Result register is accepted.
    assign result_ready = rd_fire && (S_AXI_ARADDR == REG_RESULT);

    // Instantiate core
    micro_hnsw core (
        .clk(S_AXI_ACLK),
        .rst_n(S_AXI_ARESETN),
        .cmd_valid(cmd_valid),
        .cmd_ready(cmd_ready),
        .cmd_op(cmd_op),
        .cmd_dims(cmd_dims),
        .cmd_k(cmd_k),
        .vec_valid(vec_valid),
        .vec_ready(vec_ready),
        .vec_data(vec_data),
        .vec_last(vec_last),
        .result_valid(result_valid),
        .result_ready(result_ready),
        .result_idx(result_idx),
        .result_dist(result_dist),
        .result_last(result_last),
        .vector_count(vector_count)
    );

    always @(posedge S_AXI_ACLK or negedge S_AXI_ARESETN) begin
        if (!S_AXI_ARESETN) begin
            b_valid    <= 1'b0;
            b_resp     <= RESP_OKAY;
            r_valid    <= 1'b0;
            r_data     <= {C_S_AXI_DATA_WIDTH{1'b0}};
            cmd_op_q   <= 3'd0;
            cmd_dims_q <= 8'd0;
            cmd_k_q    <= 8'd0;
        end else begin
            // ---- Write handling ----
            if (wr_fire) begin
                b_valid <= 1'b1;
                // Report writes the core could not take in this cycle.
                if ((cmd_valid && !cmd_ready) || (vec_valid && !vec_ready))
                    b_resp <= RESP_SLVERR;
                else
                    b_resp <= RESP_OKAY;
                if (cmd_valid && cmd_ready) begin
                    cmd_op_q   <= S_AXI_WDATA[2:0];
                    cmd_dims_q <= S_AXI_WDATA[15:8];
                    cmd_k_q    <= S_AXI_WDATA[23:16];
                end
            end else if (b_valid && S_AXI_BREADY) begin
                b_valid <= 1'b0;
            end

            // ---- Read handling ----
            if (rd_fire) begin
                case (S_AXI_ARADDR)
                    REG_STATUS: r_data <= {16'b0, vector_count, 5'b0,
                                           vec_ready, result_valid, cmd_ready};
                    REG_RESULT: r_data <= {result_valid & result_last, result_valid,
                                           6'b0, result_dist, result_idx};
                    default:    r_data <= 32'b0;
                endcase
                r_valid <= 1'b1;
            end else if (r_valid && S_AXI_RREADY) begin
                r_valid <= 1'b0;
            end
        end
    end
endmodule