refactor(security): improve authentication logging and endpoint checks
This commit is contained in:
parent
82d1d92e58
commit
67ef83c4e3
3 changed files with 92 additions and 65 deletions
|
|
@ -1,11 +1,10 @@
|
|||
|
||||
from fastapi import Request, status
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
from sqlalchemy import select
|
||||
from typing import Optional, Callable
|
||||
from typing import Callable
|
||||
from datetime import datetime
|
||||
import logging
|
||||
import base64
|
||||
|
|
@ -16,7 +15,6 @@ security = HTTPBasic()
|
|||
|
||||
|
||||
class SwaggerAuthMiddleware:
|
||||
|
||||
PROTECTED_PATHS = ["/docs", "/redoc", "/openapi.json"]
|
||||
|
||||
def __init__(self, app: ASGIApp):
|
||||
|
|
@ -111,19 +109,30 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
|
|||
"/redoc",
|
||||
"/openapi.json",
|
||||
"/",
|
||||
"/auth/login",
|
||||
"/auth/register",
|
||||
"/auth/verify-email",
|
||||
"/auth/reset-password",
|
||||
"/auth/request-reset",
|
||||
"/auth/refresh",
|
||||
"/health",
|
||||
"/auth",
|
||||
"/api-keys/verify",
|
||||
]
|
||||
|
||||
def _is_excluded_path(self, path: str) -> bool:
|
||||
"""Vérifie si le chemin est exclu de l'authentification"""
|
||||
if path == "/":
|
||||
return True
|
||||
|
||||
for excluded in self.EXCLUDED_PATHS:
|
||||
if excluded == "/":
|
||||
continue
|
||||
if path == excluded or path.startswith(excluded + "/"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable):
|
||||
path = request.url.path
|
||||
method = request.method
|
||||
|
||||
if self._is_excluded_path(path):
|
||||
logger.debug(f" Route publique: {method} {path}")
|
||||
return await call_next(request)
|
||||
|
||||
auth_header = request.headers.get("Authorization")
|
||||
|
|
@ -142,28 +151,17 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
|
|||
request, api_key, path, method, call_next
|
||||
)
|
||||
|
||||
logger.warning(f" Aucune authentification pour {method} {path}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
"detail": "Authentification requise",
|
||||
"hint": "Utilisez 'X-API-Key: sdk_live_xxx' ou 'Authorization: Bearer <jwt>'",
|
||||
"path": path,
|
||||
},
|
||||
headers={"WWW-Authenticate": 'Bearer realm="API", charset="UTF-8"'},
|
||||
)
|
||||
|
||||
def _is_excluded_path(self, path: str) -> bool:
|
||||
"""Vérifie si le chemin est exclu de l'authentification"""
|
||||
if path == "/":
|
||||
return True
|
||||
|
||||
for excluded in self.EXCLUDED_PATHS:
|
||||
if excluded == "/":
|
||||
continue
|
||||
if path == excluded or path.startswith(excluded + "/"):
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
async def _handle_api_key_auth(
|
||||
self,
|
||||
request: Request,
|
||||
|
|
@ -179,9 +177,11 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
|
|||
|
||||
async with async_session_factory() as session:
|
||||
service = ApiKeyService(session)
|
||||
|
||||
api_key_obj = await service.verify_api_key(api_key)
|
||||
|
||||
if not api_key_obj:
|
||||
logger.warning(f" Clé API invalide pour {method} {path}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
content={
|
||||
|
|
@ -192,6 +192,7 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
|
|||
|
||||
is_allowed, rate_info = await service.check_rate_limit(api_key_obj)
|
||||
if not is_allowed:
|
||||
logger.warning(f"⚠️ Rate limit dépassé: {api_key_obj.name}")
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||
content={"detail": "Rate limit dépassé"},
|
||||
|
|
@ -203,19 +204,27 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
|
|||
|
||||
has_access = await service.check_endpoint_access(api_key_obj, path)
|
||||
if not has_access:
|
||||
logger.warning(
|
||||
f"Accès refusé: {api_key_obj.name} → {method} {path}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
content={
|
||||
"detail": "Accès non autorisé à cet endpoint",
|
||||
"endpoint": path,
|
||||
"api_key": api_key_obj.key_prefix + "...",
|
||||
"api_key_name": api_key_obj.name,
|
||||
"allowed_endpoints": (
|
||||
api_key_obj.allowed_endpoints
|
||||
if api_key_obj.allowed_endpoints
|
||||
else "Tous"
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
request.state.api_key = api_key_obj
|
||||
request.state.authenticated_via = "api_key"
|
||||
|
||||
logger.info(f"✓ API Key valide: {api_key_obj.name} → {method} {path}")
|
||||
logger.info(f"✅ API Key valide: {api_key_obj.name} → {method} {path}")
|
||||
|
||||
return await call_next(request)
|
||||
|
||||
|
|
@ -243,7 +252,7 @@ def get_auth_method(request: Request) -> str:
|
|||
__all__ = [
|
||||
"SwaggerAuthMiddleware",
|
||||
"ApiKeyMiddlewareHTTP",
|
||||
"ApiKeyMiddleware", # Alias
|
||||
"ApiKeyMiddleware",
|
||||
"get_api_key_from_request",
|
||||
"get_auth_method",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -31,9 +31,9 @@ print("\nDEBUG: Vérification des imports...")
|
|||
for module in _test_imports:
|
||||
try:
|
||||
__import__(module)
|
||||
print(f" ✅ {module}")
|
||||
print(f" {module}")
|
||||
except ImportError as e:
|
||||
print(f" ❌ {module}: {e}")
|
||||
print(f" {module}: {e}")
|
||||
|
||||
import asyncio
|
||||
import argparse
|
||||
|
|
@ -49,7 +49,7 @@ try:
|
|||
from services.api_key import ApiKeyService
|
||||
from security.auth import hash_password
|
||||
except ImportError as e:
|
||||
print(f"\n❌ ERREUR D'IMPORT: {e}")
|
||||
print(f"\n ERREUR D'IMPORT: {e}")
|
||||
print(f" Vérifiez que vous êtes dans /app")
|
||||
print(f" Commande correcte: cd /app && python scripts/manage_security.py ...")
|
||||
sys.exit(1)
|
||||
|
|
@ -67,7 +67,7 @@ async def add_swagger_user(username: str, password: str, full_name: str = None):
|
|||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing:
|
||||
logger.error(f"❌ L'utilisateur '{username}' existe déjà")
|
||||
logger.error(f" L'utilisateur '{username}' existe déjà")
|
||||
return
|
||||
|
||||
swagger_user = SwaggerUser(
|
||||
|
|
@ -80,7 +80,7 @@ async def add_swagger_user(username: str, password: str, full_name: str = None):
|
|||
session.add(swagger_user)
|
||||
await session.commit()
|
||||
|
||||
logger.info(f"✅ Utilisateur Swagger créé: {username}")
|
||||
logger.info(f" Utilisateur Swagger créé: {username}")
|
||||
logger.info(f" Nom complet: {swagger_user.full_name}")
|
||||
|
||||
|
||||
|
|
@ -96,7 +96,7 @@ async def list_swagger_users():
|
|||
|
||||
logger.info(f"👥 {len(users)} utilisateur(s) Swagger:\n")
|
||||
for user in users:
|
||||
status = "✅" if user.is_active else "❌"
|
||||
status = "" if user.is_active else ""
|
||||
logger.info(f" {status} {user.username}")
|
||||
logger.info(f" Nom: {user.full_name}")
|
||||
logger.info(f" Créé: {user.created_at}")
|
||||
|
|
@ -112,7 +112,7 @@ async def delete_swagger_user(username: str):
|
|||
user = result.scalar_one_or_none()
|
||||
|
||||
if not user:
|
||||
logger.error(f"❌ Utilisateur '{username}' introuvable")
|
||||
logger.error(f" Utilisateur '{username}' introuvable")
|
||||
return
|
||||
|
||||
await session.delete(user)
|
||||
|
|
@ -182,7 +182,7 @@ async def list_api_keys():
|
|||
is_valid = key.is_active and (
|
||||
not key.expires_at or key.expires_at > datetime.now()
|
||||
)
|
||||
status = "✅" if is_valid else "❌"
|
||||
status = "" if is_valid else ""
|
||||
|
||||
logger.info(f" {status} {key.name:<30} ({key.key_prefix}...)")
|
||||
logger.info(f" ID: {key.id}")
|
||||
|
|
@ -214,7 +214,7 @@ async def revoke_api_key(key_id: str):
|
|||
key = result.scalar_one_or_none()
|
||||
|
||||
if not key:
|
||||
logger.error(f"❌ Clé API '{key_id}' introuvable")
|
||||
logger.error(f" Clé API '{key_id}' introuvable")
|
||||
return
|
||||
|
||||
key.is_active = False
|
||||
|
|
@ -232,11 +232,11 @@ async def verify_api_key(api_key: str):
|
|||
key = await service.verify_api_key(api_key)
|
||||
|
||||
if not key:
|
||||
logger.error("❌ Clé API invalide ou expirée")
|
||||
logger.error(" Clé API invalide ou expirée")
|
||||
return
|
||||
|
||||
logger.info("=" * 60)
|
||||
logger.info("✅ Clé API valide")
|
||||
logger.info(" Clé API valide")
|
||||
logger.info("=" * 60)
|
||||
logger.info(f" Nom: {key.name}")
|
||||
logger.info(f" ID: {key.id}")
|
||||
|
|
@ -346,7 +346,7 @@ if __name__ == "__main__":
|
|||
print("\nℹ️ Interrupted")
|
||||
sys.exit(0)
|
||||
except Exception as e:
|
||||
logger.error(f"❌ Erreur: {e}")
|
||||
logger.error(f" Erreur: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
|
|
|||
|
|
@ -134,7 +134,7 @@ class ApiKeyService:
|
|||
api_key_obj.revoked_at = datetime.now()
|
||||
await self.session.commit()
|
||||
|
||||
logger.info(f" Clé API révoquée: {api_key_obj.name}")
|
||||
logger.info(f"🗑️ Clé API révoquée: {api_key_obj.name}")
|
||||
return True
|
||||
|
||||
async def get_by_id(self, key_id: str) -> Optional[ApiKey]:
|
||||
|
|
@ -150,24 +150,42 @@ class ApiKeyService:
|
|||
}
|
||||
|
||||
async def check_endpoint_access(self, api_key_obj: ApiKey, endpoint: str) -> bool:
|
||||
"""Vérifie si la clé a accès à un endpoint spécifique"""
|
||||
if not api_key_obj.allowed_endpoints:
|
||||
logger.debug(
|
||||
f"🔓 API Key {api_key_obj.name}: Aucune restriction d'endpoint"
|
||||
)
|
||||
return True
|
||||
|
||||
try:
|
||||
allowed = json.loads(api_key_obj.allowed_endpoints)
|
||||
|
||||
for pattern in allowed:
|
||||
if pattern == "*":
|
||||
return True
|
||||
if pattern.endswith("*"):
|
||||
prefix = pattern[:-1]
|
||||
if endpoint.startswith(prefix):
|
||||
return True
|
||||
if pattern == endpoint:
|
||||
if "*" in allowed or "/*" in allowed:
|
||||
logger.debug(f"🔓 API Key {api_key_obj.name}: Accès global autorisé")
|
||||
return True
|
||||
|
||||
for pattern in allowed:
|
||||
if pattern == endpoint:
|
||||
logger.debug(f" Match exact: {pattern} == {endpoint}")
|
||||
return True
|
||||
|
||||
if pattern.endswith("/*"):
|
||||
base = pattern[:-2] # "/clients/*" → "/clients"
|
||||
if endpoint == base or endpoint.startswith(base + "/"):
|
||||
logger.debug(f" Match wildcard: {pattern} ↔ {endpoint}")
|
||||
return True
|
||||
|
||||
elif pattern.endswith("*"):
|
||||
base = pattern[:-1] # "/clients*" → "/clients"
|
||||
if endpoint.startswith(base):
|
||||
logger.debug(f" Match prefix: {pattern} ↔ {endpoint}")
|
||||
return True
|
||||
|
||||
logger.warning(
|
||||
f" API Key {api_key_obj.name}: Accès refusé à {endpoint}\n"
|
||||
f" Endpoints autorisés: {allowed}"
|
||||
)
|
||||
return False
|
||||
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f" Erreur parsing allowed_endpoints pour {api_key_obj.id}")
|
||||
return False
|
||||
|
|
|
|||
Loading…
Reference in a new issue