125 lines
4.1 KiB
Python
125 lines
4.1 KiB
Python
from fastapi import FastAPI
|
|
from fastapi.middleware.cors import CORSMiddleware
|
|
from typing import List
|
|
import os
|
|
import logging
|
|
|
|
# Module logger: every configure_* function below (and setup_cors) reports the
# CORS policy it installed, or a fallback it applied, through this logger.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def configure_cors_open(app: FastAPI):
    """Install a permissive CORS policy on *app*.

    Every origin and every request header is accepted, and credentialed
    requests are NOT allowed, which keeps the ``*`` origin valid. As the log
    output states, access control is expected to come from API keys rather
    than from CORS.

    Args:
        app: application receiving the CORSMiddleware.
    """
    http_methods = ["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"]
    rate_limit_headers = ["X-RateLimit-Limit", "X-RateLimit-Remaining"]

    app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_credentials=False,
        allow_methods=http_methods,
        allow_headers=["*"],
        # Let browser scripts read the rate-limit headers.
        expose_headers=rate_limit_headers,
        max_age=3600,
    )

    summary = (
        " CORS configuré: Mode OUVERT (sécurisé par API Keys)",
        " - Origins: * (toutes)",
        " - Headers: * (dont X-API-Key)",
        " - Credentials: False",
    )
    for line in summary:
        logger.info(line)
|
|
|
|
|
|
def configure_cors_whitelist(app: FastAPI):
    """Install a CORS policy limited to the origins listed in the environment.

    Origins are read from the comma-separated ``CORS_ALLOWED_ORIGINS``
    environment variable. Surrounding whitespace and empty entries are
    ignored. Listed origins may make credentialed requests. Only the
    ``Content-Type``, ``Authorization`` and ``X-API-Key`` request headers are
    allowed.

    When the variable is unset, empty, or yields no usable entry, every origin
    (``*``) is allowed instead, and credentials are then disabled. The same
    applies if ``*`` itself appears in the variable. The CORS protocol forbids
    a wildcard origin together with credentials. Starlette's CORSMiddleware
    copes with that combination by echoing the caller's Origin alongside
    ``Access-Control-Allow-Credentials: true``, which would let any website
    send authenticated (cookie-bearing) requests.

    Args:
        app: application receiving the CORSMiddleware.
    """
    raw_origins = os.getenv("CORS_ALLOWED_ORIGINS", "")
    allowed_origins = [
        origin.strip() for origin in raw_origins.split(",") if origin.strip()
    ]
    if not allowed_origins:
        # Nothing usable configured (unset, "", " , "): fall back to any origin
        # instead of silently blocking every cross-origin request.
        allowed_origins = ["*"]

    # Never combine a wildcard origin with credentials (see docstring).
    allow_credentials = "*" not in allowed_origins

    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=allow_credentials,
        allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
        allow_headers=["Content-Type", "Authorization", "X-API-Key"],
        expose_headers=["X-RateLimit-Limit", "X-RateLimit-Remaining"],
        max_age=3600,
    )

    logger.info(" CORS configuré: Mode WHITELIST")
    logger.info(f" - Origins autorisées: {len(allowed_origins)}")
    for origin in allowed_origins:
        logger.info(f" • {origin}")
    logger.info(f" - Credentials: {allow_credentials}")
    if not allow_credentials:
        logger.warning(
            " CORS whitelist: origine '*' autorisée, credentials désactivés"
        )
|
|
|
|
|
|
def configure_cors_regex(
    app: FastAPI,
    origin_regex: str = r".*",
    allow_credentials: bool = False,
):
    """Install a CORS policy accepting origins that match a regular expression.

    Args:
        app: application receiving the CORSMiddleware.
        origin_regex: pattern matched against the request's ``Origin`` header.
            The default ``.*`` accepts any origin. (The former hard-coded
            ``*`` was not a valid regular expression: "nothing to repeat".)
        allow_credentials: whether matching origins may make credentialed
            requests. Defaults to False because the default pattern accepts
            every origin, and granting credentials to arbitrary sites would
            let any website send authenticated requests. Enable it only with
            a pattern restricted to origins you control.

    Raises:
        ValueError: if *origin_regex* is not a valid regular expression.
    """
    import re

    # Validate up front: Starlette may only compile the pattern once the
    # middleware stack is built, so a bad pattern would otherwise surface
    # later, far from this configuration call.
    try:
        re.compile(origin_regex)
    except re.error as exc:
        raise ValueError(f"Regex CORS invalide: {origin_regex!r}") from exc

    app.add_middleware(
        CORSMiddleware,
        allow_origin_regex=origin_regex,
        allow_credentials=allow_credentials,
        allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"],
        allow_headers=["Content-Type", "Authorization", "X-API-Key"],
        expose_headers=["X-RateLimit-Limit", "X-RateLimit-Remaining"],
        max_age=3600,
    )

    logger.info(" CORS configuré: Mode REGEX")
    logger.info(f" - Pattern: {origin_regex}")
    logger.info(f" - Credentials: {allow_credentials}")
|
|
|
|
|
|
def configure_cors_hybrid(app: FastAPI):
    """Install a CORS middleware mixing a credentialed whitelist with an open fallback.

    Origins listed in the comma-separated ``CORS_ALLOWED_ORIGINS`` environment
    variable get their Origin echoed back with
    ``Access-Control-Allow-Credentials: true`` and a restricted request-header
    list. Every other request gets ``Access-Control-Allow-Origin: *`` without
    credentials. A literal ``*`` entry is ignored: browsers never send
    ``Origin: *``, so such an entry could never match, and the fallback
    already covers it.

    Preflight requests (``OPTIONS`` with ``Access-Control-Request-Method``)
    are answered directly with a 200. If they were forwarded, the router would
    typically answer 405 for routes without an OPTIONS handler, and browsers
    reject a preflight whose status is not 2xx.

    Args:
        app: application receiving the middleware.
    """
    from starlette.middleware.base import BaseHTTPMiddleware
    from starlette.responses import Response

    allow_methods = "GET, POST, PUT, DELETE, PATCH, OPTIONS"
    expose_headers = "X-RateLimit-Limit, X-RateLimit-Remaining"

    class HybridCORSMiddleware(BaseHTTPMiddleware):
        """Per-request CORS: echo whitelisted origins, wildcard everyone else."""

        def __init__(self, app, known_origins: List[str]):
            super().__init__(app)
            self.known_origins = set(known_origins)

        def cors_headers(self, origin):
            """Return a fresh dict of CORS headers for a request from *origin*."""
            if origin in self.known_origins:
                return {
                    "Access-Control-Allow-Origin": origin,
                    "Access-Control-Allow-Credentials": "true",
                    "Access-Control-Allow-Methods": allow_methods,
                    "Access-Control-Allow-Headers": (
                        "Content-Type, Authorization, X-API-Key"
                    ),
                    "Access-Control-Expose-Headers": expose_headers,
                }
            return {
                "Access-Control-Allow-Origin": "*",
                "Access-Control-Allow-Methods": allow_methods,
                "Access-Control-Allow-Headers": "*",
                "Access-Control-Expose-Headers": expose_headers,
            }

        async def dispatch(self, request, call_next):
            origin = request.headers.get("origin")
            headers = self.cors_headers(origin)
            # An echoed origin makes the response depend on the Origin header,
            # so shared caches must be told via Vary.
            echoes_origin = origin in self.known_origins

            is_preflight = (
                request.method == "OPTIONS"
                and origin is not None
                and "access-control-request-method" in request.headers
            )
            if is_preflight:
                headers["Access-Control-Max-Age"] = "3600"
                if echoes_origin:
                    headers["Vary"] = "Origin"
                return Response(status_code=200, headers=headers)

            response = await call_next(request)
            for name, value in headers.items():
                response.headers[name] = value
            if echoes_origin:
                # Append rather than overwrite any Vary set by the app.
                vary = response.headers.get("Vary")
                response.headers["Vary"] = f"{vary}, Origin" if vary else "Origin"
            return response

    raw_origins = os.getenv("CORS_ALLOWED_ORIGINS", "")
    stripped = (entry.strip() for entry in raw_origins.split(","))
    known_origins = [entry for entry in stripped if entry and entry != "*"]

    app.add_middleware(HybridCORSMiddleware, known_origins=known_origins)

    logger.info(" CORS configuré: Mode HYBRIDE")
    logger.info(f" - Whitelist: {len(known_origins)} domaines")
    for origin in known_origins:
        logger.info(f" • {origin}")
    logger.info(" - Fallback: * (ouvert)")
|
|
|
|
|
|
def setup_cors(app: FastAPI, mode: str = "open"):
    """Install the CORS policy selected by *mode* on *app*.

    Args:
        app: application receiving the middleware.
        mode: one of "open", "whitelist", "regex" or "hybrid" (compared
            exactly, case-sensitive). Any other value logs a warning and
            falls back to "open".
    """
    configurators = (
        ("open", configure_cors_open),
        ("whitelist", configure_cors_whitelist),
        ("regex", configure_cors_regex),
        ("hybrid", configure_cors_hybrid),
    )
    for name, configure in configurators:
        if mode == name:
            configure(app)
            return

    logger.warning(
        f" Mode CORS inconnu: {mode}. Utilisation de 'open' par défaut."
    )
    configure_cors_open(app)
|