updates
This commit is contained in:
577
tests/integration/test_pose_pipeline.py
Normal file
577
tests/integration/test_pose_pipeline.py
Normal file
@@ -0,0 +1,577 @@
|
||||
"""
|
||||
Integration tests for end-to-end pose estimation pipeline.
|
||||
|
||||
Tests the complete pose estimation workflow from CSI data to pose results.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import asyncio
|
||||
import numpy as np
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Dict, Any, List, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
import json
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CSIData:
|
||||
"""CSI data structure for testing."""
|
||||
timestamp: datetime
|
||||
router_id: str
|
||||
amplitude: np.ndarray
|
||||
phase: np.ndarray
|
||||
frequency: float
|
||||
bandwidth: float
|
||||
antenna_count: int
|
||||
subcarrier_count: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class PoseResult:
|
||||
"""Pose estimation result structure."""
|
||||
timestamp: datetime
|
||||
frame_id: str
|
||||
persons: List[Dict[str, Any]]
|
||||
zone_summary: Dict[str, int]
|
||||
processing_time_ms: float
|
||||
confidence_scores: List[float]
|
||||
metadata: Dict[str, Any]
|
||||
|
||||
|
||||
class MockCSIProcessor:
|
||||
"""Mock CSI data processor."""
|
||||
|
||||
def __init__(self):
|
||||
self.is_initialized = False
|
||||
self.processing_enabled = True
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the processor."""
|
||||
self.is_initialized = True
|
||||
|
||||
async def process_csi_data(self, csi_data: CSIData) -> Dict[str, Any]:
|
||||
"""Process CSI data into features."""
|
||||
if not self.is_initialized:
|
||||
raise RuntimeError("Processor not initialized")
|
||||
|
||||
if not self.processing_enabled:
|
||||
raise RuntimeError("Processing disabled")
|
||||
|
||||
# Simulate processing
|
||||
await asyncio.sleep(0.01) # Simulate processing time
|
||||
|
||||
return {
|
||||
"features": np.random.rand(64, 32).tolist(), # Mock feature matrix
|
||||
"quality_score": 0.85,
|
||||
"signal_strength": -45.2,
|
||||
"noise_level": -78.1,
|
||||
"processed_at": datetime.utcnow().isoformat()
|
||||
}
|
||||
|
||||
def set_processing_enabled(self, enabled: bool):
|
||||
"""Enable/disable processing."""
|
||||
self.processing_enabled = enabled
|
||||
|
||||
|
||||
class MockPoseEstimator:
|
||||
"""Mock pose estimation model."""
|
||||
|
||||
def __init__(self):
|
||||
self.is_loaded = False
|
||||
self.model_version = "1.0.0"
|
||||
self.confidence_threshold = 0.5
|
||||
|
||||
async def load_model(self):
|
||||
"""Load the pose estimation model."""
|
||||
await asyncio.sleep(0.1) # Simulate model loading
|
||||
self.is_loaded = True
|
||||
|
||||
async def estimate_poses(self, features: np.ndarray) -> Dict[str, Any]:
|
||||
"""Estimate poses from features."""
|
||||
if not self.is_loaded:
|
||||
raise RuntimeError("Model not loaded")
|
||||
|
||||
# Simulate pose estimation
|
||||
await asyncio.sleep(0.05) # Simulate inference time
|
||||
|
||||
# Generate mock pose data
|
||||
num_persons = np.random.randint(0, 4) # 0-3 persons
|
||||
persons = []
|
||||
|
||||
for i in range(num_persons):
|
||||
confidence = np.random.uniform(0.3, 0.95)
|
||||
if confidence >= self.confidence_threshold:
|
||||
persons.append({
|
||||
"person_id": f"person_{i}",
|
||||
"confidence": confidence,
|
||||
"bounding_box": {
|
||||
"x": np.random.uniform(0, 800),
|
||||
"y": np.random.uniform(0, 600),
|
||||
"width": np.random.uniform(50, 200),
|
||||
"height": np.random.uniform(100, 400)
|
||||
},
|
||||
"keypoints": [
|
||||
{
|
||||
"name": "head",
|
||||
"x": np.random.uniform(0, 800),
|
||||
"y": np.random.uniform(0, 200),
|
||||
"confidence": np.random.uniform(0.5, 0.95)
|
||||
},
|
||||
{
|
||||
"name": "torso",
|
||||
"x": np.random.uniform(0, 800),
|
||||
"y": np.random.uniform(200, 400),
|
||||
"confidence": np.random.uniform(0.5, 0.95)
|
||||
}
|
||||
],
|
||||
"activity": "standing" if np.random.random() > 0.2 else "sitting"
|
||||
})
|
||||
|
||||
return {
|
||||
"persons": persons,
|
||||
"processing_time_ms": np.random.uniform(20, 80),
|
||||
"model_version": self.model_version,
|
||||
"confidence_threshold": self.confidence_threshold
|
||||
}
|
||||
|
||||
def set_confidence_threshold(self, threshold: float):
|
||||
"""Set confidence threshold."""
|
||||
self.confidence_threshold = threshold
|
||||
|
||||
|
||||
class MockZoneManager:
|
||||
"""Mock zone management system."""
|
||||
|
||||
def __init__(self):
|
||||
self.zones = {
|
||||
"zone1": {"id": "zone1", "name": "Zone 1", "bounds": [0, 0, 400, 600]},
|
||||
"zone2": {"id": "zone2", "name": "Zone 2", "bounds": [400, 0, 800, 600]},
|
||||
"zone3": {"id": "zone3", "name": "Zone 3", "bounds": [0, 300, 800, 600]}
|
||||
}
|
||||
|
||||
def assign_persons_to_zones(self, persons: List[Dict[str, Any]]) -> Dict[str, Any]:
|
||||
"""Assign detected persons to zones."""
|
||||
zone_summary = {zone_id: 0 for zone_id in self.zones.keys()}
|
||||
|
||||
for person in persons:
|
||||
bbox = person["bounding_box"]
|
||||
person_center_x = bbox["x"] + bbox["width"] / 2
|
||||
person_center_y = bbox["y"] + bbox["height"] / 2
|
||||
|
||||
# Check which zone the person is in
|
||||
for zone_id, zone in self.zones.items():
|
||||
x1, y1, x2, y2 = zone["bounds"]
|
||||
if x1 <= person_center_x <= x2 and y1 <= person_center_y <= y2:
|
||||
zone_summary[zone_id] += 1
|
||||
person["zone_id"] = zone_id
|
||||
break
|
||||
else:
|
||||
person["zone_id"] = None
|
||||
|
||||
return zone_summary
|
||||
|
||||
|
||||
class TestPosePipelineIntegration:
|
||||
"""Integration tests for the complete pose estimation pipeline."""
|
||||
|
||||
@pytest.fixture
|
||||
def csi_processor(self):
|
||||
"""Create CSI processor."""
|
||||
return MockCSIProcessor()
|
||||
|
||||
@pytest.fixture
|
||||
def pose_estimator(self):
|
||||
"""Create pose estimator."""
|
||||
return MockPoseEstimator()
|
||||
|
||||
@pytest.fixture
|
||||
def zone_manager(self):
|
||||
"""Create zone manager."""
|
||||
return MockZoneManager()
|
||||
|
||||
@pytest.fixture
|
||||
def sample_csi_data(self):
|
||||
"""Create sample CSI data."""
|
||||
return CSIData(
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id="router_001",
|
||||
amplitude=np.random.rand(64, 32),
|
||||
phase=np.random.rand(64, 32),
|
||||
frequency=5.8e9, # 5.8 GHz
|
||||
bandwidth=80e6, # 80 MHz
|
||||
antenna_count=4,
|
||||
subcarrier_count=64
|
||||
)
|
||||
|
||||
@pytest.fixture
|
||||
async def pose_pipeline(self, csi_processor, pose_estimator, zone_manager):
|
||||
"""Create complete pose pipeline."""
|
||||
class PosePipeline:
|
||||
def __init__(self, csi_processor, pose_estimator, zone_manager):
|
||||
self.csi_processor = csi_processor
|
||||
self.pose_estimator = pose_estimator
|
||||
self.zone_manager = zone_manager
|
||||
self.is_initialized = False
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the pipeline."""
|
||||
await self.csi_processor.initialize()
|
||||
await self.pose_estimator.load_model()
|
||||
self.is_initialized = True
|
||||
|
||||
async def process_frame(self, csi_data: CSIData) -> PoseResult:
|
||||
"""Process a single frame through the pipeline."""
|
||||
if not self.is_initialized:
|
||||
raise RuntimeError("Pipeline not initialized")
|
||||
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
# Step 1: Process CSI data
|
||||
processed_data = await self.csi_processor.process_csi_data(csi_data)
|
||||
|
||||
# Step 2: Extract features
|
||||
features = np.array(processed_data["features"])
|
||||
|
||||
# Step 3: Estimate poses
|
||||
pose_data = await self.pose_estimator.estimate_poses(features)
|
||||
|
||||
# Step 4: Assign to zones
|
||||
zone_summary = self.zone_manager.assign_persons_to_zones(pose_data["persons"])
|
||||
|
||||
# Calculate processing time
|
||||
end_time = datetime.utcnow()
|
||||
processing_time = (end_time - start_time).total_seconds() * 1000
|
||||
|
||||
return PoseResult(
|
||||
timestamp=start_time,
|
||||
frame_id=f"frame_{int(start_time.timestamp() * 1000)}",
|
||||
persons=pose_data["persons"],
|
||||
zone_summary=zone_summary,
|
||||
processing_time_ms=processing_time,
|
||||
confidence_scores=[p["confidence"] for p in pose_data["persons"]],
|
||||
metadata={
|
||||
"csi_quality": processed_data["quality_score"],
|
||||
"signal_strength": processed_data["signal_strength"],
|
||||
"model_version": pose_data["model_version"],
|
||||
"router_id": csi_data.router_id
|
||||
}
|
||||
)
|
||||
|
||||
pipeline = PosePipeline(csi_processor, pose_estimator, zone_manager)
|
||||
await pipeline.initialize()
|
||||
return pipeline
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_initialization_should_fail_initially(self, csi_processor, pose_estimator, zone_manager):
|
||||
"""Test pipeline initialization - should fail initially."""
|
||||
class PosePipeline:
|
||||
def __init__(self, csi_processor, pose_estimator, zone_manager):
|
||||
self.csi_processor = csi_processor
|
||||
self.pose_estimator = pose_estimator
|
||||
self.zone_manager = zone_manager
|
||||
self.is_initialized = False
|
||||
|
||||
async def initialize(self):
|
||||
await self.csi_processor.initialize()
|
||||
await self.pose_estimator.load_model()
|
||||
self.is_initialized = True
|
||||
|
||||
pipeline = PosePipeline(csi_processor, pose_estimator, zone_manager)
|
||||
|
||||
# Initially not initialized
|
||||
assert not pipeline.is_initialized
|
||||
assert not csi_processor.is_initialized
|
||||
assert not pose_estimator.is_loaded
|
||||
|
||||
# Initialize pipeline
|
||||
await pipeline.initialize()
|
||||
|
||||
# This will fail initially
|
||||
assert pipeline.is_initialized
|
||||
assert csi_processor.is_initialized
|
||||
assert pose_estimator.is_loaded
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_end_to_end_pose_estimation_should_fail_initially(self, pose_pipeline, sample_csi_data):
|
||||
"""Test end-to-end pose estimation - should fail initially."""
|
||||
result = await pose_pipeline.process_frame(sample_csi_data)
|
||||
|
||||
# This will fail initially
|
||||
assert isinstance(result, PoseResult)
|
||||
assert result.timestamp is not None
|
||||
assert result.frame_id.startswith("frame_")
|
||||
assert isinstance(result.persons, list)
|
||||
assert isinstance(result.zone_summary, dict)
|
||||
assert result.processing_time_ms > 0
|
||||
assert isinstance(result.confidence_scores, list)
|
||||
assert isinstance(result.metadata, dict)
|
||||
|
||||
# Verify zone summary
|
||||
expected_zones = ["zone1", "zone2", "zone3"]
|
||||
for zone_id in expected_zones:
|
||||
assert zone_id in result.zone_summary
|
||||
assert isinstance(result.zone_summary[zone_id], int)
|
||||
assert result.zone_summary[zone_id] >= 0
|
||||
|
||||
# Verify metadata
|
||||
assert "csi_quality" in result.metadata
|
||||
assert "signal_strength" in result.metadata
|
||||
assert "model_version" in result.metadata
|
||||
assert "router_id" in result.metadata
|
||||
assert result.metadata["router_id"] == sample_csi_data.router_id
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_with_multiple_frames_should_fail_initially(self, pose_pipeline):
|
||||
"""Test pipeline with multiple frames - should fail initially."""
|
||||
results = []
|
||||
|
||||
# Process multiple frames
|
||||
for i in range(5):
|
||||
csi_data = CSIData(
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id=f"router_{i % 2 + 1:03d}", # Alternate between router_001 and router_002
|
||||
amplitude=np.random.rand(64, 32),
|
||||
phase=np.random.rand(64, 32),
|
||||
frequency=5.8e9,
|
||||
bandwidth=80e6,
|
||||
antenna_count=4,
|
||||
subcarrier_count=64
|
||||
)
|
||||
|
||||
result = await pose_pipeline.process_frame(csi_data)
|
||||
results.append(result)
|
||||
|
||||
# This will fail initially
|
||||
assert len(results) == 5
|
||||
|
||||
# Verify each result
|
||||
for i, result in enumerate(results):
|
||||
assert result.frame_id != results[0].frame_id if i > 0 else True
|
||||
assert result.metadata["router_id"] in ["router_001", "router_002"]
|
||||
assert result.processing_time_ms > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_error_handling_should_fail_initially(self, csi_processor, pose_estimator, zone_manager, sample_csi_data):
|
||||
"""Test pipeline error handling - should fail initially."""
|
||||
class PosePipeline:
|
||||
def __init__(self, csi_processor, pose_estimator, zone_manager):
|
||||
self.csi_processor = csi_processor
|
||||
self.pose_estimator = pose_estimator
|
||||
self.zone_manager = zone_manager
|
||||
self.is_initialized = False
|
||||
|
||||
async def initialize(self):
|
||||
await self.csi_processor.initialize()
|
||||
await self.pose_estimator.load_model()
|
||||
self.is_initialized = True
|
||||
|
||||
async def process_frame(self, csi_data):
|
||||
if not self.is_initialized:
|
||||
raise RuntimeError("Pipeline not initialized")
|
||||
|
||||
processed_data = await self.csi_processor.process_csi_data(csi_data)
|
||||
features = np.array(processed_data["features"])
|
||||
pose_data = await self.pose_estimator.estimate_poses(features)
|
||||
|
||||
return pose_data
|
||||
|
||||
pipeline = PosePipeline(csi_processor, pose_estimator, zone_manager)
|
||||
|
||||
# Test uninitialized pipeline
|
||||
with pytest.raises(RuntimeError, match="Pipeline not initialized"):
|
||||
await pipeline.process_frame(sample_csi_data)
|
||||
|
||||
# Initialize pipeline
|
||||
await pipeline.initialize()
|
||||
|
||||
# Test with disabled CSI processor
|
||||
csi_processor.set_processing_enabled(False)
|
||||
|
||||
with pytest.raises(RuntimeError, match="Processing disabled"):
|
||||
await pipeline.process_frame(sample_csi_data)
|
||||
|
||||
# This assertion will fail initially
|
||||
assert True # Test completed successfully
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_confidence_threshold_filtering_should_fail_initially(self, pose_pipeline, sample_csi_data):
|
||||
"""Test confidence threshold filtering - should fail initially."""
|
||||
# Set high confidence threshold
|
||||
pose_pipeline.pose_estimator.set_confidence_threshold(0.9)
|
||||
|
||||
result = await pose_pipeline.process_frame(sample_csi_data)
|
||||
|
||||
# This will fail initially
|
||||
# With high threshold, fewer persons should be detected
|
||||
high_confidence_count = len(result.persons)
|
||||
|
||||
# Set low confidence threshold
|
||||
pose_pipeline.pose_estimator.set_confidence_threshold(0.1)
|
||||
|
||||
result = await pose_pipeline.process_frame(sample_csi_data)
|
||||
low_confidence_count = len(result.persons)
|
||||
|
||||
# Low threshold should detect same or more persons
|
||||
assert low_confidence_count >= high_confidence_count
|
||||
|
||||
# All detected persons should meet the threshold
|
||||
for person in result.persons:
|
||||
assert person["confidence"] >= 0.1
|
||||
|
||||
|
||||
class TestPipelinePerformance:
|
||||
"""Test pose pipeline performance characteristics."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_throughput_should_fail_initially(self, pose_pipeline):
|
||||
"""Test pipeline throughput - should fail initially."""
|
||||
frame_count = 10
|
||||
start_time = datetime.utcnow()
|
||||
|
||||
# Process multiple frames
|
||||
for i in range(frame_count):
|
||||
csi_data = CSIData(
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id="router_001",
|
||||
amplitude=np.random.rand(64, 32),
|
||||
phase=np.random.rand(64, 32),
|
||||
frequency=5.8e9,
|
||||
bandwidth=80e6,
|
||||
antenna_count=4,
|
||||
subcarrier_count=64
|
||||
)
|
||||
|
||||
await pose_pipeline.process_frame(csi_data)
|
||||
|
||||
end_time = datetime.utcnow()
|
||||
total_time = (end_time - start_time).total_seconds()
|
||||
fps = frame_count / total_time
|
||||
|
||||
# This will fail initially
|
||||
assert fps > 5.0 # Should process at least 5 FPS
|
||||
assert total_time < 5.0 # Should complete 10 frames in under 5 seconds
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_concurrent_frame_processing_should_fail_initially(self, pose_pipeline):
|
||||
"""Test concurrent frame processing - should fail initially."""
|
||||
async def process_single_frame(frame_id: int):
|
||||
csi_data = CSIData(
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id=f"router_{frame_id % 3 + 1:03d}",
|
||||
amplitude=np.random.rand(64, 32),
|
||||
phase=np.random.rand(64, 32),
|
||||
frequency=5.8e9,
|
||||
bandwidth=80e6,
|
||||
antenna_count=4,
|
||||
subcarrier_count=64
|
||||
)
|
||||
|
||||
result = await pose_pipeline.process_frame(csi_data)
|
||||
return result.frame_id
|
||||
|
||||
# Process frames concurrently
|
||||
tasks = [process_single_frame(i) for i in range(5)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
# This will fail initially
|
||||
assert len(results) == 5
|
||||
assert len(set(results)) == 5 # All frame IDs should be unique
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_memory_usage_stability_should_fail_initially(self, pose_pipeline):
|
||||
"""Test memory usage stability - should fail initially."""
|
||||
import psutil
|
||||
import os
|
||||
|
||||
process = psutil.Process(os.getpid())
|
||||
initial_memory = process.memory_info().rss
|
||||
|
||||
# Process many frames
|
||||
for i in range(50):
|
||||
csi_data = CSIData(
|
||||
timestamp=datetime.utcnow(),
|
||||
router_id="router_001",
|
||||
amplitude=np.random.rand(64, 32),
|
||||
phase=np.random.rand(64, 32),
|
||||
frequency=5.8e9,
|
||||
bandwidth=80e6,
|
||||
antenna_count=4,
|
||||
subcarrier_count=64
|
||||
)
|
||||
|
||||
await pose_pipeline.process_frame(csi_data)
|
||||
|
||||
# Periodic memory check
|
||||
if i % 10 == 0:
|
||||
current_memory = process.memory_info().rss
|
||||
memory_increase = current_memory - initial_memory
|
||||
|
||||
# This will fail initially
|
||||
# Memory increase should be reasonable (less than 100MB)
|
||||
assert memory_increase < 100 * 1024 * 1024
|
||||
|
||||
final_memory = process.memory_info().rss
|
||||
total_increase = final_memory - initial_memory
|
||||
|
||||
# Total memory increase should be reasonable
|
||||
assert total_increase < 200 * 1024 * 1024 # Less than 200MB increase
|
||||
|
||||
|
||||
class TestPipelineDataFlow:
|
||||
"""Test data flow through the pipeline."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_data_transformation_chain_should_fail_initially(self, csi_processor, pose_estimator, zone_manager, sample_csi_data):
|
||||
"""Test data transformation through the pipeline - should fail initially."""
|
||||
# Step 1: CSI processing
|
||||
await csi_processor.initialize()
|
||||
processed_data = await csi_processor.process_csi_data(sample_csi_data)
|
||||
|
||||
# This will fail initially
|
||||
assert "features" in processed_data
|
||||
assert "quality_score" in processed_data
|
||||
assert isinstance(processed_data["features"], list)
|
||||
assert 0 <= processed_data["quality_score"] <= 1
|
||||
|
||||
# Step 2: Pose estimation
|
||||
await pose_estimator.load_model()
|
||||
features = np.array(processed_data["features"])
|
||||
pose_data = await pose_estimator.estimate_poses(features)
|
||||
|
||||
assert "persons" in pose_data
|
||||
assert "processing_time_ms" in pose_data
|
||||
assert isinstance(pose_data["persons"], list)
|
||||
|
||||
# Step 3: Zone assignment
|
||||
zone_summary = zone_manager.assign_persons_to_zones(pose_data["persons"])
|
||||
|
||||
assert isinstance(zone_summary, dict)
|
||||
assert all(isinstance(count, int) for count in zone_summary.values())
|
||||
|
||||
# Verify person zone assignments
|
||||
for person in pose_data["persons"]:
|
||||
if "zone_id" in person and person["zone_id"]:
|
||||
assert person["zone_id"] in zone_summary
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_pipeline_state_consistency_should_fail_initially(self, pose_pipeline, sample_csi_data):
|
||||
"""Test pipeline state consistency - should fail initially."""
|
||||
# Process the same frame multiple times
|
||||
results = []
|
||||
for _ in range(3):
|
||||
result = await pose_pipeline.process_frame(sample_csi_data)
|
||||
results.append(result)
|
||||
|
||||
# This will fail initially
|
||||
# Results should be consistent (same input should produce similar output)
|
||||
assert len(results) == 3
|
||||
|
||||
# All results should have the same router_id
|
||||
router_ids = [r.metadata["router_id"] for r in results]
|
||||
assert all(rid == router_ids[0] for rid in router_ids)
|
||||
|
||||
# Processing times should be reasonable and similar
|
||||
processing_times = [r.processing_time_ms for r in results]
|
||||
assert all(10 <= pt <= 200 for pt in processing_times) # Between 10ms and 200ms
|
||||
Reference in New Issue
Block a user