refactor(security): improve authentication logging and endpoint checks

This commit is contained in:
Fanilo-Nantenaina 2026-01-20 16:01:54 +03:00
parent 82d1d92e58
commit 67ef83c4e3
3 changed files with 92 additions and 65 deletions

View file

@ -1,11 +1,10 @@
from fastapi import Request, status from fastapi import Request, status
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security import HTTPBasic, HTTPBasicCredentials
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp from starlette.types import ASGIApp
from sqlalchemy import select from sqlalchemy import select
from typing import Optional, Callable from typing import Callable
from datetime import datetime from datetime import datetime
import logging import logging
import base64 import base64
@ -16,7 +15,6 @@ security = HTTPBasic()
class SwaggerAuthMiddleware: class SwaggerAuthMiddleware:
PROTECTED_PATHS = ["/docs", "/redoc", "/openapi.json"] PROTECTED_PATHS = ["/docs", "/redoc", "/openapi.json"]
def __init__(self, app: ASGIApp): def __init__(self, app: ASGIApp):
@ -111,19 +109,30 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
"/redoc", "/redoc",
"/openapi.json", "/openapi.json",
"/", "/",
"/auth/login", "/health",
"/auth/register", "/auth",
"/auth/verify-email", "/api-keys/verify",
"/auth/reset-password",
"/auth/request-reset",
"/auth/refresh",
] ]
def _is_excluded_path(self, path: str) -> bool:
"""Vérifie si le chemin est exclu de l'authentification"""
if path == "/":
return True
for excluded in self.EXCLUDED_PATHS:
if excluded == "/":
continue
if path == excluded or path.startswith(excluded + "/"):
return True
return False
async def dispatch(self, request: Request, call_next: Callable): async def dispatch(self, request: Request, call_next: Callable):
path = request.url.path path = request.url.path
method = request.method method = request.method
if self._is_excluded_path(path): if self._is_excluded_path(path):
logger.debug(f" Route publique: {method} {path}")
return await call_next(request) return await call_next(request)
auth_header = request.headers.get("Authorization") auth_header = request.headers.get("Authorization")
@ -142,28 +151,17 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
request, api_key, path, method, call_next request, api_key, path, method, call_next
) )
logger.warning(f" Aucune authentification pour {method} {path}")
return JSONResponse( return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
content={ content={
"detail": "Authentification requise", "detail": "Authentification requise",
"hint": "Utilisez 'X-API-Key: sdk_live_xxx' ou 'Authorization: Bearer <jwt>'", "hint": "Utilisez 'X-API-Key: sdk_live_xxx' ou 'Authorization: Bearer <jwt>'",
"path": path,
}, },
headers={"WWW-Authenticate": 'Bearer realm="API", charset="UTF-8"'}, headers={"WWW-Authenticate": 'Bearer realm="API", charset="UTF-8"'},
) )
def _is_excluded_path(self, path: str) -> bool:
"""Vérifie si le chemin est exclu de l'authentification"""
if path == "/":
return True
for excluded in self.EXCLUDED_PATHS:
if excluded == "/":
continue
if path == excluded or path.startswith(excluded + "/"):
return True
return False
async def _handle_api_key_auth( async def _handle_api_key_auth(
self, self,
request: Request, request: Request,
@ -179,9 +177,11 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
async with async_session_factory() as session: async with async_session_factory() as session:
service = ApiKeyService(session) service = ApiKeyService(session)
api_key_obj = await service.verify_api_key(api_key) api_key_obj = await service.verify_api_key(api_key)
if not api_key_obj: if not api_key_obj:
logger.warning(f" Clé API invalide pour {method} {path}")
return JSONResponse( return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
content={ content={
@ -192,6 +192,7 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
is_allowed, rate_info = await service.check_rate_limit(api_key_obj) is_allowed, rate_info = await service.check_rate_limit(api_key_obj)
if not is_allowed: if not is_allowed:
logger.warning(f"⚠️ Rate limit dépassé: {api_key_obj.name}")
return JSONResponse( return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS, status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={"detail": "Rate limit dépassé"}, content={"detail": "Rate limit dépassé"},
@ -203,19 +204,27 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
has_access = await service.check_endpoint_access(api_key_obj, path) has_access = await service.check_endpoint_access(api_key_obj, path)
if not has_access: if not has_access:
logger.warning(
f"Accès refusé: {api_key_obj.name}{method} {path}"
)
return JSONResponse( return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
content={ content={
"detail": "Accès non autorisé à cet endpoint", "detail": "Accès non autorisé à cet endpoint",
"endpoint": path, "endpoint": path,
"api_key": api_key_obj.key_prefix + "...", "api_key_name": api_key_obj.name,
"allowed_endpoints": (
api_key_obj.allowed_endpoints
if api_key_obj.allowed_endpoints
else "Tous"
),
}, },
) )
request.state.api_key = api_key_obj request.state.api_key = api_key_obj
request.state.authenticated_via = "api_key" request.state.authenticated_via = "api_key"
logger.info(f" API Key valide: {api_key_obj.name}{method} {path}") logger.info(f" API Key valide: {api_key_obj.name}{method} {path}")
return await call_next(request) return await call_next(request)
@ -243,7 +252,7 @@ def get_auth_method(request: Request) -> str:
__all__ = [ __all__ = [
"SwaggerAuthMiddleware", "SwaggerAuthMiddleware",
"ApiKeyMiddlewareHTTP", "ApiKeyMiddlewareHTTP",
"ApiKeyMiddleware", # Alias "ApiKeyMiddleware",
"get_api_key_from_request", "get_api_key_from_request",
"get_auth_method", "get_auth_method",
] ]

View file

@ -31,9 +31,9 @@ print("\nDEBUG: Vérification des imports...")
for module in _test_imports: for module in _test_imports:
try: try:
__import__(module) __import__(module)
print(f" {module}") print(f" {module}")
except ImportError as e: except ImportError as e:
print(f" {module}: {e}") print(f" {module}: {e}")
import asyncio import asyncio
import argparse import argparse
@ -49,7 +49,7 @@ try:
from services.api_key import ApiKeyService from services.api_key import ApiKeyService
from security.auth import hash_password from security.auth import hash_password
except ImportError as e: except ImportError as e:
print(f"\n ERREUR D'IMPORT: {e}") print(f"\n ERREUR D'IMPORT: {e}")
print(f" Vérifiez que vous êtes dans /app") print(f" Vérifiez que vous êtes dans /app")
print(f" Commande correcte: cd /app && python scripts/manage_security.py ...") print(f" Commande correcte: cd /app && python scripts/manage_security.py ...")
sys.exit(1) sys.exit(1)
@ -67,7 +67,7 @@ async def add_swagger_user(username: str, password: str, full_name: str = None):
existing = result.scalar_one_or_none() existing = result.scalar_one_or_none()
if existing: if existing:
logger.error(f" L'utilisateur '{username}' existe déjà") logger.error(f" L'utilisateur '{username}' existe déjà")
return return
swagger_user = SwaggerUser( swagger_user = SwaggerUser(
@ -80,7 +80,7 @@ async def add_swagger_user(username: str, password: str, full_name: str = None):
session.add(swagger_user) session.add(swagger_user)
await session.commit() await session.commit()
logger.info(f" Utilisateur Swagger créé: {username}") logger.info(f" Utilisateur Swagger créé: {username}")
logger.info(f" Nom complet: {swagger_user.full_name}") logger.info(f" Nom complet: {swagger_user.full_name}")
@ -96,7 +96,7 @@ async def list_swagger_users():
logger.info(f"👥 {len(users)} utilisateur(s) Swagger:\n") logger.info(f"👥 {len(users)} utilisateur(s) Swagger:\n")
for user in users: for user in users:
status = "" if user.is_active else "" status = "" if user.is_active else ""
logger.info(f" {status} {user.username}") logger.info(f" {status} {user.username}")
logger.info(f" Nom: {user.full_name}") logger.info(f" Nom: {user.full_name}")
logger.info(f" Créé: {user.created_at}") logger.info(f" Créé: {user.created_at}")
@ -112,7 +112,7 @@ async def delete_swagger_user(username: str):
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if not user: if not user:
logger.error(f" Utilisateur '{username}' introuvable") logger.error(f" Utilisateur '{username}' introuvable")
return return
await session.delete(user) await session.delete(user)
@ -182,7 +182,7 @@ async def list_api_keys():
is_valid = key.is_active and ( is_valid = key.is_active and (
not key.expires_at or key.expires_at > datetime.now() not key.expires_at or key.expires_at > datetime.now()
) )
status = "" if is_valid else "" status = "" if is_valid else ""
logger.info(f" {status} {key.name:<30} ({key.key_prefix}...)") logger.info(f" {status} {key.name:<30} ({key.key_prefix}...)")
logger.info(f" ID: {key.id}") logger.info(f" ID: {key.id}")
@ -214,7 +214,7 @@ async def revoke_api_key(key_id: str):
key = result.scalar_one_or_none() key = result.scalar_one_or_none()
if not key: if not key:
logger.error(f" Clé API '{key_id}' introuvable") logger.error(f" Clé API '{key_id}' introuvable")
return return
key.is_active = False key.is_active = False
@ -232,11 +232,11 @@ async def verify_api_key(api_key: str):
key = await service.verify_api_key(api_key) key = await service.verify_api_key(api_key)
if not key: if not key:
logger.error(" Clé API invalide ou expirée") logger.error(" Clé API invalide ou expirée")
return return
logger.info("=" * 60) logger.info("=" * 60)
logger.info(" Clé API valide") logger.info(" Clé API valide")
logger.info("=" * 60) logger.info("=" * 60)
logger.info(f" Nom: {key.name}") logger.info(f" Nom: {key.name}")
logger.info(f" ID: {key.id}") logger.info(f" ID: {key.id}")
@ -346,7 +346,7 @@ if __name__ == "__main__":
print("\n Interrupted") print("\n Interrupted")
sys.exit(0) sys.exit(0)
except Exception as e: except Exception as e:
logger.error(f" Erreur: {e}") logger.error(f" Erreur: {e}")
import traceback import traceback
traceback.print_exc() traceback.print_exc()

View file

@ -134,7 +134,7 @@ class ApiKeyService:
api_key_obj.revoked_at = datetime.now() api_key_obj.revoked_at = datetime.now()
await self.session.commit() await self.session.commit()
logger.info(f" Clé API révoquée: {api_key_obj.name}") logger.info(f"🗑️ Clé API révoquée: {api_key_obj.name}")
return True return True
async def get_by_id(self, key_id: str) -> Optional[ApiKey]: async def get_by_id(self, key_id: str) -> Optional[ApiKey]:
@ -150,24 +150,42 @@ class ApiKeyService:
} }
async def check_endpoint_access(self, api_key_obj: ApiKey, endpoint: str) -> bool: async def check_endpoint_access(self, api_key_obj: ApiKey, endpoint: str) -> bool:
"""Vérifie si la clé a accès à un endpoint spécifique"""
if not api_key_obj.allowed_endpoints: if not api_key_obj.allowed_endpoints:
logger.debug(
f"🔓 API Key {api_key_obj.name}: Aucune restriction d'endpoint"
)
return True return True
try: try:
allowed = json.loads(api_key_obj.allowed_endpoints) allowed = json.loads(api_key_obj.allowed_endpoints)
for pattern in allowed: if "*" in allowed or "/*" in allowed:
if pattern == "*": logger.debug(f"🔓 API Key {api_key_obj.name}: Accès global autorisé")
return True
if pattern.endswith("*"):
prefix = pattern[:-1]
if endpoint.startswith(prefix):
return True
if pattern == endpoint:
return True return True
for pattern in allowed:
if pattern == endpoint:
logger.debug(f" Match exact: {pattern} == {endpoint}")
return True
if pattern.endswith("/*"):
base = pattern[:-2] # "/clients/*" → "/clients"
if endpoint == base or endpoint.startswith(base + "/"):
logger.debug(f" Match wildcard: {pattern}{endpoint}")
return True
elif pattern.endswith("*"):
base = pattern[:-1] # "/clients*" → "/clients"
if endpoint.startswith(base):
logger.debug(f" Match prefix: {pattern}{endpoint}")
return True
logger.warning(
f" API Key {api_key_obj.name}: Accès refusé à {endpoint}\n"
f" Endpoints autorisés: {allowed}"
)
return False return False
except json.JSONDecodeError: except json.JSONDecodeError:
logger.error(f" Erreur parsing allowed_endpoints pour {api_key_obj.id}") logger.error(f" Erreur parsing allowed_endpoints pour {api_key_obj.id}")
return False return False