From 3cdb490ee5975877b43d8463c99b92a19769f3fb Mon Sep 17 00:00:00 2001 From: Fanilo-Nantenaina Date: Tue, 20 Jan 2026 14:47:07 +0300 Subject: [PATCH] refactor(security): improve middleware structure and configuration handling --- api.py | 4 +- database/db_config.py | 4 +- middleware/security.py | 345 +++++++++++++++++++++-------------------- security/auth.py | 11 +- 4 files changed, 188 insertions(+), 176 deletions(-) diff --git a/api.py b/api.py index c35e280..06f57e4 100644 --- a/api.py +++ b/api.py @@ -96,7 +96,7 @@ from utils.generic_functions import ( ) -from middleware.security import SwaggerAuthMiddleware, ApiKeyMiddleware +from middleware.security import SwaggerAuthMiddleware, ApiKeyMiddlewareHTTP from core.dependencies import get_current_user from config.cors_config import setup_cors from routes.api_keys import router as api_keys_router @@ -198,7 +198,7 @@ app.openapi = custom_openapi """ setup_cors(app, mode="open") app.add_middleware(SwaggerAuthMiddleware) -app.add_middleware(ApiKeyMiddleware) +app.add_middleware(ApiKeyMiddlewareHTTP) app.include_router(api_keys_router) app.include_router(auth_router) diff --git a/database/db_config.py b/database/db_config.py index bb98f5c..692822c 100644 --- a/database/db_config.py +++ b/database/db_config.py @@ -1,14 +1,14 @@ -import os from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from sqlalchemy.pool import NullPool from sqlalchemy import event, text import logging +from config.config import settings from database.models.generic_model import Base logger = logging.getLogger(__name__) -DATABASE_URL = os.getenv("DATABASE_URL") +DATABASE_URL = settings.database_url def _configure_sqlite_connection(dbapi_connection, connection_record): diff --git a/middleware/security.py b/middleware/security.py index 137e7dd..2bea533 100644 --- a/middleware/security.py +++ b/middleware/security.py @@ -1,49 +1,23 @@ from fastapi import Request, status from fastapi.responses import JSONResponse from fastapi.security import HTTPBasic, HTTPBasicCredentials +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.types import ASGIApp from sqlalchemy import select -from typing import Optional +from typing import Callable from datetime import datetime import logging - -from database import get_session -from database.models.api_key import SwaggerUser -from security.auth import verify_password +import base64 logger = logging.getLogger(__name__) security = HTTPBasic() -async def verify_swagger_credentials(credentials: HTTPBasicCredentials) -> bool: - username = credentials.username - password = credentials.password - - try: - async for session in get_session(): - result = await session.execute( - select(SwaggerUser).where(SwaggerUser.username == username) - ) - swagger_user = result.scalar_one_or_none() - - if swagger_user and swagger_user.is_active: - if verify_password(password, swagger_user.hashed_password): - swagger_user.last_login = datetime.now() - await session.commit() - - logger.info(f" Accès Swagger autorisé (DB): {username}") - return True - - logger.warning(f" Tentative d'accès Swagger refusée: {username}") - return False - - except Exception as e: - logger.error(f" Erreur vérification Swagger credentials: {e}") - return False - - class SwaggerAuthMiddleware: - def __init__(self, app): + PROTECTED_PATHS = ["/docs", "/redoc", "/openapi.json"] + + def __init__(self, app: ASGIApp): self.app = app async def __call__(self, scope, receive, send): @@ -54,183 +28,220 @@ class SwaggerAuthMiddleware: request = Request(scope, receive=receive) path = request.url.path - protected_paths = ["/docs", "/redoc", "/openapi.json"] + if not any(path.startswith(p) for p in self.PROTECTED_PATHS): + await self.app(scope, receive, send) + return - if any(path.startswith(protected_path) for protected_path in protected_paths): - auth_header = request.headers.get("Authorization") + auth_header = request.headers.get("Authorization") - if not auth_header or not auth_header.startswith("Basic "): + if not auth_header or not auth_header.startswith("Basic "): + response = JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Authentification requise pour la documentation"}, + headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'}, + ) + await response(scope, receive, send) + return + + try: + encoded_credentials = auth_header.split(" ")[1] + decoded_credentials = base64.b64decode(encoded_credentials).decode("utf-8") + username, password = decoded_credentials.split(":", 1) + + credentials = HTTPBasicCredentials(username=username, password=password) + + if not await self._verify_credentials(credentials): response = JSONResponse( status_code=status.HTTP_401_UNAUTHORIZED, - content={ - "detail": "Authentification requise pour accéder à la documentation" - }, + content={"detail": "Identifiants invalides"}, headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'}, ) await response(scope, receive, send) return - try: - import base64 - - encoded_credentials = auth_header.split(" ")[1] - decoded_credentials = base64.b64decode(encoded_credentials).decode( - "utf-8" - ) - username, password = decoded_credentials.split(":", 1) - - credentials = HTTPBasicCredentials(username=username, password=password) - - if not await verify_swagger_credentials(credentials): - response = JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={"detail": "Identifiants invalides"}, - headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'}, - ) - await response(scope, receive, send) - return - - except Exception as e: - logger.error(f" Erreur parsing auth header: {e}") - response = JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={"detail": "Format d'authentification invalide"}, - headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'}, - ) - await response(scope, receive, send) - return + except Exception as e: + logger.error(f"Erreur parsing auth header: {e}") + response = JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Format d'authentification invalide"}, + headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'}, + ) + await response(scope, receive, send) + return await self.app(scope, receive, send) + async def _verify_credentials(self, credentials: HTTPBasicCredentials) -> bool: + """Vérifie les identifiants dans la base de données""" + from database.db_config import async_session_factory + from database.models.api_key import SwaggerUser + from security.auth import verify_password -class ApiKeyMiddleware: - def __init__(self, app): - self.app = app + try: + async with async_session_factory() as session: + result = await session.execute( + select(SwaggerUser).where( + SwaggerUser.username == credentials.username + ) + ) + swagger_user = result.scalar_one_or_none() - async def __call__(self, scope, receive, send): - if scope["type"] != "http": - await self.app(scope, receive, send) - return + if swagger_user and swagger_user.is_active: + if verify_password( + credentials.password, swagger_user.hashed_password + ): + swagger_user.last_login = datetime.now() + await session.commit() + logger.info(f"✓ Accès Swagger autorisé: {credentials.username}") + return True - request = Request(scope, receive=receive) + logger.warning(f"✗ Accès Swagger refusé: {credentials.username}") + return False + + except Exception as e: + logger.error(f"Erreur vérification credentials: {e}") + return False + + +class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware): + EXCLUDED_PATHS = [ + "/docs", + "/redoc", + "/openapi.json", + "/", + "/auth/login", + "/auth/register", + "/auth/verify-email", + "/auth/reset-password", + "/auth/request-reset", + "/auth/refresh", + ] + + async def dispatch(self, request: Request, call_next: Callable): path = request.url.path + method = request.method - excluded_paths = [ - "/docs", - "/redoc", - "/openapi.json", - "/health", - "/", - "/auth/login", - "/auth/register", - "/auth/verify-email", - "/auth/reset-password", - "/auth/request-reset", - "/auth/refresh", - ] - - if any(path.startswith(excluded_path) for excluded_path in excluded_paths): - await self.app(scope, receive, send) - return + if self._is_excluded_path(path): + return await call_next(request) auth_header = request.headers.get("Authorization") has_jwt = auth_header and auth_header.startswith("Bearer ") api_key = request.headers.get("X-API-Key") - has_api_key = api_key is not None + has_api_key = bool(api_key) if has_jwt: - logger.debug(f" JWT détecté pour {path}") - await self.app(scope, receive, send) - return + logger.debug(f"JWT détecté pour {method} {path}") + return await call_next(request) - elif has_api_key: - logger.debug(f" API Key détectée pour {path}") + if has_api_key: + logger.debug(f"API Key détectée pour {method} {path}") + return await self._handle_api_key_auth( + request, api_key, path, method, call_next + ) + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "detail": "Authentification requise", + "hint": "Utilisez 'X-API-Key: sdk_live_xxx' ou 'Authorization: Bearer '", + }, + headers={"WWW-Authenticate": 'Bearer realm="API", charset="UTF-8"'}, + ) + + def _is_excluded_path(self, path: str) -> bool: + """Vérifie si le chemin est exclu de l'authentification""" + if path == "/": + return True + + for excluded in self.EXCLUDED_PATHS: + if excluded == "/": + continue + if path == excluded or path.startswith(excluded + "/"): + return True + + return False + + async def _handle_api_key_auth( + self, + request: Request, + api_key: str, + path: str, + method: str, + call_next: Callable, + ): + """Gère l'authentification par API Key""" + try: + from database.db_config import async_session_factory from services.api_key import ApiKeyService - try: - async for session in get_session(): - api_key_service = ApiKeyService(session) - api_key_obj = await api_key_service.verify_api_key(api_key) + async with async_session_factory() as session: + service = ApiKeyService(session) + api_key_obj = await service.verify_api_key(api_key) - if not api_key_obj: - response = JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={ - "detail": "Clé API invalide ou expirée", - "hint": "Utilisez X-API-Key: sdk_live_xxx ou Authorization: Bearer ", - }, - ) - await response(scope, receive, send) - return - - is_allowed, rate_info = await api_key_service.check_rate_limit( - api_key_obj + if not api_key_obj: + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "detail": "Clé API invalide ou expirée", + "hint": "Vérifiez votre clé X-API-Key", + }, ) - if not is_allowed: - response = JSONResponse( - status_code=status.HTTP_429_TOO_MANY_REQUESTS, - content={"detail": "Rate limit dépassé"}, - headers={ - "X-RateLimit-Limit": str(rate_info["limit"]), - "X-RateLimit-Remaining": "0", - }, - ) - await response(scope, receive, send) - return - has_access = await api_key_service.check_endpoint_access( - api_key_obj, path + is_allowed, rate_info = await service.check_rate_limit(api_key_obj) + if not is_allowed: + return JSONResponse( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + content={"detail": "Rate limit dépassé"}, + headers={ + "X-RateLimit-Limit": str(rate_info["limit"]), + "X-RateLimit-Remaining": "0", + }, ) - if not has_access: - response = JSONResponse( - status_code=status.HTTP_403_FORBIDDEN, - content={ - "detail": "Accès non autorisé à cet endpoint", - "endpoint": path, - "api_key": api_key_obj.key_prefix + "...", - }, - ) - await response(scope, receive, send) - return - request.state.api_key = api_key_obj - request.state.authenticated_via = "api_key" + has_access = await service.check_endpoint_access(api_key_obj, path) + if not has_access: + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={ + "detail": "Accès non autorisé à cet endpoint", + "endpoint": path, + "api_key": api_key_obj.key_prefix + "...", + }, + ) - logger.info(f" API Key valide: {api_key_obj.name} → {path}") + request.state.api_key = api_key_obj + request.state.authenticated_via = "api_key" - await self.app(scope, receive, send) - return + logger.info(f"✓ API Key valide: {api_key_obj.name} → {method} {path}") - except Exception as e: - logger.error(f" Erreur validation API Key: {e}", exc_info=True) - response = JSONResponse( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - content={ - "detail": "Erreur interne lors de la validation de la clé" - }, - ) - await response(scope, receive, send) - return + return await call_next(request) - else: - response = JSONResponse( - status_code=status.HTTP_401_UNAUTHORIZED, - content={ - "detail": "Authentification requise", - "hint": "Utilisez soit 'X-API-Key: sdk_live_xxx' soit 'Authorization: Bearer '", - }, - headers={"WWW-Authenticate": 'Bearer realm="API", charset="UTF-8"'}, + except Exception as e: + logger.error(f"Erreur validation API Key: {e}", exc_info=True) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"detail": "Erreur interne lors de la validation"}, ) - await response(scope, receive, send) - return -def get_api_key_from_request(request: Request) -> Optional: +ApiKeyMiddleware = ApiKeyMiddlewareHTTP + + +def get_api_key_from_request(request: Request): """Récupère l'objet ApiKey depuis la requête si présent""" return getattr(request.state, "api_key", None) def get_auth_method(request: Request) -> str: + """Retourne la méthode d'authentification utilisée""" return getattr(request.state, "authenticated_via", "none") + + +__all__ = [ + "SwaggerAuthMiddleware", + "ApiKeyMiddlewareHTTP", + "ApiKeyMiddleware", + "get_api_key_from_request", + "get_auth_method", +] diff --git a/security/auth.py b/security/auth.py index 3708708..e05b6a0 100644 --- a/security/auth.py +++ b/security/auth.py @@ -4,12 +4,13 @@ from typing import Optional, Dict import jwt import secrets import hashlib -import os -SECRET_KEY = os.getenv("JWT_SECRET") -ALGORITHM = os.getenv("JWT_ALGORITHM") -ACCESS_TOKEN_EXPIRE_MINUTES = os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES") -REFRESH_TOKEN_EXPIRE_DAYS = os.getenv("REFRESH_TOKEN_EXPIRE_DAYS") +from config.config import settings + +SECRET_KEY = settings.jwt_secret +ALGORITHM = settings.jwt_algorithm +ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes +REFRESH_TOKEN_EXPIRE_DAYS = settings.refresh_token_expire_days pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")