# ai-agent/security/audit.py  (listing metadata: 277 lines, 10 KiB, Python, no trailing newline)

import base64
import hashlib
import hmac
import os
import sqlite3
import threading
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Dict, List, Optional

from cryptography.fernet import Fernet
class SecureAudit:
    """Tamper-evident audit log stored in SQLite.

    Every row carries an HMAC-SHA256 ``integrity_hash`` computed over the
    row's own fields plus the ``integrity_hash`` of the row before it, so
    editing, reordering or removing a row in the middle of the log breaks the
    chain and is reported by :meth:`verify_log_integrity`.

    Sensitive values (``key``, ``cron``) are encrypted with Fernet, i.e.
    AES-128-CBC authenticated with HMAC-SHA256. The Fernet key is derived
    from the persisted HMAC key, so ciphertexts remain decryptable across
    process restarts.

    Rows written by earlier revisions of this class mixed an unstored
    timestamp into their hash and therefore cannot be re-verified.
    """

    # Queue length at which queue_access() flushes immediately.
    BATCH_SIZE = 10
    # Seconds a queued entry may wait before the background timer flushes it.
    BATCH_INTERVAL = 1.0
    # Tolerated clock skew (seconds) before a stored timestamp counts as
    # "from the future" during verification.
    MAX_CLOCK_SKEW = 30

    _INSERT_SQL = """
        INSERT INTO audit_logs (
            sequence, timestamp, operation, key_hash, encrypted_key,
            encrypted_cron, obfuscated_task_id, success, user, reason,
            integrity_hash, previous_hash
        ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
    """

    def __init__(self, rbac_engine, db_path: str = "audit.db", key_path: str = "audit.key"):
        """Open (or create) the audit database and key material.

        Args:
            rbac_engine: Object exposing ``_audit_access_attempt(user,
                resource, operation, success, reason)``; it is notified after
                each :meth:`log_operation` call that names a user.
            db_path: Location of the SQLite database file.
            key_path: Location of the HMAC key file; created on first use.

        Raises:
            ValueError: If an existing key file is empty.
        """
        self.rbac = rbac_engine
        self.sequence = 0
        self.last_hash = ""
        self._lock = threading.Lock()
        self._batch_queue: List[Dict] = []
        self._batch_timer: Optional[threading.Timer] = None

        # Key management
        self.key_path = Path(key_path)
        self.hmac_key = self._init_key()
        self.fernet = Fernet(self._derive_fernet_key())

        # Database
        self.db_path = Path(db_path)
        self._init_db()
        # Continue the existing chain instead of forking a new one at 1.
        self._restore_chain_state()

    # ------------------------------------------------------------------
    # Key handling
    # ------------------------------------------------------------------
    def _init_key(self) -> bytes:
        """Load the HMAC key from ``key_path``, creating it if missing.

        A new key file is created with mode 0o600 from the start (exclusive
        create), so the secret is never readable by other users, not even
        briefly.

        Raises:
            ValueError: If the key file exists but is empty.
        """
        try:
            key = self.key_path.read_bytes()
        except FileNotFoundError:
            key = os.urandom(32)
            try:
                fd = os.open(
                    self.key_path,
                    os.O_WRONLY | os.O_CREAT | os.O_EXCL,
                    0o600,
                )
            except FileExistsError:
                # Another process created the file first; use its key.
                key = self.key_path.read_bytes()
            else:
                with os.fdopen(fd, "wb") as handle:
                    handle.write(key)
        if not key:
            raise ValueError(f"HMAC key file {self.key_path} is empty")
        return key

    def _derive_fernet_key(self) -> bytes:
        """Derive a stable Fernet key from the persisted HMAC key.

        The label keeps the encryption key distinct from the key used for
        integrity hashes. The result is the 32-byte digest in URL-safe
        base64, which is the format ``Fernet`` expects.
        """
        digest = hmac.new(
            self.hmac_key, b"secure-audit/fernet-key/v1", hashlib.sha256
        ).digest()
        return base64.urlsafe_b64encode(digest)

    # ------------------------------------------------------------------
    # Database plumbing
    # ------------------------------------------------------------------
    def _init_db(self):
        """Create the ``audit_logs`` table and its indexes if absent."""
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:  # transaction scope only; closed in ``finally``
                # NOTE(review): the FOREIGN KEY targets a non-unique column
                # and SQLite does not enforce foreign keys unless asked to.
                conn.execute("""
                    CREATE TABLE IF NOT EXISTS audit_logs (
                        id INTEGER PRIMARY KEY,
                        sequence INTEGER,
                        timestamp TEXT,
                        operation TEXT,
                        key_hash TEXT,
                        encrypted_key TEXT,
                        encrypted_cron TEXT DEFAULT '',
                        obfuscated_task_id TEXT DEFAULT '',
                        success INTEGER,
                        user TEXT,
                        reason TEXT,
                        integrity_hash TEXT,
                        previous_hash TEXT,
                        FOREIGN KEY(previous_hash) REFERENCES audit_logs(integrity_hash)
                    )
                """)
                conn.execute("CREATE INDEX IF NOT EXISTS idx_timestamp ON audit_logs(timestamp)")
                conn.execute("CREATE INDEX IF NOT EXISTS idx_user ON audit_logs(user)")
                conn.execute("CREATE INDEX IF NOT EXISTS idx_operation ON audit_logs(operation)")
        finally:
            conn.close()

    def _restore_chain_state(self):
        """Set ``sequence`` / ``last_hash`` from the newest stored row."""
        conn = sqlite3.connect(self.db_path)
        try:
            tail = conn.execute(
                "SELECT sequence, integrity_hash FROM audit_logs "
                "ORDER BY sequence DESC, id DESC LIMIT 1"
            ).fetchone()
        finally:
            conn.close()
        if tail is None:
            self.sequence, self.last_hash = 0, ""
        else:
            self.sequence = int(tail[0] or 0)
            self.last_hash = tail[1] or ""

    def _insert_rows(self, rows) -> None:
        """Insert sealed rows atomically: all are committed, or none."""
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:
                conn.executemany(self._INSERT_SQL, rows)
        finally:
            conn.close()

    # ------------------------------------------------------------------
    # Hashing helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive ``datetime``.

        Naive values keep the stored ISO strings in the same format the
        table already uses, so string comparison in purge keeps working.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _parse_timestamp(timestamp) -> Optional[datetime]:
        """Parse an ISO timestamp to naive UTC; ``None`` if unparseable."""
        try:
            parsed = datetime.fromisoformat(timestamp)
        except (TypeError, ValueError):
            return None
        if parsed.tzinfo is not None:
            parsed = parsed.astimezone(timezone.utc).replace(tzinfo=None)
        return parsed

    @staticmethod
    def _text(value) -> str:
        """Normalise a column value as it reads back from a TEXT column."""
        return "" if value is None else str(value)

    @classmethod
    def _integrity_payload(cls, sequence, timestamp, operation, key_hash,
                           encrypted_cron, obfuscated_task_id, success,
                           user, reason) -> str:
        """Build the canonical string that a row's HMAC is computed over.

        Only values that round-trip unchanged through SQLite are used, so the
        same payload can be rebuilt from a stored row during verification.
        """
        return repr((
            int(sequence),
            cls._text(timestamp),
            cls._text(operation),
            cls._text(key_hash),
            cls._text(encrypted_cron),
            cls._text(obfuscated_task_id),
            int(bool(success)),
            None if user is None else str(user),
            cls._text(reason),
        ))

    def _calculate_hmac(self, data: str, previous_hash: Optional[str] = None) -> str:
        """Return HMAC-SHA256 of ``data`` chained to ``previous_hash``.

        Args:
            data: Canonical payload of the entry being sealed.
            previous_hash: Hash of the preceding entry; defaults to the
                current chain tail (``self.last_hash``).

        The result depends only on its inputs and the key, so it can be
        recomputed later to detect tampering.
        """
        if previous_hash is None:
            previous_hash = self.last_hash
        return hmac.new(
            self.hmac_key,
            (data + previous_hash).encode(),
            hashlib.sha256,
        ).hexdigest()

    def _seal_entry(self, sequence, previous_hash, *, timestamp, operation,
                    key_hash, encrypted_key, encrypted_cron,
                    obfuscated_task_id, success, user, reason):
        """Return ``(row, integrity_hash)`` for one entry, ready to insert."""
        payload = self._integrity_payload(
            sequence, timestamp, operation, key_hash, encrypted_cron,
            obfuscated_task_id, success, user, reason,
        )
        integrity_hash = self._calculate_hmac(payload, previous_hash)
        row = (
            sequence, timestamp, operation, key_hash, encrypted_key,
            encrypted_cron, obfuscated_task_id, int(bool(success)), user,
            reason, integrity_hash, previous_hash,
        )
        return row, integrity_hash

    def _verify_timestamp(self, timestamp: str, max_skew: int = 30) -> bool:
        """Return True if ``timestamp`` is within ``max_skew`` seconds of now.

        Unparseable timestamps yield ``False``.
        """
        log_time = self._parse_timestamp(timestamp)
        if log_time is None:
            return False
        return abs((self._utcnow() - log_time).total_seconds()) <= max_skew

    def _obfuscate_task_id(self, task_id: str) -> str:
        """Return HMAC-SHA256 of ``task_id`` combined with a random salt.

        NOTE(review): the salt is not stored, so the result cannot be matched
        back to a task ID even with the key - confirm that is intended.
        """
        salt = os.urandom(16).hex()
        return hmac.new(
            self.hmac_key,
            (task_id + salt).encode(),
            hashlib.sha256,
        ).hexdigest()

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def log_operation(
        self,
        operation: str,
        key: str,
        success: bool,
        user: Optional[str] = None,
        reason: Optional[str] = None,
        cron: Optional[str] = None,
        task_id: Optional[str] = None
    ) -> str:
        """Append one operation to the audit chain.

        Args:
            operation: Name of the operation performed.
            key: Key the operation touched; stored Fernet-encrypted.
            success: Whether the operation succeeded.
            user: Acting user; when truthy the RBAC engine is notified.
            reason: Free-text explanation, stored as ``""`` when omitted.
            cron: Cron expression; stored Fernet-encrypted when given.
            task_id: Task identifier; stored obfuscated when given.

        Returns:
            The integrity hash of the new entry.
        """
        encrypted_key = self.fernet.encrypt(key.encode()).decode()
        # NOTE(review): this hashes the ciphertext, which differs on every
        # call, so key_hash cannot be used to correlate entries by key.
        hashed_key = hashlib.sha256(encrypted_key.encode()).hexdigest()
        encrypted_cron = self.fernet.encrypt(cron.encode()).decode() if cron else ""
        obfuscated_task_id = self._obfuscate_task_id(task_id) if task_id else ""

        with self._lock:
            sequence = self.sequence + 1
            row, integrity_hash = self._seal_entry(
                sequence,
                self.last_hash,
                timestamp=self._utcnow().isoformat(),
                operation=operation,
                key_hash=hashed_key,
                encrypted_key=encrypted_key,
                encrypted_cron=encrypted_cron,
                obfuscated_task_id=obfuscated_task_id,
                success=success,
                user=user,
                reason=reason or "",
            )
            self._insert_rows([row])
            # Advance the chain only after the row is durably committed.
            self.sequence = sequence
            self.last_hash = integrity_hash

        # Notify RBAC outside the lock: if the engine calls back into this
        # logger (e.g. queue_access) it must not deadlock.
        if user:
            self.rbac._audit_access_attempt(
                user,
                "memory",
                operation,
                success,
                reason or f"Memory {operation} operation"
            )
        return integrity_hash

    def verify_log_integrity(self) -> bool:
        """Check every stored row against the hash chain.

        A log passes when, for each row in sequence order:

        - ``previous_hash`` equals the preceding row's ``integrity_hash``.
          The first surviving row is the anchor: it must have an empty
          ``previous_hash`` if it is sequence 1, otherwise it is accepted
          because older rows may have been purged.
        - the timestamp parses and is not later than now plus
          ``MAX_CLOCK_SKEW`` seconds;
        - the recomputed HMAC matches the stored ``integrity_hash``.

        Returns:
            True if all rows pass (an empty log passes), else False.
        """
        conn = sqlite3.connect(self.db_path)
        try:
            rows = conn.execute("""
                SELECT sequence, timestamp, operation, key_hash,
                       encrypted_cron, obfuscated_task_id, success, user,
                       reason, integrity_hash, previous_hash
                FROM audit_logs
                ORDER BY sequence, id
            """).fetchall()
        finally:
            conn.close()

        latest_allowed = self._utcnow() + timedelta(seconds=self.MAX_CLOCK_SKEW)
        expected_previous = None  # None until the anchor row has been seen
        for (seq, timestamp, operation, key_hash, encrypted_cron,
             obfuscated_task_id, success, user, reason,
             stored_hash, previous_hash) in rows:
            previous_hash = previous_hash or ""
            stored_hash = stored_hash or ""

            if expected_previous is None:
                if seq == 1 and previous_hash != "":
                    return False
            elif previous_hash != expected_previous:
                return False

            log_time = self._parse_timestamp(timestamp)
            if log_time is None or log_time > latest_allowed:
                return False

            try:
                payload = self._integrity_payload(
                    seq, timestamp, operation, key_hash, encrypted_cron,
                    obfuscated_task_id, success, user, reason,
                )
            except (TypeError, ValueError):
                return False  # e.g. a NULL or non-numeric sequence
            recomputed = self._calculate_hmac(payload, previous_hash)
            if not hmac.compare_digest(recomputed.encode(), str(stored_hash).encode()):
                return False

            expected_previous = stored_hash
        return True

    def purge_old_entries(self, days: int = 90) -> int:
        """Delete entries older than ``days`` days.

        Returns:
            The number of rows removed.
        """
        cutoff = (self._utcnow() - timedelta(days=days)).isoformat()
        conn = sqlite3.connect(self.db_path)
        try:
            with conn:
                deleted = conn.execute(
                    "DELETE FROM audit_logs WHERE timestamp < ?", (cutoff,)
                ).rowcount
        finally:
            conn.close()
        return deleted

    def queue_access(self, operation: str, user: str, data: dict, status: str):
        """Queue an access attempt for batched logging.

        The queue is flushed as soon as it holds ``BATCH_SIZE`` entries;
        otherwise a background timer flushes it after ``BATCH_INTERVAL``
        seconds.
        """
        record = {
            'operation': operation,
            'user': user,
            'data': data,
            'status': status,
            'timestamp': self._utcnow().isoformat()
        }
        with self._lock:
            self._batch_queue.append(record)
            flush_now = len(self._batch_queue) >= self.BATCH_SIZE
            if not flush_now and self._batch_timer is None:
                timer = threading.Timer(self.BATCH_INTERVAL, self._flush_batch)
                timer.daemon = True  # never keep the process alive
                self._batch_timer = timer
                timer.start()
        # Flush after releasing the lock: _flush_batch acquires it itself and
        # threading.Lock is not reentrant.
        if flush_now:
            self._flush_batch()

    def _flush_batch(self):
        """Write all queued entries to the database as one transaction.

        If the write fails the entries are put back at the front of the
        queue and the exception propagates, so nothing is silently lost.
        """
        with self._lock:
            pending_timer, self._batch_timer = self._batch_timer, None
            if pending_timer is not None:
                # Harmless when this method is itself the timer callback.
                pending_timer.cancel()

            batch, self._batch_queue = self._batch_queue, []
            if not batch:
                return

            sequence = self.sequence
            previous_hash = self.last_hash
            rows = []
            for entry in batch:
                sequence += 1
                hashed_data = hashlib.sha256(str(entry['data']).encode()).hexdigest()
                row, previous_hash = self._seal_entry(
                    sequence,
                    previous_hash,
                    timestamp=entry['timestamp'],
                    operation=entry['operation'],
                    key_hash=hashed_data,
                    encrypted_key=None,
                    encrypted_cron="",
                    obfuscated_task_id="",
                    success=entry['status'] == 'completed',
                    user=entry['user'],
                    reason=entry['status'],
                )
                rows.append(row)

            try:
                self._insert_rows(rows)
            except Exception:
                self._batch_queue[:0] = batch
                raise
            # Advance the chain only after the batch is committed.
            self.sequence = sequence
            self.last_hash = previous_hash