271 lines
9.6 KiB
Python
271 lines
9.6 KiB
Python
from fastapi import Request, status
|
|
from fastapi.responses import JSONResponse
|
|
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
from starlette.types import ASGIApp
|
|
from sqlalchemy import select
|
|
from typing import Callable
|
|
from datetime import datetime
|
|
import logging
|
|
import base64
|
|
|
|
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)

# Shared FastAPI HTTP Basic security scheme. The middleware classes in this
# module parse the Authorization header themselves and do not reference it.
security = HTTPBasic()
|
|
|
|
|
|
class SwaggerAuthMiddleware:
    """ASGI middleware protecting the API documentation with HTTP Basic auth.

    Any HTTP request whose path starts with one of ``PROTECTED_PATHS`` must
    send ``Authorization: Basic <base64(username:password)>`` carrying the
    credentials of an active ``SwaggerUser``. Otherwise a 401 JSON response
    with a ``WWW-Authenticate: Basic`` challenge is returned. All other
    requests, and every non-HTTP scope, are forwarded to the app unchanged.
    """

    # Matched with a plain ``str.startswith``. "/docs/oauth2-redirect" is
    # covered, but so is any other path that merely begins with "/docs".
    PROTECTED_PATHS = ["/docs", "/redoc", "/openapi.json"]

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope, receive, send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = Request(scope, receive=receive)
        path = request.url.path

        if not any(path.startswith(p) for p in self.PROTECTED_PATHS):
            await self.app(scope, receive, send)
            return

        # The auth-scheme token is case-insensitive (RFC 7235 section 2.1),
        # so "basic xxx" is accepted as well as "Basic xxx".
        scheme, separator, encoded = (
            request.headers.get("Authorization") or ""
        ).partition(" ")
        if not separator or scheme.lower() != "basic":
            await self._send_unauthorized(
                scope, receive, send, "Authentification requise pour la documentation"
            )
            return

        # Only the header parsing sits inside the try. A failure while sending
        # a response, or inside the downstream app, must not be reported as a
        # malformed header, and must not trigger a second response.
        try:
            credentials = self._decode_basic_credentials(encoded)
        except ValueError as exc:
            # Client-side error (bad base64 / UTF-8 / missing ':'): warning level.
            logger.warning("Erreur parsing auth header: %s", exc)
            await self._send_unauthorized(
                scope, receive, send, "Format d'authentification invalide"
            )
            return

        if not await self._verify_credentials(credentials):
            await self._send_unauthorized(scope, receive, send, "Identifiants invalides")
            return

        await self.app(scope, receive, send)

    @staticmethod
    def _decode_basic_credentials(encoded: str) -> HTTPBasicCredentials:
        """Decode the base64 payload of a Basic ``Authorization`` header.

        Args:
            encoded: The text following the ``Basic`` scheme token.

        Returns:
            The username/password pair. The password may itself contain ':'.

        Raises:
            ValueError: If ``encoded`` is not strict base64 (``binascii.Error``),
                if it is not valid UTF-8 (``UnicodeDecodeError``), or if it
                has no ':' separator. All three are ``ValueError`` subclasses.
        """
        # validate=True rejects characters outside the base64 alphabet instead
        # of silently discarding them.
        decoded = base64.b64decode(encoded.strip(), validate=True).decode("utf-8")
        username, password = decoded.split(":", 1)
        return HTTPBasicCredentials(username=username, password=password)

    @staticmethod
    async def _send_unauthorized(scope, receive, send, detail: str) -> None:
        """Send a 401 JSON response with a Basic challenge for the Swagger realm."""
        response = JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content={"detail": detail},
            headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'},
        )
        await response(scope, receive, send)

    async def _verify_credentials(self, credentials: HTTPBasicCredentials) -> bool:
        """Check the credentials against the ``SwaggerUser`` table.

        On success, ``last_login`` is updated and committed.

        Returns:
            True only for an existing, active user whose password verifies.
            False in every other case, including any unexpected error, so the
            check always fails closed.
        """
        try:
            # Imported lazily, as before, but inside the try. An import or
            # configuration failure then denies access instead of escaping.
            from database.db_config import async_session_factory
            from database.models.api_key import SwaggerUser
            from security.auth import verify_password

            async with async_session_factory() as session:
                result = await session.execute(
                    select(SwaggerUser).where(
                        SwaggerUser.username == credentials.username
                    )
                )
                swagger_user = result.scalar_one_or_none()

                # NOTE(review): no password hash is checked when the user does
                # not exist, so response timing may reveal valid usernames.
                # verify_password is called synchronously; if it is a slow hash
                # it blocks the event loop. Both to be confirmed.
                if (
                    swagger_user
                    and swagger_user.is_active
                    and verify_password(
                        credentials.password, swagger_user.hashed_password
                    )
                ):
                    swagger_user.last_login = datetime.now()
                    await session.commit()
                    logger.info("✓ Accès Swagger autorisé: %s", credentials.username)
                    return True

            logger.warning("✗ Accès Swagger refusé: %s", credentials.username)
            return False

        except Exception as exc:
            logger.exception("Erreur vérification credentials: %s", exc)
            return False
|
|
|
|
|
|
class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
    """Require an API key or a Bearer token on every non-excluded path.

    - Paths in ``EXCLUDED_PATHS`` are forwarded without any check. A path
      matches when it equals an entry or lies below it; "/" matches only
      exactly.
    - ``Authorization: Bearer <token>``: a token starting with ``sdk_live_``
      is treated as an API key. Any other token is forwarded unchecked, on
      the assumption that the downstream app handles JWT validation.
    - ``X-API-Key: <key>``: validated by ``_handle_api_key_auth``.
    - Anything else receives a 401 JSON response.
    """

    EXCLUDED_PATHS = [
        "/docs",
        "/redoc",
        "/openapi.json",
        "/",
        "/health",
        "/auth",
        "/api-keys/verify",
    ]

    def _is_excluded_path(self, path: str) -> bool:
        """Return True if ``path`` bypasses authentication.

        "/" is an exact match only. Every other entry matches itself and any
        sub-path ("/auth" matches "/auth/login", but not "/authz").
        """
        if path == "/":
            return True

        for excluded in self.EXCLUDED_PATHS:
            if excluded == "/":
                # Would otherwise match every path as a prefix.
                continue
            if path == excluded or path.startswith(excluded + "/"):
                return True

        return False

    async def dispatch(self, request: Request, call_next: Callable):
        """Route the request to the right authentication strategy."""
        path = request.url.path
        method = request.method

        if self._is_excluded_path(path):
            return await call_next(request)

        auth_header = request.headers.get("Authorization")
        api_key_header = request.headers.get("X-API-Key")

        if auth_header and auth_header.startswith("Bearer "):
            token = auth_header.split(" ")[1]

            if token.startswith("sdk_live_"):
                # An API key sent as a Bearer token takes precedence over
                # any X-API-Key header.
                logger.warning(
                    " API Key envoyée dans Authorization au lieu de X-API-Key"
                )
                api_key_header = token
            else:
                # NOTE(review): the JWT is NOT verified here. Every protected
                # endpoint must validate it itself (e.g. via a dependency).
                logger.debug(" JWT détecté pour %s %s", method, path)
                return await call_next(request)

        if api_key_header:
            logger.debug(" API Key détectée pour %s %s", method, path)
            return await self._handle_api_key_auth(
                request, api_key_header, path, method, call_next
            )

        logger.warning(" Aucune authentification: %s %s", method, path)
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content={
                "detail": "Authentification requise",
                "hint": "Utilisez 'X-API-Key: sdk_live_xxx' ou 'Authorization: Bearer <jwt>'",
                "path": path,
            },
            headers={"WWW-Authenticate": 'Bearer realm="API", charset="UTF-8"'},
        )

    async def _handle_api_key_auth(
        self,
        request: Request,
        api_key: str,
        path: str,
        method: str,
        call_next: Callable,
    ):
        """Validate ``api_key`` strictly and forward the request only if it passes.

        Checks run in order: key validity (401), rate limit (429) and endpoint
        access (403). On success, ``request.state.api_key`` holds the ApiKey
        object and ``request.state.authenticated_via`` is ``"api_key"``.

        Returns:
            The downstream response. If the key is rejected, a JSONResponse
            describing the rejection. If validation itself failed, a generic
            500 that does not leak the exception text.
        """
        try:
            from database.db_config import async_session_factory
            from services.api_key import ApiKeyService

            async with async_session_factory() as session:
                service = ApiKeyService(session)

                api_key_obj = await service.verify_api_key(api_key)
                if not api_key_obj:
                    logger.warning(" Clé API invalide: %s %s", method, path)
                    return JSONResponse(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        content={
                            "detail": "Clé API invalide ou expirée",
                            "hint": "Vérifiez votre clé X-API-Key",
                        },
                    )

                is_allowed, rate_info = await service.check_rate_limit(api_key_obj)
                if not is_allowed:
                    logger.warning("⚠️ Rate limit: %s", api_key_obj.name)
                    return JSONResponse(
                        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                        content={"detail": "Rate limit dépassé"},
                        headers={
                            "X-RateLimit-Limit": str(rate_info["limit"]),
                            "X-RateLimit-Remaining": "0",
                        },
                    )

                if not await service.check_endpoint_access(api_key_obj, path):
                    allowed = self._allowed_endpoints_for_display(
                        api_key_obj.allowed_endpoints
                    )
                    logger.warning(
                        " ACCÈS REFUSÉ: %s\n Endpoint demandé: %s\n Endpoints autorisés: %s",
                        api_key_obj.name,
                        path,
                        allowed,
                    )
                    return JSONResponse(
                        status_code=status.HTTP_403_FORBIDDEN,
                        content={
                            "detail": "Accès non autorisé à cet endpoint",
                            "endpoint_requested": path,
                            "api_key_name": api_key_obj.name,
                            "allowed_endpoints": allowed,
                            "hint": "Cette clé API n'a pas accès à cet endpoint. Contactez l'administrateur.",
                        },
                    )

                # Set state and log while the session is still open, exactly
                # as before. Only the downstream call is moved out.
                request.state.api_key = api_key_obj
                request.state.authenticated_via = "api_key"
                logger.info(
                    " ACCÈS AUTORISÉ: %s → %s %s", api_key_obj.name, method, path
                )

        except Exception as exc:
            logger.exception(" Erreur validation API Key: %s", exc)
            # Do not echo the exception text to the client: it can expose
            # internal details such as SQL, hostnames or file paths.
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={"detail": "Erreur interne"},
            )

        # Forward outside both the try and the DB session. The pooled
        # connection is then not held for the whole downstream request, and
        # errors raised by the endpoint are not misreported (and leaked) as
        # API-key validation failures.
        return await call_next(request)

    @staticmethod
    def _allowed_endpoints_for_display(raw):
        """Best-effort decoding of ``allowed_endpoints`` for the 403 payload and log.

        ``raw`` is decoded with ``json.loads``, as in the original 403 path,
        and an empty value is shown as ``["Tous"]``. A list or tuple is
        returned as a list. Anything that cannot be decoded is returned as
        ``str(raw)``, so a bad value no longer turns the 403 into a 500.
        """
        import json

        if not raw:
            return ["Tous"]
        if isinstance(raw, (list, tuple)):
            return list(raw)
        try:
            return json.loads(raw)
        except (TypeError, ValueError):
            return str(raw)
|
|
|
|
|
|
# Alias: ``ApiKeyMiddleware`` is the very same class object as
# ``ApiKeyMiddlewareHTTP``. Both names are exported in ``__all__``; presumably
# the short name is kept for existing importers (to be confirmed).
ApiKeyMiddleware = ApiKeyMiddlewareHTTP
|
|
|
|
|
|
def get_api_key_from_request(request: Request):
    """Return the ApiKey object stored on ``request.state.api_key``.

    ``ApiKeyMiddlewareHTTP`` sets this attribute after a successful API-key
    check. When nothing has set it, ``None`` is returned.
    """
    state = request.state
    try:
        return state.api_key
    except AttributeError:
        # Attribute absent: the request did not go through API-key auth.
        return None
|
|
|
|
|
|
def get_auth_method(request: Request) -> str:
    """Return how the request was authenticated.

    Reads ``request.state.authenticated_via`` (set to ``"api_key"`` by
    ``ApiKeyMiddlewareHTTP``). Returns ``"none"`` when it was never set.
    """
    state = request.state
    try:
        return state.authenticated_via
    except AttributeError:
        return "none"
|
|
|
|
|
|
# Public API of this module: the two middlewares (plus the alias) and the
# request-state accessors.
__all__ = [
    # Middlewares
    "SwaggerAuthMiddleware",
    "ApiKeyMiddlewareHTTP",
    "ApiKeyMiddleware",
    # Request-state helpers
    "get_api_key_from_request",
    "get_auth_method",
]
|