"""Security tests for event framework integration.""" import unittest import time from unittest.mock import patch, MagicMock from security.encrypt import AES256Cipher from events.core import EventSystem class TestEventSecurity(unittest.TestCase): """Security-specific event framework tests.""" def setUp(self): self.cipher = AES256Cipher() self.system = EventSystem(MagicMock()) self.original_key = self.cipher.key def test_key_rotation(self): """Test event handling during key rotation.""" # Initial key works event1 = {'type': 'rotate', 'data': 'secret1'} self.system.publish(event1) # Rotate key new_key = AES256Cipher.generate_key() self.cipher.rotate_key(new_key) # New key works event2 = {'type': 'rotate', 'data': 'secret2'} self.system.publish(event2) # Verify both events processed time.sleep(0.1) self.assertEqual(len(self.system.get_processed_events()), 2) def test_invalid_key_handling(self): """Test handling of events with invalid keys.""" with patch('security.encrypt.AES256Cipher.decrypt') as mock_decrypt: mock_decrypt.side_effect = ValueError("Invalid key") error_count = 0 def error_handler(event): nonlocal error_count error_count += 1 self.system.subscribe('invalid', error_handler) self.system.publish({'type': 'invalid', 'data': 'bad'}) time.sleep(0.1) self.assertEqual(error_count, 1) def test_tampered_event_detection(self): """Test detection of tampered event payloads.""" with patch('security.encrypt.AES256Cipher.verify_tag') as mock_verify: mock_verify.return_value = False tampered_count = 0 def tamper_handler(event): nonlocal tampered_count tampered_count += 1 self.system.subscribe('tampered', tamper_handler) self.system.publish({'type': 'tampered', 'data': 'changed'}) time.sleep(0.1) self.assertEqual(tampered_count, 1) def test_security_performance(self): """Test security operation performance.""" start_time = time.time() for i in range(100): self.system.publish({'type': 'perf', 'data': str(i)}) duration = time.time() - start_time stats = self.system.get_performance_stats() self.assertLess(duration, 1.0) # 100 events in <1s self.assertEqual(stats['total_events'], 100) self.assertLess(stats['avg_security_latency'], 0.01) def test_critical_path_coverage(self): """Verify 100% coverage of security critical paths.""" # Test all security-sensitive event types test_cases = [ ('auth', {'user': 'admin', 'action': 'login'}), ('permission', {'resource': 'db', 'access': 'write'}), ('audit', {'action': 'delete', 'target': 'record123'}) ] results = [] def handler(event): results.append(event['type']) self.system.subscribe('*', handler) for event_type, payload in test_cases: self.system.publish({'type': event_type, **payload}) time.sleep(0.1) self.assertEqual(sorted(results), ['auth', 'audit', 'permission']) def test_key_rotation_edge_cases(self): """Test edge cases during key rotation.""" # Test rapid key rotation for i in range(5): new_key = AES256Cipher.generate_key() self.cipher.rotate_key(new_key) event = {'type': 'rotate', 'data': f'secret{i}'} self.system.publish(event) time.sleep(0.2) self.assertEqual(len(self.system.get_processed_events()), 5) def test_tampered_event_types(self): """Test detection of various tampered event types.""" tamper_types = ['auth', 'config', 'data', 'system'] tampered_count = 0 def tamper_handler(event): nonlocal tampered_count tampered_count += 1 self.system.subscribe('*', tamper_handler) with patch('security.encrypt.AES256Cipher.verify_tag') as mock_verify: mock_verify.return_value = False for event_type in tamper_types: self.system.publish({'type': event_type, 'data': 'tampered'}) time.sleep(0.1) self.assertEqual(tampered_count, len(tamper_types)) def test_negative_security_operations(self): """Test negative cases for security operations.""" # Test invalid key format with self.assertRaises(ValueError): self.cipher.rotate_key('invalid-key-format') # Test empty event handling with self.assertRaises(ValueError): self.system.publish(None) # Test invalid event structure with self.assertRaises(ValueError): self.system.publish({'invalid': 'structure'}) def test_malformed_encryption_headers(self): """Test handling of events with malformed encryption headers.""" with patch('security.encrypt.AES256Cipher.decrypt') as mock_decrypt: mock_decrypt.side_effect = ValueError("Invalid header") error_count = 0 def error_handler(event): nonlocal error_count error_count += 1 self.system.subscribe('malformed', error_handler) self.system.publish({'type': 'malformed', 'data': 'bad_header'}) time.sleep(0.1) self.assertEqual(error_count, 1) def test_partial_message_corruption(self): """Test detection of partially corrupted messages.""" with patch('security.encrypt.AES256Cipher.decrypt') as mock_decrypt: # Return partial data mock_decrypt.return_value = {'type': 'partial', 'data': 'corrupt'} corrupt_count = 0 def corrupt_handler(event): nonlocal corrupt_count if len(event.get('data', '')) < 10: # Simulate truncated data corrupt_count += 1 self.system.subscribe('partial', corrupt_handler) self.system.publish({'type': 'partial', 'data': 'full_message'}) time.sleep(0.1) self.assertEqual(corrupt_count, 1) def test_replay_attack_detection(self): """Test detection of replayed events.""" event_id = '12345' event = {'type': 'replay', 'id': event_id, 'data': 'original'} # First publish should succeed self.system.publish(event) time.sleep(0.1) # Replay should be detected replay_count = 0 def replay_handler(e): nonlocal replay_count if e.get('replay_detected'): replay_count += 1 self.system.subscribe('replay', replay_handler) self.system.publish(event) time.sleep(0.1) self.assertEqual(replay_count, 1) def test_timing_side_channels(self): """Test for timing side channels in security operations.""" test_cases = [ ('valid', 'normal_data'), ('invalid', 'x'*1000) # Larger payload ] timings = [] for case_type, data in test_cases: start = time.time() self.system.publish({'type': 'timing', 'data': data}) elapsed = time.time() - start timings.append(elapsed) # Timing difference should be minimal time_diff = abs(timings[1] - timings[0]) self.assertLess(time_diff, 0.01, f"Timing difference {time_diff:.4f}s > 10ms threshold") if __name__ == '__main__': unittest.main()