"""HTTP security middleware and lifecycle hooks for the FastAPI application.

Provides Starlette ``BaseHTTPMiddleware`` subclasses for security response
headers, CSRF double-submit validation, per-IP rate limiting (backed by
``redis_service``), request logging, and a placeholder for client
fingerprint validation, plus helpers to register them and to start/stop the
Redis connection they depend on.
"""

from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from typing import Set
import logging
import time

from config.config import settings
from security.csrf import CSRFProtection
from security.fingerprint import get_client_ip
from services.redis_service import redis_service

logger = logging.getLogger(__name__)

# Paths (compared after stripping trailing "/") that bypass rate limiting.
RATE_LIMIT_EXEMPT_PATHS: Set[str] = {
    "/health",
    "/docs",
    "/redoc",
    "/openapi.json",
}


class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Add defensive HTTP security headers to every response.

    HSTS is only sent when ``settings.is_production`` is true, so local
    plain-HTTP development is not pinned to HTTPS by the browser.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        response = await call_next(request)

        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        # NOTE(review): modern browsers have removed the XSS auditor and OWASP
        # now recommends "0" for this header; kept unchanged pending a decision.
        response.headers["X-XSS-Protection"] = "1; mode=block"
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"

        if settings.is_production:
            response.headers["Strict-Transport-Security"] = (
                "max-age=31536000; includeSubDomains"
            )

        # Applied to every response (including /docs), disabling client caching.
        response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
        response.headers["Pragma"] = "no-cache"
        response.headers["Permissions-Policy"] = (
            "geolocation=(), microphone=(), camera=()"
        )

        return response


class CSRFMiddleware(BaseHTTPMiddleware):
    """Reject requests that fail CSRF double-submit validation with HTTP 403.

    Which requests are exempt, and how the double-submit check works, are
    decided by ``CSRFProtection`` (defined elsewhere).
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        if CSRFProtection.is_exempt(request):
            return await call_next(request)

        if not CSRFProtection.validate_double_submit(request):
            logger.warning(
                "CSRF validation echouee: %s %s depuis %s",
                request.method,
                request.url.path,
                get_client_ip(request),
            )
            return JSONResponse(
                status_code=status.HTTP_403_FORBIDDEN,
                content={"detail": "Verification CSRF echouee"},
            )

        return await call_next(request)


class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client-IP request rate limiter backed by ``redis_service``.

    Each request increments a counter keyed ``api:<ip>`` with a window of
    ``settings.rate_limit_api_window_seconds``. Once the counter exceeds
    ``settings.rate_limit_api_requests``, the request is rejected with HTTP 429
    *before* it reaches the endpoint. If the Redis call fails, the limiter
    fails open: the error is logged and the request proceeds unthrottled.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        path = request.url.path.rstrip("/")
        if path in RATE_LIMIT_EXEMPT_PATHS:
            return await call_next(request)

        ip = get_client_ip(request)
        key = f"api:{ip}"
        window_seconds = settings.rate_limit_api_window_seconds
        max_requests = settings.rate_limit_api_requests

        # Only the Redis call is guarded: wrapping call_next here would swallow
        # endpoint exceptions and re-run the request a second time.
        try:
            count = await redis_service.increment_rate_limit(key, window_seconds)
        except Exception as e:
            # Deliberate fail-open so a Redis outage does not take the API down.
            logger.error("Erreur rate limiting: %s", e)
            return await call_next(request)

        # Check the limit before invoking the endpoint so over-limit requests
        # never execute (and never produce side effects).
        if count > max_requests:
            logger.warning(
                "Rate limit depasse pour IP %s: %s/%s", ip, count, max_requests
            )
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={"detail": "Limite de requetes atteinte"},
                headers={
                    "X-RateLimit-Limit": str(max_requests),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": str(window_seconds),
                    "Retry-After": str(window_seconds),
                },
            )

        response = await call_next(request)

        remaining = max(0, max_requests - count)
        response.headers["X-RateLimit-Limit"] = str(max_requests)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        # NOTE(review): this reports the window length, not the seconds left in
        # the current window; a key TTL from Redis would be more accurate.
        response.headers["X-RateLimit-Reset"] = str(window_seconds)
        return response


class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log method, path, status, duration and client IP for each request.

    The level follows the status code: ERROR for 5xx, WARNING for 4xx,
    INFO otherwise. If the downstream handler raises, the failure is logged
    at ERROR and the exception is re-raised unchanged.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        # perf_counter is monotonic, so durations are immune to clock changes.
        start_time = time.perf_counter()
        ip = get_client_ip(request)
        method = request.method
        path = request.url.path

        try:
            response = await call_next(request)
        except Exception:
            duration_ms = (time.perf_counter() - start_time) * 1000
            logger.exception(
                "%s %s - exception - %.2fms - %s", method, path, duration_ms, ip
            )
            raise

        duration_ms = (time.perf_counter() - start_time) * 1000

        if response.status_code >= 500:
            log_level = logging.ERROR
        elif response.status_code >= 400:
            log_level = logging.WARNING
        else:
            log_level = logging.INFO

        logger.log(
            log_level,
            "%s %s - %s - %.2fms - %s",
            method,
            path,
            response.status_code,
            duration_ms,
            ip,
        )

        return response


class FingerprintValidationMiddleware(BaseHTTPMiddleware):
    """Placeholder for client-fingerprint validation on sensitive auth routes.

    ``VALIDATION_PATHS`` lists the routes intended to be checked. No check is
    performed yet: every request, protected or not, is forwarded unchanged.
    """

    VALIDATION_PATHS: Set[str] = {
        "/auth/refresh",
        "/auth/logout",
    }

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        # TODO(review): implement fingerprint validation for requests whose
        # path (trailing "/" stripped) is in VALIDATION_PATHS; currently a no-op.
        return await call_next(request)


def setup_security_middleware(app) -> None:
    """Register the security middleware on *app*.

    Only request logging, security headers and the (currently no-op)
    fingerprint middleware are registered here. Per Starlette's
    ``add_middleware`` semantics, the middleware added last wraps the others.

    NOTE(review): ``CSRFMiddleware`` and ``RateLimitMiddleware`` are defined
    in this module but not registered here — confirm whether that is
    intentional or whether they are added elsewhere.
    """
    app.add_middleware(RequestLoggingMiddleware)
    app.add_middleware(SecurityHeadersMiddleware)
    app.add_middleware(FingerprintValidationMiddleware)


async def init_security_services() -> None:
    """Connect to Redis; on failure log a warning and continue in degraded mode."""
    try:
        await redis_service.connect()
        logger.info("Services de securite initialises")
    except Exception as e:
        logger.warning("Redis non disponible, fonctionnement en mode degrade: %s", e)


async def shutdown_security_services() -> None:
    """Disconnect from Redis, logging (not raising) any error during shutdown."""
    try:
        await redis_service.disconnect()
        logger.info("Services de securite arretes")
    except Exception as e:
        logger.error("Erreur arret services securite: %s", e)