117 lines
3.2 KiB
Python
117 lines
3.2 KiB
Python
"""
|
|
security/csrf.py - Protection contre les attaques Cross-Site Request Forgery
|
|
"""
|
|
|
|
from fastapi import Request, HTTPException, status
|
|
from typing import Optional, Set
|
|
import logging
|
|
|
|
from config.config import settings
|
|
from security.auth import decode_token, create_csrf_token, constant_time_compare
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# HTTP methods for which CSRFProtection.is_exempt skips the CSRF check.
SAFE_METHODS: Set[str] = {
    "GET",
    "HEAD",
    "OPTIONS",
    "TRACE",
}

# Request paths for which CSRFProtection.is_exempt skips the CSRF check.
# Entries are written without a trailing slash; is_exempt strips trailing
# slashes from the request path before comparing.
CSRF_EXEMPT_PATHS: Set[str] = {
    # Authentication flow
    "/auth/forgot-password",
    "/auth/login",
    "/auth/register",
    "/auth/resend-verification",
    "/auth/verify-email",
    # Service / documentation endpoints
    "/docs",
    "/health",
    "/openapi.json",
    "/redoc",
    # Inbound webhooks
    "/webhooks/universign",
}
|
|
|
|
|
|
class CSRFProtection:
    """Stateless helpers implementing the CSRF checks for incoming requests.

    Two checks are available and can be combined through
    ``validate_request``:

    * ``validate_double_submit`` compares the ``X-CSRF-Token`` header with
      the CSRF cookie;
    * ``validate_token`` decodes the header token with ``decode_token`` and,
      when a session id is supplied, checks the token's ``sid`` claim
      against it.

    Every method is a classmethod; the class holds no state of its own.
    """

    @classmethod
    def is_exempt(cls, request: Request) -> bool:
        """Return True when *request* does not need a CSRF check.

        A request is exempt when its HTTP method is in ``SAFE_METHODS``, or
        when its path (trailing slashes ignored) equals an entry of
        ``CSRF_EXEMPT_PATHS`` or lies beneath one.

        Sub-paths are matched on whole path segments: ``/docs/oauth2-redirect``
        is covered by ``/docs`` but ``/docs-internal`` is not. A plain string
        prefix test would let any route that merely shares a prefix with an
        exempt entry bypass CSRF validation.
        """
        if request.method in SAFE_METHODS:
            return True

        path = request.url.path.rstrip("/")

        for exempt_path in CSRF_EXEMPT_PATHS:
            prefix = exempt_path.rstrip("/")
            if path == prefix:
                return True
            # An empty prefix (entry "/" or "") must not exempt every path,
            # so sub-path matching only applies to non-empty prefixes.
            if prefix and path.startswith(prefix + "/"):
                return True

        return False

    @classmethod
    def generate_token(cls, session_id: str) -> str:
        """Return a CSRF token for *session_id*, built by ``create_csrf_token``."""
        return create_csrf_token(session_id)

    @classmethod
    def validate_token(cls, request: Request, session_id: Optional[str] = None) -> bool:
        """Validate the token carried in the ``X-CSRF-Token`` header.

        Args:
            request: Incoming request whose headers are inspected.
            session_id: When truthy, the decoded token's ``sid`` claim must
                equal this value. When ``None`` or empty, the session
                binding is not checked.

        Returns:
            True when the header is present, ``decode_token`` returns a
            truthy payload for ``expected_type="csrf"``, and the optional
            session check passes. False otherwise; each failure is logged
            at warning level.
        """
        csrf_header = request.headers.get("X-CSRF-Token")
        if not csrf_header:
            logger.warning("Token CSRF manquant dans le header")
            return False

        payload = decode_token(csrf_header, expected_type="csrf")
        if not payload:
            logger.warning("Token CSRF invalide ou expire")
            return False

        if session_id and payload.get("sid") != session_id:
            logger.warning("Token CSRF ne correspond pas a la session")
            return False

        return True

    @classmethod
    def validate_double_submit(cls, request: Request) -> bool:
        """Check that the CSRF header and the CSRF cookie carry the same token.

        The cookie name is read from ``settings.cookie_csrf_token_name``.

        Returns:
            True when both values are present and ``constant_time_compare``
            reports them equal. False otherwise; each failure is logged at
            warning level.
        """
        header_token = request.headers.get("X-CSRF-Token")
        cookie_token = request.cookies.get(settings.cookie_csrf_token_name)

        if not header_token or not cookie_token:
            logger.warning("Token CSRF manquant (header ou cookie)")
            return False

        if not constant_time_compare(header_token, cookie_token):
            logger.warning("Tokens CSRF ne correspondent pas")
            return False

        return True

    @classmethod
    def validate_request(
        cls,
        request: Request,
        session_id: Optional[str] = None,
        use_double_submit: bool = True,
    ) -> bool:
        """Run the full CSRF validation for *request*.

        Args:
            request: Incoming request to validate.
            session_id: Forwarded to ``validate_token`` for the optional
                session-binding check.
            use_double_submit: When True (default), the header/cookie
                comparison must pass before the token itself is validated.

        Returns:
            True when the request is exempt (see ``is_exempt``) or passes
            every enabled check, False otherwise.
        """
        if cls.is_exempt(request):
            return True

        if use_double_submit and not cls.validate_double_submit(request):
            return False

        return cls.validate_token(request, session_id)
|
|
|
|
|
|
async def verify_csrf(request: Request, session_id: Optional[str] = None) -> None:
    """Raise HTTP 403 unless *request* is exempt or passes CSRF validation.

    Args:
        request: Incoming request to check.
        session_id: Optional session id forwarded to
            ``CSRFProtection.validate_request``.

    Raises:
        HTTPException: With status 403 when the request is neither exempt
            nor valid. The failure is logged at warning level first.
    """
    # Short-circuit: validation only runs for non-exempt requests.
    accepted = CSRFProtection.is_exempt(request) or CSRFProtection.validate_request(
        request, session_id
    )
    if accepted:
        return

    logger.warning(
        f"Verification CSRF echouee pour {request.method} {request.url.path}"
    )
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN, detail="Verification CSRF echouee"
    )
|
|
|
|
|
|
def generate_csrf_for_session(session_id: str) -> str:
    """Return a CSRF token for *session_id*.

    Module-level convenience wrapper around ``CSRFProtection.generate_token``.
    """
    token = CSRFProtection.generate_token(session_id)
    return token
|