import hashlib
import json
import logging
import secrets
from datetime import datetime, timedelta
from typing import Dict, List, Optional

from sqlalchemy import and_, or_, select
from sqlalchemy.ext.asyncio import AsyncSession

from database.models.api_key import ApiKey

logger = logging.getLogger(__name__)

# Number of leading plaintext characters stored in clear to identify a key.
_KEY_PREFIX_LENGTH = 12


class ApiKeyService:
    """Create, verify, list and revoke API keys stored as ``ApiKey`` rows.

    Plaintext keys are never persisted: only their SHA-256 hex digest
    (``key_hash``) and a short clear-text prefix (``key_prefix``) are stored.
    All timestamps come from naive ``datetime.now()`` (local time).
    NOTE(review): confirm the ``ApiKey`` datetime columns are naive as well,
    otherwise comparisons against them will mix naive and aware values.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    @staticmethod
    def generate_api_key() -> str:
        """Return a new random plaintext key of the form ``sdk_live_<token>``.

        The random part is ``secrets.token_urlsafe(32)`` (32 random bytes,
        URL-safe base64 encoded).
        """
        random_part = secrets.token_urlsafe(32)
        return f"sdk_live_{random_part}"

    @staticmethod
    def hash_api_key(api_key: str) -> str:
        """Return the SHA-256 hex digest of *api_key* (the stored ``key_hash``)."""
        return hashlib.sha256(api_key.encode()).hexdigest()

    @staticmethod
    def get_key_prefix(api_key: str) -> str:
        """Return the first 12 characters of *api_key* (all of it if shorter)."""
        # Slicing already clamps to the string length; no explicit check needed.
        return api_key[:_KEY_PREFIX_LENGTH]

    async def create_api_key(
        self,
        name: str,
        description: Optional[str] = None,
        created_by: str = "system",
        user_id: Optional[str] = None,
        expires_in_days: Optional[int] = None,
        rate_limit_per_minute: int = 60,
        allowed_endpoints: Optional[List[str]] = None,
    ) -> tuple[ApiKey, str]:
        """Create, persist and return a new API key.

        Args:
            name: Human-readable key name.
            description: Optional free-text description.
            created_by: Identifier of the creator (defaults to ``"system"``).
            user_id: Optional owner id.
            expires_in_days: Lifetime in days counted from now. ``None`` *and*
                ``0`` both mean the key never expires.
            rate_limit_per_minute: Limit stored on the key (note that
                ``check_rate_limit`` does not currently enforce it).
            allowed_endpoints: Endpoint patterns the key may access: exact
                paths, ``"prefix*"`` patterns or ``"*"``. ``None`` or an
                empty list is stored as ``NULL``, which grants access to
                every endpoint.

        Returns:
            ``(api_key_obj, api_key_plain)``. The plaintext key is only
            available here; it is not stored and cannot be recovered later.
        """
        api_key_plain = self.generate_api_key()
        key_hash = self.hash_api_key(api_key_plain)
        key_prefix = self.get_key_prefix(api_key_plain)

        # Truthiness check: 0 days is treated the same as None (no expiry).
        expires_at = (
            datetime.now() + timedelta(days=expires_in_days)
            if expires_in_days
            else None
        )

        api_key_obj = ApiKey(
            key_hash=key_hash,
            key_prefix=key_prefix,
            name=name,
            description=description,
            created_by=created_by,
            user_id=user_id,
            expires_at=expires_at,
            rate_limit_per_minute=rate_limit_per_minute,
            allowed_endpoints=(
                json.dumps(allowed_endpoints) if allowed_endpoints else None
            ),
        )

        self.session.add(api_key_obj)
        await self.session.commit()
        await self.session.refresh(api_key_obj)

        logger.info(" Clé API créée: %s (prefix: %s)", name, key_prefix)
        return api_key_obj, api_key_plain

    async def verify_api_key(self, api_key_plain: str) -> Optional[ApiKey]:
        """Return the usable key matching *api_key_plain*, or ``None``.

        A key is usable when it is active, not revoked and not expired
        (``expires_at`` is ``NULL`` or strictly in the future). On success,
        ``total_requests`` is incremented, ``last_used_at`` is set, and the
        session is committed.
        """
        if not api_key_plain:
            # Missing/empty credential: nothing to look up, and hashing
            # None would raise AttributeError.
            logger.warning(" Clé API invalide ou expirée")
            return None

        now = datetime.now()
        key_hash = self.hash_api_key(api_key_plain)
        result = await self.session.execute(
            select(ApiKey).where(
                and_(
                    ApiKey.key_hash == key_hash,
                    ApiKey.is_active,
                    ApiKey.revoked_at.is_(None),
                    or_(ApiKey.expires_at.is_(None), ApiKey.expires_at > now),
                )
            )
        )
        api_key_obj = result.scalar_one_or_none()
        if api_key_obj is None:
            logger.warning(" Clé API invalide ou expirée")
            return None

        # NOTE(review): this is a read-modify-write in Python, so concurrent
        # verifications of the same key can lose increments. A SQL-side
        # "total_requests = total_requests + 1" UPDATE would be atomic.
        api_key_obj.total_requests = (api_key_obj.total_requests or 0) + 1
        api_key_obj.last_used_at = now
        # Read before commit: if the session expires objects on commit,
        # touching attributes afterwards would require a lazy reload.
        key_name = api_key_obj.name
        await self.session.commit()

        logger.debug(" Clé API validée: %s", key_name)
        return api_key_obj

    async def list_api_keys(
        self,
        include_revoked: bool = False,
        user_id: Optional[str] = None,
    ) -> List[ApiKey]:
        """List API keys, newest first.

        Args:
            include_revoked: When false (default), keys with a ``revoked_at``
                timestamp are excluded.
            user_id: When truthy, only keys owned by this user are returned;
                ``None`` or ``""`` returns keys of all users.
        """
        query = select(ApiKey)
        if not include_revoked:
            query = query.where(ApiKey.revoked_at.is_(None))
        if user_id:
            query = query.where(ApiKey.user_id == user_id)
        query = query.order_by(ApiKey.created_at.desc())

        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def revoke_api_key(self, key_id: str) -> bool:
        """Deactivate and revoke the key with id *key_id*.

        Returns ``False`` if no such key exists, ``True`` otherwise. Revoking
        an already-revoked key keeps its original ``revoked_at`` timestamp.
        """
        api_key_obj = await self.get_by_id(key_id)
        if api_key_obj is None:
            return False

        api_key_obj.is_active = False
        if api_key_obj.revoked_at is None:
            # Preserve the first revocation time for auditing.
            api_key_obj.revoked_at = datetime.now()
        key_name = api_key_obj.name
        await self.session.commit()

        logger.info(" Clé API révoquée: %s", key_name)
        return True

    async def get_by_id(self, key_id: str) -> Optional[ApiKey]:
        """Return the key with id *key_id*, or ``None`` if it does not exist."""
        result = await self.session.execute(select(ApiKey).where(ApiKey.id == key_id))
        return result.scalar_one_or_none()

    async def check_rate_limit(self, api_key_obj: ApiKey) -> tuple[bool, Dict]:
        """Placeholder rate-limit check that always allows the request.

        No requests are counted; the returned dict simply echoes the key's
        ``rate_limit_per_minute`` as both ``limit`` and ``remaining``.
        """
        # TODO: actually enforce the limit (e.g. a per-key sliding window).
        limit = api_key_obj.rate_limit_per_minute
        return True, {"allowed": True, "limit": limit, "remaining": limit}

    async def check_endpoint_access(self, api_key_obj: ApiKey, endpoint: str) -> bool:
        """Return whether *api_key_obj* may access *endpoint*.

        A key with no stored ``allowed_endpoints`` is unrestricted. Otherwise
        the stored value must be a JSON list of patterns: ``"*"`` matches
        everything, ``"prefix*"`` matches by prefix, anything else must match
        exactly. Malformed data denies access (fail closed).
        """
        if not api_key_obj.allowed_endpoints:
            return True

        try:
            allowed = json.loads(api_key_obj.allowed_endpoints)
        except json.JSONDecodeError:
            logger.error(" Erreur parsing allowed_endpoints pour %s", api_key_obj.id)
            return False

        if not isinstance(allowed, list):
            # A bare JSON string would otherwise be iterated char by char and a
            # trailing "*" character would match the wildcard (fail-open).
            logger.error(" Erreur parsing allowed_endpoints pour %s", api_key_obj.id)
            return False

        for pattern in allowed:
            if not isinstance(pattern, str):
                continue  # ignore malformed entries instead of crashing
            if pattern == "*" or pattern == endpoint:
                return True
            if pattern.endswith("*") and endpoint.startswith(pattern[:-1]):
                return True
        return False


def api_key_to_response(api_key_obj: ApiKey, show_key: bool = False) -> Dict:
    """Serialize *api_key_obj* into a response dict.

    ``allowed_endpoints`` is decoded from its JSON text (``None`` if absent or
    malformed) and ``is_expired`` is computed against the current local time.
    ``show_key`` currently has no effect: the plaintext key is never stored,
    so it cannot be included here.
    """
    allowed_endpoints = None
    if api_key_obj.allowed_endpoints:
        try:
            allowed_endpoints = json.loads(api_key_obj.allowed_endpoints)
        except json.JSONDecodeError:
            # Still serialize the rest of the key, but surface the corruption.
            logger.warning(
                " Erreur parsing allowed_endpoints pour %s", api_key_obj.id
            )

    # Same boundary as ApiKeyService.verify_api_key: a key is valid only while
    # expires_at > now, so expires_at == now already counts as expired.
    is_expired = (
        api_key_obj.expires_at is not None
        and api_key_obj.expires_at <= datetime.now()
    )

    return {
        "id": api_key_obj.id,
        "name": api_key_obj.name,
        "description": api_key_obj.description,
        "key_prefix": api_key_obj.key_prefix,
        "is_active": api_key_obj.is_active,
        "is_expired": is_expired,
        "rate_limit_per_minute": api_key_obj.rate_limit_per_minute,
        "allowed_endpoints": allowed_endpoints,
        "total_requests": api_key_obj.total_requests,
        "last_used_at": api_key_obj.last_used_at,
        "created_at": api_key_obj.created_at,
        "expires_at": api_key_obj.expires_at,
        "revoked_at": api_key_obj.revoked_at,
        "created_by": api_key_obj.created_by,
    }