233 lines
7.6 KiB
Python
233 lines
7.6 KiB
Python
import secrets
|
|
import hashlib
|
|
import json
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional, List, Dict
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select, and_, or_
|
|
import logging
|
|
|
|
from database.models.api_key import ApiKey
|
|
|
|
# Logger named after this module so its records can be filtered per module.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class ApiKeyService:
    """API key management service.

    Wraps an ``AsyncSession`` and provides creation, verification, listing and
    revocation of ``ApiKey`` rows. Only the SHA-256 hash of a key and a short
    display prefix are persisted; the plaintext key is returned exactly once,
    by :meth:`create_api_key`.
    """

    def __init__(self, session: AsyncSession):
        self.session = session

    @staticmethod
    def generate_api_key() -> str:
        """Generate a new random API key.

        Format: ``sdk_live_`` followed by ``secrets.token_urlsafe(32)``.
        """
        random_part = secrets.token_urlsafe(32)
        return f"sdk_live_{random_part}"

    @staticmethod
    def hash_api_key(api_key: str) -> str:
        """Return the hex SHA-256 digest of *api_key* (the value stored in the DB)."""
        return hashlib.sha256(api_key.encode()).hexdigest()

    @staticmethod
    def get_key_prefix(api_key: str) -> str:
        """Return the first 12 characters of *api_key* (the whole key if shorter).

        Stored next to the hash so a key can be identified without revealing it.
        """
        # Slicing already clamps to the string length; no explicit check needed.
        return api_key[:12]

    async def create_api_key(
        self,
        name: str,
        description: Optional[str] = None,
        created_by: str = "system",
        user_id: Optional[str] = None,
        expires_in_days: Optional[int] = None,
        rate_limit_per_minute: int = 60,
        allowed_endpoints: Optional[List[str]] = None,
    ) -> tuple[ApiKey, str]:
        """Create, persist and commit a new API key.

        Args:
            name: human-readable key name.
            description: optional free-text description.
            created_by: recorded as the key's creator (default ``"system"``).
            user_id: optional owner id stored on the row.
            expires_in_days: lifetime in days from now. Any falsy value
                (``None`` or ``0``) means the key never expires.
            rate_limit_per_minute: limit stored on the row (not enforced yet,
                see :meth:`check_rate_limit`).
            allowed_endpoints: endpoint patterns (exact path, ``prefix*`` or
                ``"*"``), stored as a JSON string. ``None`` or an empty list is
                stored as ``None``, i.e. unrestricted access.

        Returns:
            tuple[ApiKey, str]: (persisted ApiKey, plaintext key -- show it only once).
        """
        api_key_plain = self.generate_api_key()
        key_hash = self.hash_api_key(api_key_plain)
        key_prefix = self.get_key_prefix(api_key_plain)

        # NOTE(review): a falsy expires_in_days (including 0) yields a key that
        # never expires; confirm callers never use 0 to mean "shortest expiry".
        # NOTE(review): naive local time, same convention as every timestamp in
        # this module -- confirm the DB column matches before moving to UTC.
        expires_at = None
        if expires_in_days:
            expires_at = datetime.now() + timedelta(days=expires_in_days)

        api_key_obj = ApiKey(
            key_hash=key_hash,
            key_prefix=key_prefix,
            name=name,
            description=description,
            created_by=created_by,
            user_id=user_id,
            expires_at=expires_at,
            rate_limit_per_minute=rate_limit_per_minute,
            allowed_endpoints=json.dumps(allowed_endpoints)
            if allowed_endpoints
            else None,
        )

        self.session.add(api_key_obj)
        await self.session.commit()
        await self.session.refresh(api_key_obj)

        logger.info(f"✅ Clé API créée: {name} (prefix: {key_prefix})")

        return api_key_obj, api_key_plain

    async def verify_api_key(self, api_key_plain: str) -> Optional[ApiKey]:
        """Look up a plaintext key and return its row if it is currently usable.

        A key is usable when its hash matches, ``is_active`` is true, it has no
        ``revoked_at`` and it has either no ``expires_at`` or one in the future.
        On success the row's ``total_requests`` is incremented, ``last_used_at``
        is set and the session is committed.

        Returns:
            Optional[ApiKey]: the matching ApiKey, or ``None`` if none is valid.
        """
        key_hash = self.hash_api_key(api_key_plain)

        result = await self.session.execute(
            select(ApiKey).where(
                and_(
                    ApiKey.key_hash == key_hash,
                    # `== True` builds a SQL expression here, not a Python test.
                    ApiKey.is_active == True,  # noqa: E712
                    ApiKey.revoked_at.is_(None),
                    or_(
                        ApiKey.expires_at.is_(None), ApiKey.expires_at > datetime.now()
                    ),
                )
            )
        )

        api_key_obj = result.scalar_one_or_none()

        if api_key_obj:
            # Usage statistics.
            # NOTE(review): read-modify-write in Python; concurrent requests on
            # the same key can lose increments. An SQL-side UPDATE
            # (total_requests = total_requests + 1) would make it atomic.
            api_key_obj.total_requests += 1
            api_key_obj.last_used_at = datetime.now()
            await self.session.commit()

            logger.debug(f"🔑 Clé API validée: {api_key_obj.name}")
        else:
            logger.warning("⚠️ Clé API invalide ou expirée")

        return api_key_obj

    async def list_api_keys(
        self,
        include_revoked: bool = False,
        user_id: Optional[str] = None,
    ) -> List[ApiKey]:
        """List API keys, newest first.

        Args:
            include_revoked: when False (default), rows with ``revoked_at`` set
                are excluded.
            user_id: when given, only keys belonging to that user are returned.
        """
        query = select(ApiKey)

        if not include_revoked:
            query = query.where(ApiKey.revoked_at.is_(None))

        if user_id:
            query = query.where(ApiKey.user_id == user_id)

        query = query.order_by(ApiKey.created_at.desc())

        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def revoke_api_key(self, key_id: str) -> bool:
        """Revoke the key with id *key_id*.

        Sets ``is_active`` to False and stamps ``revoked_at``. Revoking an
        already-revoked key keeps its original ``revoked_at``.

        Returns:
            bool: False if no key has that id, True otherwise.
        """
        result = await self.session.execute(select(ApiKey).where(ApiKey.id == key_id))
        api_key_obj = result.scalar_one_or_none()

        if not api_key_obj:
            return False

        api_key_obj.is_active = False
        # Don't overwrite the original revocation time on a repeated revoke.
        if api_key_obj.revoked_at is None:
            api_key_obj.revoked_at = datetime.now()
        await self.session.commit()

        logger.info(f"🚫 Clé API révoquée: {api_key_obj.name}")
        return True

    async def get_by_id(self, key_id: str) -> Optional[ApiKey]:
        """Return the key with id *key_id*, or ``None`` if it does not exist."""
        result = await self.session.execute(select(ApiKey).where(ApiKey.id == key_id))
        return result.scalar_one_or_none()

    async def check_rate_limit(self, api_key_obj: ApiKey) -> tuple[bool, Dict]:
        """Check a key's rate limit (placeholder, to be backed by Redis/cache).

        Returns:
            tuple[bool, Dict]: (is_allowed, info_dict). Currently always allows
            and reports the full ``rate_limit_per_minute`` as remaining.
        """
        # TODO: implement real rate limiting (e.g. with Redis).
        return True, {
            "allowed": True,
            "limit": api_key_obj.rate_limit_per_minute,
            "remaining": api_key_obj.rate_limit_per_minute,
        }

    async def check_endpoint_access(self, api_key_obj: ApiKey, endpoint: str) -> bool:
        """Return whether *api_key_obj* may call *endpoint*.

        ``allowed_endpoints`` holds a JSON array of patterns: ``"*"`` grants
        everything, a pattern ending in ``*`` matches by prefix, anything else
        must equal *endpoint* exactly. An empty/``None`` value means no
        restriction. Malformed data (invalid JSON, or not a JSON array) denies
        access; non-string entries inside the array are ignored.
        """
        if not api_key_obj.allowed_endpoints:
            # No restriction configured: full access.
            return True

        try:
            allowed = json.loads(api_key_obj.allowed_endpoints)
        except (json.JSONDecodeError, TypeError):
            # TypeError: stored value is not str/bytes (unexpected column type).
            logger.error(f"❌ Erreur parsing allowed_endpoints pour {api_key_obj.id}")
            return False

        if not isinstance(allowed, list):
            # A bare JSON string would otherwise be iterated character by
            # character: '"/v1/x*"' contains a lone "*", which matched the
            # wildcard and granted full access. Fail closed instead.
            logger.error(f"❌ Erreur parsing allowed_endpoints pour {api_key_obj.id}")
            return False

        for pattern in allowed:
            if not isinstance(pattern, str):
                # Non-string entries used to raise AttributeError on .endswith().
                continue
            if pattern == "*" or pattern == endpoint:
                return True
            if pattern.endswith("*") and endpoint.startswith(pattern[:-1]):
                return True

        return False
|
|
|
|
|
def api_key_to_response(api_key_obj: ApiKey, show_key: bool = False) -> Dict:
    """Serialize an ``ApiKey`` row into an API response dict.

    Args:
        api_key_obj: the key row to serialize.
        show_key: currently unused -- only the hash and prefix are stored, so
            the plaintext key cannot be included. Kept for interface compatibility.

    Returns:
        Dict: public fields of the key. ``allowed_endpoints`` is the decoded
        JSON value (``None`` when unset or when the stored value is not valid
        JSON). ``is_expired`` is True once ``expires_at`` is reached, using the
        same boundary as ``ApiKeyService.verify_api_key`` (valid only while
        ``expires_at > now``).
    """
    allowed_endpoints = None
    if api_key_obj.allowed_endpoints:
        try:
            allowed_endpoints = json.loads(api_key_obj.allowed_endpoints)
        except json.JSONDecodeError:
            # Best-effort: still return the key, but don't hide the corrupt row.
            logger.error(f"❌ Erreur parsing allowed_endpoints pour {api_key_obj.id}")

    # `<=` so a key is reported expired exactly when verification starts
    # rejecting it. NOTE(review): naive local time, matching the rest of the module.
    is_expired = (
        api_key_obj.expires_at is not None
        and api_key_obj.expires_at <= datetime.now()
    )

    return {
        "id": api_key_obj.id,
        "name": api_key_obj.name,
        "description": api_key_obj.description,
        "key_prefix": api_key_obj.key_prefix,
        "is_active": api_key_obj.is_active,
        "is_expired": is_expired,
        "rate_limit_per_minute": api_key_obj.rate_limit_per_minute,
        "allowed_endpoints": allowed_endpoints,
        "total_requests": api_key_obj.total_requests,
        "last_used_at": api_key_obj.last_used_at,
        "created_at": api_key_obj.created_at,
        "expires_at": api_key_obj.expires_at,
        "revoked_at": api_key_obj.revoked_at,
        "created_by": api_key_obj.created_by,
    }
|