# Sage100-vps/middleware/security.py
# 2026-01-02 17:56:28 +03:00
#
# 181 lines
# 5.6 KiB
# Python

from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from typing import Set
import logging
import time
from config.config import settings
from security.csrf import CSRFProtection
from security.fingerprint import get_client_ip
from services.redis_service import redis_service
logger = logging.getLogger(__name__)

# Request paths that skip rate limiting: the health check and the
# auto-generated API documentation endpoints. Entries carry no trailing "/".
RATE_LIMIT_EXEMPT_PATHS: Set[str] = set(
    ("/health", "/docs", "/redoc", "/openapi.json")
)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach defensive HTTP security headers to every response.

    Always set:
      - ``X-Content-Type-Options: nosniff``
      - ``X-Frame-Options: DENY``
      - ``X-XSS-Protection: 0`` -- explicitly disables the legacy browser
        XSS auditor. The old ``1; mode=block`` value is deprecated and the
        auditor itself could be abused to create XSS/info-leak issues.
      - ``Referrer-Policy: strict-origin-when-cross-origin``
      - ``Cache-Control`` / ``Pragma``: responses must not be cached.
      - ``Permissions-Policy``: geolocation, microphone and camera disabled.

    Only when ``settings.is_production`` is true:
      - ``Strict-Transport-Security`` (one year, including subdomains).

    Existing values of these headers set by route handlers are overwritten.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        response = await call_next(request)
        headers = response.headers

        headers["X-Content-Type-Options"] = "nosniff"
        headers["X-Frame-Options"] = "DENY"
        # "0" disables the deprecated XSS auditor (OWASP / MDN recommendation);
        # CSP is the modern mechanism for XSS mitigation.
        headers["X-XSS-Protection"] = "0"
        headers["Referrer-Policy"] = "strict-origin-when-cross-origin"

        # HSTS only makes sense when the service is reached over HTTPS,
        # which is assumed to be the case in production only.
        if settings.is_production:
            headers["Strict-Transport-Security"] = (
                "max-age=31536000; includeSubDomains"
            )

        headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
        headers["Pragma"] = "no-cache"
        headers["Permissions-Policy"] = (
            "geolocation=(), microphone=(), camera=()"
        )
        return response
class CSRFMiddleware(BaseHTTPMiddleware):
    """Reject requests that fail the CSRF double-submit check.

    A request is forwarded when ``CSRFProtection.is_exempt`` accepts it, or
    otherwise when ``CSRFProtection.validate_double_submit`` succeeds (the
    validator is not consulted for exempt requests). Anything else is logged
    with method, path and client IP and answered with HTTP 403.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        # Short-circuit "or" keeps the original order: the exemption check
        # runs first and the double-submit validation only when needed.
        allowed = CSRFProtection.is_exempt(
            request
        ) or CSRFProtection.validate_double_submit(request)
        if allowed:
            return await call_next(request)

        logger.warning(
            f"CSRF validation echouee: {request.method} {request.url.path} "
            f"depuis {get_client_ip(request)}"
        )
        return JSONResponse(
            content={"detail": "Verification CSRF echouee"},
            status_code=status.HTTP_403_FORBIDDEN,
        )
class RateLimitMiddleware(BaseHTTPMiddleware):
    """Per-client-IP request rate limiting backed by Redis.

    Paths listed in ``RATE_LIMIT_EXEMPT_PATHS`` (compared with any trailing
    "/" removed) are never limited. For every other request a counter keyed
    ``api:<ip>`` is incremented through ``redis_service`` for a window of
    ``settings.rate_limit_api_window_seconds``. Once the count exceeds
    ``settings.rate_limit_api_requests`` the request is rejected with 429
    *before* it reaches the application.

    Successful responses carry ``X-RateLimit-Limit``, ``X-RateLimit-Remaining``
    and ``X-RateLimit-Reset``. NOTE(review): ``X-RateLimit-Reset`` is the
    window length in seconds, not the time remaining until the counter resets.

    If the Redis call fails, the error is logged and the request is allowed
    through (fail-open). Exceptions raised by the application itself are NOT
    swallowed and the request is never dispatched twice.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        path = request.url.path.rstrip("/")
        if path in RATE_LIMIT_EXEMPT_PATHS:
            return await call_next(request)

        ip = get_client_ip(request)
        key = f"api:{ip}"
        window_seconds = settings.rate_limit_api_window_seconds
        max_requests = settings.rate_limit_api_requests

        # Only the Redis interaction is guarded: a limiter outage must not
        # take the API down, but downstream errors must propagate normally
        # instead of triggering a second call_next().
        try:
            count = await redis_service.increment_rate_limit(key, window_seconds)
        except Exception as e:
            logger.error(f"Erreur rate limiting: {e}")
            return await call_next(request)

        # Reject before dispatching so over-limit requests cause no work
        # or side effects in the application.
        if count > max_requests:
            logger.warning(
                f"Rate limit depasse pour IP {ip}: {count}/{max_requests}"
            )
            return JSONResponse(
                status_code=status.HTTP_429_TOO_MANY_REQUESTS,
                content={"detail": "Limite de requetes atteinte"},
                headers={
                    "X-RateLimit-Limit": str(max_requests),
                    "X-RateLimit-Remaining": "0",
                    "X-RateLimit-Reset": str(window_seconds),
                    "Retry-After": str(window_seconds),
                },
            )

        response = await call_next(request)
        remaining = max(0, max_requests - count)
        response.headers["X-RateLimit-Limit"] = str(max_requests)
        response.headers["X-RateLimit-Remaining"] = str(remaining)
        response.headers["X-RateLimit-Reset"] = str(window_seconds)
        return response
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Emit one access-log line per request.

    Format: ``"<method> <path> - <status> - <duration>ms - <client ip>"``.
    The level is ERROR for 5xx, WARNING for 4xx and INFO otherwise.
    If the downstream application raises, an ERROR line with ``exception``
    in place of the status is logged and the exception is re-raised
    unchanged.
    """

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        # perf_counter() is monotonic; time.time() follows the wall clock and
        # can jump (NTP, manual changes), giving wrong or negative durations.
        start_time = time.perf_counter()
        ip = get_client_ip(request)
        method = request.method
        path = request.url.path

        try:
            response = await call_next(request)
        except Exception:
            # Without this, requests whose handler crashes leave no access-log
            # line at all. The traceback itself is left to outer handlers.
            duration_ms = (time.perf_counter() - start_time) * 1000
            logger.error(
                f"{method} {path} - exception - {duration_ms:.2f}ms - {ip}"
            )
            raise

        duration_ms = (time.perf_counter() - start_time) * 1000
        status_code = response.status_code
        if status_code >= 500:
            log_level = logging.ERROR
        elif status_code >= 400:
            log_level = logging.WARNING
        else:
            log_level = logging.INFO
        logger.log(
            log_level,
            f"{method} {path} - {status_code} - {duration_ms:.2f}ms - {ip}",
        )
        return response
class FingerprintValidationMiddleware(BaseHTTPMiddleware):
    """Hook for client-fingerprint checks on selected auth endpoints.

    NOTE(review): as written this middleware performs no validation at all;
    every request, whether or not its path is in ``VALIDATION_PATHS``, is
    forwarded unchanged to the next handler.
    """

    # Paths (without trailing "/") presumably intended for fingerprint
    # validation -- confirm against the auth flow.
    VALIDATION_PATHS: Set[str] = set(("/auth/refresh", "/auth/logout"))

    async def dispatch(
        self, request: Request, call_next: RequestResponseEndpoint
    ) -> Response:
        # TODO(review): for requests whose path (trailing "/" stripped) is in
        # VALIDATION_PATHS, compare the client fingerprint here. Until then,
        # both targeted and untargeted requests simply pass through.
        return await call_next(request)
def setup_security_middleware(app) -> None:
    """Register the security middleware on the given application.

    Registration order: request logging, security headers, then fingerprint
    validation. In Starlette each ``add_middleware`` call wraps the ones
    added before it, so the last one registered runs outermost.

    NOTE(review): ``CSRFMiddleware`` and ``RateLimitMiddleware`` are defined
    in this module but are not registered here -- confirm whether they are
    added elsewhere or intentionally disabled.
    """
    middleware_stack = (
        RequestLoggingMiddleware,
        SecurityHeadersMiddleware,
        FingerprintValidationMiddleware,
    )
    for middleware_cls in middleware_stack:
        app.add_middleware(middleware_cls)
async def init_security_services() -> None:
    """Connect the Redis backend used by the security services.

    A connection failure is logged as a warning and not raised, so startup
    continues in degraded mode without Redis.
    """
    try:
        await redis_service.connect()
    except Exception as exc:
        logger.warning(f"Redis non disponible, fonctionnement en mode degrade: {exc}")
    else:
        logger.info("Services de securite initialises")
async def shutdown_security_services() -> None:
    """Disconnect the Redis backend; any failure is logged, never raised."""
    try:
        await redis_service.disconnect()
    except Exception as exc:
        logger.error(f"Erreur arret services securite: {exc}")
    else:
        logger.info("Services de securite arretes")