ai-agent/security/tests/test_event_security.py

222 lines
No EOL
8 KiB
Python

"""Security tests for event framework integration."""
import unittest
import time
from unittest.mock import patch, MagicMock
from security.encrypt import AES256Cipher
from events.core import EventSystem
class TestEventSecurity(unittest.TestCase):
"""Security-specific event framework tests."""
def setUp(self):
self.cipher = AES256Cipher()
self.system = EventSystem(MagicMock())
self.original_key = self.cipher.key
def test_key_rotation(self):
"""Test event handling during key rotation."""
# Initial key works
event1 = {'type': 'rotate', 'data': 'secret1'}
self.system.publish(event1)
# Rotate key
new_key = AES256Cipher.generate_key()
self.cipher.rotate_key(new_key)
# New key works
event2 = {'type': 'rotate', 'data': 'secret2'}
self.system.publish(event2)
# Verify both events processed
time.sleep(0.1)
self.assertEqual(len(self.system.get_processed_events()), 2)
def test_invalid_key_handling(self):
"""Test handling of events with invalid keys."""
with patch('security.encrypt.AES256Cipher.decrypt') as mock_decrypt:
mock_decrypt.side_effect = ValueError("Invalid key")
error_count = 0
def error_handler(event):
nonlocal error_count
error_count += 1
self.system.subscribe('invalid', error_handler)
self.system.publish({'type': 'invalid', 'data': 'bad'})
time.sleep(0.1)
self.assertEqual(error_count, 1)
def test_tampered_event_detection(self):
"""Test detection of tampered event payloads."""
with patch('security.encrypt.AES256Cipher.verify_tag') as mock_verify:
mock_verify.return_value = False
tampered_count = 0
def tamper_handler(event):
nonlocal tampered_count
tampered_count += 1
self.system.subscribe('tampered', tamper_handler)
self.system.publish({'type': 'tampered', 'data': 'changed'})
time.sleep(0.1)
self.assertEqual(tampered_count, 1)
def test_security_performance(self):
"""Test security operation performance."""
start_time = time.time()
for i in range(100):
self.system.publish({'type': 'perf', 'data': str(i)})
duration = time.time() - start_time
stats = self.system.get_performance_stats()
self.assertLess(duration, 1.0) # 100 events in <1s
self.assertEqual(stats['total_events'], 100)
self.assertLess(stats['avg_security_latency'], 0.01)
def test_critical_path_coverage(self):
"""Verify 100% coverage of security critical paths."""
# Test all security-sensitive event types
test_cases = [
('auth', {'user': 'admin', 'action': 'login'}),
('permission', {'resource': 'db', 'access': 'write'}),
('audit', {'action': 'delete', 'target': 'record123'})
]
results = []
def handler(event):
results.append(event['type'])
self.system.subscribe('*', handler)
for event_type, payload in test_cases:
self.system.publish({'type': event_type, **payload})
time.sleep(0.1)
self.assertEqual(sorted(results), ['auth', 'audit', 'permission'])
def test_key_rotation_edge_cases(self):
"""Test edge cases during key rotation."""
# Test rapid key rotation
for i in range(5):
new_key = AES256Cipher.generate_key()
self.cipher.rotate_key(new_key)
event = {'type': 'rotate', 'data': f'secret{i}'}
self.system.publish(event)
time.sleep(0.2)
self.assertEqual(len(self.system.get_processed_events()), 5)
def test_tampered_event_types(self):
"""Test detection of various tampered event types."""
tamper_types = ['auth', 'config', 'data', 'system']
tampered_count = 0
def tamper_handler(event):
nonlocal tampered_count
tampered_count += 1
self.system.subscribe('*', tamper_handler)
with patch('security.encrypt.AES256Cipher.verify_tag') as mock_verify:
mock_verify.return_value = False
for event_type in tamper_types:
self.system.publish({'type': event_type, 'data': 'tampered'})
time.sleep(0.1)
self.assertEqual(tampered_count, len(tamper_types))
def test_negative_security_operations(self):
"""Test negative cases for security operations."""
# Test invalid key format
with self.assertRaises(ValueError):
self.cipher.rotate_key('invalid-key-format')
# Test empty event handling
with self.assertRaises(ValueError):
self.system.publish(None)
# Test invalid event structure
with self.assertRaises(ValueError):
self.system.publish({'invalid': 'structure'})
def test_malformed_encryption_headers(self):
"""Test handling of events with malformed encryption headers."""
with patch('security.encrypt.AES256Cipher.decrypt') as mock_decrypt:
mock_decrypt.side_effect = ValueError("Invalid header")
error_count = 0
def error_handler(event):
nonlocal error_count
error_count += 1
self.system.subscribe('malformed', error_handler)
self.system.publish({'type': 'malformed', 'data': 'bad_header'})
time.sleep(0.1)
self.assertEqual(error_count, 1)
def test_partial_message_corruption(self):
"""Test detection of partially corrupted messages."""
with patch('security.encrypt.AES256Cipher.decrypt') as mock_decrypt:
# Return partial data
mock_decrypt.return_value = {'type': 'partial', 'data': 'corrupt'}
corrupt_count = 0
def corrupt_handler(event):
nonlocal corrupt_count
if len(event.get('data', '')) < 10: # Simulate truncated data
corrupt_count += 1
self.system.subscribe('partial', corrupt_handler)
self.system.publish({'type': 'partial', 'data': 'full_message'})
time.sleep(0.1)
self.assertEqual(corrupt_count, 1)
def test_replay_attack_detection(self):
"""Test detection of replayed events."""
event_id = '12345'
event = {'type': 'replay', 'id': event_id, 'data': 'original'}
# First publish should succeed
self.system.publish(event)
time.sleep(0.1)
# Replay should be detected
replay_count = 0
def replay_handler(e):
nonlocal replay_count
if e.get('replay_detected'):
replay_count += 1
self.system.subscribe('replay', replay_handler)
self.system.publish(event)
time.sleep(0.1)
self.assertEqual(replay_count, 1)
def test_timing_side_channels(self):
"""Test for timing side channels in security operations."""
test_cases = [
('valid', 'normal_data'),
('invalid', 'x'*1000) # Larger payload
]
timings = []
for case_type, data in test_cases:
start = time.time()
self.system.publish({'type': 'timing', 'data': data})
elapsed = time.time() - start
timings.append(elapsed)
# Timing difference should be minimal
time_diff = abs(timings[1] - timings[0])
self.assertLess(time_diff, 0.01,
f"Timing difference {time_diff:.4f}s > 10ms threshold")
# Allow running this module directly: `python test_event_security.py`.
if __name__ == '__main__':
    unittest.main()