303 lines
No EOL
13 KiB
Python
303 lines
No EOL
13 KiB
Python
import unittest
|
|
import ssl
|
|
from cryptography.fernet import Fernet
|
|
from security.rbac_engine import Role
|
|
import logging
|
|
import datetime
|
|
|
|
# Configure test logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# Prefer the package import. Fall back to a top-level ``encrypt`` module for
# runs where the directory containing encrypt.py is itself on sys.path.
try:
    from security.encrypt import create_tls_context
except ImportError as package_err:
    logger.warning(
        "Could not import 'security.encrypt'; trying top-level 'encrypt'. "
        "Ensure PYTHONPATH is set correctly."
    )
    try:
        from encrypt import create_tls_context
    except ImportError as fallback_err:
        # Report both attempts so the real cause of each failure is visible.
        raise ImportError(
            "Failed to import create_tls_context from 'security.encrypt' "
            f"({package_err}) or 'encrypt' ({fallback_err})"
        ) from fallback_err
|
|
|
|
|
|
# Socket timeout (seconds) for the handshake tests: a stuck peer fails the
# test quickly instead of hanging the whole run.
_HANDSHAKE_TIMEOUT = 5.0


def _make_self_signed_cert(hostname, *, not_before=None, not_after=None):
    """Generate a throwaway self-signed RSA certificate for ``hostname``.

    Args:
        hostname: Used as both the subject CN and the single DNS
            subjectAltName, so hostname verification against it can succeed.
        not_before: Start of validity; defaults to one minute ago.
        not_after: End of validity; defaults to one day from now.

    Returns:
        ``(cert_pem, key_pem)`` as ASCII ``str``. A ``str`` is required for
        ``SSLContext.load_verify_locations(cadata=...)``, which parses
        ``bytes`` as DER rather than PEM.
    """
    from cryptography import x509
    from cryptography.hazmat.primitives import hashes, serialization
    from cryptography.hazmat.primitives.asymmetric import rsa
    from cryptography.x509.oid import NameOID

    now = datetime.datetime.now(datetime.timezone.utc)
    if not_before is None:
        # Backdated so second-granularity timestamps never land in the future.
        not_before = now - datetime.timedelta(minutes=1)
    if not_after is None:
        not_after = now + datetime.timedelta(days=1)

    key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
    name = x509.Name([x509.NameAttribute(NameOID.COMMON_NAME, hostname)])
    cert = (
        x509.CertificateBuilder()
        .subject_name(name)
        .issuer_name(name)
        .public_key(key.public_key())
        .serial_number(x509.random_serial_number())
        .not_valid_before(not_before)
        .not_valid_after(not_after)
        .add_extension(
            x509.SubjectAlternativeName([x509.DNSName(hostname)]),
            critical=False,
        )
        .sign(key, hashes.SHA256())
    )
    cert_pem = cert.public_bytes(serialization.Encoding.PEM).decode("ascii")
    key_pem = key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("ascii")
    return cert_pem, key_pem


def _load_cert_chain_from_pem(context, cert_pem, key_pem):
    """Install an in-memory PEM certificate/key pair on ``context``.

    ``SSLContext.load_cert_chain`` only reads files, so the PEM data is
    written to a temporary directory that is removed once loading finishes.
    """
    import pathlib
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        cert_path = pathlib.Path(tmp_dir) / "cert.pem"
        key_path = pathlib.Path(tmp_dir) / "key.pem"
        cert_path.write_text(cert_pem, encoding="ascii")
        key_path.write_text(key_pem, encoding="ascii")
        context.load_cert_chain(certfile=str(cert_path), keyfile=str(key_path))


def _attempt_handshake(server_context, client_context, server_hostname):
    """Run one TLS handshake between two contexts over a local socket pair.

    The server end runs in a background thread. Wrapping an already-connected
    blocking socket performs the handshake immediately, so driving both ends
    from one thread deadlocks.

    Server-side errors are swallowed because callers assert on what the client
    observes. Any exception from the client-side handshake propagates.
    """
    import socket
    import threading

    server_raw, client_raw = socket.socketpair()
    server_raw.settimeout(_HANDSHAKE_TIMEOUT)
    client_raw.settimeout(_HANDSHAKE_TIMEOUT)

    def serve():
        try:
            with server_context.wrap_socket(server_raw, server_side=True):
                pass
        except (OSError, ValueError):
            # ssl.SSLError subclasses OSError; a refused handshake is expected.
            pass

    server_thread = threading.Thread(target=serve, daemon=True)
    server_thread.start()
    try:
        with client_context.wrap_socket(client_raw, server_hostname=server_hostname):
            pass
    finally:
        # Closing the client end unblocks a server still waiting for data.
        # Closing a socket that wrap_socket() already detached is a no-op.
        client_raw.close()
        server_thread.join(_HANDSHAKE_TIMEOUT)
        server_raw.close()


class TestTlsConfig(unittest.TestCase):
    """Protocol-version, cipher, certificate and OU tests for the TLS setup."""

    def test_client_context_requires_tls1_3(self):
        """Verify the ``Purpose.CLIENT_AUTH`` context requires TLS 1.3.

        In the stdlib, ``Purpose.CLIENT_AUTH`` denotes a server-side context
        (one that authenticates clients). The handshake tests below use it
        for the server end accordingly.
        """
        context = create_tls_context(purpose=ssl.Purpose.CLIENT_AUTH)
        self.assertEqual(context.minimum_version, ssl.TLSVersion.TLSv1_3,
                         "CLIENT_AUTH (server-side) context should require TLS 1.3")
        # minimum_version alone disables older protocols; the deprecated
        # OP_NO_TLSv1_x option flags are deliberately not inspected.

    def test_server_context_requires_tls1_3(self):
        """Verify the ``Purpose.SERVER_AUTH`` (client-side) context requires TLS 1.3."""
        context = create_tls_context(purpose=ssl.Purpose.SERVER_AUTH)
        self.assertEqual(context.minimum_version, ssl.TLSVersion.TLSv1_3,
                         "SERVER_AUTH (client-side) context should require TLS 1.3")
        # minimum_version alone disables older protocols; the deprecated
        # OP_NO_TLSv1_x option flags are deliberately not inspected.

    # --- Negative Test Cases ---

    def test_server_rejects_tls1_2_client(self):
        """Verify the server-side context refuses a client limited to TLS 1.2."""
        cert_pem, key_pem = _make_self_signed_cert("localhost")
        server_context = create_tls_context(purpose=ssl.Purpose.CLIENT_AUTH)
        _load_cert_chain_from_pem(server_context, cert_pem, key_pem)

        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        client_context.minimum_version = ssl.TLSVersion.TLSv1_2
        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
        # Trust the test certificate so the protocol version is the only
        # reason the handshake can fail.
        client_context.load_verify_locations(cadata=cert_pem)

        with self.assertRaises(ssl.SSLError, msg="Server should reject TLS 1.2 connection"):
            _attempt_handshake(server_context, client_context, "localhost")

    def test_client_rejects_tls1_2_server(self):
        """Verify the client-side context refuses a server limited to TLS 1.2."""
        cert_pem, key_pem = _make_self_signed_cert("localhost")
        client_context = create_tls_context(purpose=ssl.Purpose.SERVER_AUTH)
        # Trust the test certificate so the protocol version is the only
        # reason the handshake can fail.
        client_context.load_verify_locations(cadata=cert_pem)

        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        server_context.minimum_version = ssl.TLSVersion.TLSv1_2
        server_context.maximum_version = ssl.TLSVersion.TLSv1_2
        _load_cert_chain_from_pem(server_context, cert_pem, key_pem)

        with self.assertRaises(ssl.SSLError, msg="Client should reject connection to TLS 1.2 server"):
            _attempt_handshake(server_context, client_context, "localhost")

    def test_rejects_invalid_cipher_suites(self):
        """Verify a client offering only a disallowed legacy cipher suite cannot connect.

        ``SSLContext.set_ciphers`` only affects suites for TLS 1.2 and
        earlier, and the stdlib cannot restrict TLS 1.3 suites. A client left
        able to negotiate TLS 1.3 would therefore still connect with a 1.3
        suite, so the client is capped at TLS 1.2.

        Against a TLS 1.3-only server this overlaps with
        ``test_server_rejects_tls1_2_client``.
        """
        cert_pem, key_pem = _make_self_signed_cert("localhost")
        server_context = create_tls_context(purpose=ssl.Purpose.CLIENT_AUTH)
        _load_cert_chain_from_pem(server_context, cert_pem, key_pem)

        client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
        client_context.load_verify_locations(cadata=cert_pem)
        try:
            # Legacy non-PFS suite, presumed absent from the allowed list.
            client_context.set_ciphers('AES128-SHA256')
        except ssl.SSLError:
            self.skipTest("local OpenSSL build cannot enable AES128-SHA256")

        with self.assertRaises(ssl.SSLError, msg="Should reject invalid cipher suites"):
            _attempt_handshake(server_context, client_context, "localhost")

    def test_rejects_expired_certificate(self):
        """Verify the client-side context rejects an expired (but trusted) server certificate."""
        hostname = "expired.example.com"
        now = datetime.datetime.now(datetime.timezone.utc)
        cert_pem, key_pem = _make_self_signed_cert(
            hostname,
            not_before=now - datetime.timedelta(days=2),
            not_after=now - datetime.timedelta(days=1),
        )
        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        _load_cert_chain_from_pem(server_context, cert_pem, key_pem)

        client_context = create_tls_context(purpose=ssl.Purpose.SERVER_AUTH)
        # The cert must be passed as a PEM *str*: bytes would be parsed as DER.
        client_context.load_verify_locations(cadata=cert_pem)

        with self.assertRaises(ssl.SSLCertVerificationError, msg="Should reject expired certificate"):
            _attempt_handshake(server_context, client_context, hostname)

    def test_rejects_self_signed_certificate(self):
        """Verify the client-side context rejects an untrusted self-signed server certificate."""
        hostname = "untrusted.example.com"
        cert_pem, key_pem = _make_self_signed_cert(hostname)
        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
        _load_cert_chain_from_pem(server_context, cert_pem, key_pem)

        # The client is deliberately not told to trust cert_pem.
        client_context = create_tls_context(purpose=ssl.Purpose.SERVER_AUTH)

        with self.assertRaises(ssl.SSLCertVerificationError, msg="Should reject untrusted certificate"):
            _attempt_handshake(server_context, client_context, hostname)

    def test_rejects_invalid_ou_claim(self):
        """Verify the RBAC engine does not map an invalid OU claim to a role."""
        from security.rbac_engine import RBACEngine

        engine = RBACEngine(Fernet.generate_key())
        # NOTE(review): this file does not show whether the engine rejects an
        # invalid claim by raising ValueError or by returning None. Both count
        # as a rejection; any other outcome fails or errors the test.
        try:
            role = engine._get_role_from_ou('invalid-claim')
        except ValueError:
            role = None
        self.assertIsNone(role, "Should reject invalid OU claim")
|
|
|
|
|
|
class TestRBACEngineTLSIntegration(unittest.TestCase):
    """Tests for TLS client certificate integration with the RBAC engine."""

    def setUp(self):
        """Create an engine, a signed DEVELOPER OU claim and a trusted test certificate.

        Populates:
            ``self.engine``: the RBAC engine under test.
            ``self.valid_ou``: the signed OU claim.
            ``self.valid_cert``: a ``ClientCertInfo`` for a self-signed
                certificate carrying that OU, registered with the engine via
                ``add_trusted_certificate``.
        """
        from cryptography import x509
        from cryptography.hazmat.primitives import hashes, serialization
        from cryptography.hazmat.primitives.asymmetric import rsa
        from cryptography.x509.oid import NameOID

        from security.rbac_engine import RBACEngine, ClientCertInfo

        # Fresh Fernet key per test. It is deliberately not logged because it
        # is secret key material.
        valid_key = Fernet.generate_key()

        logger.info("Initializing RBACEngine with valid key")
        try:
            self.engine = RBACEngine(valid_key)
        except Exception:
            logger.exception("RBACEngine initialization failed")
            raise
        logger.info("RBACEngine initialized successfully")

        # Signed OU claim for the DEVELOPER role, embedded in the cert below.
        self.valid_ou = self.engine.create_signed_ou_claim(Role.DEVELOPER)

        private_key = rsa.generate_private_key(public_exponent=65537, key_size=2048)
        subject = issuer = x509.Name([
            x509.NameAttribute(NameOID.COUNTRY_NAME, "US"),
            x509.NameAttribute(NameOID.STATE_OR_PROVINCE_NAME, "California"),
            x509.NameAttribute(NameOID.LOCALITY_NAME, "San Francisco"),
            x509.NameAttribute(NameOID.ORGANIZATION_NAME, "Test Org"),
            x509.NameAttribute(NameOID.COMMON_NAME, "test-user"),
            x509.NameAttribute(NameOID.ORGANIZATIONAL_UNIT_NAME, self.valid_ou),
        ])
        now = datetime.datetime.now(datetime.timezone.utc)
        cert = (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer)
            .public_key(private_key.public_key())
            .serial_number(x509.random_serial_number())
            .not_valid_before(now)
            .not_valid_after(now + datetime.timedelta(days=1))
            .sign(private_key, hashes.SHA256())
        )
        cert_pem = cert.public_bytes(serialization.Encoding.PEM)

        # NOTE(review): the issuer dict below ({'CN': 'Test Org'}) does not
        # match the certificate's actual issuer (CN=test-user, O=Test Org).
        # Confirm the engine does not compare issuer names.
        self.valid_cert = ClientCertInfo(
            subject={'CN': 'test-user', 'OU': self.valid_ou},
            issuer={'CN': 'Test Org'},
            serial_number=cert.serial_number,
            fingerprint=cert.fingerprint(hashes.SHA256()).hex(),
        )

        # Register the certificate with the engine's trusted list.
        self.engine.add_trusted_certificate(cert_pem)

    def test_ou_field_mapping_valid_signed_claim(self):
        """Test that a valid signed OU claim maps to the DEVELOPER role."""
        logger.debug("Testing OU mapping with OU: %s", self.valid_ou)

        mapped_role = self.engine._get_role_from_ou(self.valid_ou)
        self.assertIsNotNone(mapped_role, "OU field should map to a valid role")
        self.assertEqual(mapped_role, Role.DEVELOPER,
                         "OU field should map to DEVELOPER role")

    def test_validate_permission_with_certificate(self):
        """Test end-to-end permission validation with the trusted client certificate."""
        result = self.engine.validate_permission(
            resource='tasks',
            action='create',
            client_cert_info=self.valid_cert,
        )
        self.assertTrue(result, "Permission should be granted for valid certificate")
|
|
|
|
if __name__ == "__main__":
    # Running the file as a script executes every TestCase defined above.
    unittest.main(module="__main__")