refactor(security): improve authentication logging and endpoint checks
This commit is contained in:
parent
82d1d92e58
commit
67ef83c4e3
3 changed files with 92 additions and 65 deletions
|
|
@ -1,11 +1,10 @@
|
||||||
|
|
||||||
from fastapi import Request, status
|
from fastapi import Request, status
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
from fastapi.security import HTTPBasic, HTTPBasicCredentials
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
from starlette.types import ASGIApp
|
from starlette.types import ASGIApp
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from typing import Optional, Callable
|
from typing import Callable
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
import base64
|
import base64
|
||||||
|
|
@ -16,7 +15,6 @@ security = HTTPBasic()
|
||||||
|
|
||||||
|
|
||||||
class SwaggerAuthMiddleware:
|
class SwaggerAuthMiddleware:
|
||||||
|
|
||||||
PROTECTED_PATHS = ["/docs", "/redoc", "/openapi.json"]
|
PROTECTED_PATHS = ["/docs", "/redoc", "/openapi.json"]
|
||||||
|
|
||||||
def __init__(self, app: ASGIApp):
|
def __init__(self, app: ASGIApp):
|
||||||
|
|
@ -111,19 +109,30 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
|
||||||
"/redoc",
|
"/redoc",
|
||||||
"/openapi.json",
|
"/openapi.json",
|
||||||
"/",
|
"/",
|
||||||
"/auth/login",
|
"/health",
|
||||||
"/auth/register",
|
"/auth",
|
||||||
"/auth/verify-email",
|
"/api-keys/verify",
|
||||||
"/auth/reset-password",
|
|
||||||
"/auth/request-reset",
|
|
||||||
"/auth/refresh",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
def _is_excluded_path(self, path: str) -> bool:
|
||||||
|
"""Vérifie si le chemin est exclu de l'authentification"""
|
||||||
|
if path == "/":
|
||||||
|
return True
|
||||||
|
|
||||||
|
for excluded in self.EXCLUDED_PATHS:
|
||||||
|
if excluded == "/":
|
||||||
|
continue
|
||||||
|
if path == excluded or path.startswith(excluded + "/"):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next: Callable):
|
async def dispatch(self, request: Request, call_next: Callable):
|
||||||
path = request.url.path
|
path = request.url.path
|
||||||
method = request.method
|
method = request.method
|
||||||
|
|
||||||
if self._is_excluded_path(path):
|
if self._is_excluded_path(path):
|
||||||
|
logger.debug(f" Route publique: {method} {path}")
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
auth_header = request.headers.get("Authorization")
|
auth_header = request.headers.get("Authorization")
|
||||||
|
|
@ -142,28 +151,17 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
|
||||||
request, api_key, path, method, call_next
|
request, api_key, path, method, call_next
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.warning(f" Aucune authentification pour {method} {path}")
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
content={
|
content={
|
||||||
"detail": "Authentification requise",
|
"detail": "Authentification requise",
|
||||||
"hint": "Utilisez 'X-API-Key: sdk_live_xxx' ou 'Authorization: Bearer <jwt>'",
|
"hint": "Utilisez 'X-API-Key: sdk_live_xxx' ou 'Authorization: Bearer <jwt>'",
|
||||||
|
"path": path,
|
||||||
},
|
},
|
||||||
headers={"WWW-Authenticate": 'Bearer realm="API", charset="UTF-8"'},
|
headers={"WWW-Authenticate": 'Bearer realm="API", charset="UTF-8"'},
|
||||||
)
|
)
|
||||||
|
|
||||||
def _is_excluded_path(self, path: str) -> bool:
|
|
||||||
"""Vérifie si le chemin est exclu de l'authentification"""
|
|
||||||
if path == "/":
|
|
||||||
return True
|
|
||||||
|
|
||||||
for excluded in self.EXCLUDED_PATHS:
|
|
||||||
if excluded == "/":
|
|
||||||
continue
|
|
||||||
if path == excluded or path.startswith(excluded + "/"):
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def _handle_api_key_auth(
|
async def _handle_api_key_auth(
|
||||||
self,
|
self,
|
||||||
request: Request,
|
request: Request,
|
||||||
|
|
@ -179,9 +177,11 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
|
||||||
|
|
||||||
async with async_session_factory() as session:
|
async with async_session_factory() as session:
|
||||||
service = ApiKeyService(session)
|
service = ApiKeyService(session)
|
||||||
|
|
||||||
api_key_obj = await service.verify_api_key(api_key)
|
api_key_obj = await service.verify_api_key(api_key)
|
||||||
|
|
||||||
if not api_key_obj:
|
if not api_key_obj:
|
||||||
|
logger.warning(f" Clé API invalide pour {method} {path}")
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
content={
|
content={
|
||||||
|
|
@ -192,6 +192,7 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
|
||||||
|
|
||||||
is_allowed, rate_info = await service.check_rate_limit(api_key_obj)
|
is_allowed, rate_info = await service.check_rate_limit(api_key_obj)
|
||||||
if not is_allowed:
|
if not is_allowed:
|
||||||
|
logger.warning(f"⚠️ Rate limit dépassé: {api_key_obj.name}")
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
|
||||||
content={"detail": "Rate limit dépassé"},
|
content={"detail": "Rate limit dépassé"},
|
||||||
|
|
@ -203,19 +204,27 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
|
||||||
|
|
||||||
has_access = await service.check_endpoint_access(api_key_obj, path)
|
has_access = await service.check_endpoint_access(api_key_obj, path)
|
||||||
if not has_access:
|
if not has_access:
|
||||||
|
logger.warning(
|
||||||
|
f"Accès refusé: {api_key_obj.name} → {method} {path}"
|
||||||
|
)
|
||||||
return JSONResponse(
|
return JSONResponse(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
content={
|
content={
|
||||||
"detail": "Accès non autorisé à cet endpoint",
|
"detail": "Accès non autorisé à cet endpoint",
|
||||||
"endpoint": path,
|
"endpoint": path,
|
||||||
"api_key": api_key_obj.key_prefix + "...",
|
"api_key_name": api_key_obj.name,
|
||||||
|
"allowed_endpoints": (
|
||||||
|
api_key_obj.allowed_endpoints
|
||||||
|
if api_key_obj.allowed_endpoints
|
||||||
|
else "Tous"
|
||||||
|
),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
request.state.api_key = api_key_obj
|
request.state.api_key = api_key_obj
|
||||||
request.state.authenticated_via = "api_key"
|
request.state.authenticated_via = "api_key"
|
||||||
|
|
||||||
logger.info(f"✓ API Key valide: {api_key_obj.name} → {method} {path}")
|
logger.info(f"✅ API Key valide: {api_key_obj.name} → {method} {path}")
|
||||||
|
|
||||||
return await call_next(request)
|
return await call_next(request)
|
||||||
|
|
||||||
|
|
@ -243,7 +252,7 @@ def get_auth_method(request: Request) -> str:
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"SwaggerAuthMiddleware",
|
"SwaggerAuthMiddleware",
|
||||||
"ApiKeyMiddlewareHTTP",
|
"ApiKeyMiddlewareHTTP",
|
||||||
"ApiKeyMiddleware", # Alias
|
"ApiKeyMiddleware",
|
||||||
"get_api_key_from_request",
|
"get_api_key_from_request",
|
||||||
"get_auth_method",
|
"get_auth_method",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -31,9 +31,9 @@ print("\nDEBUG: Vérification des imports...")
|
||||||
for module in _test_imports:
|
for module in _test_imports:
|
||||||
try:
|
try:
|
||||||
__import__(module)
|
__import__(module)
|
||||||
print(f" ✅ {module}")
|
print(f" {module}")
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(f" ❌ {module}: {e}")
|
print(f" {module}: {e}")
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import argparse
|
import argparse
|
||||||
|
|
@ -49,7 +49,7 @@ try:
|
||||||
from services.api_key import ApiKeyService
|
from services.api_key import ApiKeyService
|
||||||
from security.auth import hash_password
|
from security.auth import hash_password
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(f"\n❌ ERREUR D'IMPORT: {e}")
|
print(f"\n ERREUR D'IMPORT: {e}")
|
||||||
print(f" Vérifiez que vous êtes dans /app")
|
print(f" Vérifiez que vous êtes dans /app")
|
||||||
print(f" Commande correcte: cd /app && python scripts/manage_security.py ...")
|
print(f" Commande correcte: cd /app && python scripts/manage_security.py ...")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
@ -67,7 +67,7 @@ async def add_swagger_user(username: str, password: str, full_name: str = None):
|
||||||
existing = result.scalar_one_or_none()
|
existing = result.scalar_one_or_none()
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
logger.error(f"❌ L'utilisateur '{username}' existe déjà")
|
logger.error(f" L'utilisateur '{username}' existe déjà")
|
||||||
return
|
return
|
||||||
|
|
||||||
swagger_user = SwaggerUser(
|
swagger_user = SwaggerUser(
|
||||||
|
|
@ -80,7 +80,7 @@ async def add_swagger_user(username: str, password: str, full_name: str = None):
|
||||||
session.add(swagger_user)
|
session.add(swagger_user)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
logger.info(f"✅ Utilisateur Swagger créé: {username}")
|
logger.info(f" Utilisateur Swagger créé: {username}")
|
||||||
logger.info(f" Nom complet: {swagger_user.full_name}")
|
logger.info(f" Nom complet: {swagger_user.full_name}")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -96,7 +96,7 @@ async def list_swagger_users():
|
||||||
|
|
||||||
logger.info(f"👥 {len(users)} utilisateur(s) Swagger:\n")
|
logger.info(f"👥 {len(users)} utilisateur(s) Swagger:\n")
|
||||||
for user in users:
|
for user in users:
|
||||||
status = "✅" if user.is_active else "❌"
|
status = "" if user.is_active else ""
|
||||||
logger.info(f" {status} {user.username}")
|
logger.info(f" {status} {user.username}")
|
||||||
logger.info(f" Nom: {user.full_name}")
|
logger.info(f" Nom: {user.full_name}")
|
||||||
logger.info(f" Créé: {user.created_at}")
|
logger.info(f" Créé: {user.created_at}")
|
||||||
|
|
@ -112,7 +112,7 @@ async def delete_swagger_user(username: str):
|
||||||
user = result.scalar_one_or_none()
|
user = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not user:
|
if not user:
|
||||||
logger.error(f"❌ Utilisateur '{username}' introuvable")
|
logger.error(f" Utilisateur '{username}' introuvable")
|
||||||
return
|
return
|
||||||
|
|
||||||
await session.delete(user)
|
await session.delete(user)
|
||||||
|
|
@ -182,7 +182,7 @@ async def list_api_keys():
|
||||||
is_valid = key.is_active and (
|
is_valid = key.is_active and (
|
||||||
not key.expires_at or key.expires_at > datetime.now()
|
not key.expires_at or key.expires_at > datetime.now()
|
||||||
)
|
)
|
||||||
status = "✅" if is_valid else "❌"
|
status = "" if is_valid else ""
|
||||||
|
|
||||||
logger.info(f" {status} {key.name:<30} ({key.key_prefix}...)")
|
logger.info(f" {status} {key.name:<30} ({key.key_prefix}...)")
|
||||||
logger.info(f" ID: {key.id}")
|
logger.info(f" ID: {key.id}")
|
||||||
|
|
@ -214,7 +214,7 @@ async def revoke_api_key(key_id: str):
|
||||||
key = result.scalar_one_or_none()
|
key = result.scalar_one_or_none()
|
||||||
|
|
||||||
if not key:
|
if not key:
|
||||||
logger.error(f"❌ Clé API '{key_id}' introuvable")
|
logger.error(f" Clé API '{key_id}' introuvable")
|
||||||
return
|
return
|
||||||
|
|
||||||
key.is_active = False
|
key.is_active = False
|
||||||
|
|
@ -232,11 +232,11 @@ async def verify_api_key(api_key: str):
|
||||||
key = await service.verify_api_key(api_key)
|
key = await service.verify_api_key(api_key)
|
||||||
|
|
||||||
if not key:
|
if not key:
|
||||||
logger.error("❌ Clé API invalide ou expirée")
|
logger.error(" Clé API invalide ou expirée")
|
||||||
return
|
return
|
||||||
|
|
||||||
logger.info("=" * 60)
|
logger.info("=" * 60)
|
||||||
logger.info("✅ Clé API valide")
|
logger.info(" Clé API valide")
|
||||||
logger.info("=" * 60)
|
logger.info("=" * 60)
|
||||||
logger.info(f" Nom: {key.name}")
|
logger.info(f" Nom: {key.name}")
|
||||||
logger.info(f" ID: {key.id}")
|
logger.info(f" ID: {key.id}")
|
||||||
|
|
@ -346,7 +346,7 @@ if __name__ == "__main__":
|
||||||
print("\nℹ️ Interrupted")
|
print("\nℹ️ Interrupted")
|
||||||
sys.exit(0)
|
sys.exit(0)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.error(f"❌ Erreur: {e}")
|
logger.error(f" Erreur: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
|
||||||
|
|
@ -134,7 +134,7 @@ class ApiKeyService:
|
||||||
api_key_obj.revoked_at = datetime.now()
|
api_key_obj.revoked_at = datetime.now()
|
||||||
await self.session.commit()
|
await self.session.commit()
|
||||||
|
|
||||||
logger.info(f" Clé API révoquée: {api_key_obj.name}")
|
logger.info(f"🗑️ Clé API révoquée: {api_key_obj.name}")
|
||||||
return True
|
return True
|
||||||
|
|
||||||
async def get_by_id(self, key_id: str) -> Optional[ApiKey]:
|
async def get_by_id(self, key_id: str) -> Optional[ApiKey]:
|
||||||
|
|
@ -150,24 +150,42 @@ class ApiKeyService:
|
||||||
}
|
}
|
||||||
|
|
||||||
async def check_endpoint_access(self, api_key_obj: ApiKey, endpoint: str) -> bool:
|
async def check_endpoint_access(self, api_key_obj: ApiKey, endpoint: str) -> bool:
|
||||||
"""Vérifie si la clé a accès à un endpoint spécifique"""
|
|
||||||
if not api_key_obj.allowed_endpoints:
|
if not api_key_obj.allowed_endpoints:
|
||||||
|
logger.debug(
|
||||||
|
f"🔓 API Key {api_key_obj.name}: Aucune restriction d'endpoint"
|
||||||
|
)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
allowed = json.loads(api_key_obj.allowed_endpoints)
|
allowed = json.loads(api_key_obj.allowed_endpoints)
|
||||||
|
|
||||||
for pattern in allowed:
|
if "*" in allowed or "/*" in allowed:
|
||||||
if pattern == "*":
|
logger.debug(f"🔓 API Key {api_key_obj.name}: Accès global autorisé")
|
||||||
return True
|
|
||||||
if pattern.endswith("*"):
|
|
||||||
prefix = pattern[:-1]
|
|
||||||
if endpoint.startswith(prefix):
|
|
||||||
return True
|
|
||||||
if pattern == endpoint:
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
for pattern in allowed:
|
||||||
|
if pattern == endpoint:
|
||||||
|
logger.debug(f" Match exact: {pattern} == {endpoint}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
if pattern.endswith("/*"):
|
||||||
|
base = pattern[:-2] # "/clients/*" → "/clients"
|
||||||
|
if endpoint == base or endpoint.startswith(base + "/"):
|
||||||
|
logger.debug(f" Match wildcard: {pattern} ↔ {endpoint}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
elif pattern.endswith("*"):
|
||||||
|
base = pattern[:-1] # "/clients*" → "/clients"
|
||||||
|
if endpoint.startswith(base):
|
||||||
|
logger.debug(f" Match prefix: {pattern} ↔ {endpoint}")
|
||||||
|
return True
|
||||||
|
|
||||||
|
logger.warning(
|
||||||
|
f" API Key {api_key_obj.name}: Accès refusé à {endpoint}\n"
|
||||||
|
f" Endpoints autorisés: {allowed}"
|
||||||
|
)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
except json.JSONDecodeError:
|
except json.JSONDecodeError:
|
||||||
logger.error(f" Erreur parsing allowed_endpoints pour {api_key_obj.id}")
|
logger.error(f" Erreur parsing allowed_endpoints pour {api_key_obj.id}")
|
||||||
return False
|
return False
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue