feat(api): add tag-based OpenAPI schema filtering for Swagger users
This commit is contained in:
parent
a7457c3979
commit
92597a1143
3 changed files with 166 additions and 43 deletions
95
api.py
95
api.py
|
|
@ -1,4 +1,5 @@
|
|||
from fastapi import FastAPI, HTTPException, Path, Query, Depends, status, Body
|
||||
from fastapi import FastAPI, HTTPException, Path, Query, Depends, status, Body, Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from fastapi.openapi.utils import get_openapi
|
||||
from fastapi.responses import StreamingResponse, HTMLResponse, Response
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
|
|
@ -175,18 +176,17 @@ app = FastAPI(
|
|||
) """
|
||||
|
||||
|
||||
def custom_openapi():
|
||||
if app.openapi_schema:
|
||||
return app.openapi_schema
|
||||
def generate_filtered_openapi(app: FastAPI, allowed_tags: Optional[List[str]] = None):
|
||||
"""Génère le schéma OpenAPI filtré selon les tags autorisés"""
|
||||
|
||||
openapi_schema = get_openapi(
|
||||
base_schema = get_openapi(
|
||||
title=app.title,
|
||||
version=app.version,
|
||||
description=app.description,
|
||||
routes=app.routes,
|
||||
)
|
||||
|
||||
openapi_schema["components"]["securitySchemes"] = {
|
||||
base_schema["components"]["securitySchemes"] = {
|
||||
"HTTPBearer": {
|
||||
"type": "http",
|
||||
"scheme": "bearer",
|
||||
|
|
@ -201,13 +201,68 @@ def custom_openapi():
|
|||
},
|
||||
}
|
||||
|
||||
openapi_schema["security"] = [{"HTTPBearer": []}, {"ApiKeyAuth": []}]
|
||||
base_schema["security"] = [{"HTTPBearer": []}, {"ApiKeyAuth": []}]
|
||||
|
||||
app.openapi_schema = openapi_schema
|
||||
return app.openapi_schema
|
||||
if not allowed_tags:
|
||||
return base_schema
|
||||
|
||||
filtered_paths = {}
|
||||
|
||||
for path, path_item in base_schema.get("paths", {}).items():
|
||||
for method, operation in path_item.items():
|
||||
if method in ["get", "post", "put", "delete", "patch", "options"]:
|
||||
operation_tags = operation.get("tags", [])
|
||||
|
||||
if any(tag in allowed_tags for tag in operation_tags):
|
||||
if path not in filtered_paths:
|
||||
filtered_paths[path] = {}
|
||||
filtered_paths[path][method] = operation
|
||||
|
||||
base_schema["paths"] = filtered_paths
|
||||
|
||||
if "tags" in base_schema:
|
||||
base_schema["tags"] = [
|
||||
tag_obj
|
||||
for tag_obj in base_schema["tags"]
|
||||
if tag_obj.get("name") in allowed_tags
|
||||
]
|
||||
|
||||
return base_schema
|
||||
|
||||
|
||||
app.openapi = custom_openapi
|
||||
async def get_swagger_user_from_request(request: Request) -> Optional[dict]:
|
||||
"""Récupère l'utilisateur Swagger depuis la requête authentifiée"""
|
||||
auth_header = request.headers.get("Authorization")
|
||||
|
||||
if not auth_header or not auth_header.startswith("Basic "):
|
||||
return None
|
||||
|
||||
import base64
|
||||
from fastapi.security import HTTPBasicCredentials
|
||||
from database.db_config import async_session_factory
|
||||
from database.models.api_key import SwaggerUser
|
||||
from sqlalchemy import select
|
||||
|
||||
try:
|
||||
encoded_credentials = auth_header.split(" ")[1]
|
||||
decoded_credentials = base64.b64decode(encoded_credentials).decode("utf-8")
|
||||
username, _ = decoded_credentials.split(":", 1)
|
||||
|
||||
async with async_session_factory() as session:
|
||||
result = await session.execute(
|
||||
select(SwaggerUser).where(SwaggerUser.username == username)
|
||||
)
|
||||
swagger_user = result.scalar_one_or_none()
|
||||
|
||||
if swagger_user and swagger_user.is_active:
|
||||
return {
|
||||
"username": swagger_user.username,
|
||||
"allowed_tags": swagger_user.allowed_tags_list,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"Erreur récupération utilisateur Swagger: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
setup_cors(app, mode="open")
|
||||
|
|
@ -221,6 +276,26 @@ app.include_router(universign_router)
|
|||
app.include_router(entreprises_router)
|
||||
|
||||
|
||||
@app.get("/openapi.json", include_in_schema=False)
|
||||
async def get_openapi_filtered(request: Request):
|
||||
"""Retourne le schéma OpenAPI filtré selon l'utilisateur"""
|
||||
|
||||
swagger_user = await get_swagger_user_from_request(request)
|
||||
|
||||
if not swagger_user:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Authentification requise"},
|
||||
headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'},
|
||||
)
|
||||
|
||||
allowed_tags = swagger_user.get("allowed_tags")
|
||||
|
||||
schema = generate_filtered_openapi(app, allowed_tags)
|
||||
|
||||
return JSONResponse(content=schema)
|
||||
|
||||
|
||||
@app.get("/clients", response_model=List[ClientDetails], tags=["Clients"])
|
||||
async def obtenir_clients(
|
||||
query: Optional[str] = Query(None),
|
||||
|
|
|
|||
|
|
@ -1,4 +1,6 @@
|
|||
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Text
|
||||
from typing import Optional, List
|
||||
import json
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
|
|
@ -49,8 +51,23 @@ class SwaggerUser(Base):
|
|||
|
||||
is_active = Column(Boolean, default=True, nullable=False)
|
||||
|
||||
allowed_tags = Column(Text, nullable=True)
|
||||
|
||||
created_at = Column(DateTime, default=datetime.now, nullable=False)
|
||||
last_login = Column(DateTime, nullable=True)
|
||||
|
||||
@property
|
||||
def allowed_tags_list(self) -> Optional[List[str]]:
|
||||
if self.allowed_tags:
|
||||
try:
|
||||
return json.loads(self.allowed_tags)
|
||||
except json.JSONDecodeError:
|
||||
return None
|
||||
return None
|
||||
|
||||
@allowed_tags_list.setter
|
||||
def allowed_tags_list(self, tags: Optional[List[str]]):
|
||||
self.allowed_tags = json.dumps(tags) if tags is not None else None
|
||||
|
||||
def __repr__(self):
|
||||
return f"<SwaggerUser(username='{self.username}', active={self.is_active})>"
|
||||
|
|
|
|||
|
|
@ -1,6 +1,13 @@
|
|||
import sys
|
||||
import os
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
import argparse
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional, List
|
||||
import json
|
||||
from sqlalchemy import select
|
||||
|
||||
_current_file = Path(__file__).resolve()
|
||||
_script_dir = _current_file.parent
|
||||
|
|
@ -35,30 +42,28 @@ for module in _test_imports:
|
|||
except ImportError as e:
|
||||
print(f" {module}: {e}")
|
||||
|
||||
import asyncio
|
||||
import argparse
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
try:
|
||||
from database.db_config import async_session_factory
|
||||
from database.models.user import User
|
||||
from database.models.api_key import SwaggerUser, ApiKey
|
||||
from services.api_key import ApiKeyService
|
||||
from security.auth import hash_password
|
||||
except ImportError as e:
|
||||
print(f"\n ERREUR D'IMPORT: {e}")
|
||||
print(f" Vérifiez que vous êtes dans /app")
|
||||
print(f" Commande correcte: cd /app && python scripts/manage_security.py ...")
|
||||
print(" Vérifiez que vous êtes dans /app")
|
||||
print(" Commande correcte: cd /app && python scripts/manage_security.py ...")
|
||||
sys.exit(1)
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def add_swagger_user(username: str, password: str, full_name: str = None):
|
||||
async def add_swagger_user(
|
||||
username: str,
|
||||
password: str,
|
||||
full_name: str = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
):
|
||||
"""Ajouter un utilisateur Swagger"""
|
||||
async with async_session_factory() as session:
|
||||
result = await session.execute(
|
||||
|
|
@ -75,6 +80,7 @@ async def add_swagger_user(username: str, password: str, full_name: str = None):
|
|||
hashed_password=hash_password(password),
|
||||
full_name=full_name or username,
|
||||
is_active=True,
|
||||
allowed_tags=json.dumps(tags) if tags else None,
|
||||
)
|
||||
|
||||
session.add(swagger_user)
|
||||
|
|
@ -96,14 +102,28 @@ async def list_swagger_users():
|
|||
|
||||
logger.info(f"👥 {len(users)} utilisateur(s) Swagger:\n")
|
||||
for user in users:
|
||||
status = "" if user.is_active else ""
|
||||
status = "ACTIF" if user.is_active else "NON ACTIF"
|
||||
logger.info(f" {status} {user.username}")
|
||||
logger.info(f" Nom: {user.full_name}")
|
||||
logger.info(f" Créé: {user.created_at}")
|
||||
logger.info(f" Dernière connexion: {user.last_login or 'Jamais'}\n")
|
||||
logger.info(f" Dernière connexion: {user.last_login or 'Jamais'}")
|
||||
|
||||
if user.allowed_tags:
|
||||
try:
|
||||
tags = json.loads(user.allowed_tags)
|
||||
if tags:
|
||||
logger.info(f" Tags autorisés: {', '.join(tags)}")
|
||||
else:
|
||||
logger.info(" Tags autorisés: Tous (admin)")
|
||||
except json.JSONDecodeError:
|
||||
logger.info(" Tags: Erreur format")
|
||||
else:
|
||||
logger.info(" Tags autorisés: Tous (admin)")
|
||||
|
||||
logger.info("")
|
||||
|
||||
|
||||
async def delete_swagger_user(username: str):
|
||||
async def delete_swagger_user(username: str, tags: Optional[List[str]] = None):
|
||||
"""Supprimer un utilisateur Swagger"""
|
||||
async with async_session_factory() as session:
|
||||
result = await session.execute(
|
||||
|
|
@ -117,7 +137,7 @@ async def delete_swagger_user(username: str):
|
|||
|
||||
await session.delete(user)
|
||||
await session.commit()
|
||||
logger.info(f"🗑️ Utilisateur Swagger supprimé: {username}")
|
||||
logger.info("🗑️ Utilisateur Swagger supprimé: {}".format(username))
|
||||
|
||||
|
||||
async def create_api_key(
|
||||
|
|
@ -143,21 +163,23 @@ async def create_api_key(
|
|||
logger.info("=" * 70)
|
||||
logger.info("🔑 Clé API créée avec succès")
|
||||
logger.info("=" * 70)
|
||||
logger.info(f" ID: {api_key_obj.id}")
|
||||
logger.info(f" Nom: {api_key_obj.name}")
|
||||
logger.info(f" Clé: {api_key_plain}")
|
||||
logger.info(f" Préfixe: {api_key_obj.key_prefix}")
|
||||
logger.info(f" Rate limit: {api_key_obj.rate_limit_per_minute} req/min")
|
||||
logger.info(f" Expire le: {api_key_obj.expires_at}")
|
||||
logger.info(" ID: {}".format(api_key_obj.id))
|
||||
logger.info(" Nom: {}".format(api_key_obj.name))
|
||||
logger.info(" Clé: {}".format(api_key_plain))
|
||||
logger.info(" Préfixe: {}".format(api_key_obj.key_prefix))
|
||||
logger.info(
|
||||
" Rate limit: {} req/min".format(api_key_obj.rate_limit_per_minute)
|
||||
)
|
||||
logger.info(" Expire le: {}".format(api_key_obj.expires_at))
|
||||
|
||||
if api_key_obj.allowed_endpoints:
|
||||
import json
|
||||
|
||||
try:
|
||||
endpoints_list = json.loads(api_key_obj.allowed_endpoints)
|
||||
logger.info(f" Endpoints: {', '.join(endpoints_list)}")
|
||||
except:
|
||||
logger.info(f" Endpoints: {api_key_obj.allowed_endpoints}")
|
||||
logger.info(" Endpoints: {}".format(", ".join(endpoints_list)))
|
||||
except Exception:
|
||||
logger.info(" Endpoints: {}".format(api_key_obj.allowed_endpoints))
|
||||
else:
|
||||
logger.info(" Endpoints: Tous (aucune restriction)")
|
||||
|
||||
|
|
@ -200,7 +222,7 @@ async def list_api_keys():
|
|||
if len(endpoints) > 4:
|
||||
display += f"... (+{len(endpoints) - 4})"
|
||||
logger.info(f" Endpoints: {display}")
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
logger.info(" Endpoints: Tous")
|
||||
|
|
@ -250,7 +272,7 @@ async def verify_api_key(api_key: str):
|
|||
try:
|
||||
endpoints = json.loads(key.allowed_endpoints)
|
||||
logger.info(f" Endpoints autorisés: {endpoints}")
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
logger.info(" Endpoints autorisés: Tous")
|
||||
|
|
@ -269,6 +291,8 @@ Exemples:
|
|||
python scripts/manage_security.py apikey create "SDK-ReadOnly" --endpoints "/clients" "/clients/*" "/devis" "/devis/*"
|
||||
python scripts/manage_security.py apikey list
|
||||
python scripts/manage_security.py apikey verify sdk_live_xxxxx
|
||||
python scripts/manage_security.py swagger add client_user Secret123 --full-name "Client Tech IT" --tags Authentication Clients Devis Factures
|
||||
python scripts/manage_security.py swagger add admin_user AdminPass --tags # vide = tout voir
|
||||
""",
|
||||
)
|
||||
subparsers = parser.add_subparsers(dest="command", help="Commandes")
|
||||
|
|
@ -280,6 +304,11 @@ Exemples:
|
|||
add_p.add_argument("username", help="Nom d'utilisateur")
|
||||
add_p.add_argument("password", help="Mot de passe")
|
||||
add_p.add_argument("--full-name", help="Nom complet")
|
||||
add_p.add_argument(
|
||||
"--tags",
|
||||
nargs="*",
|
||||
help="Tags OpenAPI autorisés (ex. Clients Devis Authentication)",
|
||||
)
|
||||
|
||||
swagger_sub.add_parser("list", help="Lister utilisateurs")
|
||||
|
||||
|
|
@ -312,11 +341,13 @@ Exemples:
|
|||
|
||||
if args.command == "swagger":
|
||||
if args.swagger_command == "add":
|
||||
await add_swagger_user(args.username, args.password, args.full_name)
|
||||
await add_swagger_user(
|
||||
args.username, args.password, args.full_name, args.tags
|
||||
)
|
||||
elif args.swagger_command == "list":
|
||||
await list_swagger_users()
|
||||
elif args.swagger_command == "delete":
|
||||
await delete_swagger_user(args.username)
|
||||
await delete_swagger_user(args.username, args.tags)
|
||||
else:
|
||||
swagger_parser.print_help()
|
||||
|
||||
|
|
|
|||
Loading…
Reference in a new issue