refactor(security): improve middleware structure and configuration handling

This commit is contained in:
Fanilo-Nantenaina 2026-01-20 14:47:07 +03:00
parent c84e4ddc20
commit 3cdb490ee5
4 changed files with 188 additions and 176 deletions

4
api.py
View file

@ -96,7 +96,7 @@ from utils.generic_functions import (
) )
from middleware.security import SwaggerAuthMiddleware, ApiKeyMiddleware from middleware.security import SwaggerAuthMiddleware, ApiKeyMiddlewareHTTP
from core.dependencies import get_current_user from core.dependencies import get_current_user
from config.cors_config import setup_cors from config.cors_config import setup_cors
from routes.api_keys import router as api_keys_router from routes.api_keys import router as api_keys_router
@ -198,7 +198,7 @@ app.openapi = custom_openapi """
setup_cors(app, mode="open") setup_cors(app, mode="open")
app.add_middleware(SwaggerAuthMiddleware) app.add_middleware(SwaggerAuthMiddleware)
app.add_middleware(ApiKeyMiddleware) app.add_middleware(ApiKeyMiddlewareHTTP)
app.include_router(api_keys_router) app.include_router(api_keys_router)
app.include_router(auth_router) app.include_router(auth_router)

View file

@ -1,14 +1,14 @@
import os
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker
from sqlalchemy.pool import NullPool from sqlalchemy.pool import NullPool
from sqlalchemy import event, text from sqlalchemy import event, text
import logging import logging
from config.config import settings
from database.models.generic_model import Base from database.models.generic_model import Base
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
DATABASE_URL = os.getenv("DATABASE_URL") DATABASE_URL = settings.database_url
def _configure_sqlite_connection(dbapi_connection, connection_record): def _configure_sqlite_connection(dbapi_connection, connection_record):

View file

@ -1,49 +1,23 @@
from fastapi import Request, status from fastapi import Request, status
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security import HTTPBasic, HTTPBasicCredentials
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from sqlalchemy import select from sqlalchemy import select
from typing import Optional from typing import Callable
from datetime import datetime from datetime import datetime
import logging import logging
import base64
from database import get_session
from database.models.api_key import SwaggerUser
from security.auth import verify_password
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
security = HTTPBasic() security = HTTPBasic()
async def verify_swagger_credentials(credentials: HTTPBasicCredentials) -> bool:
username = credentials.username
password = credentials.password
try:
async for session in get_session():
result = await session.execute(
select(SwaggerUser).where(SwaggerUser.username == username)
)
swagger_user = result.scalar_one_or_none()
if swagger_user and swagger_user.is_active:
if verify_password(password, swagger_user.hashed_password):
swagger_user.last_login = datetime.now()
await session.commit()
logger.info(f" Accès Swagger autorisé (DB): {username}")
return True
logger.warning(f" Tentative d'accès Swagger refusée: {username}")
return False
except Exception as e:
logger.error(f" Erreur vérification Swagger credentials: {e}")
return False
class SwaggerAuthMiddleware: class SwaggerAuthMiddleware:
def __init__(self, app): PROTECTED_PATHS = ["/docs", "/redoc", "/openapi.json"]
def __init__(self, app: ASGIApp):
self.app = app self.app = app
async def __call__(self, scope, receive, send): async def __call__(self, scope, receive, send):
@ -54,34 +28,29 @@ class SwaggerAuthMiddleware:
request = Request(scope, receive=receive) request = Request(scope, receive=receive)
path = request.url.path path = request.url.path
protected_paths = ["/docs", "/redoc", "/openapi.json"] if not any(path.startswith(p) for p in self.PROTECTED_PATHS):
await self.app(scope, receive, send)
return
if any(path.startswith(protected_path) for protected_path in protected_paths):
auth_header = request.headers.get("Authorization") auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Basic "): if not auth_header or not auth_header.startswith("Basic "):
response = JSONResponse( response = JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
content={ content={"detail": "Authentification requise pour la documentation"},
"detail": "Authentification requise pour accéder à la documentation"
},
headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'}, headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'},
) )
await response(scope, receive, send) await response(scope, receive, send)
return return
try: try:
import base64
encoded_credentials = auth_header.split(" ")[1] encoded_credentials = auth_header.split(" ")[1]
decoded_credentials = base64.b64decode(encoded_credentials).decode( decoded_credentials = base64.b64decode(encoded_credentials).decode("utf-8")
"utf-8"
)
username, password = decoded_credentials.split(":", 1) username, password = decoded_credentials.split(":", 1)
credentials = HTTPBasicCredentials(username=username, password=password) credentials = HTTPBasicCredentials(username=username, password=password)
if not await verify_swagger_credentials(credentials): if not await self._verify_credentials(credentials):
response = JSONResponse( response = JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Identifiants invalides"}, content={"detail": "Identifiants invalides"},
@ -91,7 +60,7 @@ class SwaggerAuthMiddleware:
return return
except Exception as e: except Exception as e:
logger.error(f" Erreur parsing auth header: {e}") logger.error(f"Erreur parsing auth header: {e}")
response = JSONResponse( response = JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Format d'authentification invalide"}, content={"detail": "Format d'authentification invalide"},
@ -102,24 +71,43 @@ class SwaggerAuthMiddleware:
await self.app(scope, receive, send) await self.app(scope, receive, send)
async def _verify_credentials(self, credentials: HTTPBasicCredentials) -> bool:
"""Vérifie les identifiants dans la base de données"""
from database.db_config import async_session_factory
from database.models.api_key import SwaggerUser
from security.auth import verify_password
class ApiKeyMiddleware: try:
def __init__(self, app): async with async_session_factory() as session:
self.app = app result = await session.execute(
select(SwaggerUser).where(
SwaggerUser.username == credentials.username
)
)
swagger_user = result.scalar_one_or_none()
async def __call__(self, scope, receive, send): if swagger_user and swagger_user.is_active:
if scope["type"] != "http": if verify_password(
await self.app(scope, receive, send) credentials.password, swagger_user.hashed_password
return ):
swagger_user.last_login = datetime.now()
await session.commit()
logger.info(f"✓ Accès Swagger autorisé: {credentials.username}")
return True
request = Request(scope, receive=receive) logger.warning(f"✗ Accès Swagger refusé: {credentials.username}")
path = request.url.path return False
excluded_paths = [ except Exception as e:
logger.error(f"Erreur vérification credentials: {e}")
return False
class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
EXCLUDED_PATHS = [
"/docs", "/docs",
"/redoc", "/redoc",
"/openapi.json", "/openapi.json",
"/health",
"/", "/",
"/auth/login", "/auth/login",
"/auth/register", "/auth/register",
@ -129,47 +117,80 @@ class ApiKeyMiddleware:
"/auth/refresh", "/auth/refresh",
] ]
if any(path.startswith(excluded_path) for excluded_path in excluded_paths): async def dispatch(self, request: Request, call_next: Callable):
await self.app(scope, receive, send) path = request.url.path
return method = request.method
if self._is_excluded_path(path):
return await call_next(request)
auth_header = request.headers.get("Authorization") auth_header = request.headers.get("Authorization")
has_jwt = auth_header and auth_header.startswith("Bearer ") has_jwt = auth_header and auth_header.startswith("Bearer ")
api_key = request.headers.get("X-API-Key") api_key = request.headers.get("X-API-Key")
has_api_key = api_key is not None has_api_key = bool(api_key)
if has_jwt: if has_jwt:
logger.debug(f" JWT détecté pour {path}") logger.debug(f"JWT détecté pour {method} {path}")
await self.app(scope, receive, send) return await call_next(request)
return
elif has_api_key: if has_api_key:
logger.debug(f" API Key détectée pour {path}") logger.debug(f"API Key détectée pour {method} {path}")
return await self._handle_api_key_auth(
request, api_key, path, method, call_next
)
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
"detail": "Authentification requise",
"hint": "Utilisez 'X-API-Key: sdk_live_xxx' ou 'Authorization: Bearer <jwt>'",
},
headers={"WWW-Authenticate": 'Bearer realm="API", charset="UTF-8"'},
)
def _is_excluded_path(self, path: str) -> bool:
"""Vérifie si le chemin est exclu de l'authentification"""
if path == "/":
return True
for excluded in self.EXCLUDED_PATHS:
if excluded == "/":
continue
if path == excluded or path.startswith(excluded + "/"):
return True
return False
async def _handle_api_key_auth(
self,
request: Request,
api_key: str,
path: str,
method: str,
call_next: Callable,
):
"""Gère l'authentification par API Key"""
try:
from database.db_config import async_session_factory
from services.api_key import ApiKeyService from services.api_key import ApiKeyService
try: async with async_session_factory() as session:
async for session in get_session(): service = ApiKeyService(session)
api_key_service = ApiKeyService(session) api_key_obj = await service.verify_api_key(api_key)
api_key_obj = await api_key_service.verify_api_key(api_key)
if not api_key_obj: if not api_key_obj:
response = JSONResponse( return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
content={ content={
"detail": "Clé API invalide ou expirée", "detail": "Clé API invalide ou expirée",
"hint": "Utilisez X-API-Key: sdk_live_xxx ou Authorization: Bearer <jwt>", "hint": "Vérifiez votre clé X-API-Key",
}, },
) )
await response(scope, receive, send)
return
is_allowed, rate_info = await api_key_service.check_rate_limit( is_allowed, rate_info = await service.check_rate_limit(api_key_obj)
api_key_obj
)
if not is_allowed: if not is_allowed:
response = JSONResponse( return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS, status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={"detail": "Rate limit dépassé"}, content={"detail": "Rate limit dépassé"},
headers={ headers={
@ -177,14 +198,10 @@ class ApiKeyMiddleware:
"X-RateLimit-Remaining": "0", "X-RateLimit-Remaining": "0",
}, },
) )
await response(scope, receive, send)
return
has_access = await api_key_service.check_endpoint_access( has_access = await service.check_endpoint_access(api_key_obj, path)
api_key_obj, path
)
if not has_access: if not has_access:
response = JSONResponse( return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
content={ content={
"detail": "Accès non autorisé à cet endpoint", "detail": "Accès non autorisé à cet endpoint",
@ -192,45 +209,39 @@ class ApiKeyMiddleware:
"api_key": api_key_obj.key_prefix + "...", "api_key": api_key_obj.key_prefix + "...",
}, },
) )
await response(scope, receive, send)
return
request.state.api_key = api_key_obj request.state.api_key = api_key_obj
request.state.authenticated_via = "api_key" request.state.authenticated_via = "api_key"
logger.info(f" API Key valide: {api_key_obj.name} {path}") logger.info(f" API Key valide: {api_key_obj.name} {method} {path}")
await self.app(scope, receive, send) return await call_next(request)
return
except Exception as e: except Exception as e:
logger.error(f" Erreur validation API Key: {e}", exc_info=True) logger.error(f"Erreur validation API Key: {e}", exc_info=True)
response = JSONResponse( return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={ content={"detail": "Erreur interne lors de la validation"},
"detail": "Erreur interne lors de la validation de la clé"
},
) )
await response(scope, receive, send)
return
else:
response = JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={
"detail": "Authentification requise",
"hint": "Utilisez soit 'X-API-Key: sdk_live_xxx' soit 'Authorization: Bearer <jwt>'",
},
headers={"WWW-Authenticate": 'Bearer realm="API", charset="UTF-8"'},
)
await response(scope, receive, send)
return
def get_api_key_from_request(request: Request) -> Optional: ApiKeyMiddleware = ApiKeyMiddlewareHTTP
def get_api_key_from_request(request: Request):
"""Récupère l'objet ApiKey depuis la requête si présent""" """Récupère l'objet ApiKey depuis la requête si présent"""
return getattr(request.state, "api_key", None) return getattr(request.state, "api_key", None)
def get_auth_method(request: Request) -> str: def get_auth_method(request: Request) -> str:
"""Retourne la méthode d'authentification utilisée"""
return getattr(request.state, "authenticated_via", "none") return getattr(request.state, "authenticated_via", "none")
__all__ = [
"SwaggerAuthMiddleware",
"ApiKeyMiddlewareHTTP",
"ApiKeyMiddleware",
"get_api_key_from_request",
"get_auth_method",
]

View file

@ -4,12 +4,13 @@ from typing import Optional, Dict
import jwt import jwt
import secrets import secrets
import hashlib import hashlib
import os
SECRET_KEY = os.getenv("JWT_SECRET") from config.config import settings
ALGORITHM = os.getenv("JWT_ALGORITHM")
ACCESS_TOKEN_EXPIRE_MINUTES = os.getenv("ACCESS_TOKEN_EXPIRE_MINUTES") SECRET_KEY = settings.jwt_secret
REFRESH_TOKEN_EXPIRE_DAYS = os.getenv("REFRESH_TOKEN_EXPIRE_DAYS") ALGORITHM = settings.jwt_algorithm
ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes
REFRESH_TOKEN_EXPIRE_DAYS = settings.refresh_token_expire_days
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")