From abc9ff820a807d67ecdced7b3ba2bd69d5733a80 Mon Sep 17 00:00:00 2001 From: Fanilo-Nantenaina Date: Tue, 20 Jan 2026 11:11:32 +0300 Subject: [PATCH] feat(security): implement api key management and authentication system --- config/cors_config.py | 125 +++++++++++++ database/models/api_key.py | 56 ++++++ middleware/security.py | 236 +++++++++++++++++++++++++ routes/api_keys.py | 154 ++++++++++++++++ schemas/api_key.py | 77 ++++++++ scripts/manage_security.py | 264 +++++++++++++++++++++++++++ scripts/test_security.py | 354 +++++++++++++++++++++++++++++++++++++ services/api_key.py | 205 +++++++++++++++++++++ 8 files changed, 1471 insertions(+) create mode 100644 config/cors_config.py create mode 100644 database/models/api_key.py create mode 100644 middleware/security.py create mode 100644 routes/api_keys.py create mode 100644 schemas/api_key.py create mode 100644 scripts/manage_security.py create mode 100644 scripts/test_security.py create mode 100644 services/api_key.py diff --git a/config/cors_config.py b/config/cors_config.py new file mode 100644 index 0000000..0f3a4d2 --- /dev/null +++ b/config/cors_config.py @@ -0,0 +1,125 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from typing import List +import os +import logging + +logger = logging.getLogger(__name__) + + +def configure_cors_open(app: FastAPI): + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=False, + allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], + allow_headers=["*"], + expose_headers=["X-RateLimit-Limit", "X-RateLimit-Remaining"], + max_age=3600, + ) + + logger.info(" CORS configuré: Mode OUVERT (sécurisé par API Keys)") + logger.info(" - Origins: * (toutes)") + logger.info(" - Headers: * (dont X-API-Key)") + logger.info(" - Credentials: False") + + +def configure_cors_whitelist(app: FastAPI): + allowed_origins_str = os.getenv("CORS_ALLOWED_ORIGINS", "") + + if allowed_origins_str: + allowed_origins = [ + origin.strip() + for origin in allowed_origins_str.split(",") + if origin.strip() + ] + else: + allowed_origins = ["*"] + + app.add_middleware( + CORSMiddleware, + allow_origins=allowed_origins, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "X-API-Key"], + expose_headers=["X-RateLimit-Limit", "X-RateLimit-Remaining"], + max_age=3600, + ) + + logger.info(" CORS configuré: Mode WHITELIST") + logger.info(f" - Origins autorisées: {len(allowed_origins)}") + for origin in allowed_origins: + logger.info(f" • {origin}") + + +def configure_cors_regex(app: FastAPI): + origin_regex = r"*" + + app.add_middleware( + CORSMiddleware, + allow_origin_regex=origin_regex, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "X-API-Key"], + expose_headers=["X-RateLimit-Limit", "X-RateLimit-Remaining"], + max_age=3600, + ) + + logger.info(" CORS configuré: Mode REGEX") + logger.info(f" - Pattern: {origin_regex}") + + +def configure_cors_hybrid(app: FastAPI): + from starlette.middleware.base import BaseHTTPMiddleware + + class HybridCORSMiddleware(BaseHTTPMiddleware): + def __init__(self, app, known_origins: List[str]): + super().__init__(app) + self.known_origins = set(known_origins) + + async def dispatch(self, request, call_next): + origin = request.headers.get("origin") + + if origin in self.known_origins: + response = await call_next(request) + response.headers["Access-Control-Allow-Origin"] = origin + response.headers["Access-Control-Allow-Credentials"] = "true" + response.headers["Access-Control-Allow-Methods"] = ( + "GET, POST, PUT, DELETE, PATCH, OPTIONS" + ) + response.headers["Access-Control-Allow-Headers"] = ( + "Content-Type, Authorization, X-API-Key" + ) + return response + + response = await call_next(request) + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Methods"] = ( + "GET, POST, PUT, DELETE, PATCH, OPTIONS" + ) + response.headers["Access-Control-Allow-Headers"] = "*" + return response + + known_origins = ["*"] + + app.add_middleware(HybridCORSMiddleware, known_origins=known_origins) + + logger.info(" CORS configuré: Mode HYBRIDE") + logger.info(f" - Whitelist: {len(known_origins)} domaines") + logger.info(" - Fallback: * (ouvert)") + + +def setup_cors(app: FastAPI, mode: str = "open"): + if mode == "open": + configure_cors_open(app) + elif mode == "whitelist": + configure_cors_whitelist(app) + elif mode == "regex": + configure_cors_regex(app) + elif mode == "hybrid": + configure_cors_hybrid(app) + else: + logger.warning( + f" Mode CORS inconnu: {mode}. Utilisation de 'open' par défaut." + ) + configure_cors_open(app) diff --git a/database/models/api_key.py b/database/models/api_key.py new file mode 100644 index 0000000..0d246ab --- /dev/null +++ b/database/models/api_key.py @@ -0,0 +1,56 @@ +from sqlalchemy import Column, String, Boolean, DateTime, Integer, Text +from datetime import datetime +import uuid + +from database.models.generic_model import Base + + +class ApiKey(Base): + """Modèle pour les clés API publiques""" + + __tablename__ = "api_keys" + + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + key_hash = Column(String(64), unique=True, nullable=False, index=True) + key_prefix = Column(String(10), nullable=False) + + name = Column(String(255), nullable=False) + description = Column(Text, nullable=True) + + user_id = Column(String(36), nullable=True) + created_by = Column(String(255), nullable=False) + + is_active = Column(Boolean, default=True, nullable=False) + rate_limit_per_minute = Column(Integer, default=60, nullable=False) + allowed_endpoints = Column(Text, nullable=True) + + total_requests = Column(Integer, default=0, nullable=False) + last_used_at = Column(DateTime, nullable=True) + + created_at = Column(DateTime, default=datetime.now, nullable=False) + expires_at = Column(DateTime, nullable=True) + revoked_at = Column(DateTime, nullable=True) + + def __repr__(self): + return f"" + + +class SwaggerUser(Base): + """Modèle pour les utilisateurs autorisés à accéder au Swagger""" + + __tablename__ = "swagger_users" + + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + username = Column(String(100), unique=True, nullable=False, index=True) + hashed_password = Column(String(255), nullable=False) + + full_name = Column(String(255), nullable=True) + email = Column(String(255), nullable=True) + + is_active = Column(Boolean, default=True, nullable=False) + + created_at = Column(DateTime, default=datetime.now, nullable=False) + last_login = Column(DateTime, nullable=True) + + def __repr__(self): + return f"" diff --git a/middleware/security.py b/middleware/security.py new file mode 100644 index 0000000..137e7dd --- /dev/null +++ b/middleware/security.py @@ -0,0 +1,236 @@ +from fastapi import Request, status +from fastapi.responses import JSONResponse +from fastapi.security import HTTPBasic, HTTPBasicCredentials +from sqlalchemy import select +from typing import Optional +from datetime import datetime +import logging + +from database import get_session +from database.models.api_key import SwaggerUser +from security.auth import verify_password + +logger = logging.getLogger(__name__) + +security = HTTPBasic() + + +async def verify_swagger_credentials(credentials: HTTPBasicCredentials) -> bool: + username = credentials.username + password = credentials.password + + try: + async for session in get_session(): + result = await session.execute( + select(SwaggerUser).where(SwaggerUser.username == username) + ) + swagger_user = result.scalar_one_or_none() + + if swagger_user and swagger_user.is_active: + if verify_password(password, swagger_user.hashed_password): + swagger_user.last_login = datetime.now() + await session.commit() + + logger.info(f" Accès Swagger autorisé (DB): {username}") + return True + + logger.warning(f" Tentative d'accès Swagger refusée: {username}") + return False + + except Exception as e: + logger.error(f" Erreur vérification Swagger credentials: {e}") + return False + + +class SwaggerAuthMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + request = Request(scope, receive=receive) + path = request.url.path + + protected_paths = ["/docs", "/redoc", "/openapi.json"] + + if any(path.startswith(protected_path) for protected_path in protected_paths): + auth_header = request.headers.get("Authorization") + + if not auth_header or not auth_header.startswith("Basic "): + response = JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "detail": "Authentification requise pour accéder à la documentation" + }, + headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'}, + ) + await response(scope, receive, send) + return + + try: + import base64 + + encoded_credentials = auth_header.split(" ")[1] + decoded_credentials = base64.b64decode(encoded_credentials).decode( + "utf-8" + ) + username, password = decoded_credentials.split(":", 1) + + credentials = HTTPBasicCredentials(username=username, password=password) + + if not await verify_swagger_credentials(credentials): + response = JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Identifiants invalides"}, + headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'}, + ) + await response(scope, receive, send) + return + + except Exception as e: + logger.error(f" Erreur parsing auth header: {e}") + response = JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Format d'authentification invalide"}, + headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'}, + ) + await response(scope, receive, send) + return + + await self.app(scope, receive, send) + + +class ApiKeyMiddleware: + def __init__(self, app): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + request = Request(scope, receive=receive) + path = request.url.path + + excluded_paths = [ + "/docs", + "/redoc", + "/openapi.json", + "/health", + "/", + "/auth/login", + "/auth/register", + "/auth/verify-email", + "/auth/reset-password", + "/auth/request-reset", + "/auth/refresh", + ] + + if any(path.startswith(excluded_path) for excluded_path in excluded_paths): + await self.app(scope, receive, send) + return + + auth_header = request.headers.get("Authorization") + has_jwt = auth_header and auth_header.startswith("Bearer ") + + api_key = request.headers.get("X-API-Key") + has_api_key = api_key is not None + + if has_jwt: + logger.debug(f" JWT détecté pour {path}") + await self.app(scope, receive, send) + return + + elif has_api_key: + logger.debug(f" API Key détectée pour {path}") + + from services.api_key import ApiKeyService + + try: + async for session in get_session(): + api_key_service = ApiKeyService(session) + api_key_obj = await api_key_service.verify_api_key(api_key) + + if not api_key_obj: + response = JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "detail": "Clé API invalide ou expirée", + "hint": "Utilisez X-API-Key: sdk_live_xxx ou Authorization: Bearer ", + }, + ) + await response(scope, receive, send) + return + + is_allowed, rate_info = await api_key_service.check_rate_limit( + api_key_obj + ) + if not is_allowed: + response = JSONResponse( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + content={"detail": "Rate limit dépassé"}, + headers={ + "X-RateLimit-Limit": str(rate_info["limit"]), + "X-RateLimit-Remaining": "0", + }, + ) + await response(scope, receive, send) + return + + has_access = await api_key_service.check_endpoint_access( + api_key_obj, path + ) + if not has_access: + response = JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={ + "detail": "Accès non autorisé à cet endpoint", + "endpoint": path, + "api_key": api_key_obj.key_prefix + "...", + }, + ) + await response(scope, receive, send) + return + + request.state.api_key = api_key_obj + request.state.authenticated_via = "api_key" + + logger.info(f" API Key valide: {api_key_obj.name} → {path}") + + await self.app(scope, receive, send) + return + + except Exception as e: + logger.error(f" Erreur validation API Key: {e}", exc_info=True) + response = JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={ + "detail": "Erreur interne lors de la validation de la clé" + }, + ) + await response(scope, receive, send) + return + + else: + response = JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "detail": "Authentification requise", + "hint": "Utilisez soit 'X-API-Key: sdk_live_xxx' soit 'Authorization: Bearer '", + }, + headers={"WWW-Authenticate": 'Bearer realm="API", charset="UTF-8"'}, + ) + await response(scope, receive, send) + return + + +def get_api_key_from_request(request: Request) -> Optional: + """Récupère l'objet ApiKey depuis la requête si présent""" + return getattr(request.state, "api_key", None) + + +def get_auth_method(request: Request) -> str: + return getattr(request.state, "authenticated_via", "none") diff --git a/routes/api_keys.py b/routes/api_keys.py new file mode 100644 index 0000000..27f0efc --- /dev/null +++ b/routes/api_keys.py @@ -0,0 +1,154 @@ +from fastapi import APIRouter, Depends, HTTPException, status, Query +from sqlalchemy.ext.asyncio import AsyncSession +import logging + +from database import get_session, User +from core.dependencies import get_current_user, require_role +from services.api_key import ApiKeyService, api_key_to_response +from schemas.api_key import ( + ApiKeyCreate, + ApiKeyCreatedResponse, + ApiKeyResponse, + ApiKeyList, +) + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/api-keys", tags=["API Keys Management"]) + + +@router.post( + "", + response_model=ApiKeyCreatedResponse, + status_code=status.HTTP_201_CREATED, + dependencies=[Depends(require_role("admin", "super_admin"))], +) +async def create_api_key( + data: ApiKeyCreate, + session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), +): + service = ApiKeyService(session) + + api_key_obj, api_key_plain = await service.create_api_key( + name=data.name, + description=data.description, + created_by=user.email, + user_id=user.id, + expires_in_days=data.expires_in_days, + rate_limit_per_minute=data.rate_limit_per_minute, + allowed_endpoints=data.allowed_endpoints, + ) + + logger.info(f" Clé API créée par {user.email}: {data.name}") + + response_data = api_key_to_response(api_key_obj) + response_data["api_key"] = api_key_plain + + return ApiKeyCreatedResponse(**response_data) + + +@router.get("", response_model=ApiKeyList) +async def list_api_keys( + include_revoked: bool = Query(False, description="Inclure les clés révoquées"), + session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), +): + service = ApiKeyService(session) + + user_id = None if user.role in ["admin", "super_admin"] else user.id + + keys = await service.list_api_keys(include_revoked=include_revoked, user_id=user_id) + + items = [ApiKeyResponse(**api_key_to_response(k)) for k in keys] + + return ApiKeyList(total=len(items), items=items) + + +@router.get("/{key_id}", response_model=ApiKeyResponse) +async def get_api_key( + key_id: str, + session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), +): + """ Récupérer une clé API par son ID""" + service = ApiKeyService(session) + + api_key_obj = await service.get_by_id(key_id) + + if not api_key_obj: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Clé API {key_id} introuvable", + ) + + if user.role not in ["admin", "super_admin"]: + if api_key_obj.user_id != user.id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Accès refusé à cette clé", + ) + + return ApiKeyResponse(**api_key_to_response(api_key_obj)) + + +@router.delete("/{key_id}", status_code=status.HTTP_200_OK) +async def revoke_api_key( + key_id: str, + session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), +): + service = ApiKeyService(session) + + api_key_obj = await service.get_by_id(key_id) + + if not api_key_obj: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Clé API {key_id} introuvable", + ) + + if user.role not in ["admin", "super_admin"]: + if api_key_obj.user_id != user.id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Accès refusé à cette clé", + ) + + success = await service.revoke_api_key(key_id) + + if not success: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Erreur lors de la révocation", + ) + + logger.info(f" Clé API révoquée par {user.email}: {api_key_obj.name}") + + return { + "success": True, + "message": f"Clé API '{api_key_obj.name}' révoquée avec succès", + } + + +@router.post("/verify", status_code=status.HTTP_200_OK) +async def verify_api_key_endpoint( + api_key: str = Query(..., description="Clé API à vérifier"), + session: AsyncSession = Depends(get_session), +): + service = ApiKeyService(session) + + api_key_obj = await service.verify_api_key(api_key) + + if not api_key_obj: + return { + "valid": False, + "message": "Clé API invalide, expirée ou révoquée", + } + + return { + "valid": True, + "message": "Clé API valide", + "key_name": api_key_obj.name, + "rate_limit": api_key_obj.rate_limit_per_minute, + "expires_at": api_key_obj.expires_at, + } diff --git a/schemas/api_key.py b/schemas/api_key.py new file mode 100644 index 0000000..4ec49b6 --- /dev/null +++ b/schemas/api_key.py @@ -0,0 +1,77 @@ +from pydantic import BaseModel, Field +from typing import Optional, List +from datetime import datetime + + +class ApiKeyCreate(BaseModel): + """Schema pour créer une clé API""" + + name: str = Field(..., min_length=3, max_length=255, description="Nom de la clé") + description: Optional[str] = Field(None, description="Description de l'usage") + expires_in_days: Optional[int] = Field( + None, ge=1, le=3650, description="Expiration en jours (max 10 ans)" + ) + rate_limit_per_minute: int = Field( + 60, ge=1, le=1000, description="Limite de requêtes par minute" + ) + allowed_endpoints: Optional[List[str]] = Field( + None, description="Endpoints autorisés ([] = tous, ['/clients*'] = wildcard)" + ) + + +class ApiKeyResponse(BaseModel): + """Schema de réponse pour une clé API""" + + id: str + name: str + description: Optional[str] + key_prefix: str + is_active: bool + is_expired: bool + rate_limit_per_minute: int + allowed_endpoints: Optional[List[str]] + total_requests: int + last_used_at: Optional[datetime] + created_at: datetime + expires_at: Optional[datetime] + revoked_at: Optional[datetime] + created_by: str + + +class ApiKeyCreatedResponse(ApiKeyResponse): + """Schema de réponse après création (inclut la clé en clair)""" + + api_key: str = Field( + ..., description=" Clé API en clair - à sauvegarder immédiatement" + ) + + +class ApiKeyList(BaseModel): + """Liste de clés API""" + + total: int + items: List[ApiKeyResponse] + + +class SwaggerUserCreate(BaseModel): + """Schema pour créer un utilisateur Swagger""" + + username: str = Field(..., min_length=3, max_length=100) + password: str = Field(..., min_length=8) + full_name: Optional[str] = None + email: Optional[str] = None + + +class SwaggerUserResponse(BaseModel): + """Schema de réponse pour un utilisateur Swagger""" + + id: str + username: str + full_name: Optional[str] + email: Optional[str] + is_active: bool + created_at: datetime + last_login: Optional[datetime] + + class Config: + from_attributes = True diff --git a/scripts/manage_security.py b/scripts/manage_security.py new file mode 100644 index 0000000..1f234b9 --- /dev/null +++ b/scripts/manage_security.py @@ -0,0 +1,264 @@ +import asyncio +import sys +from pathlib import Path +import argparse + +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from database import get_session +from database.models.api_key import SwaggerUser +from services.api_key import ApiKeyService +from security.auth import hash_password +from sqlalchemy import select +import logging + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +async def add_swagger_user(username: str, password: str, full_name: str = None): + """Ajouter un utilisateur Swagger""" + async with get_session() as session: + result = await session.execute( + select(SwaggerUser).where(SwaggerUser.username == username) + ) + existing = result.scalar_one_or_none() + + if existing: + logger.error(f" L'utilisateur {username} existe déjà") + return + + user = SwaggerUser( + username=username, + hashed_password=hash_password(password), + full_name=full_name or username, + is_active=True, + ) + + session.add(user) + await session.commit() + + logger.info(f" Utilisateur Swagger créé: {username}") + print("\n Utilisateur créé avec succès") + print(f" Username: {username}") + print(" Accès: https://votre-serveur/docs") + + +async def list_swagger_users(): + """Lister les utilisateurs Swagger""" + async with get_session() as session: + result = await session.execute(select(SwaggerUser)) + users = result.scalars().all() + + if not users: + print("Aucun utilisateur Swagger trouvé") + return + + print(f"\n {len(users)} utilisateur(s) Swagger:\n") + for user in users: + status = " Actif" if user.is_active else " Inactif" + print(f" • {user.username:<20} {status}") + if user.full_name: + print(f" Nom: {user.full_name}") + if user.last_login: + print(f" Dernière connexion: {user.last_login}") + print() + + +async def delete_swagger_user(username: str): + """Supprimer un utilisateur Swagger""" + async with get_session() as session: + result = await session.execute( + select(SwaggerUser).where(SwaggerUser.username == username) + ) + user = result.scalar_one_or_none() + + if not user: + logger.error(f" Utilisateur {username} introuvable") + return + + await session.delete(user) + await session.commit() + + logger.info(f"🗑️ Utilisateur supprimé: {username}") + + +async def create_api_key( + name: str, + description: str = None, + expires_in_days: int = 365, + rate_limit: int = 60, + endpoints: list = None, +): + """Créer une clé API""" + async with get_session() as session: + service = ApiKeyService(session) + + api_key_obj, api_key_plain = await service.create_api_key( + name=name, + description=description, + created_by="CLI", + expires_in_days=expires_in_days, + rate_limit_per_minute=rate_limit, + allowed_endpoints=endpoints, + ) + + print("\n Clé API créée avec succès\n") + print(f" ID: {api_key_obj.id}") + print(f" Nom: {name}") + print(f" Clé: {api_key_plain}") + print(f" Préfixe: {api_key_obj.key_prefix}") + print(f" Rate limit: {rate_limit} req/min") + print(f" Expire le: {api_key_obj.expires_at or 'Jamais'}") + print("\n IMPORTANT: Sauvegardez cette clé, elle ne sera plus affichée !\n") + + +async def list_api_keys(): + """Lister les clés API""" + async with get_session() as session: + service = ApiKeyService(session) + keys = await service.list_api_keys() + + if not keys: + print("Aucune clé API trouvée") + return + + print(f"\n {len(keys)} clé(s) API:\n") + for key in keys: + status = "" if key.is_active else "" + expired = ( + "⏰ Expirée" + if key.expires_at and key.expires_at < datetime.now() + else "" + ) + + print(f" {status} {key.name:<30} ({key.key_prefix}...)") + print(f" ID: {key.id}") + print(f" Requêtes: {key.total_requests}") + print(f" Dernière utilisation: {key.last_used_at or 'Jamais'}") + if expired: + print(f" {expired}") + print() + + +async def revoke_api_key(key_id: str): + """Révoquer une clé API""" + async with get_session() as session: + service = ApiKeyService(session) + + api_key = await service.get_by_id(key_id) + if not api_key: + logger.error(f" Clé {key_id} introuvable") + return + + success = await service.revoke_api_key(key_id) + + if success: + logger.info(f" Clé révoquée: {api_key.name}") + print(f"\n Clé '{api_key.name}' révoquée avec succès") + else: + logger.error(" Erreur lors de la révocation") + + +async def verify_api_key_cmd(api_key: str): + """Vérifier une clé API""" + async with get_session() as session: + service = ApiKeyService(session) + api_key_obj = await service.verify_api_key(api_key) + + if api_key_obj: + print("\n Clé API valide\n") + print(f" Nom: {api_key_obj.name}") + print(f" ID: {api_key_obj.id}") + print(f" Rate limit: {api_key_obj.rate_limit_per_minute} req/min") + print(f" Requêtes: {api_key_obj.total_requests}") + print(f" Expire le: {api_key_obj.expires_at or 'Jamais'}\n") + else: + print("\n Clé API invalide, expirée ou révoquée\n") + + +async def main(): + parser = argparse.ArgumentParser( + description="Gestion de la sécurité Sage Dataven API" + ) + + subparsers = parser.add_subparsers(dest="command", help="Commandes disponibles") + + swagger_parser = subparsers.add_parser( + "swagger", help="Gestion utilisateurs Swagger" + ) + swagger_subparsers = swagger_parser.add_subparsers(dest="action") + + swagger_add = swagger_subparsers.add_parser("add", help="Ajouter un utilisateur") + swagger_add.add_argument("username", help="Nom d'utilisateur") + swagger_add.add_argument("password", help="Mot de passe") + swagger_add.add_argument("--full-name", help="Nom complet") + + swagger_subparsers.add_parser("list", help="Lister les utilisateurs") + + swagger_delete = swagger_subparsers.add_parser( + "delete", help="Supprimer un utilisateur" + ) + swagger_delete.add_argument("username", help="Nom d'utilisateur") + + apikey_parser = subparsers.add_parser("apikey", help="Gestion clés API") + apikey_subparsers = apikey_parser.add_subparsers(dest="action") + + apikey_create = apikey_subparsers.add_parser("create", help="Créer une clé API") + apikey_create.add_argument("name", help="Nom de la clé") + apikey_create.add_argument("--description", help="Description") + apikey_create.add_argument( + "--days", type=int, default=365, help="Expiration en jours" + ) + apikey_create.add_argument( + "--rate-limit", type=int, default=60, help="Limite req/min" + ) + apikey_create.add_argument("--endpoints", nargs="+", help="Endpoints autorisés") + + apikey_subparsers.add_parser("list", help="Lister les clés") + + apikey_revoke = apikey_subparsers.add_parser("revoke", help="Révoquer une clé") + apikey_revoke.add_argument("key_id", help="ID de la clé") + + apikey_verify = apikey_subparsers.add_parser("verify", help="Vérifier une clé") + apikey_verify.add_argument("api_key", help="Clé API à vérifier") + + args = parser.parse_args() + + if not args.command: + parser.print_help() + return + + if args.command == "swagger": + if args.action == "add": + await add_swagger_user(args.username, args.password, args.full_name) + elif args.action == "list": + await list_swagger_users() + elif args.action == "delete": + await delete_swagger_user(args.username) + else: + swagger_parser.print_help() + + elif args.command == "apikey": + if args.action == "create": + await create_api_key( + args.name, + args.description, + args.days, + args.rate_limit, + args.endpoints, + ) + elif args.action == "list": + await list_api_keys() + elif args.action == "revoke": + await revoke_api_key(args.key_id) + elif args.action == "verify": + await verify_api_key_cmd(args.api_key) + else: + apikey_parser.print_help() + + +if __name__ == "__main__": + from datetime import datetime + + asyncio.run(main()) diff --git a/scripts/test_security.py b/scripts/test_security.py new file mode 100644 index 0000000..497870e --- /dev/null +++ b/scripts/test_security.py @@ -0,0 +1,354 @@ +import requests +import argparse +import sys +from typing import Tuple + + +class SecurityTester: + def __init__(self, base_url: str): + self.base_url = base_url.rstrip("/") + self.results = {"passed": 0, "failed": 0, "tests": []} + + def log_test(self, name: str, passed: bool, details: str = ""): + """Enregistrer le résultat d'un test""" + status = " PASS" if passed else " FAIL" + print(f"{status} - {name}") + if details: + print(f" {details}") + + self.results["tests"].append( + {"name": name, "passed": passed, "details": details} + ) + + if passed: + self.results["passed"] += 1 + else: + self.results["failed"] += 1 + + def test_swagger_without_auth(self) -> bool: + """Test 1: Swagger UI devrait demander une authentification""" + print("\n Test 1: Protection Swagger UI") + + try: + response = requests.get(f"{self.base_url}/docs", timeout=5) + + if response.status_code == 401: + self.log_test( + "Swagger protégé", + True, + "Code 401 retourné sans authentification", + ) + return True + else: + self.log_test( + "Swagger protégé", + False, + f"Code {response.status_code} au lieu de 401", + ) + return False + + except Exception as e: + self.log_test("Swagger protégé", False, f"Erreur: {str(e)}") + return False + + def test_swagger_with_auth(self, username: str, password: str) -> bool: + """Test 2: Swagger UI accessible avec credentials valides""" + print("\n Test 2: Accès Swagger avec authentification") + + try: + response = requests.get( + f"{self.base_url}/docs", auth=(username, password), timeout=5 + ) + + if response.status_code == 200: + self.log_test( + "Accès Swagger avec auth", + True, + f"Authentifié comme {username}", + ) + return True + else: + self.log_test( + "Accès Swagger avec auth", + False, + f"Code {response.status_code}, credentials invalides?", + ) + return False + + except Exception as e: + self.log_test("Accès Swagger avec auth", False, f"Erreur: {str(e)}") + return False + + def test_api_without_auth(self) -> bool: + """Test 3: Endpoints API devraient demander une authentification""" + print("\n Test 3: Protection des endpoints API") + + test_endpoints = ["/api/v1/clients", "/api/v1/documents"] + + all_protected = True + for endpoint in test_endpoints: + try: + response = requests.get(f"{self.base_url}{endpoint}", timeout=5) + + if response.status_code == 401: + print(f" {endpoint} protégé (401)") + else: + print( + f" {endpoint} accessible sans auth (code {response.status_code})" + ) + all_protected = False + + except Exception as e: + print(f" {endpoint} erreur: {str(e)}") + all_protected = False + + self.log_test("Endpoints API protégés", all_protected) + return all_protected + + def test_health_endpoint_public(self) -> bool: + """Test 4: Endpoint /health devrait être accessible sans auth""" + print("\n Test 4: Endpoint /health public") + + try: + response = requests.get(f"{self.base_url}/health", timeout=5) + + if response.status_code == 200: + self.log_test("/health accessible", True, "Endpoint public fonctionne") + return True + else: + self.log_test( + "/health accessible", + False, + f"Code {response.status_code} inattendu", + ) + return False + + except Exception as e: + self.log_test("/health accessible", False, f"Erreur: {str(e)}") + return False + + def test_api_key_creation(self, username: str, password: str) -> Tuple[bool, str]: + """Test 5: Créer une clé API via l'endpoint""" + print("\n Test 5: Création d'une clé API") + + try: + login_response = requests.post( + f"{self.base_url}/api/v1/auth/login", + json={"email": username, "password": password}, + timeout=5, + ) + + if login_response.status_code != 200: + self.log_test( + "Création clé API", + False, + "Impossible de se connecter pour obtenir un JWT", + ) + return False, "" + + jwt_token = login_response.json().get("access_token") + + create_response = requests.post( + f"{self.base_url}/api/v1/api-keys", + headers={"Authorization": f"Bearer {jwt_token}"}, + json={ + "name": "Test API Key", + "description": "Clé de test automatisé", + "rate_limit_per_minute": 60, + "expires_in_days": 30, + }, + timeout=5, + ) + + if create_response.status_code == 201: + api_key = create_response.json().get("api_key") + self.log_test("Création clé API", True, f"Clé créée: {api_key[:20]}...") + return True, api_key + else: + self.log_test( + "Création clé API", + False, + f"Code {create_response.status_code}", + ) + return False, "" + + except Exception as e: + self.log_test("Création clé API", False, f"Erreur: {str(e)}") + return False, "" + + def test_api_key_usage(self, api_key: str) -> bool: + """Test 6: Utiliser une clé API pour accéder à un endpoint""" + print("\n Test 6: Utilisation d'une clé API") + + if not api_key: + self.log_test("Utilisation clé API", False, "Pas de clé disponible") + return False + + try: + response = requests.get( + f"{self.base_url}/api/v1/clients", + headers={"X-API-Key": api_key}, + timeout=5, + ) + + if response.status_code == 200: + self.log_test("Utilisation clé API", True, "Clé acceptée") + return True + else: + self.log_test( + "Utilisation clé API", + False, + f"Code {response.status_code}, clé refusée?", + ) + return False + + except Exception as e: + self.log_test("Utilisation clé API", False, f"Erreur: {str(e)}") + return False + + def test_invalid_api_key(self) -> bool: + """Test 7: Une clé invalide devrait être refusée""" + print("\n Test 7: Rejet de clé API invalide") + + invalid_key = "sdk_live_invalid_key_12345" + + try: + response = requests.get( + f"{self.base_url}/api/v1/clients", + headers={"X-API-Key": invalid_key}, + timeout=5, + ) + + if response.status_code == 401: + self.log_test("Clé invalide rejetée", True, "Code 401 comme attendu") + return True + else: + self.log_test( + "Clé invalide rejetée", + False, + f"Code {response.status_code} au lieu de 401", + ) + return False + + except Exception as e: + self.log_test("Clé invalide rejetée", False, f"Erreur: {str(e)}") + return False + + def test_rate_limiting(self, api_key: str) -> bool: + """Test 8: Rate limiting (optionnel, peut prendre du temps)""" + print("\n Test 8: Rate limiting (test simple)") + + if not api_key: + self.log_test("Rate limiting", False, "Pas de clé disponible") + return False + + print(" Envoi de 70 requêtes rapides...") + + rate_limited = False + for i in range(70): + try: + response = requests.get( + f"{self.base_url}/health", + headers={"X-API-Key": api_key}, + timeout=1, + ) + + if response.status_code == 429: + rate_limited = True + print(f" Rate limit atteint à la requête {i + 1}") + break + + except Exception: + pass + + if rate_limited: + self.log_test("Rate limiting", True, "Rate limit détecté") + return True + else: + self.log_test( + "Rate limiting", + True, + "Aucun rate limit détecté (peut être normal si pas implémenté)", + ) + return True + + def print_summary(self): + """Afficher le résumé des tests""" + print("\n" + "=" * 60) + print(" RÉSUMÉ DES TESTS") + print("=" * 60) + + total = self.results["passed"] + self.results["failed"] + success_rate = (self.results["passed"] / total * 100) if total > 0 else 0 + + print(f"\nTotal: {total} tests") + print(f" Réussis: {self.results['passed']}") + print(f" Échoués: {self.results['failed']}") + print(f"Taux de réussite: {success_rate:.1f}%\n") + + if self.results["failed"] == 0: + print("🎉 Tous les tests sont passés ! Sécurité OK.") + return 0 + else: + print(" Certains tests ont échoué. Vérifiez la configuration.") + return 1 + + +def main(): + parser = argparse.ArgumentParser( + description="Test automatisé de la sécurité de l'API" + ) + + parser.add_argument( + "--url", + required=True, + help="URL de base de l'API (ex: http://localhost:8000)", + ) + + parser.add_argument( + "--swagger-user", required=True, help="Utilisateur Swagger pour les tests" + ) + + parser.add_argument( + "--swagger-pass", required=True, help="Mot de passe Swagger pour les tests" + ) + + parser.add_argument( + "--skip-rate-limit", + action="store_true", + help="Sauter le test de rate limiting (long)", + ) + + args = parser.parse_args() + + print(" Démarrage des tests de sécurité") + print(f" URL cible: {args.url}\n") + + tester = SecurityTester(args.url) + + tester.test_swagger_without_auth() + tester.test_swagger_with_auth(args.swagger_user, args.swagger_pass) + tester.test_api_without_auth() + tester.test_health_endpoint_public() + + success, api_key = tester.test_api_key_creation( + args.swagger_user, args.swagger_pass + ) + + if success and api_key: + tester.test_api_key_usage(api_key) + tester.test_invalid_api_key() + + if not args.skip_rate_limit: + tester.test_rate_limiting(api_key) + else: + print("\n Test de rate limiting sauté") + else: + print("\n Tests avec clé API sautés (création échouée)") + + exit_code = tester.print_summary() + sys.exit(exit_code) + + +if __name__ == "__main__": + main() diff --git a/services/api_key.py b/services/api_key.py new file mode 100644 index 0000000..ad3cf6f --- /dev/null +++ b/services/api_key.py @@ -0,0 +1,205 @@ +import secrets +import hashlib +import json +from datetime import datetime, timedelta +from typing import Optional, List, Dict +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, and_, or_ +import logging + +from database.models.api_key import ApiKey + +logger = logging.getLogger(__name__) + + +class ApiKeyService: + """Service de gestion des clés API""" + + def __init__(self, session: AsyncSession): + self.session = session + + @staticmethod + def generate_api_key() -> str: + """Génère une clé API unique et sécurisée""" + random_part = secrets.token_urlsafe(32) + return f"sdk_live_{random_part}" + + @staticmethod + def hash_api_key(api_key: str) -> str: + """Hash la clé API pour stockage sécurisé""" + return hashlib.sha256(api_key.encode()).hexdigest() + + @staticmethod + def get_key_prefix(api_key: str) -> str: + """Extrait le préfixe de la clé pour identification""" + return api_key[:12] if len(api_key) >= 12 else api_key + + async def create_api_key( + self, + name: str, + description: Optional[str] = None, + created_by: str = "system", + user_id: Optional[str] = None, + expires_in_days: Optional[int] = None, + rate_limit_per_minute: int = 60, + allowed_endpoints: Optional[List[str]] = None, + ) -> tuple[ApiKey, str]: + api_key_plain = self.generate_api_key() + key_hash = self.hash_api_key(api_key_plain) + key_prefix = self.get_key_prefix(api_key_plain) + + expires_at = None + if expires_in_days: + expires_at = datetime.now() + timedelta(days=expires_in_days) + + api_key_obj = ApiKey( + key_hash=key_hash, + key_prefix=key_prefix, + name=name, + description=description, + created_by=created_by, + user_id=user_id, + expires_at=expires_at, + rate_limit_per_minute=rate_limit_per_minute, + allowed_endpoints=json.dumps(allowed_endpoints) + if allowed_endpoints + else None, + ) + + self.session.add(api_key_obj) + await self.session.commit() + await self.session.refresh(api_key_obj) + + logger.info(f" Clé API créée: {name} (prefix: {key_prefix})") + + return api_key_obj, api_key_plain + + async def verify_api_key(self, api_key_plain: str) -> Optional[ApiKey]: + key_hash = self.hash_api_key(api_key_plain) + + result = await self.session.execute( + select(ApiKey).where( + and_( + ApiKey.key_hash == key_hash, + ApiKey.is_active, + ApiKey.revoked_at.is_(None), + or_( + ApiKey.expires_at.is_(None), ApiKey.expires_at > datetime.now() + ), + ) + ) + ) + + api_key_obj = result.scalar_one_or_none() + + if api_key_obj: + api_key_obj.total_requests += 1 + api_key_obj.last_used_at = datetime.now() + await self.session.commit() + + logger.debug(f" Clé API validée: {api_key_obj.name}") + else: + logger.warning(" Clé API invalide ou expirée") + + return api_key_obj + + async def list_api_keys( + self, + include_revoked: bool = False, + user_id: Optional[str] = None, + ) -> List[ApiKey]: + """Liste les clés API""" + query = select(ApiKey) + + if not include_revoked: + query = query.where(ApiKey.revoked_at.is_(None)) + + if user_id: + query = query.where(ApiKey.user_id == user_id) + + query = query.order_by(ApiKey.created_at.desc()) + + result = await self.session.execute(query) + return list(result.scalars().all()) + + async def revoke_api_key(self, key_id: str) -> bool: + """Révoque une clé API""" + result = await self.session.execute(select(ApiKey).where(ApiKey.id == key_id)) + api_key_obj = result.scalar_one_or_none() + + if not api_key_obj: + return False + + api_key_obj.is_active = False + api_key_obj.revoked_at = datetime.now() + await self.session.commit() + + logger.info(f" Clé API révoquée: {api_key_obj.name}") + return True + + async def get_by_id(self, key_id: str) -> Optional[ApiKey]: + """Récupère une clé API par son ID""" + result = await self.session.execute(select(ApiKey).where(ApiKey.id == key_id)) + return result.scalar_one_or_none() + + async def check_rate_limit(self, api_key_obj: ApiKey) -> tuple[bool, Dict]: + return True, { + "allowed": True, + "limit": api_key_obj.rate_limit_per_minute, + "remaining": api_key_obj.rate_limit_per_minute, + } + + async def check_endpoint_access(self, api_key_obj: ApiKey, endpoint: str) -> bool: + """Vérifie si la clé a accès à un endpoint spécifique""" + if not api_key_obj.allowed_endpoints: + return True + + try: + allowed = json.loads(api_key_obj.allowed_endpoints) + + for pattern in allowed: + if pattern == "*": + return True + if pattern.endswith("*"): + prefix = pattern[:-1] + if endpoint.startswith(prefix): + return True + if pattern == endpoint: + return True + + return False + except json.JSONDecodeError: + logger.error(f" Erreur parsing allowed_endpoints pour {api_key_obj.id}") + return False + + +def api_key_to_response(api_key_obj: ApiKey, show_key: bool = False) -> Dict: + """Convertit un objet ApiKey en réponse API""" + + allowed_endpoints = None + if api_key_obj.allowed_endpoints: + try: + allowed_endpoints = json.loads(api_key_obj.allowed_endpoints) + except json.JSONDecodeError: + pass + + is_expired = False + if api_key_obj.expires_at: + is_expired = api_key_obj.expires_at < datetime.now() + + return { + "id": api_key_obj.id, + "name": api_key_obj.name, + "description": api_key_obj.description, + "key_prefix": api_key_obj.key_prefix, + "is_active": api_key_obj.is_active, + "is_expired": is_expired, + "rate_limit_per_minute": api_key_obj.rate_limit_per_minute, + "allowed_endpoints": allowed_endpoints, + "total_requests": api_key_obj.total_requests, + "last_used_at": api_key_obj.last_used_at, + "created_at": api_key_obj.created_at, + "expires_at": api_key_obj.expires_at, + "revoked_at": api_key_obj.revoked_at, + "created_by": api_key_obj.created_by, + }