181 lines
5.6 KiB
Python
181 lines
5.6 KiB
Python
from fastapi import Request, Response, status
|
|
from fastapi.responses import JSONResponse
|
|
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
|
|
from typing import Set
|
|
import logging
|
|
import time
|
|
|
|
from config.config import settings
|
|
from security.csrf import CSRFProtection
|
|
from security.fingerprint import get_client_ip
|
|
from services.redis_service import redis_service
|
|
|
|
# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(__name__)

# Paths that bypass RateLimitMiddleware. That middleware strips trailing
# slashes before looking a path up here, so "/docs/" also matches "/docs".
# Presumably these are the health check plus FastAPI's default docs/schema
# routes.
RATE_LIMIT_EXEMPT_PATHS: Set[str] = {"/health", "/docs", "/redoc", "/openapi.json"}
|
|
|
|
|
|
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach defensive HTTP security headers to every response.

    Headers are written after the downstream handler has produced its
    response, so they replace any value the handler set for the same header
    names. ``Strict-Transport-Security`` is only added when
    ``settings.is_production`` is true.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        response = await call_next(request)

        response.headers["X-Content-Type-Options"] = "nosniff"
        response.headers["X-Frame-Options"] = "DENY"
        # The legacy browser XSS auditor is deprecated, and its
        # "1; mode=block" mode can itself be abused to leak information
        # cross-site. Current guidance (OWASP Secure Headers, MDN) is to
        # disable it explicitly.
        response.headers["X-XSS-Protection"] = "0"
        response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"

        if settings.is_production:
            # HSTS pins browsers to HTTPS for one year (incl. subdomains);
            # presumably only production is guaranteed to be served over TLS.
            response.headers["Strict-Transport-Security"] = (
                "max-age=31536000; includeSubDomains"
            )

        # Disable caching of every response (Pragma for HTTP/1.0 caches).
        response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
        response.headers["Pragma"] = "no-cache"

        # Deny access to sensitive browser features for all origins.
        response.headers["Permissions-Policy"] = (
            "geolocation=(), microphone=(), camera=()"
        )

        return response
|
|
|
|
|
|
class CSRFMiddleware(BaseHTTPMiddleware):
    """Enforce CSRF double-submit validation on every non-exempt request.

    Requests for which ``CSRFProtection.is_exempt`` is truthy are forwarded
    without validation. Otherwise ``CSRFProtection.validate_double_submit``
    must pass; on failure the attempt is logged with the client IP and a
    403 JSON response is returned instead of calling the handler.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        # ``or`` short-circuits: exempt requests never reach the validator.
        passes_check = CSRFProtection.is_exempt(
            request
        ) or CSRFProtection.validate_double_submit(request)

        if passes_check:
            return await call_next(request)

        logger.warning(
            f"CSRF validation echouee: {request.method} {request.url.path} depuis {get_client_ip(request)}"
        )
        return JSONResponse(
            content={"detail": "Verification CSRF echouee"},
            status_code=status.HTTP_403_FORBIDDEN,
        )
|
|
|
|
|
|
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Limit the number of API requests per client IP.

    Each non-exempt request increments a Redis counter keyed ``api:<ip>``
    (presumably expiring after ``rate_limit_api_window_seconds`` - confirm in
    ``redis_service.increment_rate_limit``). Once the count exceeds
    ``rate_limit_api_requests``, the request is rejected with 429 *before*
    the handler runs. If the Redis bookkeeping fails, the middleware fails
    open and serves the request without rate-limit headers.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        path = request.url.path.rstrip("/")
        if path in RATE_LIMIT_EXEMPT_PATHS:
            return await call_next(request)

        ip = get_client_ip(request)
        key = f"api:{ip}"
        window_seconds = settings.rate_limit_api_window_seconds
        max_requests = settings.rate_limit_api_requests

        # Guard ONLY the rate-limit bookkeeping. Exceptions from the downstream
        # handler must not be caught here: doing so would mislabel them as
        # rate-limit errors and invoke the handler a second time.
        try:
            count = await redis_service.increment_rate_limit(key, window_seconds)
            remaining = max(0, max_requests - count)
            limit_exceeded = count > max_requests
        except Exception as e:
            # Fail open: a Redis outage must not take the API down.
            logger.error(f"Erreur rate limiting: {e}")
            return await call_next(request)

        # Reject before dispatching so over-limit requests have no side effects.
        if limit_exceeded:
            logger.warning(
                f"Rate limit depasse pour IP {ip}: {count}/{max_requests}"
            )
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={"detail": "Limite de requetes atteinte"},
                headers={
                    "X-RateLimit-Limit": str(max_requests),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": str(window_seconds),
                    "Retry-After": str(window_seconds),
                },
            )

        response = await call_next(request)

        response.headers["X-RateLimit-Limit"] = str(max_requests)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        # NOTE(review): this reports the configured window length, not a
        # countdown to when the counter actually resets.
        response.headers["X-RateLimit-Reset"] = str(window_seconds)

        return response
|
|
|
|
|
|
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Log one line per request: method, path, status, latency and client IP.

    The log level follows the response status: ERROR for 5xx, WARNING for 4xx,
    INFO otherwise. If the downstream handler raises, an ERROR line is logged
    and the exception is re-raised unchanged.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        # perf_counter is monotonic: unlike time.time() it is unaffected by
        # system clock adjustments, so measured durations cannot go negative.
        start_time = time.perf_counter()

        ip = get_client_ip(request)
        method = request.method
        path = request.url.path

        try:
            response = await call_next(request)
        except Exception:
            duration_ms = (time.perf_counter() - start_time) * 1000
            # Without this, requests that fail with an exception would leave
            # no trace in the access log.
            logger.error(
                "%s %s - exception non geree - %.2fms - %s",
                method,
                path,
                duration_ms,
                ip,
            )
            raise

        duration_ms = (time.perf_counter() - start_time) * 1000

        status_code = response.status_code
        if status_code >= 500:
            log_level = logging.ERROR
        elif status_code >= 400:
            log_level = logging.WARNING
        else:
            log_level = logging.INFO

        # Lazy %-formatting: this runs for every request, so skip building the
        # string when the level is disabled. Output matches the previous
        # f-string exactly.
        logger.log(
            log_level,
            "%s %s - %s - %.2fms - %s",
            method,
            path,
            status_code,
            duration_ms,
            ip,
        )

        return response
|
|
|
|
|
|
class FingerprintValidationMiddleware(BaseHTTPMiddleware):
    """Hook intended to validate client fingerprints on sensitive auth routes.

    NOTE(review): no validation is implemented. Requests to the paths in
    ``VALIDATION_PATHS`` are forwarded exactly like every other request.
    """

    # Auth endpoints (matched after trailing slashes are stripped) that are
    # meant to receive fingerprint validation.
    VALIDATION_PATHS: Set[str] = {"/auth/refresh", "/auth/logout"}

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        normalized_path = request.url.path.rstrip("/")

        if normalized_path in self.VALIDATION_PATHS:
            # TODO: fingerprint validation placeholder - matched requests
            # currently fall through unchanged.
            pass

        return await call_next(request)
|
|
|
|
|
|
def setup_security_middleware(app) -> None:
    """Register the security middlewares on *app*, in a fixed order.

    Order matters: in Starlette, each ``add_middleware`` call wraps the
    middleware added before it, so the last one registered runs first on an
    incoming request (verify against the installed Starlette version).

    NOTE(review): ``CSRFMiddleware`` and ``RateLimitMiddleware`` are defined in
    this module but are not registered here. Confirm whether they are added
    elsewhere or intentionally disabled.
    """
    for middleware_cls in (
        RequestLoggingMiddleware,
        SecurityHeadersMiddleware,
        FingerprintValidationMiddleware,
    ):
        app.add_middleware(middleware_cls)
|
|
|
|
|
|
async def init_security_services() -> None:
    """Connect the Redis service used by the security layer.

    A connection failure does not raise: it is logged as a warning and the
    application keeps starting in degraded mode.
    """
    try:
        await redis_service.connect()
    except Exception as exc:
        logger.warning(f"Redis non disponible, fonctionnement en mode degrade: {exc}")
    else:
        logger.info("Services de securite initialises")
|
|
|
|
|
|
async def shutdown_security_services() -> None:
    """Disconnect the Redis service; any error is logged rather than raised."""
    try:
        await redis_service.disconnect()
    except Exception as exc:
        logger.error(f"Erreur arret services securite: {exc}")
    else:
        logger.info("Services de securite arretes")
|