import hashlib
import hmac
import os
import threading
from datetime import datetime, timezone
from typing import Callable, Dict, List, Optional, Tuple


class MemoryAudit:
    """Tamper-evident, in-memory audit log for memory operations.

    Every entry carries an HMAC-SHA256 that is chained to the previous
    entry's digest, so editing, reordering, inserting or removing entries
    is detected by :meth:`verify_log_integrity`. Operations performed on
    behalf of a user are additionally reported to the RBAC engine.
    """

    def __init__(self, rbac_engine, hmac_key: Optional[bytes] = None):
        """Initialize the audit logger.

        Args:
            rbac_engine: Object that is called as
                ``_audit_access_attempt(user, "memory", operation, success,
                reason)`` for every logged operation that names a user.
            hmac_key: Secret key for the integrity HMACs. When omitted (or
                empty) a random 256-bit key is generated, so the hashes can
                only be verified by this instance.
        """
        self.rbac = rbac_engine
        self.sequence = 0
        self.log_entries: List[Dict] = []
        self._lock = threading.Lock()
        self.hmac_key = hmac_key or os.urandom(32)  # 256-bit HMAC key
        # Digest of the newest entry; b'' is the genesis value of the chain.
        self._previous_hash = b''

    def _hash_key(self, key: str) -> str:
        """Return a keyed SHA-256 hex digest of *key* so raw keys never enter the log.

        NOTE(review): this method was called by ``log_operation`` but not
        defined in the reviewed source. A keyed hash (rather than plain
        SHA-256) is used so low-entropy keys cannot be recovered by
        brute force — confirm this matches the intended behavior.
        """
        # Domain-separation prefix keeps these digests distinct from the
        # chain HMACs, which are computed with the same secret key.
        mac = hmac.new(self.hmac_key, b"key-hash:" + key.encode(), hashlib.sha256)
        return mac.hexdigest()

    @staticmethod
    def _serialize_entry(entry: Dict) -> str:
        """Return the text covered by an entry's HMAC: every field except the hash.

        Shared by logging and verification so the two cannot drift apart.
        """
        return str({k: v for k, v in entry.items() if k != "integrity_hash"})

    def _chain_hmac(self, previous_hash: bytes, data: str) -> bytes:
        """Return HMAC-SHA256(previous_hash || data) as raw bytes."""
        mac = hmac.new(self.hmac_key, digestmod=hashlib.sha256)
        mac.update(previous_hash)  # links this entry to its predecessor
        mac.update(data.encode())
        return mac.digest()

    def _generate_hmac(self, data: str) -> Tuple[str, bytes]:
        """Chain *data* onto the log and return its digest as (hex, raw bytes).

        Advances the chain head; callers must hold ``self._lock``.
        """
        current_hash = self._chain_hmac(self._previous_hash, data)
        self._previous_hash = current_hash
        return current_hash.hex(), current_hash

    def log_operation(
        self,
        operation: str,
        key: str,
        success: bool,
        user: Optional[str] = None,
        reason: Optional[str] = None
    ) -> str:
        """Record an operation and return its chained integrity hash (hex).

        Args:
            operation: Operation name, e.g. the kind of memory access.
            key: Memory key the operation touched; only its keyed hash is stored.
            success: Whether the operation succeeded.
            user: Acting user. When truthy, the RBAC engine is notified too.
            reason: Optional free-text reason; stored as "" when omitted.

        Returns:
            The hex HMAC-SHA256 stored in the entry's ``integrity_hash`` field.
        """
        # Hash before taking a sequence number so a failure here cannot
        # leave a gap in the sequence.
        hashed_key = self._hash_key(key)

        with self._lock:
            self.sequence += 1
            entry = {
                "sequence": self.sequence,
                # Naive UTC ISO string, the same format datetime.utcnow()
                # produced (utcnow is deprecated since Python 3.12).
                "timestamp": datetime.now(timezone.utc).replace(tzinfo=None).isoformat(),
                "operation": operation,
                "key_hash": hashed_key,
                "success": success,
                "user": user,
                "reason": reason or ""
            }
            integrity_hash, _ = self._generate_hmac(self._serialize_entry(entry))
            entry["integrity_hash"] = integrity_hash
            self.log_entries.append(entry)

        # Notify RBAC outside the lock: the lock is not reentrant, so an
        # engine that queried this log from its callback would deadlock.
        if user:
            self.rbac._audit_access_attempt(
                user,
                "memory",
                operation,
                success,
                reason or f"Memory {operation} operation"
            )
        return integrity_hash

    def verify_log_integrity(self) -> bool:
        """Recompute the HMAC chain and report whether the log is untampered.

        Returns False if any entry's fields or hash were altered, if entries
        were reordered, inserted or removed, or if the log was truncated.
        """
        with self._lock:
            previous_hash = b''
            for entry in self.log_entries:
                digest = self._chain_hmac(previous_hash, self._serialize_entry(entry))
                stored = entry.get("integrity_hash")
                # Constant-time comparison; a missing/non-str hash is tampering.
                if not isinstance(stored, str) or not hmac.compare_digest(
                    digest.hex().encode(), stored.encode()
                ):
                    return False
                previous_hash = digest
            # Dropping entries from the end leaves every remaining link valid;
            # only comparing with the live chain head reveals the truncation.
            return hmac.compare_digest(previous_hash, self._previous_hash)

    def _select(self, predicate: Callable[[Dict], bool]) -> List[Dict]:
        """Return copies of the entries matching *predicate*, taken under the lock.

        Entry values are str/int/bool/None, so shallow copies fully isolate
        callers from the stored entries (mutating a result cannot corrupt
        the audit log).
        """
        with self._lock:
            return [dict(entry) for entry in self.log_entries if predicate(entry)]

    def by_operation(self, operation: str) -> List[Dict]:
        """Return copies of the entries whose ``operation`` equals *operation*."""
        return self._select(lambda entry: entry["operation"] == operation)

    def by_user(self, user: str) -> List[Dict]:
        """Return copies of the entries recorded for *user*."""
        return self._select(lambda entry: entry.get("user") == user)

    def by_time_range(self, start: str, end: str) -> List[Dict]:
        """Return copies of the entries with ``start <= timestamp <= end``.

        Timestamps are compared as strings, so *start* and *end* should be
        naive UTC ISO-8601 strings in the same format as stored timestamps.
        """
        return self._select(lambda entry: start <= entry["timestamp"] <= end)