# ai-agent/security/memory/audit.py
#
# 107 lines
# No EOL
# 3.7 KiB
# Python

import hashlib
import hmac
import os
import threading
from datetime import datetime, timezone
from typing import Dict, List, Optional, Tuple
class MemoryAudit:
    """Tamper-evident, in-memory audit log for memory operations.

    Every entry is authenticated with HMAC-SHA256 and chained to the digest
    of the previous entry, so editing, reordering, or removing entries is
    detected by ``verify_log_integrity``. Entries that name a user are also
    forwarded to the RBAC engine supplied at construction time.
    """

    def __init__(self, rbac_engine, hmac_key: Optional[bytes] = None):
        """Initialize audit logger with RBAC integration and HMAC protection.

        Args:
            rbac_engine: Object exposing ``_audit_access_attempt(user,
                resource, operation, success, reason)``; it is called for
                every logged operation that names a user.
            hmac_key: Secret key for the integrity HMAC. A random 256-bit
                key is generated when it is omitted or empty.
        """
        self.rbac = rbac_engine
        self.sequence = 0
        self.log_entries: List[Dict] = []
        self._lock = threading.Lock()
        self.hmac_key = hmac_key or os.urandom(32)  # 256-bit HMAC key
        self._previous_hash = b''  # Digest of the latest entry (chain head)

    @staticmethod
    def _chained_digest(
        key: bytes, previous_hash: bytes, data: str
    ) -> Tuple[str, bytes]:
        """Return ``(hex, raw)`` HMAC-SHA256 of ``previous_hash + data``.

        Shared by logging and verification so both sides always compute the
        chain link in exactly the same way.
        """
        mac = hmac.new(key, digestmod=hashlib.sha256)
        mac.update(previous_hash)  # Chain with previous hash
        mac.update(data.encode())
        return mac.hexdigest(), mac.digest()

    def _generate_hmac(self, data: str) -> Tuple[str, bytes]:
        """Generate the next chained HMAC-SHA256 and advance the chain head.

        Must be called with ``self._lock`` held, because it reads and then
        replaces ``self._previous_hash``.

        Returns:
            The hex digest and the raw digest of the new chain link.
        """
        hex_digest, raw_digest = self._chained_digest(
            self.hmac_key, self._previous_hash, data
        )
        self._previous_hash = raw_digest
        return hex_digest, raw_digest

    def _hash_key(self, key: str) -> str:
        """Return the SHA-256 hex digest of ``key``.

        Used so that raw memory keys never appear in the audit trail.

        NOTE(review): ``log_operation`` calls this helper but no definition
        was present in this file; an unsalted SHA-256 is assumed here --
        confirm against the intended design.
        """
        return hashlib.sha256(key.encode()).hexdigest()

    def log_operation(
        self,
        operation: str,
        key: str,
        success: bool,
        user: Optional[str] = None,
        reason: Optional[str] = None
    ) -> str:
        """Log an operation with integrity verification hash.

        Args:
            operation: Name of the memory operation being recorded.
            key: Memory key the operation touched; only its hash is stored.
            success: Whether the operation succeeded.
            user: Acting user, if any. When truthy, the RBAC engine is
                notified as well.
            reason: Optional free-text explanation (stored as ``""`` when
                missing).

        Returns:
            The hex HMAC-SHA256 integrity hash stored with the new entry.
        """
        # Hash first so a bad key cannot leave a gap in the sequence numbers.
        hashed_key = self._hash_key(key)

        with self._lock:
            self.sequence += 1
            # Naive UTC ISO-8601 string, the same format ``utcnow()`` gave,
            # without relying on the deprecated ``datetime.utcnow``.
            timestamp = (
                datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
            )
            entry = {
                "sequence": self.sequence,
                "timestamp": timestamp,
                "operation": operation,
                "key_hash": hashed_key,
                "success": success,
                "user": user,
                "reason": reason or ""
            }
            # Generate HMAC-SHA256 integrity hash with chaining. The MAC
            # covers ``str(entry)`` before ``integrity_hash`` is added.
            integrity_hash, _ = self._generate_hmac(str(entry))
            entry["integrity_hash"] = integrity_hash
            # Store entry
            self.log_entries.append(entry)

        # Notify RBAC system after releasing the (non-reentrant) lock, so an
        # engine that calls back into this logger cannot deadlock it.
        if user:
            self.rbac._audit_access_attempt(
                user,
                "memory",
                operation,
                success,
                reason or f"Memory {operation} operation"
            )
        return integrity_hash

    def verify_log_integrity(self) -> bool:
        """Verify all log entries' HMAC integrity with chaining.

        Recomputes the chain from the first entry and checks each stored
        ``integrity_hash`` in constant time. The recomputed chain head must
        also equal the head recorded by the last ``log_operation`` call, so
        dropping trailing entries (or emptying the log) is detected too.

        Returns:
            ``True`` when the log is intact (including a never-used, empty
            log), ``False`` otherwise.
        """
        with self._lock:
            previous_hash = b''
            for entry in self.log_entries:
                payload = {
                    name: value for name, value in entry.items()
                    if name != "integrity_hash"
                }
                expected_hex, previous_hash = self._chained_digest(
                    self.hmac_key, previous_hash, str(payload)
                )
                stored = entry.get("integrity_hash")
                # A missing or non-string hash is itself evidence of tampering.
                if not isinstance(stored, str):
                    return False
                # Compare as bytes: constant-time and safe for any string.
                if not hmac.compare_digest(
                    expected_hex.encode(), stored.encode()
                ):
                    return False
            # Catch truncation: the chain must end where logging left off.
            return hmac.compare_digest(previous_hash, self._previous_hash)

    def by_operation(self, operation: str) -> List[Dict]:
        """Filter log entries by operation type.

        Returns the stored entry dicts (not copies) whose ``operation``
        equals ``operation``, in logging order.
        """
        with self._lock:
            return [entry for entry in self.log_entries
                    if entry["operation"] == operation]

    def by_user(self, user: str) -> List[Dict]:
        """Filter log entries by user.

        Returns the stored entry dicts (not copies) whose ``user`` equals
        ``user``, in logging order.
        """
        with self._lock:
            return [entry for entry in self.log_entries
                    if entry.get("user") == user]

    def by_time_range(self, start: str, end: str) -> List[Dict]:
        """Filter log entries between start and end timestamps (ISO format).

        Both bounds are inclusive and compared to the stored timestamp
        strings lexicographically, so they should use the same ISO-8601
        layout as the log (naive UTC).
        """
        with self._lock:
            return [entry for entry in self.log_entries
                    if start <= entry["timestamp"] <= end]