refactor(security): improve middleware structure and configuration handling

This commit is contained in:
Fanilo-Nantenaina 2026-01-20 14:47:07 +03:00
parent c84e4ddc20
commit 3cdb490ee5
4 changed files with 188 additions and 176 deletions

4
api.py
View file

@ -96,7 +96,7 @@ from utils.generic_functions import (
)
from middleware.security import SwaggerAuthMiddleware, ApiKeyMiddleware
from middleware.security import SwaggerAuthMiddleware, ApiKeyMiddlewareHTTP
from core.dependencies import get_current_user
from config.cors_config import setup_cors
from routes.api_keys import router as api_keys_router
@ -198,7 +198,7 @@ app.openapi = custom_openapi """
setup_cors(app, mode="open")
app.add_middleware(SwaggerAuthMiddleware)
app.add_middleware(ApiKeyMiddleware)
app.add_middleware(ApiKeyMiddlewareHTTP)
app.include_router(api_keys_router)
app.include_router(auth_router)

View file

@ -1,14 +1,14 @@
import os
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.pool import NullPool
from sqlalchemy import event, text
import logging
from config.config import settings
from database.models.generic_model import Base
logger = logging.getLogger(__name__)
DATABASE_URL = os.getenv("DATABASE_URL")
DATABASE_URL = settings.database_url
def _configure_sqlite_connection(dbapi_connection, connection_record):

View file

@ -1,49 +1,23 @@
from fastapi import Request, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from sqlalchemy import select
from typing import Optional
from typing import Callable
from datetime import datetime
import logging
from database import get_session
from database.models.api_key import SwaggerUser
from security.auth import verify_password
import base64
logger = logging.getLogger(__name__)
security = HTTPBasic()
async def verify_swagger_credentials(credentials: HTTPBasicCredentials) -> bool:
username = credentials.username
password = credentials.password
try:
async for session in get_session():
result = await session.execute(
select(SwaggerUser).where(SwaggerUser.username == username)
)
swagger_user = result.scalar_one_or_none()
if swagger_user and swagger_user.is_active:
if verify_password(password, swagger_user.hashed_password):
swagger_user.last_login = datetime.now()
await session.commit()
logger.info(f" Accès Swagger autorisé (DB): {username}")
return True
logger.warning(f" Tentative d'accès Swagger refusée: {username}")
return False
except Exception as e:
logger.error(f" Erreur vérification Swagger credentials: {e}")
return False
class SwaggerAuthMiddleware:
def __init__(self, app):
PROTECTED_PATHS = ["/docs", "/redoc", "/openapi.json"]
def __init__(self, app: ASGIApp):
self.app = app
async def __call__(self, scope, receive, send):
@ -54,183 +28,220 @@ class SwaggerAuthMiddleware:
request = Request(scope, receive=receive)
path = request.url.path
protected_paths = ["/docs", "/redoc", "/openapi.json"]
if not any(path.startswith(p) for p in self.PROTECTED_PATHS):
await self.app(scope, receive, send)
return
if any(path.startswith(protected_path) for protected_path in protected_paths):
auth_header = request.headers.get("Authorization")
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Basic "):
if not auth_header or not auth_header.startswith("Basic "):
response = JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Authentification requise pour la documentation"},
headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'},
)
await response(scope, receive, send)
return
try:
encoded_credentials = auth_header.split(" ")[1]
decoded_credentials = base64.b64decode(encoded_credentials).decode("utf-8")
username, password = decoded_credentials.split(":", 1)
credentials = HTTPBasicCredentials(username=username, password=password)
if not await self._verify_credentials(credentials):
response = JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
"detail": "Authentification requise pour accéder à la documentation"
},
content={"detail": "Identifiants invalides"},
headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'},
)
await response(scope, receive, send)
return
try:
import base64
encoded_credentials = auth_header.split(" ")[1]
decoded_credentials = base64.b64decode(encoded_credentials).decode(
"utf-8"
)
username, password = decoded_credentials.split(":", 1)
credentials = HTTPBasicCredentials(username=username, password=password)
if not await verify_swagger_credentials(credentials):
response = JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Identifiants invalides"},
headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'},
)
await response(scope, receive, send)
return
except Exception as e:
logger.error(f" Erreur parsing auth header: {e}")
response = JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Format d'authentification invalide"},
headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'},
)
await response(scope, receive, send)
return
except Exception as e:
logger.error(f"Erreur parsing auth header: {e}")
response = JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Format d'authentification invalide"},
headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'},
)
await response(scope, receive, send)
return
await self.app(scope, receive, send)
async def _verify_credentials(self, credentials: HTTPBasicCredentials) -> bool:
"""Vérifie les identifiants dans la base de données"""
from database.db_config import async_session_factory
from database.models.api_key import SwaggerUser
from security.auth import verify_password
class ApiKeyMiddleware:
def __init__(self, app):
self.app = app
try:
async with async_session_factory() as session:
result = await session.execute(
select(SwaggerUser).where(
SwaggerUser.username == credentials.username
)
)
swagger_user = result.scalar_one_or_none()
async def __call__(self, scope, receive, send):
if scope["type"] != "http":
await self.app(scope, receive, send)
return
if swagger_user and swagger_user.is_active:
if verify_password(
credentials.password, swagger_user.hashed_password
):
swagger_user.last_login = datetime.now()
await session.commit()
logger.info(f"✓ Accès Swagger autorisé: {credentials.username}")
return True
request = Request(scope, receive=receive)
logger.warning(f"✗ Accès Swagger refusé: {credentials.username}")
return False
except Exception as e:
logger.error(f"Erreur vérification credentials: {e}")
return False
class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
EXCLUDED_PATHS = [
"/docs",
"/redoc",
"/openapi.json",
"/",
"/auth/login",
"/auth/register",
"/auth/verify-email",
"/auth/reset-password",
"/auth/request-reset",
"/auth/refresh",
]
async def dispatch(self, request: Request, call_next: Callable):
path = request.url.path
method = request.method
excluded_paths = [
"/docs",
"/redoc",
"/openapi.json",
"/health",
"/",
"/auth/login",
"/auth/register",
"/auth/verify-email",
"/auth/reset-password",
"/auth/request-reset",
"/auth/refresh",
]
if any(path.startswith(excluded_path) for excluded_path in excluded_paths):
await self.app(scope, receive, send)
return
if self._is_excluded_path(path):
return await call_next(request)
auth_header = request.headers.get("Authorization")
has_jwt = auth_header and auth_header.startswith("Bearer ")
api_key = request.headers.get("X-API-Key")
has_api_key = api_key is not None
has_api_key = bool(api_key)
if has_jwt:
logger.debug(f" JWT détecté pour {path}")
await self.app(scope, receive, send)
return
logger.debug(f"JWT détecté pour {method} {path}")
return await call_next(request)
elif has_api_key:
logger.debug(f" API Key détectée pour {path}")
if has_api_key:
logger.debug(f"API Key détectée pour {method} {path}")
return await self._handle_api_key_auth(
request, api_key, path, method, call_next
)
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
"detail": "Authentification requise",
"hint": "Utilisez 'X-API-Key: sdk_live_xxx' ou 'Authorization: Bearer <jwt>'",
},
headers={"WWW-Authenticate": 'Bearer realm="API", charset="UTF-8"'},
)
def _is_excluded_path(self, path: str) -> bool:
"""Vérifie si le chemin est exclu de l'authentification"""
if path == "/":
return True
for excluded in self.EXCLUDED_PATHS:
if excluded == "/":
continue
if path == excluded or path.startswith(excluded + "/"):
return True
return False
async def _handle_api_key_auth(
self,
request: Request,
api_key: str,
path: str,
method: str,
call_next: Callable,
):
"""Gère l'authentification par API Key"""
try:
from database.db_config import async_session_factory
from services.api_key import ApiKeyService
try:
async for session in get_session():
api_key_service = ApiKeyService(session)
api_key_obj = await api_key_service.verify_api_key(api_key)
async with async_session_factory() as session:
service = ApiKeyService(session)
api_key_obj = await service.verify_api_key(api_key)
if not api_key_obj:
response = JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
"detail": "Clé API invalide ou expirée",
"hint": "Utilisez X-API-Key: sdk_live_xxx ou Authorization: Bearer <jwt>",
},
)
await response(scope, receive, send)
return
is_allowed, rate_info = await api_key_service.check_rate_limit(
api_key_obj
if not api_key_obj:
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
"detail": "Clé API invalide ou expirée",
"hint": "Vérifiez votre clé X-API-Key",
},
)
if not is_allowed:
response = JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={"detail": "Rate limit dépassé"},
headers={
"X-RateLimit-Limit": str(rate_info["limit"]),
"X-RateLimit-Remaining": "0",
},
)
await response(scope, receive, send)
return
has_access = await api_key_service.check_endpoint_access(
api_key_obj, path
is_allowed, rate_info = await service.check_rate_limit(api_key_obj)
if not is_allowed:
return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={"detail": "Rate limit dépassé"},
headers={
"X-RateLimit-Limit": str(rate_info["limit"]),
"X-RateLimit-Remaining": "0",
},
)
if not has_access:
response = JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={
"detail": "Accès non autorisé à cet endpoint",
"endpoint": path,
"api_key": api_key_obj.key_prefix + "...",
},
)
await response(scope, receive, send)
return
request.state.api_key = api_key_obj
request.state.authenticated_via = "api_key"
has_access = await service.check_endpoint_access(api_key_obj, path)
if not has_access:
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={
"detail": "Accès non autorisé à cet endpoint",
"endpoint": path,
"api_key": api_key_obj.key_prefix + "...",
},
)
logger.info(f" API Key valide: {api_key_obj.name}{path}")
request.state.api_key = api_key_obj
request.state.authenticated_via = "api_key"
await self.app(scope, receive, send)
return
logger.info(f"✓ API Key valide: {api_key_obj.name}{method} {path}")
except Exception as e:
logger.error(f" Erreur validation API Key: {e}", exc_info=True)
response = JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={
"detail": "Erreur interne lors de la validation de la clé"
},
)
await response(scope, receive, send)
return
return await call_next(request)
else:
response = JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
"detail": "Authentification requise",
"hint": "Utilisez soit 'X-API-Key: sdk_live_xxx' soit 'Authorization: Bearer <jwt>'",
},
headers={"WWW-Authenticate": 'Bearer realm="API", charset="UTF-8"'},
except Exception as e:
logger.error(f"Erreur validation API Key: {e}", exc_info=True)
return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"detail": "Erreur interne lors de la validation"},
)
await response(scope, receive, send)
return
def get_api_key_from_request(request: Request) -> Optional:
ApiKeyMiddleware = ApiKeyMiddlewareHTTP
def get_api_key_from_request(request: Request):
"""Récupère l'objet ApiKey depuis la requête si présent"""
return getattr(request.state, "api_key", None)
def get_auth_method(request: Request) -> str:
"""Retourne la méthode d'authentification utilisée"""
return getattr(request.state, "authenticated_via", "none")
__all__ = [
"SwaggerAuthMiddleware",
"ApiKeyMiddlewareHTTP",
"ApiKeyMiddleware",
"get_api_key_from_request",
"get_auth_method",
]

View file

@ -4,12 +4,13 @@ from typing import Optional, Dict
import jwt
import secrets
import hashlib
import os
SECRET_KEY = os.getenv("JWT_SECRET")
ALGORITHM = os.getenv("JWT_ALGORITHM")
ACCESS_TOKEN_EXPIRE_MINUTES = os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES")
REFRESH_TOKEN_EXPIRE_DAYS = os.getenv("REFRESH_TOKEN_EXPIRE_DAYS")
from config.config import settings
SECRET_KEY = settings.jwt_secret
ALGORITHM = settings.jwt_algorithm
ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes
REFRESH_TOKEN_EXPIRE_DAYS = settings.refresh_token_expire_days
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")