"""Tests for the TLS context factory and the RBAC engine's client-cert integration.

The negative TLS tests run real handshakes between two in-memory
``ssl.SSLObject`` endpoints (see ``_handshake``), so every rejection is
produced by OpenSSL itself instead of being raised by the test.
"""

import datetime
import logging
import os
import ssl
import tempfile
import unittest

from cryptography import x509
from cryptography.fernet import Fernet, InvalidToken
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import rsa
from cryptography.x509.oid import NameOID

from security.rbac_engine import Role

# Configure test logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

try:
    from security.encrypt import create_tls_context
except ImportError:
    # Fallback for running the tests from inside the package directory when
    # the project root is not on sys.path.
    logger.warning(
        "Could not import 'security.encrypt'. Ensure PYTHONPATH is set correctly."
    )
    try:
        from encrypt import create_tls_context
    except ImportError as err:
        raise ImportError(
            "Failed to import create_tls_context from security.encrypt"
        ) from err


# Hostname placed in the SAN of generated server certificates and passed as
# server_hostname by clients, so hostname checks never cause a failure.
_TEST_HOSTNAME = "tls-test.example.com"

# OpenSSL X509_V_ERR_* codes, surfaced as SSLCertVerificationError.verify_code.
_X509_V_ERR_CERT_HAS_EXPIRED = 10
_X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT = 18


def _utcnow():
    """Return the current time as a timezone-aware UTC datetime."""
    return datetime.datetime.now(datetime.timezone.utc)


def _common_name(common_name):
    """Return an X.509 name consisting of a single CN attribute."""
    return x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, common_name)])


def _make_self_signed_cert(subject, *, not_before=None, not_after=None, dns_name=None):
    """Create a self-signed RSA-2048/SHA-256 certificate.

    Args:
        subject: ``x509.Name`` used as both subject and issuer.
        not_before: start of validity (default: now, UTC).
        not_after: end of validity (default: one day after now, UTC).
        dns_name: when given, add the extensions a TLS server certificate
            needs to pass hostname checks and strict X.509 verification
            (SAN, basicConstraints, key identifiers). The aim is that a
            handshake failure can only come from the property under test.

    Returns:
        ``(certificate, private_key)``.
    """
    now = _utcnow()
    if not_before is None:
        not_before = now
    if not_after is None:
        not_after = now + datetime.timedelta(days=1)

    private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    public_key = private_key.public_key()
    builder = (
        x509.CertificateBuilder()
        .subject_name(subject)
        .issuer_name(subject)  # self-signed: issuer == subject
        .public_key(public_key)
        .serial_number(x509.random_serial_number())
        .not_valid_before(not_before)
        .not_valid_after(not_after)
    )
    if dns_name is not None:
        builder = (
            builder.add_extension(
                x509.SubjectAlternativeName([x509.DNSName(dns_name)]), critical=False
            )
            .add_extension(
                x509.BasicConstraints(ca=False, path_length=None), critical=True
            )
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(public_key), critical=False
            )
            .add_extension(
                x509.AuthorityKeyIdentifier.from_issuer_public_key(public_key),
                critical=False,
            )
        )
    return builder.sign(private_key, hashes.SHA256()), private_key


def _load_cert_chain(context, cert, private_key):
    """Install *cert* / *private_key* as *context*'s own certificate chain.

    ``SSLContext.load_cert_chain`` only accepts file paths, so the PEM data is
    written to a temporary directory that is removed afterwards.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        cert_path = os.path.join(tmp_dir, "cert.pem")
        key_path = os.path.join(tmp_dir, "key.pem")
        with open(cert_path, "wb") as fh:
            fh.write(cert.public_bytes(serialization.Encoding.PEM))
        with open(key_path, "wb") as fh:
            fh.write(
                private_key.private_bytes(
                    serialization.Encoding.PEM,
                    serialization.PrivateFormat.PKCS8,
                    serialization.NoEncryption(),
                )
            )
        context.load_cert_chain(cert_path, key_path)


def _make_server_identity():
    """Return a currently valid ``(cert, key)`` pair for ``_TEST_HOSTNAME``."""
    return _make_self_signed_cert(_common_name(_TEST_HOSTNAME), dns_name=_TEST_HOSTNAME)


def _unverified_client_context():
    """Return a client context that skips certificate and hostname checks.

    Used when the peer under test is the server, so that only protocol or
    cipher negotiation can make the handshake fail.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
    context.check_hostname = False  # must be cleared before relaxing verify_mode
    context.verify_mode = ssl.CERT_NONE
    return context


def _step_handshake(endpoint):
    """Advance *endpoint*'s handshake; return True once it has completed."""
    try:
        endpoint.do_handshake()
    except ssl.SSLWantReadError:
        return False  # waiting for bytes from the peer
    return True


def _handshake(client_context, server_context, server_hostname, max_rounds=20):
    """Run a TLS handshake between two contexts over in-memory BIOs.

    Both endpoints are driven from the current thread, so neither side can
    block waiting for the other. A blocking ``wrap_socket`` on a
    ``socketpair`` deadlocks here.

    Returns:
        The ``(client, server)`` ``ssl.SSLObject`` pair once both sides finish.

    Raises:
        ssl.SSLError: propagated from whichever endpoint aborts the handshake.
        AssertionError: if the handshake neither completes nor fails within
            *max_rounds* exchanges.
    """
    client_in, client_out = ssl.MemoryBIO(), ssl.MemoryBIO()
    server_in, server_out = ssl.MemoryBIO(), ssl.MemoryBIO()
    client = client_context.wrap_bio(
        client_in, client_out, server_side=False, server_hostname=server_hostname
    )
    server = server_context.wrap_bio(server_in, server_out, server_side=True)

    client_done = server_done = False
    for _ in range(max_rounds):
        if not client_done:
            client_done = _step_handshake(client)
        server_in.write(client_out.read())
        if not server_done:
            server_done = _step_handshake(server)
        client_in.write(server_out.read())
        if client_done and server_done:
            return client, server
    raise AssertionError(f"TLS handshake did not finish within {max_rounds} rounds")


class TestTlsConfig(unittest.TestCase):
    def test_client_context_requires_tls1_3(self):
        """Verify client context minimum version is TLS 1.3"""
        # Purpose.SERVER_AUTH builds a client-side context (it authenticates servers).
        context = create_tls_context(purpose=ssl.Purpose.SERVER_AUTH)
        self.assertEqual(context.minimum_version, ssl.TLSVersion.TLSv1_3,
                         "Client context should require TLS 1.3")
        # Setting minimum_version implicitly disables older protocols.
        # Explicit checks for OP_NO_TLSv1_x flags are unreliable and removed.

    def test_server_context_requires_tls1_3(self):
        """Verify server context minimum version is TLS 1.3"""
        # Purpose.CLIENT_AUTH builds a server-side context (it authenticates clients).
        context = create_tls_context(purpose=ssl.Purpose.CLIENT_AUTH)
        self.assertEqual(context.minimum_version, ssl.TLSVersion.TLSv1_3,
                         "Server context should require TLS 1.3")
        # Setting minimum_version implicitly disables older protocols.
        # Explicit checks for OP_NO_TLSv1_x flags are unreliable and removed.

    # --- Negative Test Cases ---

    def test_server_rejects_tls1_2_client(self):
        """Verify server context rejects connection attempts using only TLS 1.2"""
        server_context = create_tls_context(purpose=ssl.Purpose.CLIENT_AUTH)
        cert, key = _make_server_identity()
        _load_cert_chain(server_context, cert, key)

        # The client skips verification and is pinned to TLS 1.2, so the
        # protocol version is the only thing the server can object to.
        client_context = _unverified_client_context()
        client_context.minimum_version = ssl.TLSVersion.TLSv1_2
        client_context.maximum_version = ssl.TLSVersion.TLSv1_2

        with self.assertRaises(ssl.SSLError, msg="Server should reject TLS 1.2 connection"):
            _handshake(client_context, server_context, _TEST_HOSTNAME)

    def test_client_rejects_tls1_2_server(self):
        """Verify client context rejects connection attempts to a TLS 1.2 server"""
        client_context = create_tls_context(purpose=ssl.Purpose.SERVER_AUTH)
        cert, key = _make_server_identity()
        # Trust the server certificate so that only the protocol version is at issue.
        client_context.load_verify_locations(
            cadata=cert.public_bytes(serialization.Encoding.DER)
        )

        # Configure server to only use TLS 1.2
        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
        _load_cert_chain(server_context, cert, key)

        with self.assertRaises(ssl.SSLError,
                               msg="Client should reject connection to TLS 1.2 server"):
            _handshake(client_context, server_context, _TEST_HOSTNAME)

    def test_rejects_invalid_cipher_suites(self):
        """Verify connection fails when only invalid cipher suites are offered"""
        server_context = create_tls_context(purpose=ssl.Purpose.CLIENT_AUTH)
        cert, key = _make_server_identity()
        _load_cert_chain(server_context, cert, key)

        # set_ciphers() only governs TLS <= 1.2 suites; TLS 1.3 suites cannot be
        # disabled through it. The client is therefore capped at TLS 1.2 so that
        # AES128-SHA256 (not in the allowed list) is genuinely all it can offer.
        client_context = _unverified_client_context()
        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
        client_context.set_ciphers('AES128-SHA256')  # Not in allowed list

        with self.assertRaises(ssl.SSLError, msg="Should reject invalid cipher suites"):
            _handshake(client_context, server_context, _TEST_HOSTNAME)

    def test_rejects_expired_certificate(self):
        """Verify connection fails with expired certificate"""
        hostname = "expired.example.com"
        now = _utcnow()
        cert, key = _make_self_signed_cert(
            _common_name(hostname),
            not_before=now - datetime.timedelta(days=2),
            not_after=now - datetime.timedelta(days=1),
            dns_name=hostname,
        )
        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        _load_cert_chain(server_context, cert, key)

        client_context = create_tls_context(purpose=ssl.Purpose.SERVER_AUTH)
        # Trust the certificate so that its expiry is its only defect.
        # cadata takes PEM *text* or DER *bytes*; PEM bytes would be misparsed.
        client_context.load_verify_locations(
            cadata=cert.public_bytes(serialization.Encoding.DER)
        )

        with self.assertRaises(ssl.SSLCertVerificationError,
                               msg="Should reject expired certificate") as cm:
            _handshake(client_context, server_context, hostname)
        self.assertEqual(cm.exception.verify_code, _X509_V_ERR_CERT_HAS_EXPIRED,
                         cm.exception.verify_message)

    def test_rejects_self_signed_certificate(self):
        """Verify connection fails with untrusted self-signed certificate"""
        hostname = "untrusted.example.com"
        cert, key = _make_self_signed_cert(_common_name(hostname), dns_name=hostname)
        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        _load_cert_chain(server_context, cert, key)

        # Create context without trusting the cert
        client_context = create_tls_context(purpose=ssl.Purpose.SERVER_AUTH)

        with self.assertRaises(ssl.SSLCertVerificationError,
                               msg="Should reject untrusted certificate") as cm:
            _handshake(client_context, server_context, hostname)
        self.assertEqual(cm.exception.verify_code,
                         _X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT,
                         cm.exception.verify_message)

    def test_rejects_invalid_ou_claim(self):
        """Verify RBAC integration rejects invalid OU claims"""
        from security.rbac_engine import RBACEngine

        engine = RBACEngine(Fernet.generate_key())
        # An unsigned OU must not map to any role. Explicitly raising is also
        # accepted as a rejection. InvalidToken is included on the assumption
        # that claims are Fernet tokens — NOTE(review): confirm against RBACEngine.
        try:
            mapped_role = engine._get_role_from_ou('invalid-claim')
        except (ValueError, InvalidToken):
            mapped_role = None
        self.assertIsNone(mapped_role, "Should reject invalid OU claim")


class TestRBACEngineTLSIntegration(unittest.TestCase):
    """Tests for TLS client certificate integration with RBAC engine."""

    def setUp(self):
        """Set up an RBACEngine and a trusted self-signed cert with a signed OU claim."""
        from security.rbac_engine import RBACEngine, ClientCertInfo

        # Generate valid Fernet key. Do not log the key itself: it is secret
        # key material handed to the engine.
        valid_key = Fernet.generate_key()
        logger.debug("Generated Fernet key for RBACEngine")
        try:
            logger.info("Initializing RBACEngine with valid key")
            self.engine = RBACEngine(valid_key)
        except Exception:
            logger.exception("RBACEngine initialization failed")
            raise
        logger.info("RBACEngine initialized successfully")

        # Create a test certificate info with valid signed OU claim
        self.valid_ou = self.engine.create_signed_ou_claim(Role.DEVELOPER)

        subject = x509.Name([
            x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
            x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"),
            x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"),
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Test Org"),
            x509.NameAttribute(NameOID.COMMON_NAME, "test-user"),
            x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, self.valid_ou),
        ])
        cert, _ = _make_self_signed_cert(subject)
        cert_pem = cert.public_bytes(serialization.Encoding.PEM)

        self.valid_cert = ClientCertInfo(
            subject={'CN': 'test-user', 'OU': self.valid_ou},
            issuer={'CN': 'Test Org'},
            serial_number=cert.serial_number,
            fingerprint=cert.fingerprint(hashes.SHA256()).hex(),
        )

        # Add test certificate to trusted list
        self.engine.add_trusted_certificate(cert_pem)

    def test_ou_field_mapping_valid_signed_claim(self):
        """Test that valid signed OU claim correctly maps to RBAC role."""
        logger.debug("Testing OU mapping with OU: %s", self.valid_ou)

        mapped_role = self.engine._get_role_from_ou(self.valid_ou)
        self.assertIsNotNone(mapped_role, "OU field should map to a valid role")
        self.assertEqual(mapped_role, Role.DEVELOPER,
                         "OU field should map to DEVELOPER role")

    def test_validate_permission_with_certificate(self):
        """Test end-to-end permission validation with client certificate."""
        result = self.engine.validate_permission(
            resource='tasks',
            action='create',
            client_cert_info=self.valid_cert,
        )
        self.assertTrue(result, "Permission should be granted for valid certificate")


if __name__ == '__main__':
    unittest.main()