refactor(api): simplify auth scheme handling and improve schema filtering

This commit is contained in:
Fanilo-Nantenaina 2026-01-21 13:41:37 +03:00
parent 437ecd0ed3
commit a6a623d1ab

76
api.py
View file

@ -303,21 +303,12 @@ def get_auth_schemes_for_user(swagger_user: dict) -> dict:
"description": "Authentification JWT pour utilisateurs (POST /auth/login)", "description": "Authentification JWT pour utilisateurs (POST /auth/login)",
} }
if "API Keys Management" in allowed_tags or len(allowed_tags) > 3: schemes["ApiKeyAuth"] = {
schemes["ApiKeyAuth"] = { "type": "apiKey",
"type": "apiKey", "in": "header",
"in": "header", "name": "X-API-Key",
"name": "X-API-Key", "description": "Clé API pour intégrations externes (format: sdk_live_xxx)",
"description": "Clé API pour intégrations externes (format: sdk_live_xxx)", }
}
if not schemes:
schemes["HTTPBearer"] = {
"type": "http",
"scheme": "bearer",
"bearerFormat": "JWT",
"description": "Authentification requise",
}
return schemes return schemes
@ -353,15 +344,17 @@ def generate_filtered_openapi_schema(
base_schema["components"]["securitySchemes"] = auth_schemes base_schema["components"]["securitySchemes"] = auth_schemes
security_requirements = [] security_requirements = []
if "HTTPBearer" in auth_schemes: if "HTTPBearer" in auth_schemes and "ApiKeyAuth" in auth_schemes:
security_requirements.append({"HTTPBearer": []}) security_requirements = [{"HTTPBearer": []}, {"ApiKeyAuth": []}]
if "ApiKeyAuth" in auth_schemes: elif "HTTPBearer" in auth_schemes:
security_requirements.append({"ApiKeyAuth": []}) security_requirements = [{"HTTPBearer": []}]
elif "ApiKeyAuth" in auth_schemes:
security_requirements = [{"ApiKeyAuth": []}]
base_schema["security"] = security_requirements if security_requirements else [] base_schema["security"] = security_requirements
if not allowed_tags: if not allowed_tags:
logger.info(" Schéma OpenAPI complet (admin)") logger.info("⚙️ Schéma OpenAPI complet (admin)")
return base_schema return base_schema
filtered_paths = {} filtered_paths = {}
@ -384,6 +377,7 @@ def generate_filtered_openapi_schema(
operation_tags = operation.get("tags", []) operation_tags = operation.get("tags", [])
if any(tag in allowed_tags for tag in operation_tags): if any(tag in allowed_tags for tag in operation_tags):
operation["security"] = security_requirements
filtered_operations[method] = operation filtered_operations[method] = operation
if filtered_operations: if filtered_operations:
@ -400,7 +394,39 @@ def generate_filtered_openapi_schema(
if "components" in base_schema and "schemas" in base_schema["components"]: if "components" in base_schema and "schemas" in base_schema["components"]:
all_schemas = base_schema["components"]["schemas"] all_schemas = base_schema["components"]["schemas"]
filtered_schemas = get_schemas_for_tags(allowed_tags, all_schemas)
referenced_schemas = set()
def extract_schema_refs(obj):
"""Extrait récursivement tous les $ref depuis un objet"""
if isinstance(obj, dict):
for key, value in obj.items():
if key == "$ref" and isinstance(value, str):
schema_name = value.split("/")[-1]
referenced_schemas.add(schema_name)
else:
extract_schema_refs(value)
elif isinstance(obj, list):
for item in obj:
extract_schema_refs(item)
extract_schema_refs(filtered_paths)
def add_dependencies(schema_name):
if schema_name not in all_schemas:
return
schema_def = all_schemas[schema_name]
extract_schema_refs(schema_def)
initial_schemas = referenced_schemas.copy()
for schema_name in initial_schemas:
add_dependencies(schema_name)
filtered_schemas = {}
for schema_name in referenced_schemas:
if schema_name in all_schemas:
filtered_schemas[schema_name] = all_schemas[schema_name]
base_schema["components"]["schemas"] = filtered_schemas base_schema["components"]["schemas"] = filtered_schemas
logger.info( logger.info(
@ -429,10 +455,14 @@ async def custom_openapi_endpoint(request: Request):
username = swagger_user.get("username", "unknown") username = swagger_user.get("username", "unknown")
allowed_tags = swagger_user.get("allowed_tags") allowed_tags = swagger_user.get("allowed_tags")
logger.info(f" OpenAPI demandé par: {username}, tags: {allowed_tags or 'ALL'}") logger.info(f"📖 OpenAPI demandé par: {username}, tags: {allowed_tags or 'ALL'}")
schema = generate_filtered_openapi_schema(app, allowed_tags, swagger_user) schema = generate_filtered_openapi_schema(app, allowed_tags, swagger_user)
if request.url.scheme == "https":
if "servers" not in schema or not schema["servers"]:
schema["servers"] = [{"url": str(request.base_url).rstrip("/")}]
return JSONResponse(content=schema) return JSONResponse(content=schema)