feat(security): enhance swagger auth with user context and filtered docs

This commit is contained in:
Fanilo-Nantenaina 2026-01-21 12:56:02 +03:00
parent 5f40c677a8
commit 797aed0240
2 changed files with 184 additions and 106 deletions

222
api.py
View file

@ -3,6 +3,8 @@ from fastapi.responses import JSONResponse
from fastapi.openapi.utils import get_openapi from fastapi.openapi.utils import get_openapi
from fastapi.responses import StreamingResponse, HTMLResponse, Response from fastapi.responses import StreamingResponse, HTMLResponse, Response
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from fastapi.openapi.docs import get_swagger_ui_html, get_redoc_html
from pydantic import BaseModel, Field, EmailStr from pydantic import BaseModel, Field, EmailStr
from typing import List, Optional from typing import List, Optional
from datetime import datetime, date from datetime import datetime, date
@ -96,7 +98,6 @@ from utils.generic_functions import (
universign_envoyer, universign_envoyer,
) )
from middleware.security import SwaggerAuthMiddleware, ApiKeyMiddlewareHTTP from middleware.security import SwaggerAuthMiddleware, ApiKeyMiddlewareHTTP
from core.dependencies import get_current_user from core.dependencies import get_current_user
from config.cors_config import setup_cors from config.cors_config import setup_cors
@ -123,12 +124,12 @@ logger = logging.getLogger(__name__)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: FastAPI):
"""Lifecycle de l'application"""
await init_db() await init_db()
logger.info("Base de données initialisée") logger.info("Base de données initialisée")
email_queue.session_factory = async_session_factory email_queue.session_factory = async_session_factory
email_queue.sage_client = sage_client email_queue.sage_client = sage_client
logger.info("sage_client injecté dans email_queue") logger.info("sage_client injecté dans email_queue")
email_queue.start(num_workers=settings.max_email_workers) email_queue.start(num_workers=settings.max_email_workers)
@ -137,18 +138,12 @@ async def lifespan(app: FastAPI):
sync_service = UniversignSyncService( sync_service = UniversignSyncService(
api_url=settings.universign_api_url, api_key=settings.universign_api_key api_url=settings.universign_api_url, api_key=settings.universign_api_key
) )
sync_service.configure( sync_service.configure(
sage_client=sage_client, email_queue=email_queue, settings=settings sage_client=sage_client, email_queue=email_queue, settings=settings
) )
scheduler = UniversignSyncScheduler( scheduler = UniversignSyncScheduler(sync_service=sync_service, interval_minutes=5)
sync_service=sync_service,
interval_minutes=5,
)
sync_task = asyncio.create_task(scheduler.start(async_session_factory)) sync_task = asyncio.create_task(scheduler.start(async_session_factory))
logger.info("Synchronisation Universign démarrée (5min)") logger.info("Synchronisation Universign démarrée (5min)")
yield yield
@ -160,25 +155,24 @@ async def lifespan(app: FastAPI):
app = FastAPI( app = FastAPI(
title="Sage Gateways", title="Sage Gateways API",
version="3.0.0", version="3.0.0",
description="Configuration multi-tenant des connexions Sage Gateway", description="API multi-tenant pour Sage 100c avec authentification hybride",
lifespan=lifespan, lifespan=lifespan,
openapi_tags=TAGS_METADATA, openapi_tags=TAGS_METADATA,
docs_url=None,
redoc_url=None,
openapi_url=None,
) )
""" app.add_middleware(
CORSMiddleware, def get_swagger_user_from_state(request: Request) -> Optional[dict]:
allow_origins=settings.cors_origins, return getattr(request.state, "swagger_user", None)
allow_methods=["GET", "POST", "PUT", "DELETE"],
allow_headers=["*"],
allow_credentials=True,
) """
def generate_filtered_openapi(app: FastAPI, allowed_tags: Optional[List[str]] = None): def generate_filtered_openapi_schema(
"""Génère le schéma OpenAPI filtré selon les tags autorisés""" app: FastAPI, allowed_tags: Optional[List[str]] = None
) -> dict:
base_schema = get_openapi( base_schema = get_openapi(
title=app.title, title=app.title,
version=app.version, version=app.version,
@ -204,19 +198,33 @@ def generate_filtered_openapi(app: FastAPI, allowed_tags: Optional[List[str]] =
base_schema["security"] = [{"HTTPBearer": []}, {"ApiKeyAuth": []}] base_schema["security"] = [{"HTTPBearer": []}, {"ApiKeyAuth": []}]
if not allowed_tags: if not allowed_tags:
logger.info("📚 Schéma OpenAPI complet (admin)")
return base_schema return base_schema
filtered_paths = {} filtered_paths = {}
for path, path_item in base_schema.get("paths", {}).items(): for path, path_item in base_schema.get("paths", {}).items():
filtered_operations = {}
for method, operation in path_item.items(): for method, operation in path_item.items():
if method in ["get", "post", "put", "delete", "patch", "options"]: if method not in [
"get",
"post",
"put",
"delete",
"patch",
"options",
"head",
]:
continue
operation_tags = operation.get("tags", []) operation_tags = operation.get("tags", [])
if any(tag in allowed_tags for tag in operation_tags): if any(tag in allowed_tags for tag in operation_tags):
if path not in filtered_paths: filtered_operations[method] = operation
filtered_paths[path] = {}
filtered_paths[path][method] = operation if filtered_operations:
filtered_paths[path] = filtered_operations
base_schema["paths"] = filtered_paths base_schema["paths"] = filtered_paths
@ -227,48 +235,81 @@ def generate_filtered_openapi(app: FastAPI, allowed_tags: Optional[List[str]] =
if tag_obj.get("name") in allowed_tags if tag_obj.get("name") in allowed_tags
] ]
logger.info(f"🔒 Schéma filtré: {len(filtered_paths)} paths, tags: {allowed_tags}")
return base_schema return base_schema
async def get_swagger_user_from_request(request: Request) -> Optional[dict]: @app.get("/openapi.json", include_in_schema=False)
"""Récupère l'utilisateur Swagger depuis la requête authentifiée""" async def custom_openapi_endpoint(request: Request):
auth_header = request.headers.get("Authorization") swagger_user = get_swagger_user_from_state(request)
if not auth_header or not auth_header.startswith("Basic "): if not swagger_user:
return None return JSONResponse(
status_code=401,
import base64 content={"detail": "Authentification Swagger requise"},
from fastapi.security import HTTPBasicCredentials headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'},
from database.db_config import async_session_factory
from database.models.api_key import SwaggerUser
from sqlalchemy import select
try:
encoded_credentials = auth_header.split(" ")[1]
decoded_credentials = base64.b64decode(encoded_credentials).decode("utf-8")
username, _ = decoded_credentials.split(":", 1)
async with async_session_factory() as session:
result = await session.execute(
select(SwaggerUser).where(SwaggerUser.username == username)
) )
swagger_user = result.scalar_one_or_none()
if swagger_user and swagger_user.is_active: username = swagger_user.get("username", "unknown")
return { allowed_tags = swagger_user.get("allowed_tags")
"username": swagger_user.username,
"allowed_tags": swagger_user.allowed_tags_list,
}
except Exception as e:
logger.error(f"Erreur récupération utilisateur Swagger: {e}")
return None logger.info(f"📖 OpenAPI demandé par: {username}, tags: {allowed_tags or 'ALL'}")
schema = generate_filtered_openapi_schema(app, allowed_tags)
return JSONResponse(content=schema)
@app.get("/docs", include_in_schema=False)
async def custom_swagger_ui(request: Request):
swagger_user = get_swagger_user_from_state(request)
if not swagger_user:
return JSONResponse(
status_code=401,
content={"detail": "Authentification Swagger requise"},
headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'},
)
return get_swagger_ui_html(
openapi_url="/openapi.json",
title=f"{app.title} - Documentation",
swagger_favicon_url="https://fastapi.tiangolo.com/img/favicon.png",
swagger_ui_parameters={
"persistAuthorization": True,
"displayRequestDuration": True,
"filter": True,
"tryItOutEnabled": True,
},
)
@app.get("/redoc", include_in_schema=False)
async def custom_redoc(request: Request):
swagger_user = get_swagger_user_from_state(request)
if not swagger_user:
return JSONResponse(
status_code=401,
content={"detail": "Authentification Swagger requise"},
headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'},
)
return get_redoc_html(
openapi_url="/openapi.json",
title=f"{app.title} - Documentation",
redoc_favicon_url="https://fastapi.tiangolo.com/img/favicon.png",
)
setup_cors(app, mode="open") setup_cors(app, mode="open")
app.add_middleware(SwaggerAuthMiddleware) app.add_middleware(SwaggerAuthMiddleware)
app.add_middleware(ApiKeyMiddlewareHTTP) app.add_middleware(ApiKeyMiddlewareHTTP)
app.include_router(api_keys_router) app.include_router(api_keys_router)
app.include_router(auth_router) app.include_router(auth_router)
app.include_router(sage_gateway_router) app.include_router(sage_gateway_router)
@ -276,26 +317,6 @@ app.include_router(universign_router)
app.include_router(entreprises_router) app.include_router(entreprises_router)
@app.get("/openapi.json", include_in_schema=False)
async def get_openapi_filtered(request: Request):
"""Retourne le schéma OpenAPI filtré selon l'utilisateur"""
swagger_user = await get_swagger_user_from_request(request)
if not swagger_user:
return JSONResponse(
status_code=401,
content={"detail": "Authentification requise"},
headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'},
)
allowed_tags = swagger_user.get("allowed_tags")
schema = generate_filtered_openapi(app, allowed_tags)
return JSONResponse(content=schema)
@app.get("/clients", response_model=List[ClientDetails], tags=["Clients"]) @app.get("/clients", response_model=List[ClientDetails], tags=["Clients"])
async def obtenir_clients( async def obtenir_clients(
query: Optional[str] = Query(None), query: Optional[str] = Query(None),
@ -3409,7 +3430,7 @@ async def get_reglement_detail(rg_no):
return sage_client.get_reglement_detail(rg_no) return sage_client.get_reglement_detail(rg_no)
@app.get("/health", tags=["System"]) """ @app.get("/health", tags=["System"])
async def health_check( async def health_check(
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
sage: SageGatewayClient = Depends(get_sage_client_for_user), sage: SageGatewayClient = Depends(get_sage_client_for_user),
@ -3426,29 +3447,64 @@ async def health_check(
"queue_size": email_queue.queue.qsize(), "queue_size": email_queue.queue.qsize(),
}, },
"timestamp": datetime.now().isoformat(), "timestamp": datetime.now().isoformat(),
} } """
@app.get("/", tags=["System"]) @app.get("/", tags=["System"])
async def root(): async def root():
"""
Point d'entrée de l'API
"""
return { return {
"api": "Sage 100c Dataven - VPS Linux", "api": "Sage 100c Dataven API",
"version": "3.0.0", "version": "3.0.0",
"documentation": "/docs (authentification requise)", "status": "operational",
"health": "/health", "documentation": {
"swagger": "/docs",
"redoc": "/redoc",
"openapi": "/openapi.json",
},
"authentication": { "authentication": {
"methods": [ "methods": [
{ {
"type": "JWT", "type": "JWT (Bearer Token)",
"header": "Authorization: Bearer <token>", "header": "Authorization: Bearer <token>",
"endpoint": "/api/auth/login", "obtain_token": "POST /auth/login",
"description": "Pour les utilisateurs finaux",
}, },
{ {
"type": "API Key", "type": "API Key",
"header": "X-API-Key: sdk_live_xxx", "header": "X-API-Key: sdk_live_xxx",
"endpoint": "/api/api-keys", "manage_keys": "GET /api-keys",
"description": "Pour les intégrations externes",
},
],
"note": "Les routes acceptent JWT OU API Key (au choix)",
},
"swagger_access": {
"authentication": "HTTP Basic Auth (voir /scripts/manage_security.py)",
"filtering": "Les routes visibles dépendent des tags autorisés de l'utilisateur",
},
}
@app.get("/health", tags=["System"])
async def health_check():
"""
Vérification de santé de l'API (sans authentification)
"""
return {
"status": "healthy",
"timestamp": "2025-01-21T00:00:00Z",
"services": {
"api": "operational",
"database": "connected",
"email_queue": {
"running": email_queue.running,
"workers": len(email_queue.workers)
if hasattr(email_queue, "workers")
else 0,
}, },
]
}, },
} }

View file

@ -2,12 +2,13 @@ from fastapi import Request, status
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from fastapi.security import HTTPBasic, HTTPBasicCredentials from fastapi.security import HTTPBasic, HTTPBasicCredentials
from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp from starlette.types import ASGIApp, Receive, Send
from sqlalchemy import select from sqlalchemy import select
from typing import Callable from typing import Callable, Optional
from datetime import datetime from datetime import datetime
import logging import logging
import base64 import base64
import json
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -20,7 +21,7 @@ class SwaggerAuthMiddleware:
def __init__(self, app: ASGIApp): def __init__(self, app: ASGIApp):
self.app = app self.app = app
async def __call__(self, scope, receive, send): async def __call__(self, scope, receive: Receive, send: Send):
if scope["type"] != "http": if scope["type"] != "http":
await self.app(scope, receive, send) await self.app(scope, receive, send)
return return
@ -50,7 +51,9 @@ class SwaggerAuthMiddleware:
credentials = HTTPBasicCredentials(username=username, password=password) credentials = HTTPBasicCredentials(username=username, password=password)
if not await self._verify_credentials(credentials): swagger_user = await self._verify_credentials(credentials)
if not swagger_user:
response = JSONResponse( response = JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
content={"detail": "Identifiants invalides"}, content={"detail": "Identifiants invalides"},
@ -59,6 +62,14 @@ class SwaggerAuthMiddleware:
await response(scope, receive, send) await response(scope, receive, send)
return return
if "state" not in scope:
scope["state"] = {}
scope["state"]["swagger_user"] = swagger_user
logger.info(
f"✓ Swagger auth: {swagger_user['username']} - tags: {swagger_user.get('allowed_tags', 'ALL')}"
)
except Exception as e: except Exception as e:
logger.error(f"Erreur parsing auth header: {e}") logger.error(f"Erreur parsing auth header: {e}")
response = JSONResponse( response = JSONResponse(
@ -71,8 +82,9 @@ class SwaggerAuthMiddleware:
await self.app(scope, receive, send) await self.app(scope, receive, send)
async def _verify_credentials(self, credentials: HTTPBasicCredentials) -> bool: async def _verify_credentials(
"""Vérifie les identifiants dans la base de données""" self, credentials: HTTPBasicCredentials
) -> Optional[dict]:
from database.db_config import async_session_factory from database.db_config import async_session_factory
from database.models.api_key import SwaggerUser from database.models.api_key import SwaggerUser
from security.auth import verify_password from security.auth import verify_password
@ -92,15 +104,22 @@ class SwaggerAuthMiddleware:
): ):
swagger_user.last_login = datetime.now() swagger_user.last_login = datetime.now()
await session.commit() await session.commit()
logger.info(f"✓ Accès Swagger autorisé: {credentials.username}") logger.info(f"✓ Accès Swagger autorisé: {credentials.username}")
return True
return {
"id": swagger_user.id,
"username": swagger_user.username,
"allowed_tags": swagger_user.allowed_tags_list, # None = admin complet
"is_active": swagger_user.is_active,
}
logger.warning(f"✗ Accès Swagger refusé: {credentials.username}") logger.warning(f"✗ Accès Swagger refusé: {credentials.username}")
return False return None
except Exception as e: except Exception as e:
logger.error(f"Erreur vérification credentials: {e}") logger.error(f"Erreur vérification credentials: {e}", exc_info=True)
return False return None
class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware): class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
@ -116,7 +135,7 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
] ]
def _is_excluded_path(self, path: str) -> bool: def _is_excluded_path(self, path: str) -> bool:
"""Vérifie si le chemin est exclu de l'authentification""" """Vérifie si le chemin est exclu de l'authentification API Key"""
if path == "/": if path == "/":
return True return True
@ -149,7 +168,7 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
if token.startswith("sdk_live_"): if token.startswith("sdk_live_"):
logger.warning( logger.warning(
" API Key envoyée dans Authorization au lieu de X-API-Key" " API Key envoyée dans Authorization au lieu de X-API-Key"
) )
return await self._handle_api_key_auth( return await self._handle_api_key_auth(
request, token, path, method, call_next request, token, path, method, call_next
@ -159,7 +178,7 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
request.state.authenticated_via = "jwt" request.state.authenticated_via = "jwt"
return await call_next(request) return await call_next(request)
logger.debug(f" Aucune auth pour {method} {path} → délégation à FastAPI") logger.debug(f" Aucune auth pour {method} {path} → délégation à FastAPI")
return await call_next(request) return await call_next(request)
async def _handle_api_key_auth( async def _handle_api_key_auth(
@ -170,7 +189,6 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
method: str, method: str,
call_next: Callable, call_next: Callable,
): ):
"""Gère l'authentification par API Key avec vérification STRICTE"""
try: try:
from database.db_config import async_session_factory from database.db_config import async_session_factory
from services.api_key import ApiKeyService from services.api_key import ApiKeyService
@ -181,7 +199,7 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
api_key_obj = await service.verify_api_key(api_key) api_key_obj = await service.verify_api_key(api_key)
if not api_key_obj: if not api_key_obj:
logger.warning(f" Clé API invalide: {method} {path}") logger.warning(f" Clé API invalide: {method} {path}")
return JSONResponse( return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
content={ content={
@ -192,7 +210,7 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
is_allowed, rate_info = await service.check_rate_limit(api_key_obj) is_allowed, rate_info = await service.check_rate_limit(api_key_obj)
if not is_allowed: if not is_allowed:
logger.warning(f" Rate limit: {api_key_obj.name}") logger.warning(f" Rate limit: {api_key_obj.name}")
return JSONResponse( return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS, status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={"detail": "Rate limit dépassé"}, content={"detail": "Rate limit dépassé"},
@ -205,8 +223,6 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
has_access = await service.check_endpoint_access(api_key_obj, path) has_access = await service.check_endpoint_access(api_key_obj, path)
if not has_access: if not has_access:
import json
allowed = ( allowed = (
json.loads(api_key_obj.allowed_endpoints) json.loads(api_key_obj.allowed_endpoints)
if api_key_obj.allowed_endpoints if api_key_obj.allowed_endpoints
@ -214,7 +230,7 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
) )
logger.warning( logger.warning(
f" ACCÈS REFUSÉ: {api_key_obj.name}\n" f"🚫 ACCÈS REFUSÉ: {api_key_obj.name}\n"
f" Endpoint demandé: {path}\n" f" Endpoint demandé: {path}\n"
f" Endpoints autorisés: {allowed}" f" Endpoints autorisés: {allowed}"
) )
@ -233,12 +249,12 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
request.state.api_key = api_key_obj request.state.api_key = api_key_obj
request.state.authenticated_via = "api_key" request.state.authenticated_via = "api_key"
logger.info(f" ACCÈS AUTORISÉ: {api_key_obj.name}{method} {path}") logger.info(f" ACCÈS AUTORISÉ: {api_key_obj.name}{method} {path}")
return await call_next(request) return await call_next(request)
except Exception as e: except Exception as e:
logger.error(f" Erreur validation API Key: {e}", exc_info=True) logger.error(f"💥 Erreur validation API Key: {e}", exc_info=True)
return JSONResponse( return JSONResponse(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
content={"detail": f"Erreur interne: {str(e)}"}, content={"detail": f"Erreur interne: {str(e)}"},
@ -248,7 +264,7 @@ class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware):
ApiKeyMiddleware = ApiKeyMiddlewareHTTP ApiKeyMiddleware = ApiKeyMiddlewareHTTP
def get_api_key_from_request(request: Request): def get_api_key_from_request(request: Request) -> Optional:
"""Récupère l'objet ApiKey depuis la requête si présent""" """Récupère l'objet ApiKey depuis la requête si présent"""
return getattr(request.state, "api_key", None) return getattr(request.state, "api_key", None)
@ -258,10 +274,16 @@ def get_auth_method(request: Request) -> str:
return getattr(request.state, "authenticated_via", "none") return getattr(request.state, "authenticated_via", "none")
def get_swagger_user_from_request(request: Request) -> Optional[dict]:
"""Récupère l'utilisateur Swagger depuis la requête"""
return getattr(request.state, "swagger_user", None)
__all__ = [ __all__ = [
"SwaggerAuthMiddleware", "SwaggerAuthMiddleware",
"ApiKeyMiddlewareHTTP", "ApiKeyMiddlewareHTTP",
"ApiKeyMiddleware", "ApiKeyMiddleware",
"get_api_key_from_request", "get_api_key_from_request",
"get_auth_method", "get_auth_method",
"get_swagger_user_from_request",
] ]