# ai-agent/security/audit.py
#
# (File-viewer metadata captured with this listing: 511 lines, no trailing
# newline, 22 KiB, Python.)

import contextlib
import hashlib
import hmac
import json
import logging
import os
import sqlite3
import threading
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Dict, List, Optional

from cryptography.fernet import Fernet


class SecureAudit:
    """Tamper-evident audit logger backed by SQLite.

    Every row written by this class is appended to a hash chain. Its
    ``integrity_hash`` is an HMAC-SHA256, keyed with the secret stored at
    ``key_path``, over the row's signed columns plus the previous row's hash.
    :meth:`verify_log_integrity` recomputes that chain from the stored columns.

    Sensitive values (operation keys, cron expressions, TLS client identities)
    are encrypted with Fernet (AES-128-CBC with HMAC-SHA256 authentication).
    The Fernet key is persisted next to the HMAC key, so stored ciphertext can
    still be decrypted after a restart.
    """

    _logger = logging.getLogger(__name__)

    # Column name -> SQL declaration. Each column is declared exactly once
    # (SQLite rejects duplicate column names). There is deliberately no
    # FOREIGN KEY on previous_hash: integrity_hash is not UNIQUE, so SQLite
    # would report a foreign-key mismatch whenever enforcement is enabled.
    # Chain linkage is checked by verify_log_integrity instead.
    _SCHEMA = (
        ("id", "INTEGER PRIMARY KEY"),
        ("sequence", "INTEGER"),
        ("timestamp", "TEXT"),
        ("operation_type", "TEXT"),
        ("operation_result", "BOOLEAN"),
        ("user_identity", "TEXT"),
        ("tls_version", "TEXT"),
        ("tls_cipher", "TEXT"),
        ("cert_fingerprint", "TEXT"),
        ("cert_subject", "TEXT"),
        ("cert_issuer", "TEXT"),
        ("cert_validity", "TEXT"),
        ("cert_revoked", "BOOLEAN"),
        ("role_mapped", "TEXT"),
        ("boundary_violation", "BOOLEAN"),
        ("cipher_suite", "TEXT"),
        ("client_cert_subject", "TEXT"),
        ("client_cert_issuer", "TEXT"),
        ("client_cert_validity", "TEXT"),
        ("client_cert_revoked", "INTEGER DEFAULT 0"),
        ("operation", "TEXT"),
        ("key_hash", "TEXT"),
        ("encrypted_key", "TEXT"),
        ("encrypted_cron", "TEXT DEFAULT ''"),
        ("obfuscated_task_id", "TEXT DEFAULT ''"),
        ("success", "INTEGER"),
        ("user", "TEXT"),
        ("reason", "TEXT"),
        ("integrity_hash", "TEXT"),
        ("previous_hash", "TEXT"),
        ("tls_details", "TEXT"),
    )

    # Columns covered by each row's HMAC. Values are normalised on write to
    # exactly the Python types SQLite returns on read, so verification can
    # recompute the digest from the stored row.
    _SIGNED_COLUMNS = (
        "sequence", "timestamp", "operation", "key_hash", "encrypted_key",
        "encrypted_cron", "obfuscated_task_id", "success", "user", "reason",
        "tls_version", "cipher_suite", "cert_fingerprint",
        "client_cert_subject", "client_cert_issuer", "client_cert_validity",
        "client_cert_revoked", "tls_details",
    )
    _INSERT_COLUMNS = _SIGNED_COLUMNS + ("integrity_hash", "previous_hash")
    _INTEGER_COLUMNS = frozenset({"sequence", "success", "client_cert_revoked"})
    _NULLABLE_COLUMNS = frozenset({"user"})
    _INDEXES = (
        ("idx_timestamp", "timestamp"),
        ("idx_user", "user"),
        ("idx_operation", "operation"),
    )

    # queue_access flushes when this many entries are pending, or this many
    # seconds after the first pending entry, whichever comes first.
    _BATCH_SIZE = 10
    _BATCH_INTERVAL = 1.0

    def __init__(self, rbac_engine, db_path: str = "audit.db", key_path: str = "audit.key",
                 fernet_key_path: Optional[str] = None):
        """Open (or create) the audit store.

        Args:
            rbac_engine: Object whose ``_audit_access_attempt`` is called when
                :meth:`log_operation` is given a user.
            db_path: SQLite database file.
            key_path: File holding the HMAC key. It is created with mode 0600
                if missing.
            fernet_key_path: File holding the Fernet key. Defaults to
                ``<key_path>.fernet`` and is created with mode 0600 if missing.
        """
        self.rbac = rbac_engine
        # Reentrant because queue_access flushes while already holding it.
        self._lock = threading.RLock()
        self.sequence = 0
        self.last_hash = ""
        self._batch_queue: List[Dict] = []
        self._batch_timer: Optional[threading.Timer] = None
        # Key management
        self.key_path = Path(key_path)
        self.hmac_key = self._init_key()
        if fernet_key_path:
            fernet_path = Path(fernet_key_path)
        else:
            fernet_path = self.key_path.with_name(self.key_path.name + ".fernet")
        self.fernet = Fernet(self._load_or_create_secret(fernet_path, Fernet.generate_key))
        # Database
        self.db_path = Path(db_path)
        self._init_db()
        # Continue the existing chain instead of restarting at sequence 0.
        self._load_chain_state()

    # ------------------------------------------------------------------ #
    # Key, connection and schema helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _load_or_create_secret(path: Path, generate) -> bytes:
        """Return the secret stored at *path*, creating it from ``generate()`` if absent.

        A new file is opened with O_EXCL and permission bits 0600 at creation
        time, so there is no window in which it has looser permissions. If
        another process creates the file first, that process's secret is
        read and returned instead.
        """
        try:
            return path.read_bytes()
        except FileNotFoundError:
            pass
        secret = generate()
        try:
            fd = os.open(path, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600)
        except FileExistsError:
            return path.read_bytes()
        with os.fdopen(fd, "wb") as handle:
            handle.write(secret)
        return secret

    def _init_key(self) -> bytes:
        """Load the HMAC key from ``self.key_path``, generating 32 random bytes if absent."""
        return self._load_or_create_secret(self.key_path, lambda: os.urandom(32))

    @contextlib.contextmanager
    def _connect(self):
        """Yield a connection that commits on success, rolls back on error, and is always closed.

        ``with sqlite3.connect(...)`` on its own only manages the transaction;
        it does not close the connection.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:
                yield conn
        finally:
            conn.close()

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime (same format ``utcnow()`` gave)."""
        return datetime.now(timezone.utc).replace(tzinfo=None)

    def _init_db(self):
        """Create audit_logs and its indexes, adding any columns an older database lacks."""
        columns_sql = ", ".join(f'"{name}" {decl}' for name, decl in self._SCHEMA)
        with self._connect() as conn:
            conn.execute(f"CREATE TABLE IF NOT EXISTS audit_logs ({columns_sql})")
            existing = {info[1] for info in conn.execute("PRAGMA table_info(audit_logs)")}
            for name, decl in self._SCHEMA:
                # ALTER TABLE cannot add a PRIMARY KEY column.
                if name not in existing and "PRIMARY KEY" not in decl:
                    conn.execute(f'ALTER TABLE audit_logs ADD COLUMN "{name}" {decl}')
            for index_name, column in self._INDEXES:
                conn.execute(
                    f'CREATE INDEX IF NOT EXISTS {index_name} ON audit_logs("{column}")'
                )

    def _load_chain_state(self):
        """Resume ``sequence`` and ``last_hash`` from the newest stored row, if there is one."""
        with self._connect() as conn:
            latest = conn.execute(
                "SELECT sequence, integrity_hash FROM audit_logs "
                "WHERE sequence IS NOT NULL ORDER BY sequence DESC, id DESC LIMIT 1"
            ).fetchone()
        if latest is not None:
            self.sequence = int(latest[0])
            self.last_hash = latest[1] or ""

    # ------------------------------------------------------------------ #
    # Hashing / encryption helpers
    # ------------------------------------------------------------------ #

    def _calculate_hmac(self, data: str, previous_hash: Optional[str] = None) -> str:
        """Return the hex HMAC-SHA256 of *data* chained to *previous_hash*.

        *previous_hash* defaults to ``self.last_hash``. The digest depends only
        on these inputs and the key, so it can be recomputed during
        verification.
        """
        chain = self.last_hash if previous_hash is None else previous_hash
        return hmac.new(self.hmac_key, (data + chain).encode(), hashlib.sha256).hexdigest()

    def _keyed_digest(self, value: str, purpose: str) -> str:
        """Return a deterministic keyed HMAC-SHA256 (hex) of *value*.

        The result is domain-separated by *purpose*, so equal values can be
        correlated without being revealed.
        """
        return hmac.new(self.hmac_key, f"{purpose}:{value}".encode(), hashlib.sha256).hexdigest()

    def _obfuscate_task_id(self, task_id: str) -> str:
        """Pseudonymise a task ID with the secret HMAC key.

        The result is deterministic, so all audit rows for one task share the
        same obfuscated ID. A random, unstored salt would make the value
        meaningless.
        """
        return self._keyed_digest(task_id, "task")

    def _encrypt(self, text: str) -> str:
        """Fernet-encrypt *text* and return the token as ``str``."""
        return self.fernet.encrypt(text.encode()).decode()

    def _verify_timestamp(self, timestamp: str, max_skew: int = 30) -> bool:
        """Check that *timestamp* parses and is not too far in the future.

        Returns True if *timestamp* is valid ISO-8601 and lies at most
        *max_skew* seconds in the future. Past timestamps are accepted,
        because audit rows are naturally older than the moment they are
        verified. Naive timestamps are treated as UTC.
        """
        try:
            log_time = datetime.fromisoformat(timestamp)
        except (TypeError, ValueError):
            return False
        if log_time.tzinfo is not None:
            log_time = log_time.astimezone(timezone.utc).replace(tzinfo=None)
        return (log_time - self._utcnow()).total_seconds() <= max_skew

    # ------------------------------------------------------------------ #
    # Chained writes
    # ------------------------------------------------------------------ #

    @classmethod
    def _normalize_row(cls, fields: Dict, sequence: int) -> Dict:
        """Coerce *fields* to the exact Python types SQLite returns for each signed column."""
        row = {}
        for column in cls._SIGNED_COLUMNS:
            value = sequence if column == "sequence" else fields.get(column)
            if column in cls._INTEGER_COLUMNS:
                row[column] = int(value or 0)
            elif value is None:
                row[column] = None if column in cls._NULLABLE_COLUMNS else ""
            else:
                row[column] = str(value)
        return row

    @classmethod
    def _signing_payload(cls, row: Dict) -> str:
        """Return the canonical string that a row's HMAC covers."""
        return json.dumps([row[column] for column in cls._SIGNED_COLUMNS], separators=(",", ":"))

    def _append_entry(self, conn: sqlite3.Connection, fields: Dict) -> str:
        """Insert one chained row on *conn* and advance the in-memory chain state.

        The caller must hold ``self._lock``. Returns the new row's integrity hash.
        """
        row = self._normalize_row(fields, self.sequence + 1)
        previous_hash = self.last_hash
        integrity_hash = self._calculate_hmac(self._signing_payload(row), previous_hash)
        column_list = ", ".join(f'"{column}"' for column in self._INSERT_COLUMNS)
        placeholders = ", ".join(["?"] * len(self._INSERT_COLUMNS))
        conn.execute(
            f"INSERT INTO audit_logs ({column_list}) VALUES ({placeholders})",
            [row[column] for column in self._SIGNED_COLUMNS] + [integrity_hash, previous_hash],
        )
        self.sequence = row["sequence"]
        self.last_hash = integrity_hash
        return integrity_hash

    def _write_entries(self, rows: List[Dict]) -> List[str]:
        """Append *rows* to the chain in one transaction and return their integrity hashes.

        If anything fails, the transaction is rolled back and the in-memory
        ``sequence``/``last_hash`` are restored, so the chain stays consistent
        with the database.
        """
        with self._lock:
            saved_state = (self.sequence, self.last_hash)
            try:
                with self._connect() as conn:
                    return [self._append_entry(conn, row) for row in rows]
            except BaseException:
                self.sequence, self.last_hash = saved_state
                raise

    # ------------------------------------------------------------------ #
    # Public logging API
    # ------------------------------------------------------------------ #

    def log_operation(
        self,
        operation: str,
        key: str,
        success: bool,
        user: Optional[str] = None,
        reason: Optional[str] = None,
        cron: Optional[str] = None,
        task_id: Optional[str] = None,
        tls_params: Optional[Dict] = None
    ) -> str:
        """Append an operation record to the hash chain.

        Args:
            operation: Operation name, stored in the ``operation`` column.
            key: The key operated on. It is stored Fernet-encrypted in
                ``encrypted_key`` and as a keyed digest in ``key_hash``.
            success: Whether the operation succeeded, stored as 0/1.
            user: Acting user. When truthy, the RBAC engine is also notified.
            reason: Free-text reason. Defaults to "".
            cron: Optional cron expression, stored Fernet-encrypted.
            task_id: Optional task ID, stored as a keyed pseudonym.
            tls_params: Optional dict. The keys read are version, cipher,
                cert_fingerprint, cert_subject, cert_issuer, cert_validity and
                cert_revoked.

        Returns:
            The hex HMAC-SHA256 integrity hash of the new row.
        """
        tls = tls_params or {}
        integrity_hash = self._write_entries([{
            "timestamp": self._utcnow().isoformat(),
            "operation": operation,
            # Digest of the plaintext key, so the same key can be correlated
            # across rows. Hashing the Fernet token would give a different
            # value every time, because Fernet encryption is randomised.
            "key_hash": self._keyed_digest(key, "key"),
            "encrypted_key": self._encrypt(key),
            "encrypted_cron": self._encrypt(cron) if cron else "",
            "obfuscated_task_id": self._obfuscate_task_id(task_id) if task_id else "",
            "success": int(bool(success)),
            "user": user,
            "reason": reason or "",
            "tls_version": tls.get('version', ''),
            "cipher_suite": tls.get('cipher', ''),
            "cert_fingerprint": tls.get('cert_fingerprint', ''),
            "client_cert_subject": tls.get('cert_subject', ''),
            "client_cert_issuer": tls.get('cert_issuer', ''),
            "client_cert_validity": tls.get('cert_validity', ''),
            "client_cert_revoked": int(bool(tls.get('cert_revoked', False))),
        }])[0]
        # Notify the RBAC system outside the audit lock.
        if user:
            self.rbac._audit_access_attempt(
                user,
                "memory",
                operation,
                success,
                reason or f"Memory {operation} operation"
            )
        return integrity_hash

    def verify_log_integrity(self) -> bool:
        """Return True if the stored hash chain is intact.

        Rows are checked in sequence order. For each row:

        - Sequence numbers must be contiguous, and each ``previous_hash`` must
          equal the preceding row's ``integrity_hash``. The oldest surviving
          row is accepted as the anchor, so :meth:`purge_old_entries` does not
          break verification. A row with sequence 1 must have an empty
          ``previous_hash``.
        - The timestamp must parse and must not lie in the future beyond the
          allowed skew.
        - ``integrity_hash`` must equal the HMAC recomputed from the row's
          signed columns.
        """
        select_list = ", ".join(f'"{column}"' for column in self._INSERT_COLUMNS)
        expected_sequence = None
        expected_previous = None
        with self._connect() as conn:
            cursor = conn.execute(f"SELECT {select_list} FROM audit_logs ORDER BY sequence, id")
            for values in cursor:
                row = dict(zip(self._INSERT_COLUMNS, values))
                sequence = row["sequence"]
                previous = row["previous_hash"] or ""
                if not isinstance(sequence, int):
                    return False
                if expected_sequence is None:
                    if sequence == 1 and previous:
                        return False
                elif sequence != expected_sequence or previous != expected_previous:
                    return False
                if not self._verify_timestamp(row["timestamp"]):
                    return False
                stored = row["integrity_hash"]
                recomputed = self._calculate_hmac(self._signing_payload(row), previous)
                if not isinstance(stored, str) or not hmac.compare_digest(
                    recomputed.encode(), stored.encode()
                ):
                    return False
                expected_previous = stored
                expected_sequence = sequence + 1
        return True

    @staticmethod
    def _client_name(cert_info: dict) -> str:
        """Return the client CN from ``cert_info['subject']``.

        If the subject is a dict, its 'CN' entry is used. Otherwise the
        subject itself is used, or 'unknown' when it is empty.
        """
        subject = cert_info.get('subject', {})
        if isinstance(subject, dict):
            return str(subject.get('CN', 'unknown'))
        return str(subject) if subject else 'unknown'

    def _build_tls_log_entry(self, cert_info: dict, tls_params: dict) -> dict:
        """Assemble the structured TLS handshake record from *cert_info* and *tls_params*."""
        cert_chain = [
            {
                'subject': cert.get('subject', 'unknown'),
                'issuer': cert.get('issuer', 'unknown'),
                'serial': cert.get('serial', 'unknown'),
                'valid_from': cert.get('valid_from', 'unknown'),
                'valid_to': cert.get('valid_to', 'unknown'),
                'key_algorithm': cert.get('key_algorithm', 'unknown'),
                'key_size': cert.get('key_size', 'unknown'),
            }
            for cert in (cert_info.get('cert_chain') or [])
        ]
        return {
            'timestamp': self._utcnow().isoformat(),
            'event': 'tls_handshake',
            'client': self._client_name(cert_info),
            'protocol': tls_params.get('protocol', 'unknown'),
            'cipher_suite': {
                'name': tls_params.get('cipher', 'unknown'),
                'key_exchange': {
                    'algorithm': tls_params.get('key_exchange', 'unknown'),
                    'strength': tls_params.get('key_strength', 'unknown'),
                    'ephemeral': tls_params.get('key_ephemeral', False),
                },
                'authentication': tls_params.get('authentication', 'unknown'),
                'encryption': {
                    'algorithm': tls_params.get('encryption', 'unknown'),
                    'strength': tls_params.get('encryption_strength', 'unknown'),
                    'mode': tls_params.get('encryption_mode', 'unknown'),
                },
                'mac': {
                    'algorithm': tls_params.get('mac', 'unknown'),
                    'strength': tls_params.get('mac_strength', 'unknown'),
                },
                'forward_secrecy': tls_params.get('forward_secrecy', False),
            },
            'session': {
                'resumed': tls_params.get('session_resumed', False),
                'id': tls_params.get('session_id', None),
                'ticket': tls_params.get('session_ticket', None),
                'lifetime': tls_params.get('session_lifetime', 0),
            },
            'certificates': cert_chain,
            'extensions': [
                {
                    'type': ext.get('type', 'unknown'),
                    # Only a digest of each extension is kept.
                    'data': hashlib.sha256(str(ext).encode()).hexdigest(),
                }
                for ext in tls_params.get('extensions', [])
            ],
            'security_indicators': {
                'ocsp_stapling': tls_params.get('ocsp_stapling', False),
                'sct_validation': tls_params.get('sct_validation', False),
                'alpn': tls_params.get('alpn_protocol', None),
                'compression': tls_params.get('compression', None),
            },
            'validation': {
                'chain_valid': tls_params.get('chain_valid', False),
                'hostname_match': tls_params.get('hostname_match', False),
                'revocation_status': tls_params.get('revocation_status', 'unknown'),
                'expiry_status': tls_params.get('expiry_status', 'valid'),
            },
        }

    def _log_tls_fallback(self, tls_params, error: Exception):
        """Best effort: append a minimal row recording that detailed TLS logging failed."""
        try:
            protocol = (tls_params.get('protocol', 'unknown')
                        if isinstance(tls_params, dict) else 'unknown')
            self._write_entries([{
                'timestamp': self._utcnow().isoformat(),
                'operation': 'tls_handshake',
                'success': 0,
                'reason': f"TLS handshake logging failed: {error}",
                'tls_version': protocol,
            }])
        except Exception:
            self._logger.exception("Fallback TLS handshake audit entry could not be written")

    def log_tls_handshake(self, cert_info: dict, tls_params: dict):
        """Log TLS handshake parameters for security auditing.

        Implements the SYM-SEC-004/005 requirements.

        Args:
            cert_info: Certificate information. The keys read are 'subject'
                (a dict with 'CN', or a plain value), 'issuer', and
                'cert_chain' (a list of dicts).
            tls_params: TLS parameters. The keys read are protocol, cipher,
                key_exchange, key_strength, key_ephemeral, authentication,
                encryption, encryption_strength, encryption_mode, mac,
                mac_strength, forward_secrecy, session_resumed, session_id,
                session_ticket, session_lifetime, extensions (a list of
                dicts), ocsp_stapling, sct_validation, alpn_protocol,
                compression, chain_valid, hostname_match, revocation_status
                and expiry_status.

        Storage (one chained row):
            - ``user``: the Fernet-encrypted client CN.
            - ``client_cert_subject``: the Fernet-encrypted certificate-chain
              summary.
            - ``tls_details``: the Fernet-encrypted full structured record.
            - ``key_hash`` and ``cert_fingerprint``: SHA-256 digests of the
              cipher-suite breakdown and the certificate chain.

        If this fails, the error is logged and a minimal fallback row is
        written instead. The method never raises.
        """
        try:
            entry = self._build_tls_log_entry(cert_info, tls_params)
            certificates = str(entry['certificates'])
            validation = entry['validation']
            self._write_entries([{
                'timestamp': entry['timestamp'],
                'operation': entry['event'],
                'key_hash': hashlib.sha256(str(entry['cipher_suite']).encode()).hexdigest(),
                'success': 1,
                'user': self._encrypt(entry['client']),
                'reason': 'TLS handshake completed',
                'tls_version': entry['protocol'],
                'cipher_suite': str(entry['cipher_suite']),
                'cert_fingerprint': hashlib.sha256(certificates.encode()).hexdigest(),
                'client_cert_subject': self._encrypt(certificates),
                # NOTE(review): assumes cert_info['issuer'] describes the
                # client certificate's issuer. Confirm against callers.
                'client_cert_issuer': cert_info.get('issuer', ''),
                'client_cert_validity': str(validation),
                'client_cert_revoked': int(validation['revocation_status'] == 'revoked'),
                'tls_details': self._encrypt(str(entry)),
            }])
        except Exception as exc:
            self._logger.exception("Error logging TLS handshake: %s", exc)
            self._log_tls_fallback(tls_params, exc)

    def purge_old_entries(self, days: int = 90) -> int:
        """Delete entries older than *days* days and return the number of rows removed.

        Every row whose sequence is at or below the newest expired row is
        also removed. The surviving rows therefore form a contiguous chain
        whose oldest row is the anchor that ``verify_log_integrity`` accepts.
        """
        cutoff = (self._utcnow() - timedelta(days=days)).isoformat()
        with self._lock, self._connect() as conn:
            return conn.execute(
                "DELETE FROM audit_logs WHERE timestamp < ? "
                "OR sequence <= (SELECT MAX(sequence) FROM audit_logs WHERE timestamp < ?)",
                (cutoff, cutoff),
            ).rowcount

    def queue_access(self, operation: str, user: str, data: dict, status: str):
        """Queue an access attempt for batched, hash-chained logging.

        The batch is written once it holds ``_BATCH_SIZE`` entries, or
        ``_BATCH_INTERVAL`` seconds after the first pending entry, whichever
        comes first. Only a keyed digest of ``str(data)`` is stored.
        """
        with self._lock:
            self._batch_queue.append({
                'operation': operation,
                'user': user,
                'data': data,
                'status': status,
                'timestamp': self._utcnow().isoformat(),
            })
            if len(self._batch_queue) >= self._BATCH_SIZE:
                self._flush_batch()  # safe: self._lock is an RLock
            elif self._batch_timer is None:
                # One-shot timer, scheduled only while entries are pending.
                # Timer threads inherit the creating thread's daemon flag.
                # From a non-daemon thread, a pending flush therefore still
                # runs before the interpreter exits.
                self._batch_timer = threading.Timer(self._BATCH_INTERVAL, self._flush_batch)
                self._batch_timer.start()

    def _flush_batch(self):
        """Write all queued access entries to the database in one transaction.

        This runs either from the batch timer or synchronously from
        :meth:`queue_access`. If the write fails, the entries are put back at
        the front of the queue and the exception propagates.
        """
        with self._lock:
            timer, self._batch_timer = self._batch_timer, None
            if timer is not None:
                timer.cancel()  # harmless when called from the timer's own thread
            batch, self._batch_queue = self._batch_queue, []
            if not batch:
                return
            rows = [
                {
                    'timestamp': item['timestamp'],
                    'operation': item['operation'],
                    'key_hash': self._keyed_digest(str(item['data']), "data"),
                    'success': int(item['status'] == 'completed'),
                    'user': item['user'],
                    'reason': item['status'],
                }
                for item in batch
            ]
            try:
                self._write_entries(rows)
            except Exception:
                self._batch_queue[:0] = batch
                raise