diff --git a/api.py b/api.py index 5f3be7c..d1d477e 100644 --- a/api.py +++ b/api.py @@ -303,21 +303,12 @@ def get_auth_schemes_for_user(swagger_user: dict) -> dict: "description": "Authentification JWT pour utilisateurs (POST /auth/login)", } - if "API Keys Management" in allowed_tags or len(allowed_tags) > 3: - schemes["ApiKeyAuth"] = { - "type": "apiKey", - "in": "header", - "name": "X-API-Key", - "description": "Clé API pour intégrations externes (format: sdk_live_xxx)", - } - - if not schemes: - schemes["HTTPBearer"] = { - "type": "http", - "scheme": "bearer", - "bearerFormat": "JWT", - "description": "Authentification requise", - } + schemes["ApiKeyAuth"] = { + "type": "apiKey", + "in": "header", + "name": "X-API-Key", + "description": "Clé API pour intégrations externes (format: sdk_live_xxx)", + } return schemes @@ -353,15 +344,17 @@ def generate_filtered_openapi_schema( base_schema["components"]["securitySchemes"] = auth_schemes security_requirements = [] - if "HTTPBearer" in auth_schemes: - security_requirements.append({"HTTPBearer": []}) - if "ApiKeyAuth" in auth_schemes: - security_requirements.append({"ApiKeyAuth": []}) + if "HTTPBearer" in auth_schemes and "ApiKeyAuth" in auth_schemes: + security_requirements = [{"HTTPBearer": []}, {"ApiKeyAuth": []}] + elif "HTTPBearer" in auth_schemes: + security_requirements = [{"HTTPBearer": []}] + elif "ApiKeyAuth" in auth_schemes: + security_requirements = [{"ApiKeyAuth": []}] - base_schema["security"] = security_requirements if security_requirements else [] + base_schema["security"] = security_requirements if not allowed_tags: - logger.info(" Schéma OpenAPI complet (admin)") + logger.info("⚙️ Schéma OpenAPI complet (admin)") return base_schema filtered_paths = {} @@ -384,6 +377,7 @@ def generate_filtered_openapi_schema( operation_tags = operation.get("tags", []) if any(tag in allowed_tags for tag in operation_tags): + operation["security"] = security_requirements filtered_operations[method] = operation if filtered_operations: @@ -400,7 +394,39 @@ def generate_filtered_openapi_schema( if "components" in base_schema and "schemas" in base_schema["components"]: all_schemas = base_schema["components"]["schemas"] - filtered_schemas = get_schemas_for_tags(allowed_tags, all_schemas) + + referenced_schemas = set() + + def extract_schema_refs(obj): + """Extrait récursivement tous les $ref depuis un objet""" + if isinstance(obj, dict): + for key, value in obj.items(): + if key == "$ref" and isinstance(value, str): + schema_name = value.split("/")[-1] + referenced_schemas.add(schema_name) + else: + extract_schema_refs(value) + elif isinstance(obj, list): + for item in obj: + extract_schema_refs(item) + + extract_schema_refs(filtered_paths) + + def add_dependencies(schema_name): + if schema_name not in all_schemas: + return + schema_def = all_schemas[schema_name] + extract_schema_refs(schema_def) + + initial_schemas = referenced_schemas.copy() + for schema_name in initial_schemas: + add_dependencies(schema_name) + + filtered_schemas = {} + for schema_name in referenced_schemas: + if schema_name in all_schemas: + filtered_schemas[schema_name] = all_schemas[schema_name] + base_schema["components"]["schemas"] = filtered_schemas logger.info( @@ -429,10 +455,14 @@ async def custom_openapi_endpoint(request: Request): username = swagger_user.get("username", "unknown") allowed_tags = swagger_user.get("allowed_tags") - logger.info(f" OpenAPI demandé par: {username}, tags: {allowed_tags or 'ALL'}") + logger.info(f"📖 OpenAPI demandé par: {username}, tags: {allowed_tags or 'ALL'}") schema = generate_filtered_openapi_schema(app, allowed_tags, swagger_user) + if request.url.scheme == "https": + if "servers" not in schema or not schema["servers"]: + schema["servers"] = [{"url": str(request.base_url).rstrip("/")}] + return JSONResponse(content=schema)