// Micro HNSW - ASIC Hardware Description // Ultra-minimal HNSW accelerator for vector similarity search // // Design specifications: // - Fixed-point arithmetic (Q8.8 format) // - 256 max vectors, 64 dimensions // - 8 neighbors per node, 4 levels // - Pipelined distance computation // - AXI-Lite interface for host communication // // Target: ASIC synthesis with <50K gates `timescale 1ns / 1ps module micro_hnsw #( parameter MAX_VECTORS = 256, parameter MAX_DIMS = 64, parameter MAX_NEIGHBORS = 8, parameter MAX_LEVELS = 4, parameter DATA_WIDTH = 16, // Q8.8 fixed-point parameter ADDR_WIDTH = 8 // log2(MAX_VECTORS) )( input wire clk, input wire rst_n, // Control interface input wire cmd_valid, output reg cmd_ready, input wire [2:0] cmd_op, // 0=NOP, 1=INIT, 2=INSERT, 3=SEARCH input wire [7:0] cmd_dims, input wire [7:0] cmd_k, // Vector data interface input wire vec_valid, output wire vec_ready, input wire [DATA_WIDTH-1:0] vec_data, input wire vec_last, // Result interface output reg result_valid, input wire result_ready, output reg [ADDR_WIDTH-1:0] result_idx, output reg [DATA_WIDTH-1:0] result_dist, output reg result_last, // Status output reg [ADDR_WIDTH-1:0] vector_count ); // ============ Local Parameters ============ localparam STATE_IDLE = 3'd0; localparam STATE_LOAD_VEC = 3'd1; localparam STATE_COMPUTE = 3'd2; localparam STATE_SEARCH = 3'd3; localparam STATE_OUTPUT = 3'd4; // ============ Memories ============ // Vector storage (256 x 64 x 16-bit = 256KB) reg [DATA_WIDTH-1:0] vectors [0:MAX_VECTORS-1][0:MAX_DIMS-1]; // Graph structure - neighbor lists reg [ADDR_WIDTH-1:0] neighbors [0:MAX_VECTORS-1][0:MAX_LEVELS-1][0:MAX_NEIGHBORS-1]; reg [3:0] neighbor_count [0:MAX_VECTORS-1][0:MAX_LEVELS-1]; reg [1:0] node_level [0:MAX_VECTORS-1]; // ============ Registers ============ reg [2:0] state; reg [ADDR_WIDTH-1:0] entry_point; reg [1:0] max_level; reg [7:0] current_dims; // Vector loading reg [DATA_WIDTH-1:0] query_buf [0:MAX_DIMS-1]; reg [DATA_WIDTH-1:0] insert_buf [0:MAX_DIMS-1]; reg [5:0] load_idx; // Search state reg [ADDR_WIDTH-1:0] current_node; reg [1:0] current_level; reg [7:0] current_k; reg [3:0] neighbor_idx; // Candidate buffer (sorted by distance) reg [ADDR_WIDTH-1:0] candidates [0:15]; reg [DATA_WIDTH-1:0] cand_dist [0:15]; reg [3:0] cand_count; // Distance computation reg [31:0] dist_accum; reg [5:0] dist_dim; reg dist_computing; reg [ADDR_WIDTH-1:0] dist_target; // Visited flags (bit vector) reg [MAX_VECTORS-1:0] visited; // ============ Vector Ready ============ assign vec_ready = (state == STATE_LOAD_VEC); // ============ State Machine ============ always @(posedge clk or negedge rst_n) begin if (!rst_n) begin state <= STATE_IDLE; cmd_ready <= 1'b1; result_valid <= 1'b0; vector_count <= 0; entry_point <= 0; max_level <= 0; current_dims <= 32; end else begin case (state) STATE_IDLE: begin result_valid <= 1'b0; if (cmd_valid && cmd_ready) begin cmd_ready <= 1'b0; case (cmd_op) 3'd1: begin // INIT current_dims <= cmd_dims; vector_count <= 0; entry_point <= 0; max_level <= 0; cmd_ready <= 1'b1; end 3'd2: begin // INSERT load_idx <= 0; state <= STATE_LOAD_VEC; end 3'd3: begin // SEARCH load_idx <= 0; current_k <= cmd_k; state <= STATE_LOAD_VEC; end default: cmd_ready <= 1'b1; endcase end end STATE_LOAD_VEC: begin if (vec_valid) begin if (cmd_op == 3'd2) begin insert_buf[load_idx] <= vec_data; end else begin query_buf[load_idx] <= vec_data; end if (vec_last || load_idx == current_dims - 1) begin if (cmd_op == 3'd2) begin state <= STATE_COMPUTE; // Insert processing end else begin state <= STATE_SEARCH; // Search processing end end else begin load_idx <= load_idx + 1; end end end STATE_COMPUTE: begin // Store vector integer i; for (i = 0; i < MAX_DIMS; i = i + 1) begin vectors[vector_count][i] <= insert_buf[i]; end // Generate random level (simplified) node_level[vector_count] <= vector_count[1:0] & 2'b11; // Initialize neighbors for (i = 0; i < MAX_LEVELS; i = i + 1) begin neighbor_count[vector_count][i] <= 0; end // Update entry point for first vector if (vector_count == 0) begin entry_point <= 0; max_level <= 0; end else begin // Simple nearest neighbor connection (level 0 only for minimal design) if (neighbor_count[vector_count][0] < MAX_NEIGHBORS) begin // Connect to entry point neighbors[vector_count][0][0] <= entry_point; neighbor_count[vector_count][0] <= 1; // Bidirectional connection if (neighbor_count[entry_point][0] < MAX_NEIGHBORS) begin neighbors[entry_point][0][neighbor_count[entry_point][0]] <= vector_count; neighbor_count[entry_point][0] <= neighbor_count[entry_point][0] + 1; end end end vector_count <= vector_count + 1; cmd_ready <= 1'b1; state <= STATE_IDLE; end STATE_SEARCH: begin // Initialize search visited <= 0; cand_count <= 0; current_node <= entry_point; current_level <= max_level; // Start distance computation for entry point dist_target <= entry_point; dist_accum <= 0; dist_dim <= 0; dist_computing <= 1'b1; // Simple greedy search (one level) if (!dist_computing && cand_count < current_k) begin // Add current to candidates candidates[cand_count] <= current_node; cand_dist[cand_count] <= dist_accum[DATA_WIDTH-1:0]; cand_count <= cand_count + 1; visited[current_node] <= 1'b1; // Check neighbors if (neighbor_idx < neighbor_count[current_node][0]) begin current_node <= neighbors[current_node][0][neighbor_idx]; neighbor_idx <= neighbor_idx + 1; dist_target <= neighbors[current_node][0][neighbor_idx]; dist_accum <= 0; dist_dim <= 0; dist_computing <= 1'b1; end else begin state <= STATE_OUTPUT; end end end STATE_OUTPUT: begin if (result_ready || !result_valid) begin if (cand_count > 0) begin result_valid <= 1'b1; result_idx <= candidates[0]; result_dist <= cand_dist[0]; result_last <= (cand_count == 1); // Shift candidates integer j; for (j = 0; j < 15; j = j + 1) begin candidates[j] <= candidates[j+1]; cand_dist[j] <= cand_dist[j+1]; end cand_count <= cand_count - 1; end else begin result_valid <= 1'b0; cmd_ready <= 1'b1; state <= STATE_IDLE; end end end endcase end end // ============ Distance Computation Pipeline ============ always @(posedge clk or negedge rst_n) begin if (!rst_n) begin dist_computing <= 1'b0; dist_accum <= 0; end else if (dist_computing) begin if (dist_dim < current_dims) begin // Compute (query - vector)^2 in fixed-point reg signed [DATA_WIDTH:0] diff; reg [31:0] sq; diff = $signed(query_buf[dist_dim]) - $signed(vectors[dist_target][dist_dim]); sq = diff * diff; dist_accum <= dist_accum + sq; dist_dim <= dist_dim + 1; end else begin dist_computing <= 1'b0; end end end endmodule // ============ Distance Unit - Pipelined L2 ============ module distance_unit #( parameter DATA_WIDTH = 16, parameter MAX_DIMS = 64 )( input wire clk, input wire rst_n, input wire start, input wire [5:0] dims, input wire [DATA_WIDTH-1:0] a_data, input wire [DATA_WIDTH-1:0] b_data, output reg [31:0] distance, output reg done ); reg [5:0] dim_idx; reg [31:0] accum; reg computing; always @(posedge clk or negedge rst_n) begin if (!rst_n) begin done <= 1'b0; computing <= 1'b0; accum <= 0; end else begin if (start && !computing) begin computing <= 1'b1; dim_idx <= 0; accum <= 0; done <= 1'b0; end else if (computing) begin if (dim_idx < dims) begin // Compute squared difference reg signed [DATA_WIDTH:0] diff; diff = $signed(a_data) - $signed(b_data); accum <= accum + (diff * diff); dim_idx <= dim_idx + 1; end else begin distance <= accum; done <= 1'b1; computing <= 1'b0; end end else begin done <= 1'b0; end end end endmodule // ============ Priority Queue for Candidates ============ module priority_queue #( parameter DEPTH = 16, parameter IDX_WIDTH = 8, parameter DIST_WIDTH = 16 )( input wire clk, input wire rst_n, input wire clear, // Insert interface input wire insert_valid, output wire insert_ready, input wire [IDX_WIDTH-1:0] insert_idx, input wire [DIST_WIDTH-1:0] insert_dist, // Pop interface (returns min distance) input wire pop_valid, output reg pop_ready, output reg [IDX_WIDTH-1:0] pop_idx, output reg [DIST_WIDTH-1:0] pop_dist, // Status output reg [4:0] count, output wire empty, output wire full ); reg [IDX_WIDTH-1:0] indices [0:DEPTH-1]; reg [DIST_WIDTH-1:0] distances [0:DEPTH-1]; assign empty = (count == 0); assign full = (count == DEPTH); assign insert_ready = !full; integer i; always @(posedge clk or negedge rst_n) begin if (!rst_n || clear) begin count <= 0; pop_ready <= 1'b0; end else begin // Insert operation (sorted insert) if (insert_valid && !full) begin // Find insertion position reg [4:0] pos; pos = count; for (i = count - 1; i >= 0; i = i - 1) begin if (insert_dist < distances[i]) begin indices[i+1] <= indices[i]; distances[i+1] <= distances[i]; pos = i; end end indices[pos] <= insert_idx; distances[pos] <= insert_dist; count <= count + 1; end // Pop operation if (pop_valid && !empty) begin pop_idx <= indices[0]; pop_dist <= distances[0]; pop_ready <= 1'b1; // Shift elements for (i = 0; i < DEPTH - 1; i = i + 1) begin indices[i] <= indices[i+1]; distances[i] <= distances[i+1]; end count <= count - 1; end else begin pop_ready <= 1'b0; end end end endmodule // ============ AXI-Lite Wrapper ============ module micro_hnsw_axi #( parameter C_S_AXI_DATA_WIDTH = 32, parameter C_S_AXI_ADDR_WIDTH = 8 )( // AXI-Lite interface input wire S_AXI_ACLK, input wire S_AXI_ARESETN, // Write address channel input wire [C_S_AXI_ADDR_WIDTH-1:0] S_AXI_AWADDR, input wire S_AXI_AWVALID, output wire S_AXI_AWREADY, // Write data channel input wire [C_S_AXI_DATA_WIDTH-1:0] S_AXI_WDATA, input wire [(C_S_AXI_DATA_WIDTH/8)-1:0] S_AXI_WSTRB, input wire S_AXI_WVALID, output wire S_AXI_WREADY, // Write response channel output wire [1:0] S_AXI_BRESP, output wire S_AXI_BVALID, input wire S_AXI_BREADY, // Read address channel input wire [C_S_AXI_ADDR_WIDTH-1:0] S_AXI_ARADDR, input wire S_AXI_ARVALID, output wire S_AXI_ARREADY, // Read data channel output wire [C_S_AXI_DATA_WIDTH-1:0] S_AXI_RDATA, output wire [1:0] S_AXI_RRESP, output wire S_AXI_RVALID, input wire S_AXI_RREADY ); // Register map: // 0x00: Control (W) - [2:0] cmd_op, [15:8] dims, [23:16] k // 0x04: Status (R) - [0] ready, [15:8] vector_count // 0x08: Vector Data (W) - write vector data // 0x0C: Result (R) - [7:0] idx, [23:8] distance, [31] last // Internal signals wire cmd_valid, cmd_ready; reg [2:0] cmd_op; reg [7:0] cmd_dims, cmd_k; wire vec_valid, vec_ready; reg [15:0] vec_data; reg vec_last; wire result_valid, result_ready; wire [7:0] result_idx; wire [15:0] result_dist; wire result_last; wire [7:0] vector_count; // Instantiate core micro_hnsw core ( .clk(S_AXI_ACLK), .rst_n(S_AXI_ARESETN), .cmd_valid(cmd_valid), .cmd_ready(cmd_ready), .cmd_op(cmd_op), .cmd_dims(cmd_dims), .cmd_k(cmd_k), .vec_valid(vec_valid), .vec_ready(vec_ready), .vec_data(vec_data), .vec_last(vec_last), .result_valid(result_valid), .result_ready(result_ready), .result_idx(result_idx), .result_dist(result_dist), .result_last(result_last), .vector_count(vector_count) ); // AXI-Lite state machine (simplified) reg aw_ready, w_ready, ar_ready; reg [1:0] b_resp; reg b_valid, r_valid; reg [C_S_AXI_DATA_WIDTH-1:0] r_data; assign S_AXI_AWREADY = aw_ready; assign S_AXI_WREADY = w_ready; assign S_AXI_BRESP = b_resp; assign S_AXI_BVALID = b_valid; assign S_AXI_ARREADY = ar_ready; assign S_AXI_RDATA = r_data; assign S_AXI_RRESP = 2'b00; assign S_AXI_RVALID = r_valid; assign cmd_valid = S_AXI_WVALID && (S_AXI_AWADDR == 8'h00); assign vec_valid = S_AXI_WVALID && (S_AXI_AWADDR == 8'h08); assign result_ready = S_AXI_RREADY && (S_AXI_ARADDR == 8'h0C); always @(posedge S_AXI_ACLK or negedge S_AXI_ARESETN) begin if (!S_AXI_ARESETN) begin aw_ready <= 1'b1; w_ready <= 1'b1; ar_ready <= 1'b1; b_valid <= 1'b0; r_valid <= 1'b0; end else begin // Write handling if (S_AXI_AWVALID && S_AXI_WVALID && aw_ready && w_ready) begin case (S_AXI_AWADDR) 8'h00: begin cmd_op <= S_AXI_WDATA[2:0]; cmd_dims <= S_AXI_WDATA[15:8]; cmd_k <= S_AXI_WDATA[23:16]; end 8'h08: begin vec_data <= S_AXI_WDATA[15:0]; vec_last <= S_AXI_WDATA[31]; end endcase b_valid <= 1'b1; end if (S_AXI_BREADY && b_valid) begin b_valid <= 1'b0; end // Read handling if (S_AXI_ARVALID && ar_ready) begin case (S_AXI_ARADDR) 8'h04: r_data <= {16'b0, vector_count, 7'b0, cmd_ready}; 8'h0C: r_data <= {result_last, 7'b0, result_dist, result_idx}; default: r_data <= 32'b0; endcase r_valid <= 1'b1; end if (S_AXI_RREADY && r_valid) begin r_valid <= 1'b0; end end end endmodule