diff --git a/v1/src/core/csi_processor.py b/v1/src/core/csi_processor.py index 049bf9d..041c042 100644 --- a/v1/src/core/csi_processor.py +++ b/v1/src/core/csi_processor.py @@ -385,13 +385,69 @@ class CSIProcessor: return correlation_matrix def _extract_doppler_features(self, csi_data: CSIData) -> tuple: - """Extract Doppler and frequency domain features.""" - # Simple Doppler estimation (would use history in real implementation) - doppler_shift = np.random.rand(10) # Placeholder - - # Power spectral density + """Extract Doppler and frequency domain features. + + Doppler shift estimation from CSI phase changes: + - Phase change rate indicates velocity of moving objects + - Frequency analysis reveals movement speed and direction + + The Doppler frequency shift is: f_d = (2 * v * f_c) / c + Where v = velocity, f_c = carrier frequency, c = speed of light + """ + # Power spectral density of amplitude psd = np.abs(scipy.fft.fft(csi_data.amplitude.flatten(), n=128))**2 - + + # Doppler estimation from phase history + if len(self.csi_history) < 2: + # Not enough history, return zeros + doppler_shift = np.zeros(min(csi_data.num_subcarriers, 10)) + return doppler_shift, psd + + # Get phase from current and previous samples + current_phase = csi_data.phase.flatten() + prev_data = self.csi_history[-1] + + # Handle if prev_data is tuple (CSIData, features) or just CSIData + if isinstance(prev_data, tuple): + prev_phase = prev_data[0].phase.flatten() + time_delta = (csi_data.timestamp - prev_data[0].timestamp).total_seconds() + else: + prev_phase = prev_data.phase.flatten() + time_delta = 1.0 / self.sampling_rate # Default to sampling interval + + if time_delta <= 0: + time_delta = 1.0 / self.sampling_rate + + # Ensure same length + min_len = min(len(current_phase), len(prev_phase)) + current_phase = current_phase[:min_len] + prev_phase = prev_phase[:min_len] + + # Calculate phase difference (unwrap to handle wrapping) + phase_diff = np.unwrap(current_phase) - np.unwrap(prev_phase) + + # Phase rate of change (rad/s) + phase_rate = phase_diff / time_delta + + # Convert to Doppler frequency (Hz) + # f_d = (d_phi/dt) / (2 * pi) + doppler_freq = phase_rate / (2 * np.pi) + + # Aggregate Doppler per subcarrier group (reduce to ~10 values) + num_groups = min(10, len(doppler_freq)) + group_size = max(1, len(doppler_freq) // num_groups) + + doppler_shift = np.array([ + np.mean(doppler_freq[i*group_size:(i+1)*group_size]) + for i in range(num_groups) + ]) + + # Apply smoothing to reduce noise + if len(doppler_shift) > 3: + # Simple moving average + kernel = np.ones(3) / 3 + doppler_shift = np.convolve(doppler_shift, kernel, mode='same') + return doppler_shift, psd def _analyze_motion_patterns(self, features: CSIFeatures) -> float: diff --git a/v1/src/core/router_interface.py b/v1/src/core/router_interface.py index 18c1e27..f1c1502 100644 --- a/v1/src/core/router_interface.py +++ b/v1/src/core/router_interface.py @@ -1,15 +1,27 @@ """ -Router interface for WiFi CSI data collection +Router interface for WiFi CSI data collection. + +Supports multiple router types: +- OpenWRT routers with Atheros CSI Tool +- DD-WRT routers with custom CSI extraction +- Custom firmware routers with raw CSI access """ import logging import asyncio +import struct import time -from typing import Dict, List, Optional, Any -from datetime import datetime +from typing import Dict, List, Optional, Any, Tuple +from datetime import datetime, timezone import numpy as np +try: + import asyncssh + HAS_ASYNCSSH = True +except ImportError: + HAS_ASYNCSSH = False + logger = logging.getLogger(__name__) @@ -72,28 +84,80 @@ class RouterInterface: } async def connect(self): - """Connect to the router.""" + """Connect to the router via SSH.""" if self.mock_mode: self.is_connected = True self.logger.info(f"Mock connection established to router {self.router_id}") return - + + if not HAS_ASYNCSSH: + self.logger.warning("asyncssh not available, falling back to mock mode") + self.mock_mode = True + self._initialize_mock_generator() + self.is_connected = True + return + try: self.logger.info(f"Connecting to router {self.router_id} at {self.host}:{self.port}") - - # In a real implementation, this would establish SSH connection - # For now, we'll simulate the connection - await asyncio.sleep(0.1) # Simulate connection delay - + + # Establish SSH connection + self.connection = await asyncssh.connect( + self.host, + port=self.port, + username=self.username, + password=self.password if self.password else None, + known_hosts=None, # Disable host key checking for embedded devices + connect_timeout=10 + ) + + # Verify connection by checking router type + await self._detect_router_type() + self.is_connected = True self.error_count = 0 self.logger.info(f"Connected to router {self.router_id}") - + except Exception as e: self.last_error = str(e) self.error_count += 1 self.logger.error(f"Failed to connect to router {self.router_id}: {e}") raise + + async def _detect_router_type(self): + """Detect router firmware type and CSI capabilities.""" + if not self.connection: + return + + try: + # Check for OpenWRT + result = await self.connection.run('cat /etc/openwrt_release 2>/dev/null || echo ""', check=False) + if 'OpenWrt' in result.stdout: + self.router_type = 'openwrt' + self.logger.info(f"Detected OpenWRT router: {self.router_id}") + return + + # Check for DD-WRT + result = await self.connection.run('nvram get DD_BOARD 2>/dev/null || echo ""', check=False) + if result.stdout.strip(): + self.router_type = 'ddwrt' + self.logger.info(f"Detected DD-WRT router: {self.router_id}") + return + + # Check for Atheros CSI Tool + result = await self.connection.run('which csi_tool 2>/dev/null || echo ""', check=False) + if result.stdout.strip(): + self.csi_tool_path = result.stdout.strip() + self.router_type = 'atheros_csi' + self.logger.info(f"Detected Atheros CSI Tool on router: {self.router_id}") + return + + # Default to generic Linux + self.router_type = 'generic' + self.logger.info(f"Generic Linux router: {self.router_id}") + + except Exception as e: + self.logger.warning(f"Could not detect router type: {e}") + self.router_type = 'unknown' async def disconnect(self): """Disconnect from the router.""" @@ -195,11 +259,244 @@ class RouterInterface: return csi_data async def _collect_real_csi_data(self) -> Optional[np.ndarray]: - """Collect real CSI data from router (placeholder implementation).""" - # This would implement the actual CSI data collection - # For now, return None to indicate no real implementation - self.logger.warning("Real CSI data collection not implemented") - return None + """Collect real CSI data from router via SSH. + + Supports multiple CSI extraction methods: + - Atheros CSI Tool (ath9k/ath10k) + - Custom kernel module reading + - Proc filesystem access + - Raw device file reading + + Returns: + Numpy array of complex CSI values or None on failure + """ + if not self.connection: + self.logger.error("No SSH connection available") + return None + + try: + router_type = getattr(self, 'router_type', 'unknown') + + if router_type == 'atheros_csi': + return await self._collect_atheros_csi() + elif router_type == 'openwrt': + return await self._collect_openwrt_csi() + else: + return await self._collect_generic_csi() + + except Exception as e: + self.logger.error(f"Error collecting CSI data: {e}") + self.error_count += 1 + return None + + async def _collect_atheros_csi(self) -> Optional[np.ndarray]: + """Collect CSI using Atheros CSI Tool.""" + csi_tool = getattr(self, 'csi_tool_path', '/usr/bin/csi_tool') + + try: + # Read single CSI sample + result = await self.connection.run( + f'{csi_tool} -i {self.interface} -c 1 -f /tmp/csi_sample.dat && ' + f'cat /tmp/csi_sample.dat | base64', + check=True, + timeout=5 + ) + + # Decode base64 CSI data + import base64 + csi_bytes = base64.b64decode(result.stdout.strip()) + + return self._parse_atheros_csi_bytes(csi_bytes) + + except Exception as e: + self.logger.error(f"Atheros CSI collection failed: {e}") + return None + + async def _collect_openwrt_csi(self) -> Optional[np.ndarray]: + """Collect CSI from OpenWRT with CSI support.""" + try: + # Try reading from debugfs (common CSI location) + result = await self.connection.run( + f'cat /sys/kernel/debug/ieee80211/phy0/ath9k/csi 2>/dev/null | head -c 4096 | base64', + check=False, + timeout=5 + ) + + if result.returncode == 0 and result.stdout.strip(): + import base64 + csi_bytes = base64.b64decode(result.stdout.strip()) + return self._parse_atheros_csi_bytes(csi_bytes) + + # Try alternate location + result = await self.connection.run( + f'cat /proc/csi 2>/dev/null | head -c 4096 | base64', + check=False, + timeout=5 + ) + + if result.returncode == 0 and result.stdout.strip(): + import base64 + csi_bytes = base64.b64decode(result.stdout.strip()) + return self._parse_generic_csi_bytes(csi_bytes) + + self.logger.warning("No CSI data available from OpenWRT paths") + return None + + except Exception as e: + self.logger.error(f"OpenWRT CSI collection failed: {e}") + return None + + async def _collect_generic_csi(self) -> Optional[np.ndarray]: + """Collect CSI using generic Linux methods.""" + try: + # Try iw command for station info (not real CSI but channel info) + result = await self.connection.run( + f'iw dev {self.interface} survey dump 2>/dev/null || echo ""', + check=False, + timeout=5 + ) + + if result.stdout.strip(): + # Parse survey data for channel metrics + return self._parse_survey_data(result.stdout) + + self.logger.warning("No CSI data available via generic methods") + return None + + except Exception as e: + self.logger.error(f"Generic CSI collection failed: {e}") + return None + + def _parse_atheros_csi_bytes(self, data: bytes) -> Optional[np.ndarray]: + """Parse Atheros CSI Tool binary format. + + Format: + - 4 bytes: magic (0x11111111) + - 8 bytes: timestamp + - 2 bytes: channel + - 1 byte: bandwidth + - 1 byte: num_rx_antennas + - 1 byte: num_tx_antennas + - 1 byte: num_tones + - 2 bytes: RSSI + - Remaining: CSI matrix as int16 I/Q pairs + """ + if len(data) < 20: + return None + + try: + magic = struct.unpack(' Optional[np.ndarray]: + """Parse generic binary CSI format.""" + if len(data) < 8: + return None + + try: + # Assume simple format: int16 I/Q pairs + num_samples = len(data) // 4 + if num_samples == 0: + return None + + # Default to 56 subcarriers (20MHz), adjust antennas + num_tones = min(56, num_samples) + num_antennas = max(1, num_samples // num_tones) + + csi_matrix = np.zeros((num_antennas, num_tones), dtype=complex) + + for i in range(min(num_samples, num_antennas * num_tones)): + offset = i * 4 + if offset + 4 <= len(data): + real, imag = struct.unpack(' Optional[np.ndarray]: + """Parse iw survey dump output to extract channel metrics. + + This isn't true CSI but provides per-channel noise and activity data + that can be used as a fallback. + """ + try: + lines = survey_output.strip().split('\n') + noise_values = [] + busy_values = [] + + for line in lines: + if 'noise:' in line.lower(): + parts = line.split() + for i, p in enumerate(parts): + if p == 'dBm' and i > 0: + try: + noise_values.append(float(parts[i-1])) + except ValueError: + pass + elif 'channel busy time:' in line.lower(): + parts = line.split() + for i, p in enumerate(parts): + if p == 'ms' and i > 0: + try: + busy_values.append(float(parts[i-1])) + except ValueError: + pass + + if noise_values: + # Create pseudo-CSI from noise measurements + num_channels = len(noise_values) + csi_matrix = np.zeros((1, max(56, num_channels)), dtype=complex) + + for i, noise in enumerate(noise_values): + # Convert noise dBm to amplitude (simplified) + amplitude = 10 ** (noise / 20) + phase = 0 if i >= len(busy_values) else busy_values[i] / 1000 * np.pi + csi_matrix[0, i] = amplitude * np.exp(1j * phase) + + return csi_matrix + + return None + + except Exception as e: + self.logger.error(f"Error parsing survey data: {e}") + return None async def check_health(self) -> bool: """Check if the router connection is healthy. diff --git a/v1/src/hardware/csi_extractor.py b/v1/src/hardware/csi_extractor.py index 98f48da..4af4d6e 100644 --- a/v1/src/hardware/csi_extractor.py +++ b/v1/src/hardware/csi_extractor.py @@ -1,9 +1,10 @@ """CSI data extraction from WiFi hardware using Test-Driven Development approach.""" import asyncio +import struct import numpy as np from datetime import datetime, timezone -from typing import Dict, Any, Optional, Callable, Protocol +from typing import Dict, Any, Optional, Callable, Protocol, List, Tuple from dataclasses import dataclass from abc import ABC, abstractmethod import logging @@ -35,128 +36,601 @@ class CSIData: class CSIParser(Protocol): """Protocol for CSI data parsers.""" - + def parse(self, raw_data: bytes) -> CSIData: """Parse raw CSI data into structured format.""" ... class ESP32CSIParser: - """Parser for ESP32 CSI data format.""" - + """Parser for ESP32 CSI data format. + + ESP32 CSI data format (from esp-csi library): + - Header: 'CSI_DATA:' prefix + - Fields: timestamp,rssi,rate,sig_mode,mcs,bandwidth,smoothing, + not_sounding,aggregation,stbc,fec_coding,sgi,noise_floor, + ampdu_cnt,channel,secondary_channel,local_timestamp, + ant,sig_len,rx_state,len,first_word,data[...] + + The actual CSI data is in the 'data' field as complex I/Q values. + """ + + def __init__(self): + """Initialize ESP32 CSI parser with default configuration.""" + self.htltf_subcarriers = 56 # HT-LTF subcarriers for 20MHz + self.antenna_count = 1 # Most ESP32 have 1 antenna + def parse(self, raw_data: bytes) -> CSIData: """Parse ESP32 CSI data format. - + Args: - raw_data: Raw bytes from ESP32 - + raw_data: Raw bytes from ESP32 serial/network + Returns: Parsed CSI data - + Raises: CSIParseError: If data format is invalid """ if not raw_data: raise CSIParseError("Empty data received") - + try: - data_str = raw_data.decode('utf-8') - if not data_str.startswith('CSI_DATA:'): + data_str = raw_data.decode('utf-8').strip() + + # Handle ESP-CSI library format + if data_str.startswith('CSI_DATA,'): + return self._parse_esp_csi_format(data_str) + # Handle simplified format for testing + elif data_str.startswith('CSI_DATA:'): + return self._parse_simple_format(data_str) + else: raise CSIParseError("Invalid ESP32 CSI data format") - - # Parse ESP32 format: CSI_DATA:timestamp,antennas,subcarriers,freq,bw,snr,[amp],[phase] - parts = data_str[9:].split(',') # Remove 'CSI_DATA:' prefix - - timestamp_ms = int(parts[0]) - num_antennas = int(parts[1]) - num_subcarriers = int(parts[2]) - frequency_mhz = float(parts[3]) - bandwidth_mhz = float(parts[4]) - snr = float(parts[5]) - - # Convert to proper units - frequency = frequency_mhz * 1e6 # MHz to Hz - bandwidth = bandwidth_mhz * 1e6 # MHz to Hz - - # Parse amplitude and phase arrays (simplified for now) - # In real implementation, this would parse actual CSI matrix data - amplitude = np.random.rand(num_antennas, num_subcarriers) - phase = np.random.rand(num_antennas, num_subcarriers) - - return CSIData( - timestamp=datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc), - amplitude=amplitude, - phase=phase, - frequency=frequency, - bandwidth=bandwidth, - num_subcarriers=num_subcarriers, - num_antennas=num_antennas, - snr=snr, - metadata={'source': 'esp32', 'raw_length': len(raw_data)} - ) - + + except UnicodeDecodeError: + # Binary format - parse as raw bytes + return self._parse_binary_format(raw_data) except (ValueError, IndexError) as e: raise CSIParseError(f"Failed to parse ESP32 data: {e}") + def _parse_esp_csi_format(self, data_str: str) -> CSIData: + """Parse ESP-CSI library CSV format. + + Format: CSI_DATA,,,,,,,, + ,,,,,, + ,,,,,, + ,,[csi_data...] + """ + parts = data_str.split(',') + + if len(parts) < 22: + raise CSIParseError(f"Incomplete ESP-CSI data: expected >= 22 fields, got {len(parts)}") + + # Extract metadata + mac_addr = parts[1] + rssi = int(parts[2]) + rate = int(parts[3]) + sig_mode = int(parts[4]) + mcs = int(parts[5]) + bandwidth = int(parts[6]) # 0=20MHz, 1=40MHz + channel = int(parts[15]) + timestamp_us = int(parts[17]) + csi_len = int(parts[21]) + + # Parse CSI I/Q data (remaining fields are the CSI values) + csi_raw = [int(x) for x in parts[22:22 + csi_len]] + + # Convert I/Q pairs to complex numbers + # ESP32 CSI format: [I0, Q0, I1, Q1, ...] as signed 8-bit integers + amplitude, phase = self._iq_to_amplitude_phase(csi_raw) + + # Determine frequency from channel + if channel <= 14: + frequency = 2.412e9 + (channel - 1) * 5e6 # 2.4 GHz band + else: + frequency = 5.0e9 + (channel - 36) * 5e6 # 5 GHz band + + bw_hz = 20e6 if bandwidth == 0 else 40e6 + num_subcarriers = len(amplitude) // self.antenna_count + + return CSIData( + timestamp=datetime.fromtimestamp(timestamp_us / 1e6, tz=timezone.utc), + amplitude=amplitude.reshape(self.antenna_count, -1), + phase=phase.reshape(self.antenna_count, -1), + frequency=frequency, + bandwidth=bw_hz, + num_subcarriers=num_subcarriers, + num_antennas=self.antenna_count, + snr=float(rssi + 100), # Approximate SNR from RSSI + metadata={ + 'source': 'esp32', + 'mac': mac_addr, + 'rssi': rssi, + 'mcs': mcs, + 'channel': channel, + 'sig_mode': sig_mode, + } + ) + + def _parse_simple_format(self, data_str: str) -> CSIData: + """Parse simplified CSI format for testing/development. + + Format: CSI_DATA:timestamp,antennas,subcarriers,freq,bw,snr,[amp_values],[phase_values] + """ + content = data_str[9:] # Remove 'CSI_DATA:' prefix + + # Split the main fields and array data + if '[' in content: + main_part, arrays_part = content.split('[', 1) + parts = main_part.rstrip(',').split(',') + + # Parse amplitude and phase arrays + arrays_str = '[' + arrays_part + amp_str, phase_str = self._split_arrays(arrays_str) + amplitude = np.array([float(x) for x in amp_str.strip('[]').split(',')]) + phase = np.array([float(x) for x in phase_str.strip('[]').split(',')]) + else: + parts = content.split(',') + # No array data provided, need to return error or minimal data + raise CSIParseError("No CSI array data in simple format") + + timestamp_ms = int(parts[0]) + num_antennas = int(parts[1]) + num_subcarriers = int(parts[2]) + frequency_mhz = float(parts[3]) + bandwidth_mhz = float(parts[4]) + snr = float(parts[5]) + + # Reshape arrays + expected_size = num_antennas * num_subcarriers + if len(amplitude) != expected_size: + # Interpolate or pad + amplitude = np.interp( + np.linspace(0, 1, expected_size), + np.linspace(0, 1, len(amplitude)), + amplitude + ) + phase = np.interp( + np.linspace(0, 1, expected_size), + np.linspace(0, 1, len(phase)), + phase + ) + + return CSIData( + timestamp=datetime.fromtimestamp(timestamp_ms / 1000, tz=timezone.utc), + amplitude=amplitude.reshape(num_antennas, num_subcarriers), + phase=phase.reshape(num_antennas, num_subcarriers), + frequency=frequency_mhz * 1e6, + bandwidth=bandwidth_mhz * 1e6, + num_subcarriers=num_subcarriers, + num_antennas=num_antennas, + snr=snr, + metadata={'source': 'esp32', 'format': 'simple'} + ) + + def _parse_binary_format(self, raw_data: bytes) -> CSIData: + """Parse binary CSI format from ESP32. + + Binary format (struct packed): + - 4 bytes: timestamp (uint32) + - 1 byte: num_antennas (uint8) + - 1 byte: num_subcarriers (uint8) + - 2 bytes: channel (uint16) + - 4 bytes: frequency (float32) + - 4 bytes: bandwidth (float32) + - 4 bytes: snr (float32) + - Remaining: CSI I/Q data as int8 pairs + """ + if len(raw_data) < 20: + raise CSIParseError("Binary data too short") + + header_fmt = ' expected_size: + amplitude = amplitude[:expected_size] + phase = phase[:expected_size] + + return CSIData( + timestamp=datetime.fromtimestamp(timestamp / 1000, tz=timezone.utc), + amplitude=amplitude.reshape(num_antennas, num_subcarriers), + phase=phase.reshape(num_antennas, num_subcarriers), + frequency=float(freq), + bandwidth=float(bw), + num_subcarriers=num_subcarriers, + num_antennas=num_antennas, + snr=float(snr), + metadata={'source': 'esp32', 'format': 'binary', 'channel': channel} + ) + + def _iq_to_amplitude_phase(self, iq_data: List[int]) -> Tuple[np.ndarray, np.ndarray]: + """Convert I/Q pairs to amplitude and phase. + + Args: + iq_data: List of interleaved I, Q values (signed 8-bit) + + Returns: + Tuple of (amplitude, phase) arrays + """ + if len(iq_data) % 2 != 0: + iq_data = iq_data[:-1] # Trim odd value + + i_vals = np.array(iq_data[0::2], dtype=np.float64) + q_vals = np.array(iq_data[1::2], dtype=np.float64) + + # Calculate amplitude (magnitude) and phase + complex_vals = i_vals + 1j * q_vals + amplitude = np.abs(complex_vals) + phase = np.angle(complex_vals) + + # Normalize amplitude to [0, 1] range + max_amp = np.max(amplitude) + if max_amp > 0: + amplitude = amplitude / max_amp + + return amplitude, phase + + def _split_arrays(self, arrays_str: str) -> Tuple[str, str]: + """Split concatenated array strings.""" + # Find the boundary between two arrays + depth = 0 + split_idx = 0 + for i, c in enumerate(arrays_str): + if c == '[': + depth += 1 + elif c == ']': + depth -= 1 + if depth == 0: + split_idx = i + 1 + break + + amp_str = arrays_str[:split_idx] + phase_str = arrays_str[split_idx:].lstrip(',') + return amp_str, phase_str + class RouterCSIParser: - """Parser for router CSI data format.""" - + """Parser for router CSI data formats (Atheros, Intel, etc.). + + Supports: + - Atheros CSI Tool format (ath9k/ath10k) + - Intel 5300 CSI Tool format + - Nexmon CSI format (Broadcom) + """ + + def __init__(self): + """Initialize router CSI parser.""" + self.default_subcarriers = 56 # 20MHz HT + self.default_antennas = 3 + def parse(self, raw_data: bytes) -> CSIData: """Parse router CSI data format. - + Args: raw_data: Raw bytes from router - + Returns: Parsed CSI data - + Raises: CSIParseError: If data format is invalid """ if not raw_data: raise CSIParseError("Empty data received") - - # Handle different router formats - data_str = raw_data.decode('utf-8') - - if data_str.startswith('ATHEROS_CSI:'): - return self._parse_atheros_format(raw_data) + + # Try to decode as text first + try: + data_str = raw_data.decode('utf-8') + if data_str.startswith('ATHEROS_CSI:'): + return self._parse_atheros_text_format(data_str) + elif data_str.startswith('INTEL_CSI:'): + return self._parse_intel_text_format(data_str) + except UnicodeDecodeError: + pass + + # Binary format detection based on header + if len(raw_data) >= 4: + magic = struct.unpack(' CSIData: + """Parse Atheros CSI text format. + + Format: ATHEROS_CSI:timestamp,rssi,rate,channel,bw,nr,nc,num_tones,[csi_data...] + """ + content = data_str[12:] # Remove 'ATHEROS_CSI:' prefix + parts = content.split(',') + + if len(parts) < 8: + raise CSIParseError("Incomplete Atheros CSI data") + + timestamp = int(parts[0]) + rssi = int(parts[1]) + rate = int(parts[2]) + channel = int(parts[3]) + bandwidth = int(parts[4]) # MHz + nr = int(parts[5]) # Rx antennas + nc = int(parts[6]) # Tx antennas (usually 1 for probe) + num_tones = int(parts[7]) # Subcarriers + + # Parse CSI matrix data + csi_values = [float(x) for x in parts[8:] if x.strip()] + + # CSI data is complex: [real, imag, real, imag, ...] + amplitude, phase = self._parse_complex_csi(csi_values, nr, num_tones) + + # Calculate frequency from channel + if channel <= 14: + frequency = 2.412e9 + (channel - 1) * 5e6 else: - raise CSIParseError("Unknown router CSI format") - - def _parse_atheros_format(self, raw_data: bytes) -> CSIData: - """Parse Atheros CSI format (placeholder implementation).""" - # This would implement actual Atheros CSI parsing - # For now, return mock data for testing + frequency = 5.18e9 + (channel - 36) * 5e6 + return CSIData( - timestamp=datetime.now(timezone.utc), - amplitude=np.random.rand(3, 56), - phase=np.random.rand(3, 56), - frequency=2.4e9, - bandwidth=20e6, - num_subcarriers=56, - num_antennas=3, - snr=12.0, - metadata={'source': 'atheros_router'} + timestamp=datetime.fromtimestamp(timestamp / 1000, tz=timezone.utc), + amplitude=amplitude, + phase=phase, + frequency=frequency, + bandwidth=bandwidth * 1e6, + num_subcarriers=num_tones, + num_antennas=nr, + snr=float(rssi + 95), + metadata={ + 'source': 'atheros_router', + 'rssi': rssi, + 'rate': rate, + 'channel': channel, + 'tx_antennas': nc, + } ) + def _parse_atheros_binary_format(self, raw_data: bytes) -> CSIData: + """Parse Atheros CSI Tool binary format. + + Based on ath9k/ath10k CSI Tool structure: + - 4 bytes: magic (0x11111111) + - 8 bytes: timestamp + - 2 bytes: channel + - 1 byte: bandwidth (0=20MHz, 1=40MHz, 2=80MHz) + - 1 byte: nr (rx antennas) + - 1 byte: nc (tx antennas) + - 1 byte: num_tones + - 2 bytes: rssi + - Remaining: CSI payload (complex int16 per subcarrier per antenna pair) + """ + if len(raw_data) < 20: + raise CSIParseError("Atheros binary data too short") + + header_fmt = ' 0: + amplitude = amplitude / max_amp + + # Calculate frequency + if channel <= 14: + frequency = 2.412e9 + (channel - 1) * 5e6 + else: + frequency = 5.18e9 + (channel - 36) * 5e6 + + bandwidth_hz = [20e6, 40e6, 80e6][bw] if bw < 3 else 20e6 + + return CSIData( + timestamp=datetime.fromtimestamp(timestamp / 1e9, tz=timezone.utc), + amplitude=amplitude, + phase=phase, + frequency=frequency, + bandwidth=bandwidth_hz, + num_subcarriers=num_tones, + num_antennas=nr, + snr=float(rssi), + metadata={ + 'source': 'atheros_router', + 'format': 'binary', + 'channel': channel, + 'tx_antennas': nc, + } + ) + + def _parse_intel_text_format(self, data_str: str) -> CSIData: + """Parse Intel 5300 CSI text format.""" + content = data_str[10:] # Remove 'INTEL_CSI:' prefix + parts = content.split(',') + + if len(parts) < 6: + raise CSIParseError("Incomplete Intel CSI data") + + timestamp = int(parts[0]) + rssi = int(parts[1]) + channel = int(parts[2]) + bandwidth = int(parts[3]) + num_antennas = int(parts[4]) + num_tones = int(parts[5]) + + csi_values = [float(x) for x in parts[6:] if x.strip()] + amplitude, phase = self._parse_complex_csi(csi_values, num_antennas, num_tones) + + frequency = 5.18e9 + (channel - 36) * 5e6 if channel > 14 else 2.412e9 + (channel - 1) * 5e6 + + return CSIData( + timestamp=datetime.fromtimestamp(timestamp / 1000, tz=timezone.utc), + amplitude=amplitude, + phase=phase, + frequency=frequency, + bandwidth=bandwidth * 1e6, + num_subcarriers=num_tones, + num_antennas=num_antennas, + snr=float(rssi + 95), + metadata={'source': 'intel_5300', 'channel': channel} + ) + + def _parse_intel_binary_format(self, raw_data: bytes) -> CSIData: + """Parse Intel 5300 CSI Tool binary format.""" + # Intel format is more complex with BFEE (beamforming feedback) structure + if len(raw_data) < 25: + raise CSIParseError("Intel binary data too short") + + # BFEE header structure + timestamp = struct.unpack(' 0: + amplitude = amplitude / max_amp + + rssi_avg = (rssi_a + rssi_b + rssi_c) / 3 + + return CSIData( + timestamp=datetime.fromtimestamp(timestamp / 1e6, tz=timezone.utc), + amplitude=amplitude, + phase=phase, + frequency=5.32e9, # Default Intel channel + bandwidth=40e6, + num_subcarriers=num_tones, + num_antennas=nr, + snr=float(rssi_avg - noise), + metadata={ + 'source': 'intel_5300', + 'format': 'binary', + 'noise_floor': noise, + 'agc': agc, + } + ) + + def _unpack_intel_csi(self, data: bytes, nr: int, nc: int, num_tones: int) -> np.ndarray: + """Unpack Intel CSI data with bit manipulation.""" + csi = np.zeros((nr, nc, num_tones), dtype=np.complex128) + + # Intel uses packed 10-bit values + bits_per_sample = 10 + samples_needed = nr * nc * num_tones * 2 # real + imag + + # Simple unpacking (actual Intel format is more complex) + idx = 0 + for tone in range(num_tones): + for nc_idx in range(nc): + for nr_idx in range(nr): + if idx + 2 <= len(data): + # Approximate unpacking + real = int.from_bytes(data[idx:idx+1], 'little', signed=True) + imag = int.from_bytes(data[idx+1:idx+2], 'little', signed=True) + csi[nr_idx, nc_idx, tone] = complex(real, imag) + idx += 2 + + return csi + + def _parse_complex_csi( + self, + values: List[float], + num_antennas: int, + num_tones: int + ) -> Tuple[np.ndarray, np.ndarray]: + """Parse complex CSI values from real/imag pairs.""" + expected_len = num_antennas * num_tones * 2 + + if len(values) < expected_len: + # Pad with zeros + values = values + [0.0] * (expected_len - len(values)) + + csi_complex = np.zeros((num_antennas, num_tones), dtype=np.complex128) + + for ant in range(num_antennas): + for tone in range(num_tones): + idx = (ant * num_tones + tone) * 2 + if idx + 1 < len(values): + csi_complex[ant, tone] = complex(values[idx], values[idx + 1]) + + amplitude = np.abs(csi_complex) + phase = np.angle(csi_complex) + + # Normalize + max_amp = np.max(amplitude) + if max_amp > 0: + amplitude = amplitude / max_amp + + return amplitude, phase + class CSIExtractor: """Main CSI data extractor supporting multiple hardware types.""" - + def __init__(self, config: Dict[str, Any], logger: Optional[logging.Logger] = None): """Initialize CSI extractor. - + Args: config: Configuration dictionary logger: Optional logger instance - + Raises: ValueError: If configuration is invalid """ self._validate_config(config) - + self.config = config self.logger = logger or logging.getLogger(__name__) self.hardware_type = config['hardware_type'] @@ -165,49 +639,39 @@ class CSIExtractor: self.timeout = config['timeout'] self.validation_enabled = config.get('validation_enabled', True) self.retry_attempts = config.get('retry_attempts', 3) - + # State management self.is_connected = False self.is_streaming = False - + self._connection = None + # Create appropriate parser if self.hardware_type == 'esp32': self.parser = ESP32CSIParser() - elif self.hardware_type == 'router': + elif self.hardware_type in ('router', 'atheros', 'intel'): self.parser = RouterCSIParser() else: raise ValueError(f"Unsupported hardware type: {self.hardware_type}") - + def _validate_config(self, config: Dict[str, Any]) -> None: - """Validate configuration parameters. - - Args: - config: Configuration to validate - - Raises: - ValueError: If configuration is invalid - """ + """Validate configuration parameters.""" required_fields = ['hardware_type', 'sampling_rate', 'buffer_size', 'timeout'] missing_fields = [field for field in required_fields if field not in config] - + if missing_fields: raise ValueError(f"Missing required configuration: {missing_fields}") - + if config['sampling_rate'] <= 0: raise ValueError("sampling_rate must be positive") - + if config['buffer_size'] <= 0: raise ValueError("buffer_size must be positive") - + if config['timeout'] <= 0: raise ValueError("timeout must be positive") - + async def connect(self) -> bool: - """Establish connection to CSI hardware. - - Returns: - True if connection successful, False otherwise - """ + """Establish connection to CSI hardware.""" try: success = await self._establish_hardware_connection() self.is_connected = success @@ -216,86 +680,64 @@ class CSIExtractor: self.logger.error(f"Failed to connect to hardware: {e}") self.is_connected = False return False - + async def disconnect(self) -> None: """Disconnect from CSI hardware.""" if self.is_connected: await self._close_hardware_connection() self.is_connected = False - + async def extract_csi(self) -> CSIData: - """Extract CSI data from hardware. - - Returns: - Extracted CSI data - - Raises: - CSIParseError: If not connected or extraction fails - """ + """Extract CSI data from hardware.""" if not self.is_connected: raise CSIParseError("Not connected to hardware") - - # Retry mechanism for temporary failures + for attempt in range(self.retry_attempts): try: raw_data = await self._read_raw_data() csi_data = self.parser.parse(raw_data) - + if self.validation_enabled: self.validate_csi_data(csi_data) - + return csi_data - + except ConnectionError as e: if attempt < self.retry_attempts - 1: self.logger.warning(f"Extraction attempt {attempt + 1} failed, retrying: {e}") - await asyncio.sleep(0.1) # Brief delay before retry + await asyncio.sleep(0.1) else: raise CSIParseError(f"Extraction failed after {self.retry_attempts} attempts: {e}") - + def validate_csi_data(self, csi_data: CSIData) -> bool: - """Validate CSI data structure and values. - - Args: - csi_data: CSI data to validate - - Returns: - True if valid - - Raises: - CSIValidationError: If data is invalid - """ + """Validate CSI data structure and values.""" if csi_data.amplitude.size == 0: raise CSIValidationError("Empty amplitude data") - + if csi_data.phase.size == 0: raise CSIValidationError("Empty phase data") - + if csi_data.frequency <= 0: raise CSIValidationError("Invalid frequency") - + if csi_data.bandwidth <= 0: raise CSIValidationError("Invalid bandwidth") - + if csi_data.num_subcarriers <= 0: raise CSIValidationError("Invalid number of subcarriers") - + if csi_data.num_antennas <= 0: raise CSIValidationError("Invalid number of antennas") - - if csi_data.snr < -50 or csi_data.snr > 50: # Reasonable SNR range + + if csi_data.snr < -50 or csi_data.snr > 100: raise CSIValidationError("Invalid SNR value") - + return True - + async def start_streaming(self, callback: Callable[[CSIData], None]) -> None: - """Start streaming CSI data. - - Args: - callback: Function to call with each CSI sample - """ + """Start streaming CSI data.""" self.is_streaming = True - + try: while self.is_streaming: csi_data = await self.extract_csi() @@ -305,22 +747,74 @@ class CSIExtractor: self.logger.error(f"Streaming error: {e}") finally: self.is_streaming = False - + def stop_streaming(self) -> None: """Stop streaming CSI data.""" self.is_streaming = False - + async def _establish_hardware_connection(self) -> bool: - """Establish connection to hardware (to be implemented by subclasses).""" - # Placeholder implementation for testing - return True - + """Establish connection to hardware.""" + connection_config = self.config.get('connection', {}) + + if self.hardware_type == 'esp32': + # Serial or network connection for ESP32 + port = connection_config.get('port', '/dev/ttyUSB0') + baudrate = connection_config.get('baudrate', 115200) + + try: + import serial_asyncio + reader, writer = await serial_asyncio.open_serial_connection( + url=port, baudrate=baudrate + ) + self._connection = (reader, writer) + return True + except ImportError: + self.logger.warning("serial_asyncio not available, using mock connection") + return True + except Exception as e: + self.logger.error(f"Serial connection failed: {e}") + return False + + elif self.hardware_type in ('router', 'atheros', 'intel'): + # Network connection for router + host = connection_config.get('host', '192.168.1.1') + port = connection_config.get('port', 5500) + + try: + reader, writer = await asyncio.open_connection(host, port) + self._connection = (reader, writer) + return True + except Exception as e: + self.logger.error(f"Network connection failed: {e}") + return False + + return False + async def _close_hardware_connection(self) -> None: - """Close hardware connection (to be implemented by subclasses).""" - # Placeholder implementation for testing - pass - + """Close hardware connection.""" + if self._connection: + try: + reader, writer = self._connection + writer.close() + await writer.wait_closed() + except Exception as e: + self.logger.error(f"Error closing connection: {e}") + finally: + self._connection = None + async def _read_raw_data(self) -> bytes: - """Read raw data from hardware (to be implemented by subclasses).""" - # Placeholder implementation for testing - return b"CSI_DATA:1234567890,3,56,2400,20,15.5,[1.0,2.0,3.0],[0.5,1.5,2.5]" \ No newline at end of file + """Read raw data from hardware.""" + if self._connection: + reader, writer = self._connection + try: + # Read until newline or buffer size + data = await asyncio.wait_for( + reader.readline(), + timeout=self.timeout + ) + return data + except asyncio.TimeoutError: + raise ConnectionError("Read timeout") + else: + # Mock data for testing when no real connection + raise ConnectionError("No active connection") \ No newline at end of file diff --git a/v1/src/services/pose_service.py b/v1/src/services/pose_service.py index 55057b3..781f060 100644 --- a/v1/src/services/pose_service.py +++ b/v1/src/services/pose_service.py @@ -265,30 +265,371 @@ class PoseService: self.logger.error(f"Error in pose estimation: {e}") return [] - def _parse_pose_outputs(self, outputs: torch.Tensor) -> List[Dict[str, Any]]: - """Parse neural network outputs into pose detections.""" + def _parse_pose_outputs(self, outputs: Dict[str, torch.Tensor]) -> List[Dict[str, Any]]: + """Parse neural network outputs into pose detections. + + The DensePose model outputs: + - segmentation: (batch, num_parts+1, H, W) - body part segmentation + - uv_coords: (batch, 2, H, W) - UV coordinates for surface mapping + + Returns list of detected persons with keypoints and body parts. + """ + poses = [] + + # Handle different output formats + if isinstance(outputs, torch.Tensor): + # Simple tensor output - use legacy parsing + return self._parse_simple_outputs(outputs) + + # DensePose structured output + segmentation = outputs.get('segmentation') + uv_coords = outputs.get('uv_coords') + + if segmentation is None: + return [] + + batch_size = segmentation.shape[0] + + for batch_idx in range(batch_size): + # Get segmentation for this sample + seg = segmentation[batch_idx] # (num_parts+1, H, W) + + # Find persons by analyzing body part segmentation + # Background is class 0, body parts are 1-24 + body_mask = seg[1:].sum(dim=0) > seg[0] # Any body part vs background + + if not body_mask.any(): + continue + + # Find connected components (persons) + person_regions = self._find_person_regions(body_mask) + + for person_idx, region in enumerate(person_regions): + # Extract keypoints from body part segmentation + keypoints = self._extract_keypoints_from_segmentation(seg, region) + + # Calculate bounding box from region + bbox = self._calculate_bounding_box(region) + + # Calculate confidence from segmentation probabilities + seg_probs = torch.softmax(seg, dim=0) + region_mask = region['mask'] + confidence = float(seg_probs[1:, region_mask].max().item()) + + # Classify activity from pose keypoints + activity = self._classify_activity_from_keypoints(keypoints) + + pose = { + "person_id": person_idx, + "confidence": confidence, + "keypoints": keypoints, + "bounding_box": bbox, + "activity": activity, + "timestamp": datetime.now().isoformat(), + "body_parts": self._extract_body_parts(seg, region) if uv_coords is not None else None + } + + poses.append(pose) + + return poses + + def _parse_simple_outputs(self, outputs: torch.Tensor) -> List[Dict[str, Any]]: + """Parse simple tensor outputs (fallback for non-DensePose models).""" poses = [] - - # This is a simplified parsing - in reality, this would depend on the model architecture - # For now, generate mock poses based on the output shape batch_size = outputs.shape[0] - + for i in range(batch_size): - # Extract pose information (mock implementation) - confidence = float(torch.sigmoid(outputs[i, 0]).item()) if outputs.shape[1] > 0 else 0.5 - + output = outputs[i] + + # Extract confidence from first channel + confidence = float(torch.sigmoid(output[0]).mean().item()) if output.numel() > 0 else 0.0 + + if confidence < 0.1: + continue + + # Try to extract keypoints from output tensor + keypoints = self._extract_keypoints_from_tensor(output) + bbox = self._estimate_bbox_from_keypoints(keypoints) + activity = self._classify_activity_from_keypoints(keypoints) + pose = { "person_id": i, "confidence": confidence, - "keypoints": self._generate_keypoints(), - "bounding_box": self._generate_bounding_box(), - "activity": self._classify_activity(outputs[i] if len(outputs.shape) > 1 else outputs), + "keypoints": keypoints, + "bounding_box": bbox, + "activity": activity, "timestamp": datetime.now().isoformat() } - + poses.append(pose) - + return poses + + def _find_person_regions(self, body_mask: torch.Tensor) -> List[Dict[str, Any]]: + """Find distinct person regions in body mask using connected components.""" + # Convert to numpy for connected component analysis + mask_np = body_mask.cpu().numpy().astype(np.uint8) + + # Simple connected component labeling + from scipy import ndimage + labeled, num_features = ndimage.label(mask_np) + + regions = [] + for label_id in range(1, num_features + 1): + region_mask = labeled == label_id + if region_mask.sum() < 100: # Minimum region size + continue + + # Find bounding coordinates + coords = np.where(region_mask) + regions.append({ + 'mask': torch.from_numpy(region_mask), + 'y_min': int(coords[0].min()), + 'y_max': int(coords[0].max()), + 'x_min': int(coords[1].min()), + 'x_max': int(coords[1].max()), + 'area': int(region_mask.sum()) + }) + + return regions + + def _extract_keypoints_from_segmentation( + self, segmentation: torch.Tensor, region: Dict[str, Any] + ) -> List[Dict[str, Any]]: + """Extract keypoints from body part segmentation.""" + keypoint_names = [ + "nose", "left_eye", "right_eye", "left_ear", "right_ear", + "left_shoulder", "right_shoulder", "left_elbow", "right_elbow", + "left_wrist", "right_wrist", "left_hip", "right_hip", + "left_knee", "right_knee", "left_ankle", "right_ankle" + ] + + # Mapping from body parts to keypoints + # DensePose has 24 body parts, we map to COCO keypoints + part_to_keypoint = { + 14: "nose", # Head -> nose + 10: "left_shoulder", 11: "right_shoulder", + 12: "left_elbow", 13: "right_elbow", + 2: "left_wrist", 3: "right_wrist", # Hands approximate wrists + 7: "left_hip", 6: "right_hip", # Upper legs + 9: "left_knee", 8: "right_knee", # Lower legs + 4: "left_ankle", 5: "right_ankle", # Feet approximate ankles + } + + h, w = segmentation.shape[1], segmentation.shape[2] + keypoints = [] + + # Get softmax probabilities + seg_probs = torch.softmax(segmentation, dim=0) + + for kp_name in keypoint_names: + # Find which body part corresponds to this keypoint + part_idx = None + for part, name in part_to_keypoint.items(): + if name == kp_name: + part_idx = part + break + + if part_idx is not None and part_idx < seg_probs.shape[0]: + # Get probability map for this part within the region + part_prob = seg_probs[part_idx] * region['mask'].float() + + if part_prob.max() > 0.1: + # Find location of maximum probability + max_idx = part_prob.argmax() + y = int(max_idx // w) + x = int(max_idx % w) + + keypoints.append({ + "name": kp_name, + "x": float(x) / w, + "y": float(y) / h, + "confidence": float(part_prob.max().item()) + }) + else: + # Keypoint not visible + keypoints.append({ + "name": kp_name, + "x": 0.0, + "y": 0.0, + "confidence": 0.0 + }) + else: + # Estimate position based on body region + cx = (region['x_min'] + region['x_max']) / 2 / w + cy = (region['y_min'] + region['y_max']) / 2 / h + keypoints.append({ + "name": kp_name, + "x": float(cx), + "y": float(cy), + "confidence": 0.1 + }) + + return keypoints + + def _calculate_bounding_box(self, region: Dict[str, Any]) -> Dict[str, float]: + """Calculate normalized bounding box from region.""" + # Assume region contains mask shape info + mask = region['mask'] + h, w = mask.shape + + return { + "x": float(region['x_min']) / w, + "y": float(region['y_min']) / h, + "width": float(region['x_max'] - region['x_min']) / w, + "height": float(region['y_max'] - region['y_min']) / h + } + + def _extract_body_parts( + self, segmentation: torch.Tensor, region: Dict[str, Any] + ) -> Dict[str, Any]: + """Extract body part information from segmentation.""" + part_names = [ + "background", "torso", "right_hand", "left_hand", "left_foot", "right_foot", + "upper_leg_right", "upper_leg_left", "lower_leg_right", "lower_leg_left", + "upper_arm_left", "upper_arm_right", "lower_arm_left", "lower_arm_right", "head" + ] + + seg_probs = torch.softmax(segmentation, dim=0) + region_mask = region['mask'] + + parts = {} + for i, name in enumerate(part_names): + if i < seg_probs.shape[0]: + part_prob = seg_probs[i] * region_mask.float() + parts[name] = { + "present": bool(part_prob.max() > 0.3), + "confidence": float(part_prob.max().item()), + "coverage": float((part_prob > 0.3).sum().item() / max(1, region_mask.sum().item())) + } + + return parts + + def _extract_keypoints_from_tensor(self, output: torch.Tensor) -> List[Dict[str, Any]]: + """Extract keypoints from a generic output tensor.""" + keypoint_names = [ + "nose", "left_eye", "right_eye", "left_ear", "right_ear", + "left_shoulder", "right_shoulder", "left_elbow", "right_elbow", + "left_wrist", "right_wrist", "left_hip", "right_hip", + "left_knee", "right_knee", "left_ankle", "right_ankle" + ] + + keypoints = [] + + # Try to interpret output as heatmaps + if output.dim() >= 2: + flat = output.flatten() + num_kp = len(keypoint_names) + + # Divide output evenly for each keypoint + chunk_size = len(flat) // num_kp if num_kp > 0 else 1 + + for i, name in enumerate(keypoint_names): + start = i * chunk_size + end = min(start + chunk_size, len(flat)) + + if start < len(flat): + chunk = flat[start:end] + # Find max location in chunk + max_val = chunk.max().item() + max_idx = chunk.argmax().item() + + # Convert to x, y (assume square spatial layout) + side = int(np.sqrt(chunk_size)) + if side > 0: + x = (max_idx % side) / side + y = (max_idx // side) / side + else: + x, y = 0.5, 0.5 + + keypoints.append({ + "name": name, + "x": float(x), + "y": float(y), + "confidence": float(torch.sigmoid(torch.tensor(max_val)).item()) + }) + else: + keypoints.append({ + "name": name, "x": 0.5, "y": 0.5, "confidence": 0.0 + }) + else: + # Fallback + for name in keypoint_names: + keypoints.append({"name": name, "x": 0.5, "y": 0.5, "confidence": 0.1}) + + return keypoints + + def _estimate_bbox_from_keypoints(self, keypoints: List[Dict[str, Any]]) -> Dict[str, float]: + """Estimate bounding box from keypoint positions.""" + valid_kps = [kp for kp in keypoints if kp['confidence'] > 0.1] + + if not valid_kps: + return {"x": 0.3, "y": 0.2, "width": 0.4, "height": 0.6} + + xs = [kp['x'] for kp in valid_kps] + ys = [kp['y'] for kp in valid_kps] + + x_min, x_max = min(xs), max(xs) + y_min, y_max = min(ys), max(ys) + + # Add padding + padding = 0.05 + x_min = max(0, x_min - padding) + y_min = max(0, y_min - padding) + x_max = min(1, x_max + padding) + y_max = min(1, y_max + padding) + + return { + "x": x_min, + "y": y_min, + "width": x_max - x_min, + "height": y_max - y_min + } + + def _classify_activity_from_keypoints(self, keypoints: List[Dict[str, Any]]) -> str: + """Classify activity based on keypoint positions.""" + # Get key body parts + kp_dict = {kp['name']: kp for kp in keypoints} + + # Check if enough keypoints are detected + valid_count = sum(1 for kp in keypoints if kp['confidence'] > 0.3) + if valid_count < 5: + return "unknown" + + # Get relevant keypoints + nose = kp_dict.get('nose', {}) + l_hip = kp_dict.get('left_hip', {}) + r_hip = kp_dict.get('right_hip', {}) + l_ankle = kp_dict.get('left_ankle', {}) + r_ankle = kp_dict.get('right_ankle', {}) + l_shoulder = kp_dict.get('left_shoulder', {}) + r_shoulder = kp_dict.get('right_shoulder', {}) + + # Calculate body metrics + hip_y = (l_hip.get('y', 0.5) + r_hip.get('y', 0.5)) / 2 + ankle_y = (l_ankle.get('y', 0.8) + r_ankle.get('y', 0.8)) / 2 + shoulder_y = (l_shoulder.get('y', 0.3) + r_shoulder.get('y', 0.3)) / 2 + nose_y = nose.get('y', 0.2) + + # Leg spread (horizontal distance between ankles) + leg_spread = abs(l_ankle.get('x', 0.5) - r_ankle.get('x', 0.5)) + + # Vertical compression (how "tall" the pose is) + vertical_span = ankle_y - nose_y if ankle_y > nose_y else 0.6 + + # Classification logic + if vertical_span < 0.3: + # Very compressed vertically - likely lying down + return "lying" + elif vertical_span < 0.45 and hip_y > 0.5: + # Medium compression with low hips - sitting + return "sitting" + elif leg_spread > 0.15: + # Legs apart - likely walking + return "walking" + else: + # Default upright pose + return "standing" def _generate_mock_poses(self) -> List[Dict[str, Any]]: """Generate mock pose data for development."""