# ai-agent/tests/security/test_tls_config.py
#
# File-viewer metadata captured with this source: 303 lines, no EOL, 13 KiB, Python.

import unittest
import ssl
from cryptography.fernet import Fernet
from security.rbac_engine import Role
import logging
import datetime
# Configure test logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Resolve create_tls_context. The packaged path (security.encrypt) is tried
# first; if that fails, a flat ``encrypt`` import is attempted, which only
# works when the directory containing encrypt.py is itself on sys.path.
try:
    from security.encrypt import create_tls_context
except ImportError:
    # sys/os are kept available for an optional sys.path adjustment (adding
    # the project root) should a test runner need it; none is applied here.
    import sys
    import os
    print("Warning: Could not import 'security.encrypt'. Ensure PYTHONPATH is set correctly.")
    try:
        from encrypt import create_tls_context
    except ImportError as err:
        # Chain the original failure so the traceback shows why both
        # import attempts were rejected.
        raise ImportError("Failed to import create_tls_context from security.encrypt") from err
class TestTlsConfig(unittest.TestCase):
    """Verify that contexts built by ``create_tls_context`` enforce the TLS policy.

    Following the ``ssl`` module convention, ``ssl.Purpose.SERVER_AUTH`` is
    used for client-side contexts (they authenticate servers) and
    ``ssl.Purpose.CLIENT_AUTH`` for server-side contexts.

    The negative handshake tests run both peers in memory through
    ``ssl.MemoryBIO`` pairs. A blocking ``wrap_socket(..., server_side=True)``
    on an already connected socket starts the handshake immediately and waits
    for a ClientHello that the same thread has not sent yet, so a
    single-threaded socket-pair test never gets past the server wrap.
    """

    # A TLS handshake settles within a few exchanges; this only bounds the loop.
    _MAX_HANDSHAKE_ROUNDS = 10

    @staticmethod
    def _wrap_peers(client_context, server_context, server_hostname='test'):
        """Create an in-memory client/server pair ready for a handshake.

        Returns a list of ``(ssl_object, outgoing_bio, peer_incoming_bio)``
        tuples, client first. Wrapping is kept separate from the handshake so
        that an ``ssl.SSLError`` raised while wrapping (for example a context
        created for the opposite role) is reported as a test error instead of
        being mistaken for a rejected handshake.
        """
        client_in, client_out = ssl.MemoryBIO(), ssl.MemoryBIO()
        server_in, server_out = ssl.MemoryBIO(), ssl.MemoryBIO()
        client = client_context.wrap_bio(
            client_in, client_out, server_side=False, server_hostname=server_hostname)
        server = server_context.wrap_bio(server_in, server_out, server_side=True)
        return [(client, client_out, server_in), (server, server_out, client_in)]

    @classmethod
    def _pump_handshake(cls, peers):
        """Drive the handshake for ``peers`` (from ``_wrap_peers``) to completion.

        Returns once both sides finished the handshake.

        Raises:
            ssl.SSLError: raised by whichever side rejects the handshake first.
            AssertionError: if the handshake neither completes nor fails
                within ``_MAX_HANDSHAKE_ROUNDS`` exchanges.
        """
        finished = set()
        for _ in range(cls._MAX_HANDSHAKE_ROUNDS):
            for index, (peer, outgoing, peer_incoming) in enumerate(peers):
                if index not in finished:
                    try:
                        peer.do_handshake()
                    except ssl.SSLWantReadError:
                        # This side needs bytes from the other one; they are
                        # delivered below and the handshake is retried.
                        pass
                    else:
                        finished.add(index)
                # Forward whatever this side produced (handshake records or
                # an alert) to the other side's incoming buffer.
                pending = outgoing.read()
                if pending:
                    peer_incoming.write(pending)
            if len(finished) == len(peers):
                return
        raise AssertionError("TLS handshake neither completed nor failed")

    @staticmethod
    def _self_signed_cert(common_name, valid_from, valid_until):
        """Generate a self-signed RSA certificate for ``common_name``.

        The name is also added as a DNS subjectAltName so that hostname
        matching cannot be the reason a handshake fails.

        Returns:
            ``(certificate, private_key)`` as ``cryptography`` objects.
        """
        from cryptography import x509
        from cryptography.x509.oid import NameOID
        from cryptography.hazmat.primitives import hashes
        from cryptography.hazmat.primitives.asymmetric import rsa

        private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
        name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, common_name)])
        cert = (
            x509.CertificateBuilder()
            .subject_name(name)
            .issuer_name(name)
            .public_key(private_key.public_key())
            .serial_number(x509.random_serial_number())
            .not_valid_before(valid_from)
            .not_valid_after(valid_until)
            .add_extension(
                x509.SubjectAlternativeName([x509.DNSName(common_name)]),
                critical=False,
            )
            .sign(private_key, hashes.SHA256())
        )
        return cert, private_key

    @staticmethod
    def _pem(cert):
        """Return ``cert`` as PEM text.

        ``load_verify_locations(cadata=...)`` treats ``bytes`` as DER, so PEM
        data has to be passed as ``str``.
        """
        from cryptography.hazmat.primitives import serialization

        return cert.public_bytes(serialization.Encoding.PEM).decode('ascii')

    @staticmethod
    def _plain_server_context(cert, private_key):
        """Return a stock ``PROTOCOL_TLS_SERVER`` context presenting ``cert``.

        ``load_cert_chain`` only accepts file paths, so certificate and key
        are written to a temporary directory that is removed again once they
        have been loaded.
        """
        import os
        import tempfile
        from cryptography.hazmat.primitives import serialization

        material = cert.public_bytes(serialization.Encoding.PEM) + private_key.private_bytes(
            serialization.Encoding.PEM,
            serialization.PrivateFormat.PKCS8,
            serialization.NoEncryption(),
        )
        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        with tempfile.TemporaryDirectory() as workdir:
            chain_path = os.path.join(workdir, 'server.pem')
            with open(chain_path, 'wb') as handle:
                handle.write(material)
            server_context.load_cert_chain(chain_path)
        return server_context

    def test_client_context_requires_tls1_3(self):
        """Verify client context minimum version is TLS 1.3"""
        context = create_tls_context(purpose=ssl.Purpose.SERVER_AUTH)
        # Setting minimum_version implicitly disables older protocols, so no
        # OP_NO_TLSv1_x flags are checked.
        self.assertEqual(context.minimum_version, ssl.TLSVersion.TLSv1_3,
                         "Client context should require TLS 1.3")

    def test_server_context_requires_tls1_3(self):
        """Verify server context minimum version is TLS 1.3"""
        context = create_tls_context(purpose=ssl.Purpose.CLIENT_AUTH)
        # Setting minimum_version implicitly disables older protocols, so no
        # OP_NO_TLSv1_x flags are checked.
        self.assertEqual(context.minimum_version, ssl.TLSVersion.TLSv1_3,
                         "Server context should require TLS 1.3")

    # --- Negative Test Cases ---

    def test_server_rejects_tls1_2_client(self):
        """Verify server context rejects connection attempts using only TLS 1.2"""
        server_context = create_tls_context(purpose=ssl.Purpose.CLIENT_AUTH)

        # Configure client to only use TLS 1.2
        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
        client_context.minimum_version = ssl.TLSVersion.TLSv1_2
        # Certificate checks are switched off on this throwaway client so
        # that it cannot fail for a reason other than the protocol version.
        client_context.check_hostname = False
        client_context.verify_mode = ssl.CERT_NONE

        peers = self._wrap_peers(client_context, server_context)
        with self.assertRaises(ssl.SSLError, msg="Server should reject TLS 1.2 connection"):
            self._pump_handshake(peers)

    def test_client_rejects_tls1_2_server(self):
        """Verify client context rejects connection attempts to a TLS 1.2 server"""
        client_context = create_tls_context(purpose=ssl.Purpose.SERVER_AUTH)

        # The server presents a currently valid certificate that the client
        # is told to trust, so the protocol version is the intended obstacle.
        now = datetime.datetime.now(datetime.timezone.utc)
        cert, private_key = self._self_signed_cert(
            'test', now - datetime.timedelta(minutes=5), now + datetime.timedelta(days=1))
        client_context.load_verify_locations(cadata=self._pem(cert))

        # Configure server to only use TLS 1.2
        server_context = self._plain_server_context(cert, private_key)
        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
        server_context.minimum_version = ssl.TLSVersion.TLSv1_2

        peers = self._wrap_peers(client_context, server_context, server_hostname='test')
        with self.assertRaises(ssl.SSLError, msg="Client should reject connection to TLS 1.2 server"):
            self._pump_handshake(peers)

    def test_rejects_invalid_cipher_suites(self):
        """Verify connection fails when only invalid cipher suites are offered"""
        server_context = create_tls_context(purpose=ssl.Purpose.CLIENT_AUTH)

        # Configure client with invalid cipher suites
        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        client_context.set_ciphers('AES128-SHA256')  # Not in allowed list
        # NOTE(review): set_ciphers() does not restrict TLS 1.3 suites, and
        # this client still verifies the server against an empty trust store,
        # so the failure observed here may come from certificate verification
        # rather than cipher negotiation. Confirm what create_tls_context
        # allows and tighten this test accordingly.

        peers = self._wrap_peers(client_context, server_context)
        with self.assertRaises(ssl.SSLError, msg="Should reject invalid cipher suites"):
            self._pump_handshake(peers)

    def test_rejects_expired_certificate(self):
        """Verify connection fails with expired certificate"""
        now = datetime.datetime.now(datetime.timezone.utc)
        cert, private_key = self._self_signed_cert(
            'expired.example.com',
            now - datetime.timedelta(days=2),
            now - datetime.timedelta(days=1),
        )
        server_context = self._plain_server_context(cert, private_key)

        # The certificate is trusted and its name matches, which leaves the
        # elapsed validity period as the intended reason for rejection.
        client_context = create_tls_context(purpose=ssl.Purpose.SERVER_AUTH)
        client_context.load_verify_locations(cadata=self._pem(cert))

        peers = self._wrap_peers(
            client_context, server_context, server_hostname='expired.example.com')
        with self.assertRaises(ssl.SSLCertVerificationError,
                               msg="Should reject expired certificate"):
            self._pump_handshake(peers)

    def test_rejects_self_signed_certificate(self):
        """Verify connection fails with untrusted self-signed certificate"""
        now = datetime.datetime.now(datetime.timezone.utc)
        cert, private_key = self._self_signed_cert(
            'untrusted.example.com',
            now - datetime.timedelta(minutes=5),
            now + datetime.timedelta(days=1),
        )
        server_context = self._plain_server_context(cert, private_key)

        # Create context without trusting the cert
        client_context = create_tls_context(purpose=ssl.Purpose.SERVER_AUTH)

        peers = self._wrap_peers(
            client_context, server_context, server_hostname='untrusted.example.com')
        with self.assertRaises(ssl.SSLCertVerificationError,
                               msg="Should reject untrusted certificate"):
            self._pump_handshake(peers)

    def test_rejects_invalid_ou_claim(self):
        """Verify RBAC integration rejects invalid OU claims"""
        from security.rbac_engine import ClientCertInfo

        # Create context with RBAC integration
        context = create_tls_context()
        # Create cert info with invalid OU
        invalid_cert = ClientCertInfo(
            subject={'CN': 'test-user', 'OU': 'invalid-claim'},
            issuer={'CN': 'Test Org'},
            serial_number=123,
            fingerprint='test'
        )
        # NOTE(review): this assertion is vacuous - the ValueError is raised
        # by the test itself and ``invalid_cert`` is never validated. Replace
        # the raise with the real RBAC validation call once the entry point
        # and its failure mode (exception vs. falsy result) are confirmed.
        with self.assertRaises(ValueError, msg="Should reject invalid OU claim"):
            raise ValueError("Invalid OU claim")
class TestRBACEngineTLSIntegration(unittest.TestCase):
    """Tests for TLS client certificate integration with RBAC engine."""

    def setUp(self):
        """Set up an RBAC engine and a trusted certificate with a signed OU claim.

        Leaves three attributes on the test case: ``engine`` (the RBACEngine),
        ``valid_ou`` (the OU claim created for ``Role.DEVELOPER``) and
        ``valid_cert`` (a ClientCertInfo describing the generated certificate).
        """
        from security.rbac_engine import RBACEngine, ClientCertInfo
        from cryptography import x509
        from cryptography.x509.oid import NameOID
        from cryptography.hazmat.primitives import hashes, serialization
        from cryptography.hazmat.primitives.asymmetric import rsa

        # Generate valid Fernet key. The key itself is deliberately kept out
        # of the log output; key material does not belong in logs, even for
        # throwaway test keys.
        valid_key = Fernet.generate_key()
        logger.debug("Generated Fernet key for RBACEngine")
        try:
            logger.info("Initializing RBACEngine with valid key")
            self.engine = RBACEngine(valid_key)
            logger.info("RBACEngine initialized successfully")
        except Exception:
            # logger.exception records the traceback along with the message.
            logger.exception("RBACEngine initialization failed")
            raise

        # Create a test certificate info with valid signed OU claim
        self.valid_ou = self.engine.create_signed_ou_claim(Role.DEVELOPER)

        # Generate a simple self-signed test certificate whose subject
        # carries the OU claim.
        private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
        subject = issuer = x509.Name([
            x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
            x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"),
            x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"),
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Test Org"),
            x509.NameAttribute(NameOID.COMMON_NAME, "test-user"),
            x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, self.valid_ou),
        ])
        # One timestamp anchors both ends of the validity window.
        issued_at = datetime.datetime.now(datetime.timezone.utc)
        cert = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer)
            .public_key(private_key.public_key())
            .serial_number(x509.random_serial_number())
            .not_valid_before(issued_at)
            .not_valid_after(issued_at + datetime.timedelta(days=1))
            .sign(private_key, hashes.SHA256())
        )

        # PEM encode certificate
        cert_pem = cert.public_bytes(serialization.Encoding.PEM)
        self.valid_cert = ClientCertInfo(
            subject={'CN': 'test-user', 'OU': self.valid_ou},
            issuer={'CN': 'Test Org'},
            serial_number=cert.serial_number,
            fingerprint=cert.fingerprint(hashes.SHA256()).hex()
        )
        # Add test certificate to trusted list
        self.engine.add_trusted_certificate(cert_pem)

    def test_ou_field_mapping_valid_signed_claim(self):
        """Test that valid signed OU claim correctly maps to RBAC role."""
        print("\n=== Starting test_ou_field_mapping_valid_signed_claim ===")
        print(f"Testing with OU: {self.valid_ou}")
        # Test role mapping
        mapped_role = self.engine._get_role_from_ou(self.valid_ou)
        self.assertIsNotNone(mapped_role, "OU field should map to a valid role")
        self.assertEqual(mapped_role, Role.DEVELOPER,
                         "OU field should map to DEVELOPER role")
        print("=== Test passed: Valid signed OU claim correctly mapped ===")

    def test_validate_permission_with_certificate(self):
        """Test end-to-end permission validation with client certificate."""
        print("\n=== Starting test_validate_permission_with_certificate ===")
        # Test permission validation
        result = self.engine.validate_permission(
            resource='tasks',
            action='create',
            client_cert_info=self.valid_cert
        )
        self.assertTrue(result, "Permission should be granted for valid certificate")
        print("=== Test passed: Permission correctly granted for valid certificate ===")
# Run the suite when this file is executed directly instead of being
# collected by a test runner.
if __name__ == '__main__':
    unittest.main()