# ai-agent/tests/security/test_rbac_engine.py
# (file-viewer metadata from the original paste: 319 lines, no EOL, 12 KiB, Python)

import json
import sqlite3
import ssl
import threading
import time
import unittest
from contextlib import closing
from datetime import datetime, timedelta, timezone

from cryptography import x509
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import rsa

from security.rbac_engine import RBACEngine, Role, ClientCertInfo
class TestRBACEngineSecurity(unittest.TestCase):
    """Security tests for ``RBACEngine``.

    Covers certificate OU -> role mapping, certificate revocation and expiry,
    TLS handshake audit logging, and role-boundary validation.
    """

    def setUp(self):
        """Create a fresh engine and a self-signed test certificate it trusts."""
        self.engine = RBACEngine(b'test_key_123456789012345678901234')
        # Generate test RSA key pair
        self.private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=2048,
        )
        public_key = self.private_key.public_key()
        # The OU attribute carries the role name used for RBAC role mapping.
        subject = x509.Name([
            x509.NameAttribute(x509.NameOID.COMMON_NAME, 'test'),
            x509.NameAttribute(x509.NameOID.ORGANIZATIONAL_UNIT_NAME, 'developer'),
        ])
        issuer = x509.Name([
            x509.NameAttribute(x509.NameOID.COMMON_NAME, 'test-ca'),
        ])
        # Timezone-aware UTC: datetime.utcnow() is deprecated since Python 3.12.
        now = datetime.now(timezone.utc)
        self.test_cert = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer)
            .public_key(public_key)
            .serial_number(x509.random_serial_number())
            # Backdate slightly so minor clock skew cannot make the
            # certificate "not yet valid" if its real validity is checked.
            .not_valid_before(now - timedelta(minutes=5))
            .not_valid_after(now + timedelta(days=365))
            .add_extension(
                x509.SubjectAlternativeName([x509.DNSName('test.example.com')]),
                critical=False,
            )
            .sign(self.private_key, hashes.SHA256())
        )
        self.cert_pem = self.test_cert.public_bytes(serialization.Encoding.PEM)
        self.engine.add_trusted_certificate(self.cert_pem)

    # ------------------------------------------------------------------
    # helpers
    # ------------------------------------------------------------------

    def _make_cert_info(self, subject, *, not_before=None, not_after=None):
        """Build a ``ClientCertInfo`` wrapping ``self.test_cert``.

        Args:
            subject: subject attribute dict (e.g. ``{'CN': 'test', 'OU': ...}``).
                A fresh dict should be passed per call; some tests mutate it.
            not_before: validity start; defaults to naive local ``datetime.now()``.
            not_after: validity end; defaults to ``now + 1 day``.

        Issuer, serial number and fingerprint are the fixed values every
        test in this class uses.
        """
        now = datetime.now()
        return ClientCertInfo(
            subject=subject,
            issuer={'CN': 'test-ca'},
            serial_number=123,
            not_before=now if not_before is None else not_before,
            not_after=now + timedelta(days=1) if not_after is None else not_after,
            fingerprint="test123",
            raw_cert=self.test_cert,
        )

    @staticmethod
    def _latest_handshake_row(db_path, columns):
        """Return the newest ``TLS_HANDSHAKE`` audit row, or ``None``.

        Ordering by timestamp ensures a row written by the current test is
        read, not a stale one left by an earlier test or run. ``columns`` is
        a fixed, test-authored column list (never external input).
        """
        query = (
            f"SELECT {columns} FROM audit_logs "
            "WHERE operation_type = 'TLS_HANDSHAKE' "
            "ORDER BY timestamp DESC LIMIT 1"
        )
        # sqlite3's own context manager only commits/rolls back; closing()
        # actually releases the connection.
        with closing(sqlite3.connect(db_path)) as conn:
            return conn.execute(query).fetchone()

    # ------------------------------------------------------------------
    # tests
    # ------------------------------------------------------------------

    def test_certificate_ou_role_mapping(self):
        """Test OU field to RBAC role mapping with signature verification"""
        cert_info = self._make_cert_info({'CN': 'test', 'OU': 'developer'})
        # Verify role mapping
        role = self.engine.get_role_from_certificate(cert_info)
        self.assertEqual(role, Role.DEVELOPER)
        # Verify certificate signature
        self.assertTrue(self.engine.validate_certificate(cert_info))

    def test_certificate_revocation_basic(self):
        """Test certificate revocation verification (no revocation reason)."""
        # NOTE(review): previously named test_certificate_revocation_check and
        # shadowed by the later method of the same name, so it never ran.
        cert_info = self._make_cert_info({'CN': 'test'})
        # Test non-revoked cert
        self.assertFalse(self.engine.is_certificate_revoked(cert_info))
        # Revoke cert and test
        self.engine.revoke_certificate(cert_info)
        self.assertTrue(self.engine.is_certificate_revoked(cert_info))

    def test_tls_handshake_logging_via_audit_logger(self):
        """Test TLS handshake parameters logged directly through AuditLogger."""
        # NOTE(review): previously named test_tls_handshake_logging and shadowed
        # by the later method of the same name, so it never ran. It reads
        # tls_version/tls_cipher columns while the other handshake tests read
        # a JSON tls_params column — confirm against AuditLogger's schema.
        from security.audit import AuditLogger
        audit = AuditLogger()
        # Simulate TLS handshake
        tls_params = {
            'version': 'TLSv1.3',
            'cipher': 'AES256-GCM-SHA384',
            'cert_fingerprint': 'test123',
            'cert_subject': {'CN': 'test'},
            'cert_issuer': {'CN': 'test-ca'},
            'cert_validity': 'valid',
            'cert_revoked': False,
        }
        # Log operation
        audit.log_operation(
            operation_type='TLS_HANDSHAKE',
            operation_result=True,
            user='test_user',
            role='DEVELOPER',
            boundary_violation=False,
            tls_params=tls_params,
        )
        # Verify log entry exists
        row = self._latest_handshake_row(audit.db_path, 'tls_version, tls_cipher')
        self.assertIsNotNone(row)
        self.assertEqual(row[0], 'TLSv1.3')
        self.assertEqual(row[1], 'AES256-GCM-SHA384')

    def test_invalid_role_mapping(self):
        """Test invalid OU field handling"""
        cert_info = self._make_cert_info({'CN': 'test', 'OU': 'invalid-role'})
        # Should raise exception for invalid role
        with self.assertRaises(ValueError):
            self.engine.get_role_from_certificate(cert_info)

    def test_expired_certificate(self):
        """Test expired certificate handling"""
        now = datetime.now()
        cert_info = self._make_cert_info(
            {'CN': 'test'},
            not_before=now - timedelta(days=2),
            not_after=now - timedelta(days=1),
        )
        # Should raise exception for expired cert
        with self.assertRaises(ssl.SSLError):
            self.engine.validate_certificate(cert_info)

    def test_ou_signature_validation(self):
        """Test OU parsing accepts a plain role and rejects a bad signature."""
        # These checks used to follow the expiry check above, where they read
        # an 'OU' key the expired cert's subject never had (KeyError).
        self.assertEqual(self.engine._get_role_from_ou('developer'), Role.DEVELOPER)
        # Test invalid signature
        self.assertIsNone(self.engine._get_role_from_ou("developer:invalid_signature"))

    def test_certificate_revocation_check(self):
        """Test certificate revocation checking before RBAC validation"""
        cert_info = self._make_cert_info({'CN': 'test', 'OU': 'developer:internal'})
        # Test non-revoked certificate
        self.assertFalse(self.engine.is_certificate_revoked(cert_info))
        # Verify RBAC validation works when cert is valid
        role = self.engine.get_role_from_certificate(cert_info)
        self.assertEqual(role, Role.DEVELOPER)
        # Revoke the certificate
        self.engine.revoke_certificate(cert_info, reason="testing")
        # Test revoked certificate
        self.assertTrue(self.engine.is_certificate_revoked(cert_info))
        # Verify RBAC validation fails for revoked cert
        with self.assertRaises(ValueError):
            self.engine.get_role_from_certificate(cert_info)
        # Verify certificate validation fails
        with self.assertRaises(ValueError):
            self.engine.validate_certificate(cert_info)

    def test_tls_handshake_logging(self):
        """Test TLS handshake parameter logging"""
        signed_ou = self.engine.create_signed_ou_claim(Role.DEVELOPER)
        cert_info = self._make_cert_info({'CN': 'test', 'OU': signed_ou})
        tls_params = {
            'version': 'TLSv1.3',
            'cipher': 'AES256-GCM-SHA384',
            'fingerprint': 'test123',
            'subject': {'CN': 'test'},
            'issuer': {'CN': 'test-ca'},
            'validity': 'valid',
            'revoked': False,
        }
        # This should log the TLS parameters
        self.engine.validate_certificate(cert_info, tls_params)
        # Verify audit log contains TLS params
        from security.audit import AuditLogger
        audit = AuditLogger()
        row = self._latest_handshake_row(audit.db_path, 'tls_params')
        self.assertIsNotNone(row)
        logged_params = json.loads(row[0])
        self.assertEqual(logged_params['version'], 'TLSv1.3')
        self.assertEqual(logged_params['cipher'], 'AES256-GCM-SHA384')

    def test_tls_rbac_integration(self):
        """Test full TLS-RBAC integration flow"""
        cert_info = self._make_cert_info({'CN': 'test', 'OU': 'developer'})
        tls_params = {
            'version': 'TLSv1.3',
            'cipher': 'AES256-GCM-SHA384',
            'fingerprint': 'test123',
            'subject': {'CN': 'test', 'OU': 'developer'},
            'issuer': {'CN': 'test-ca'},
            'validity': 'valid',
            'revoked': False,
            'extensions': ['SubjectAlternativeName'],
            'signature_algorithm': 'SHA256',
        }
        # Validate certificate first (includes revocation check)
        self.engine.validate_certificate(cert_info)
        # Explicitly verify revocation status is checked
        self.assertFalse(self.engine.is_certificate_revoked(cert_info))
        # Map role from certificate
        role = self.engine.get_role_from_certificate(cert_info)
        self.assertEqual(role, Role.DEVELOPER)
        # Log full TLS handshake
        self.engine.log_tls_handshake(cert_info, tls_params)
        # Verify all parameters were logged
        from security.audit import AuditLogger
        audit = AuditLogger()
        row = self._latest_handshake_row(audit.db_path, 'tls_params')
        self.assertIsNotNone(row)
        logged_params = json.loads(row[0])
        self.assertEqual(logged_params['version'], 'TLSv1.3')
        self.assertEqual(logged_params['cipher'], 'AES256-GCM-SHA384')
        self.assertEqual(logged_params['subject']['OU'], 'developer')
        self.assertEqual(logged_params['signature_algorithm'], 'SHA256')
        # Check permissions
        self.assertTrue(self.engine.check_permission(role, "tasks", "create"))

    def test_role_boundary_validation(self):
        """Test role boundary validation"""
        # Valid boundary
        cert_info = self._make_cert_info({'CN': 'test', 'OU': 'developer:internal'})
        role = self.engine.get_role_from_certificate(cert_info)
        self.assertEqual(role, Role.DEVELOPER)
        # Invalid boundary
        cert_info.subject['OU'] = 'developer:external'
        with self.assertRaises(ValueError):
            self.engine.get_role_from_certificate(cert_info)
if __name__ == "__main__":
    # Run this module's test cases when it is executed directly as a script.
    unittest.main()