# Sage100-vps/services/api_key.py
#
# 205 lines
# 6.6 KiB
# Python

import secrets
import hashlib
import json
from datetime import datetime, timedelta
from typing import Optional, List, Dict
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, and_, or_
import logging
from database.models.api_key import ApiKey
logger = logging.getLogger(__name__)
class ApiKeyService:
    """Create, verify, list and revoke API keys through an async DB session.

    The plaintext key is never handed to the ``ApiKey`` model: only its
    SHA-256 digest and its first 12 characters are stored. The plaintext is
    returned once, by ``create_api_key``.
    """

    def __init__(self, session: AsyncSession):
        """Keep a reference to the session used by every query of the service."""
        self.session = session

    @staticmethod
    def generate_api_key() -> str:
        """Return a new random key made of ``sdk_live_`` plus a URL-safe token."""
        random_part = secrets.token_urlsafe(32)
        return f"sdk_live_{random_part}"

    @staticmethod
    def hash_api_key(api_key: str) -> str:
        """Return the hexadecimal SHA-256 digest of ``api_key``."""
        return hashlib.sha256(api_key.encode()).hexdigest()

    @staticmethod
    def get_key_prefix(api_key: str) -> str:
        """Return the first 12 characters of ``api_key`` (all of it if shorter)."""
        # Slicing already returns the whole string when it is shorter than 12.
        return api_key[:12]

    async def create_api_key(
        self,
        name: str,
        description: Optional[str] = None,
        created_by: str = "system",
        user_id: Optional[str] = None,
        expires_in_days: Optional[int] = None,
        rate_limit_per_minute: int = 60,
        allowed_endpoints: Optional[List[str]] = None,
    ) -> tuple[ApiKey, str]:
        """Create, persist and return a new API key.

        Args:
            name: Label stored on the key.
            description: Optional free-text description.
            created_by: Identifier of the creator, ``"system"`` by default.
            user_id: Optional owner of the key.
            expires_in_days: Lifetime in days. A falsy value (``None`` or
                ``0``) leaves ``expires_at`` unset, i.e. no expiry.
            rate_limit_per_minute: Value stored on the key.
            allowed_endpoints: Endpoint patterns, stored JSON-encoded. ``None``
                or an empty list is stored as ``None``, which
                ``check_endpoint_access`` treats as "no restriction".

        Returns:
            ``(api_key_obj, api_key_plain)``: the refreshed ``ApiKey`` object
            and the plaintext key.
        """
        api_key_plain = self.generate_api_key()
        key_hash = self.hash_api_key(api_key_plain)
        key_prefix = self.get_key_prefix(api_key_plain)

        expires_at = None
        if expires_in_days:
            expires_at = datetime.now() + timedelta(days=expires_in_days)

        api_key_obj = ApiKey(
            key_hash=key_hash,
            key_prefix=key_prefix,
            name=name,
            description=description,
            created_by=created_by,
            user_id=user_id,
            expires_at=expires_at,
            rate_limit_per_minute=rate_limit_per_minute,
            allowed_endpoints=json.dumps(allowed_endpoints)
            if allowed_endpoints
            else None,
        )
        self.session.add(api_key_obj)
        await self.session.commit()
        await self.session.refresh(api_key_obj)

        logger.info(f" Clé API créée: {name} (prefix: {key_prefix})")
        return api_key_obj, api_key_plain

    async def verify_api_key(self, api_key_plain: str) -> Optional[ApiKey]:
        """Return the matching usable key, or ``None`` when there is none.

        A key is usable when its hash matches, it is active, it has not been
        revoked and it has no expiry or an expiry in the future. On success
        the request counter is incremented, ``last_used_at`` is updated and
        the change is committed.
        """
        # An empty or missing key can never match: skip hashing and the query.
        if not api_key_plain:
            logger.warning(" Clé API invalide ou expirée")
            return None

        key_hash = self.hash_api_key(api_key_plain)
        result = await self.session.execute(
            select(ApiKey).where(
                and_(
                    ApiKey.key_hash == key_hash,
                    ApiKey.is_active,
                    ApiKey.revoked_at.is_(None),
                    or_(
                        ApiKey.expires_at.is_(None),
                        ApiKey.expires_at > datetime.now(),
                    ),
                )
            )
        )
        api_key_obj = result.scalar_one_or_none()

        if api_key_obj:
            api_key_obj.total_requests += 1
            api_key_obj.last_used_at = datetime.now()
            await self.session.commit()
            logger.debug(f" Clé API validée: {api_key_obj.name}")
        else:
            logger.warning(" Clé API invalide ou expirée")
        return api_key_obj

    async def list_api_keys(
        self,
        include_revoked: bool = False,
        user_id: Optional[str] = None,
    ) -> List[ApiKey]:
        """List API keys, newest first.

        Args:
            include_revoked: When false, keys with a ``revoked_at`` value are
                filtered out.
            user_id: When truthy, only keys owned by this user are returned.
        """
        query = select(ApiKey)
        if not include_revoked:
            query = query.where(ApiKey.revoked_at.is_(None))
        if user_id:
            query = query.where(ApiKey.user_id == user_id)
        query = query.order_by(ApiKey.created_at.desc())

        result = await self.session.execute(query)
        return list(result.scalars().all())

    async def revoke_api_key(self, key_id: str) -> bool:
        """Deactivate the key with this id and stamp ``revoked_at``.

        Returns:
            ``True`` when the key exists and was updated, ``False`` when no
            key has this id.
        """
        result = await self.session.execute(select(ApiKey).where(ApiKey.id == key_id))
        api_key_obj = result.scalar_one_or_none()
        if not api_key_obj:
            return False

        api_key_obj.is_active = False
        api_key_obj.revoked_at = datetime.now()
        await self.session.commit()

        logger.info(f" Clé API révoquée: {api_key_obj.name}")
        return True

    async def get_by_id(self, key_id: str) -> Optional[ApiKey]:
        """Return the key with this id, or ``None`` when it does not exist."""
        result = await self.session.execute(select(ApiKey).where(ApiKey.id == key_id))
        return result.scalar_one_or_none()

    async def check_rate_limit(self, api_key_obj: ApiKey) -> tuple[bool, Dict]:
        """Report the rate-limit status of a key.

        NOTE: no usage is tracked here; every call is reported as allowed with
        the full per-minute quota remaining.
        """
        return True, {
            "allowed": True,
            "limit": api_key_obj.rate_limit_per_minute,
            "remaining": api_key_obj.rate_limit_per_minute,
        }

    async def check_endpoint_access(self, api_key_obj: ApiKey, endpoint: str) -> bool:
        """Tell whether the key may call ``endpoint``.

        A key without ``allowed_endpoints`` is unrestricted. Otherwise the
        stored value must be a JSON list of patterns: ``"*"`` matches
        everything, a pattern ending in ``*`` matches by prefix, any other
        pattern must equal ``endpoint``. Anything that is not a JSON list
        denies access.
        """
        raw_patterns = api_key_obj.allowed_endpoints
        if not raw_patterns:
            return True

        try:
            allowed = json.loads(raw_patterns)
        except (json.JSONDecodeError, TypeError):
            logger.error(f" Erreur parsing allowed_endpoints pour {api_key_obj.id}")
            return False

        # Fail closed on valid JSON of the wrong shape. Iterating a JSON
        # string or object would compare single characters or keys, so a
        # stored "*" could grant access to every endpoint.
        if not isinstance(allowed, list):
            logger.error(f" Erreur parsing allowed_endpoints pour {api_key_obj.id}")
            return False

        for pattern in allowed:
            if not isinstance(pattern, str):
                # A non-string entry cannot match an endpoint; ignore it.
                continue
            if pattern == "*" or pattern == endpoint:
                return True
            if pattern.endswith("*") and endpoint.startswith(pattern[:-1]):
                return True
        return False
def api_key_to_response(api_key_obj: ApiKey, show_key: bool = False) -> Dict:
    """Convert an ``ApiKey`` object into a plain dict for an API response.

    Args:
        api_key_obj: Key to serialise.
        show_key: Kept for interface compatibility; this function does not
            read it.

    Returns:
        A dict with the key's public fields. ``allowed_endpoints`` is the
        decoded JSON value, or ``None`` when unset or undecodable.
        ``is_expired`` is true when ``expires_at`` is set and in the past.
    """
    allowed_endpoints = None
    if api_key_obj.allowed_endpoints:
        try:
            allowed_endpoints = json.loads(api_key_obj.allowed_endpoints)
        except (json.JSONDecodeError, TypeError):
            # Best effort: a corrupt value must not break the response, but
            # it should leave a trace instead of vanishing silently.
            logger.warning(f"Erreur parsing allowed_endpoints pour {api_key_obj.id}")

    expires_at = api_key_obj.expires_at
    is_expired = False
    if expires_at is not None:
        # Use the timestamp's own tzinfo: None gives the naive local "now"
        # as before, and an aware value no longer raises TypeError.
        is_expired = expires_at < datetime.now(expires_at.tzinfo)

    return {
        "id": api_key_obj.id,
        "name": api_key_obj.name,
        "description": api_key_obj.description,
        "key_prefix": api_key_obj.key_prefix,
        "is_active": api_key_obj.is_active,
        "is_expired": is_expired,
        "rate_limit_per_minute": api_key_obj.rate_limit_per_minute,
        "allowed_endpoints": allowed_endpoints,
        "total_requests": api_key_obj.total_requests,
        "last_used_at": api_key_obj.last_used_at,
        "created_at": api_key_obj.created_at,
        "expires_at": api_key_obj.expires_at,
        "revoked_at": api_key_obj.revoked_at,
        "created_by": api_key_obj.created_by,
    }