Sage100-vps/middleware/security.py

271 lines
9.6 KiB
Python

from fastapi import Request, status
from fastapi.responses import JSONResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from sqlalchemy import select
from typing import Callable
from datetime import datetime
import logging
import base64
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# HTTP Basic scheme instance.
# NOTE(review): not referenced by the code visible in this module - confirm
# whether other modules import it before removing.
security = HTTPBasic()
class SwaggerAuthMiddleware:
    """ASGI middleware that guards the API documentation with HTTP Basic auth.

    HTTP requests whose path starts with one of ``PROTECTED_PATHS`` must carry
    an ``Authorization: Basic ...`` header whose credentials match an active
    ``SwaggerUser`` row. Every other HTTP request, and every non-HTTP scope,
    is forwarded to the wrapped application untouched.
    """

    # Path prefixes that require Basic authentication (matched with
    # str.startswith, so "/docs/oauth2-redirect" is protected as well).
    PROTECTED_PATHS = ["/docs", "/redoc", "/openapi.json"]

    def __init__(self, app: ASGIApp):
        """Store the wrapped ASGI application."""
        self.app = app

    @staticmethod
    def _unauthorized(detail: str):
        """Build the 401 response that asks the client for Basic credentials."""
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content={"detail": detail},
            headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'},
        )

    async def __call__(self, scope, receive, send):
        """Authenticate requests to protected paths, then delegate to the app.

        Sends a 401 JSON response (with a ``WWW-Authenticate`` challenge) when
        the header is missing, malformed, or carries invalid credentials.
        """
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = Request(scope, receive=receive)
        path = request.url.path
        if not any(path.startswith(prefix) for prefix in self.PROTECTED_PATHS):
            await self.app(scope, receive, send)
            return

        auth_header = request.headers.get("Authorization")
        # The authentication scheme token is case-insensitive (RFC 7235), so
        # "basic ..." and "BASIC ..." are accepted like "Basic ...".
        if not auth_header or auth_header[:6].lower() != "basic ":
            response = self._unauthorized(
                "Authentification requise pour la documentation"
            )
            await response(scope, receive, send)
            return

        # Only the parsing lives inside the try block: a failure while
        # verifying credentials or while sending a response must not be
        # reported as a malformed header, nor trigger a second response.
        try:
            encoded_credentials = auth_header.split(" ")[1]
            decoded_credentials = base64.b64decode(encoded_credentials).decode("utf-8")
            username, password = decoded_credentials.split(":", 1)
            credentials = HTTPBasicCredentials(username=username, password=password)
        except ValueError as e:
            # ValueError covers invalid base64 (binascii.Error), invalid UTF-8
            # (UnicodeDecodeError) and a payload without a ":" separator.
            logger.error(f"Erreur parsing auth header: {e}")
            response = self._unauthorized("Format d'authentification invalide")
            await response(scope, receive, send)
            return

        if not await self._verify_credentials(credentials):
            response = self._unauthorized("Identifiants invalides")
            await response(scope, receive, send)
            return

        await self.app(scope, receive, send)

    async def _verify_credentials(self, credentials: HTTPBasicCredentials) -> bool:
        """Check the credentials against the database.

        Returns True only when a ``SwaggerUser`` with the given username
        exists, is active, and the password matches its stored hash; in that
        case ``last_login`` is updated and committed. Any exception raised
        while checking is logged and treated as a refusal (returns False).
        """
        # Project imports are done lazily inside the method
        # (presumably to avoid import cycles at module load - TODO confirm).
        from database.db_config import async_session_factory
        from database.models.api_key import SwaggerUser
        from security.auth import verify_password

        try:
            async with async_session_factory() as session:
                result = await session.execute(
                    select(SwaggerUser).where(
                        SwaggerUser.username == credentials.username
                    )
                )
                swagger_user = result.scalar_one_or_none()
                if (
                    swagger_user
                    and swagger_user.is_active
                    and verify_password(
                        credentials.password, swagger_user.hashed_password
                    )
                ):
                    swagger_user.last_login = datetime.now()
                    await session.commit()
                    logger.info(f"✓ Accès Swagger autorisé: {credentials.username}")
                    return True
                logger.warning(f"✗ Accès Swagger refusé: {credentials.username}")
                return False
        except Exception as e:
            # Fail closed: a database or hashing error never grants access.
            logger.error(f"Erreur vérification credentials: {e}")
            return False
class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
    """HTTP middleware that requires an API key or a Bearer token on every
    request whose path is not excluded.

    Accepted credentials:
    - ``X-API-Key: <key>``, fully checked here (validity, rate limit,
      endpoint access);
    - ``Authorization: Bearer <token>``: a token starting with ``sdk_live_``
      is treated as an API key, any other token is forwarded as a JWT
      without being checked in this middleware.
    """

    # Paths served without authentication. Except for "/", which only matches
    # exactly, an entry matches itself and everything below it ("/auth/...").
    EXCLUDED_PATHS = [
        "/docs",
        "/redoc",
        "/openapi.json",
        "/",
        "/health",
        "/auth",
        "/api-keys/verify",
    ]

    def _is_excluded_path(self, path: str) -> bool:
        """Return True when ``path`` is exempt from authentication."""
        if path == "/":
            return True
        # "/" is skipped here: treated as a prefix it would exclude every path.
        return any(
            path == excluded or path.startswith(excluded + "/")
            for excluded in self.EXCLUDED_PATHS
            if excluded != "/"
        )

    async def dispatch(self, request: Request, call_next: Callable):
        """Route the request according to the credentials it carries.

        Excluded paths and non-API-key Bearer tokens are forwarded directly;
        API keys go through ``_handle_api_key_auth``; requests without any
        credentials receive a 401 JSON response.
        """
        path = request.url.path
        method = request.method
        if self._is_excluded_path(path):
            return await call_next(request)

        auth_header = request.headers.get("Authorization")
        api_key_header = request.headers.get("X-API-Key")

        if auth_header and auth_header.startswith("Bearer "):
            token = auth_header.split(" ")[1]
            if token.startswith("sdk_live_"):
                # An API key sent in the wrong header is tolerated and
                # overrides any X-API-Key value.
                logger.warning(
                    " API Key envoyée dans Authorization au lieu de X-API-Key"
                )
                api_key_header = token
            else:
                # NOTE(review): the token is not validated here; presumably
                # the route dependencies verify the JWT - confirm.
                logger.debug(f" JWT détecté pour {method} {path}")
                return await call_next(request)

        if api_key_header:
            logger.debug(f" API Key détectée pour {method} {path}")
            return await self._handle_api_key_auth(
                request, api_key_header, path, method, call_next
            )

        logger.warning(f" Aucune authentification: {method} {path}")
        return JSONResponse(
            status_code=status.HTTP_401_UNAUTHORIZED,
            content={
                "detail": "Authentification requise",
                "hint": "Utilisez 'X-API-Key: sdk_live_xxx' ou 'Authorization: Bearer <jwt>'",
                "path": path,
            },
            headers={"WWW-Authenticate": 'Bearer realm="API", charset="UTF-8"'},
        )

    async def _handle_api_key_auth(
        self,
        request: Request,
        api_key: str,
        path: str,
        method: str,
        call_next: Callable,
    ):
        """Validate an API key and forward the request when it is allowed.

        Returns 401 for an unknown or expired key, 429 when the rate limit is
        exceeded, 403 when the key may not call ``path``, and 500 when an
        unexpected error occurs. On success the key object and the string
        ``"api_key"`` are stored on ``request.state`` before the request is
        forwarded.
        """
        try:
            from database.db_config import async_session_factory
            from services.api_key import ApiKeyService

            # The session stays open while the request is forwarded, so the
            # key object stored on request.state remains attached to it.
            async with async_session_factory() as session:
                service = ApiKeyService(session)
                api_key_obj = await service.verify_api_key(api_key)
                if not api_key_obj:
                    logger.warning(f" Clé API invalide: {method} {path}")
                    return JSONResponse(
                        status_code=status.HTTP_401_UNAUTHORIZED,
                        content={
                            "detail": "Clé API invalide ou expirée",
                            "hint": "Vérifiez votre clé X-API-Key",
                        },
                    )

                is_allowed, rate_info = await service.check_rate_limit(api_key_obj)
                if not is_allowed:
                    logger.warning(f"⚠️ Rate limit: {api_key_obj.name}")
                    return JSONResponse(
                        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                        content={"detail": "Rate limit dépassé"},
                        headers={
                            "X-RateLimit-Limit": str(rate_info["limit"]),
                            "X-RateLimit-Remaining": "0",
                        },
                    )

                has_access = await service.check_endpoint_access(api_key_obj, path)
                if not has_access:
                    import json

                    # allowed_endpoints is stored as a JSON string; an empty
                    # value is reported as "Tous" (all endpoints).
                    allowed = (
                        json.loads(api_key_obj.allowed_endpoints)
                        if api_key_obj.allowed_endpoints
                        else ["Tous"]
                    )
                    logger.warning(
                        f" ACCÈS REFUSÉ: {api_key_obj.name}\n"
                        f" Endpoint demandé: {path}\n"
                        f" Endpoints autorisés: {allowed}"
                    )
                    return JSONResponse(
                        status_code=status.HTTP_403_FORBIDDEN,
                        content={
                            "detail": "Accès non autorisé à cet endpoint",
                            "endpoint_requested": path,
                            "api_key_name": api_key_obj.name,
                            "allowed_endpoints": allowed,
                            "hint": "Cette clé API n'a pas accès à cet endpoint. Contactez l'administrateur.",
                        },
                    )

                request.state.api_key = api_key_obj
                request.state.authenticated_via = "api_key"
                # Separator added: the key name was glued to the HTTP method.
                logger.info(
                    f" ACCÈS AUTORISÉ: {api_key_obj.name} - {method} {path}"
                )
                return await call_next(request)
        except Exception as e:
            # The full error (with traceback) goes to the server log only.
            # Exception text may contain internal details (SQL, paths, ...)
            # and must not be echoed to the client. This handler also sees
            # errors raised downstream, since call_next runs inside the try.
            logger.error(f" Erreur validation API Key: {e}", exc_info=True)
            return JSONResponse(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                content={"detail": "Erreur interne"},
            )
# Alias: exposes ApiKeyMiddlewareHTTP under the shorter name ApiKeyMiddleware.
ApiKeyMiddleware = ApiKeyMiddlewareHTTP
def get_api_key_from_request(request: Request):
    """Return the ApiKey object stored on ``request.state``, or None if absent."""
    state = request.state
    try:
        return state.api_key
    except AttributeError:
        # No API key was attached to this request.
        return None
def get_auth_method(request: Request) -> str:
    """Return the authentication method recorded on the request.

    Falls back to ``"none"`` when ``request.state`` has no
    ``authenticated_via`` attribute.
    """
    state = request.state
    try:
        return state.authenticated_via
    except AttributeError:
        return "none"
# Public names exported by `from <this module> import *`.
__all__ = [
"SwaggerAuthMiddleware",
"ApiKeyMiddlewareHTTP",
"ApiKeyMiddleware",
"get_api_key_from_request",
"get_auth_method",
]