"""Authentication middlewares.

- ``SwaggerAuthMiddleware`` puts the API documentation routes behind HTTP Basic
  auth backed by the ``SwaggerUser`` table.
- ``ApiKeyMiddlewareHTTP`` requires either a Bearer token or a valid
  ``X-API-Key`` header on every non-public route.
"""

from fastapi import Request, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp, Receive, Scope, Send
from sqlalchemy import select
from typing import Callable, Iterable, Optional
from datetime import datetime
import logging
import base64

logger = logging.getLogger(__name__)

security = HTTPBasic()


def _path_matches(path: str, bases: Iterable[str]) -> bool:
    """Return True if *path* equals one of *bases* or is nested under it.

    Matching is segment-aware: ``/docs`` matches ``/docs`` and
    ``/docs/oauth2-redirect`` but NOT ``/documents``. A base of ``/`` only
    matches the root itself (otherwise it would match every path).
    """
    return any(
        path == base or (base != "/" and path.startswith(base + "/"))
        for base in bases
    )


class SwaggerAuthMiddleware:
    """Pure ASGI middleware requiring HTTP Basic auth on the documentation routes.

    Requests whose path is one of ``PROTECTED_PATHS`` (or nested under one) must
    carry valid Basic credentials of an active ``SwaggerUser``. All other
    requests, and all non-HTTP scopes, are forwarded untouched.
    """

    PROTECTED_PATHS = ["/docs", "/redoc", "/openapi.json"]

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = Request(scope, receive=receive)
        # Segment-aware match so API routes like "/documents" are not
        # accidentally put behind the Swagger login.
        if not _path_matches(request.url.path, self.PROTECTED_PATHS):
            await self.app(scope, receive, send)
            return

        auth_header = request.headers.get("Authorization") or ""
        # Auth schemes are case-insensitive (RFC 7235), so accept "basic" too.
        scheme, sep, encoded_credentials = auth_header.partition(" ")

        if not sep or scheme.lower() != "basic":
            response = self._unauthorized(
                "Authentification requise pour la documentation"
            )
        else:
            credentials = self._decode_credentials(encoded_credentials)
            if credentials is None:
                response = self._unauthorized("Format d'authentification invalide")
            elif not await self._verify_credentials(credentials):
                response = self._unauthorized("Identifiants invalides")
            else:
                await self.app(scope, receive, send)
                return

        await response(scope, receive, send)

    @staticmethod
    def _unauthorized(detail: str) -> JSONResponse:
        """Build a 401 response carrying the Basic challenge for the docs realm."""
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content={"detail": detail},
            headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'},
        )

    @staticmethod
    def _decode_credentials(encoded: str) -> Optional[HTTPBasicCredentials]:
        """Decode a base64 ``username:password`` token.

        Returns None (after logging) when the token is not valid base64, not
        valid UTF-8, or has no ``:`` separator.
        """
        try:
            decoded = base64.b64decode(encoded.strip()).decode("utf-8")
            username, password = decoded.split(":", 1)
        except ValueError as e:  # binascii.Error and UnicodeDecodeError are ValueErrors
            # Malformed client input is not a server fault: warn, don't error.
            logger.warning("Erreur parsing auth header: %s", e)
            return None
        return HTTPBasicCredentials(username=username, password=password)

    async def _verify_credentials(self, credentials: HTTPBasicCredentials) -> bool:
        """Check Basic credentials against the ``SwaggerUser`` table.

        Returns True only for an existing, active user whose password passes
        ``verify_password``; in that case ``last_login`` is updated and
        committed. Every failure, including import or database errors, is
        logged and treated as a denial (fail closed).
        """
        try:
            # Lazy imports kept from the original (presumably to avoid import
            # cycles); they are inside the try so a failure here denies access
            # instead of escaping as an unrelated error.
            from database.db_config import async_session_factory
            from database.models.api_key import SwaggerUser
            from security.auth import verify_password

            async with async_session_factory() as session:
                result = await session.execute(
                    select(SwaggerUser).where(
                        SwaggerUser.username == credentials.username
                    )
                )
                swagger_user = result.scalar_one_or_none()

                if (
                    swagger_user
                    and swagger_user.is_active
                    and verify_password(
                        credentials.password, swagger_user.hashed_password
                    )
                ):
                    # NOTE(review): naive local timestamp, as in the original;
                    # confirm it matches the column's timezone expectations.
                    swagger_user.last_login = datetime.now()
                    await session.commit()
                    logger.info("✓ Accès Swagger autorisé: %s", credentials.username)
                    return True

            logger.warning("✗ Accès Swagger refusé: %s", credentials.username)
            return False
        except Exception as e:
            logger.error("Erreur vérification credentials: %s", e, exc_info=True)
            return False


class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
    """Require a Bearer token or a valid ``X-API-Key`` on non-public routes.

    Paths in ``EXCLUDED_PATHS`` (and anything nested under them, except for
    the root which only matches itself) bypass authentication entirely.
    """

    EXCLUDED_PATHS = [
        "/docs",
        "/redoc",
        "/openapi.json",
        "/",
        "/auth/login",
        "/auth/register",
        "/auth/verify-email",
        "/auth/reset-password",
        "/auth/request-reset",
        "/auth/refresh",
    ]

    async def dispatch(self, request: Request, call_next: Callable):
        path = request.url.path
        method = request.method

        if self._is_excluded_path(path):
            return await call_next(request)

        auth_header = request.headers.get("Authorization")
        has_jwt = auth_header and auth_header.startswith("Bearer ")

        api_key = request.headers.get("X-API-Key")
        has_api_key = bool(api_key)

        if has_jwt:
            # Only the presence of a Bearer token is checked here.
            # NOTE(review): the token is presumably validated by route
            # dependencies downstream — confirm.
            logger.debug("JWT détecté pour %s %s", method, path)
            return await call_next(request)

        if has_api_key:
            logger.debug("API Key détectée pour %s %s", method, path)
            return await self._handle_api_key_auth(
                request, api_key, path, method, call_next
            )

        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content={
                "detail": "Authentification requise",
                "hint": "Utilisez 'X-API-Key: sdk_live_xxx' ou 'Authorization: Bearer '",
            },
            headers={"WWW-Authenticate": 'Bearer realm="API", charset="UTF-8"'},
        )

    def _is_excluded_path(self, path: str) -> bool:
        """Return True if *path* is public (no authentication required).

        The root ``/`` is always public; other entries match exactly or as a
        parent path segment.
        """
        return path == "/" or _path_matches(path, self.EXCLUDED_PATHS)

    async def _handle_api_key_auth(
        self,
        request: Request,
        api_key: str,
        path: str,
        method: str,
        call_next: Callable,
    ):
        """Validate *api_key*, enforce rate limit and endpoint access, then forward.

        Returns a 401 for an unknown/expired key, 429 when rate-limited, 403
        when the key may not call *path*, and 500 if validation itself fails.
        On success the key object is stored on ``request.state`` and the
        request is forwarded.

        The database session only spans validation: the downstream handler
        runs after it is closed, so it does not pin a pooled connection, and
        its own exceptions are not misreported as API-key validation errors.
        """
        try:
            from database.db_config import async_session_factory
            from services.api_key import ApiKeyService

            async with async_session_factory() as session:
                service = ApiKeyService(session)

                api_key_obj = await service.verify_api_key(api_key)
                if not api_key_obj:
                    return JSONResponse(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        content={
                            "detail": "Clé API invalide ou expirée",
                            "hint": "Vérifiez votre clé X-API-Key",
                        },
                    )

                is_allowed, rate_info = await service.check_rate_limit(api_key_obj)
                if not is_allowed:
                    return JSONResponse(
                        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                        content={"detail": "Rate limit dépassé"},
                        headers={
                            "X-RateLimit-Limit": str(rate_info["limit"]),
                            "X-RateLimit-Remaining": "0",
                        },
                    )

                has_access = await service.check_endpoint_access(api_key_obj, path)
                if not has_access:
                    return JSONResponse(
                        status_code=status.HTTP_403_FORBIDDEN,
                        content={
                            "detail": "Accès non autorisé à cet endpoint",
                            "endpoint": path,
                            "api_key": api_key_obj.key_prefix + "...",
                        },
                    )

                # Logged while the session is still open, as in the original.
                logger.info(
                    "✓ API Key valide: %s → %s %s", api_key_obj.name, method, path
                )
        except Exception as e:
            logger.error("Erreur validation API Key: %s", e, exc_info=True)
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={"detail": "Erreur interne lors de la validation"},
            )

        # NOTE(review): api_key_obj is detached from its (closed) session here;
        # attributes already loaded remain readable — confirm downstream code
        # does not rely on lazy-loading relationships from it.
        request.state.api_key = api_key_obj
        request.state.authenticated_via = "api_key"
        return await call_next(request)


ApiKeyMiddleware = ApiKeyMiddlewareHTTP


def get_api_key_from_request(request: Request):
    """Return the ApiKey object stored on the request, or None if absent."""
    return getattr(request.state, "api_key", None)


def get_auth_method(request: Request) -> str:
    """Return the authentication method recorded on the request ("none" if unset)."""
    return getattr(request.state, "authenticated_via", "none")


__all__ = [
    "SwaggerAuthMiddleware",
    "ApiKeyMiddlewareHTTP",
    "ApiKeyMiddleware",
    "get_api_key_from_request",
    "get_auth_method",
]