feat(api): add tag-based OpenAPI schema filtering for Swagger users

This commit is contained in:
Fanilo-Nantenaina 2026-01-21 12:05:06 +03:00
parent a7457c3979
commit 92597a1143
3 changed files with 166 additions and 43 deletions

95
api.py
View file

@ -1,4 +1,5 @@
from fastapi import FastAPI, HTTPException, Path, Query, Depends, status, Body from fastapi import FastAPI, HTTPException, Path, Query, Depends, status, Body, Request
from fastapi.responses import JSONResponse
from fastapi.openapi.utils import get_openapi from fastapi.openapi.utils import get_openapi
from fastapi.responses import StreamingResponse, HTMLResponse, Response from fastapi.responses import StreamingResponse, HTMLResponse, Response
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
@ -175,18 +176,17 @@ app = FastAPI(
) """ ) """
def custom_openapi(): def generate_filtered_openapi(app: FastAPI, allowed_tags: Optional[List[str]] = None):
if app.openapi_schema: """Génère le schéma OpenAPI filtré selon les tags autorisés"""
return app.openapi_schema
openapi_schema = get_openapi( base_schema = get_openapi(
title=app.title, title=app.title,
version=app.version, version=app.version,
description=app.description, description=app.description,
routes=app.routes, routes=app.routes,
) )
openapi_schema["components"]["securitySchemes"] = { base_schema["components"]["securitySchemes"] = {
"HTTPBearer": { "HTTPBearer": {
"type": "http", "type": "http",
"scheme": "bearer", "scheme": "bearer",
@ -201,13 +201,68 @@ def custom_openapi():
}, },
} }
openapi_schema["security"] = [{"HTTPBearer": []}, {"ApiKeyAuth": []}] base_schema["security"] = [{"HTTPBearer": []}, {"ApiKeyAuth": []}]
app.openapi_schema = openapi_schema if not allowed_tags:
return app.openapi_schema return base_schema
filtered_paths = {}
for path, path_item in base_schema.get("paths", {}).items():
for method, operation in path_item.items():
if method in ["get", "post", "put", "delete", "patch", "options"]:
operation_tags = operation.get("tags", [])
if any(tag in allowed_tags for tag in operation_tags):
if path not in filtered_paths:
filtered_paths[path] = {}
filtered_paths[path][method] = operation
base_schema["paths"] = filtered_paths
if "tags" in base_schema:
base_schema["tags"] = [
tag_obj
for tag_obj in base_schema["tags"]
if tag_obj.get("name") in allowed_tags
]
return base_schema
app.openapi = custom_openapi async def get_swagger_user_from_request(request: Request) -> Optional[dict]:
"""Récupère l'utilisateur Swagger depuis la requête authentifiée"""
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Basic "):
return None
import base64
from fastapi.security import HTTPBasicCredentials
from database.db_config import async_session_factory
from database.models.api_key import SwaggerUser
from sqlalchemy import select
try:
encoded_credentials = auth_header.split(" ")[1]
decoded_credentials = base64.b64decode(encoded_credentials).decode("utf-8")
username, _ = decoded_credentials.split(":", 1)
async with async_session_factory() as session:
result = await session.execute(
select(SwaggerUser).where(SwaggerUser.username == username)
)
swagger_user = result.scalar_one_or_none()
if swagger_user and swagger_user.is_active:
return {
"username": swagger_user.username,
"allowed_tags": swagger_user.allowed_tags_list,
}
except Exception as e:
logger.error(f"Erreur récupération utilisateur Swagger: {e}")
return None
setup_cors(app, mode="open") setup_cors(app, mode="open")
@ -221,6 +276,26 @@ app.include_router(universign_router)
app.include_router(entreprises_router) app.include_router(entreprises_router)
@app.get("/openapi.json", include_in_schema=False)
async def get_openapi_filtered(request: Request):
"""Retourne le schéma OpenAPI filtré selon l'utilisateur"""
swagger_user = await get_swagger_user_from_request(request)
if not swagger_user:
return JSONResponse(
status_code=401,
content={"detail": "Authentification requise"},
headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'},
)
allowed_tags = swagger_user.get("allowed_tags")
schema = generate_filtered_openapi(app, allowed_tags)
return JSONResponse(content=schema)
@app.get("/clients", response_model=List[ClientDetails], tags=["Clients"]) @app.get("/clients", response_model=List[ClientDetails], tags=["Clients"])
async def obtenir_clients( async def obtenir_clients(
query: Optional[str] = Query(None), query: Optional[str] = Query(None),

View file

@ -1,4 +1,6 @@
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Text from sqlalchemy import Column, String, Boolean, DateTime, Integer, Text
from typing import Optional, List
import json
from datetime import datetime from datetime import datetime
import uuid import uuid
@ -49,8 +51,23 @@ class SwaggerUser(Base):
is_active = Column(Boolean, default=True, nullable=False) is_active = Column(Boolean, default=True, nullable=False)
allowed_tags = Column(Text, nullable=True)
created_at = Column(DateTime, default=datetime.now, nullable=False) created_at = Column(DateTime, default=datetime.now, nullable=False)
last_login = Column(DateTime, nullable=True) last_login = Column(DateTime, nullable=True)
@property
def allowed_tags_list(self) -> Optional[List[str]]:
if self.allowed_tags:
try:
return json.loads(self.allowed_tags)
except json.JSONDecodeError:
return None
return None
@allowed_tags_list.setter
def allowed_tags_list(self, tags: Optional[List[str]]):
self.allowed_tags = json.dumps(tags) if tags is not None else None
def __repr__(self): def __repr__(self):
return f"<SwaggerUser(username='{self.username}', active={self.is_active})>" return f"<SwaggerUser(username='{self.username}', active={self.is_active})>"

View file

@ -1,6 +1,13 @@
import sys import sys
import os import os
from pathlib import Path from pathlib import Path
import asyncio
import argparse
import logging
from datetime import datetime
from typing import Optional, List
import json
from sqlalchemy import select
_current_file = Path(__file__).resolve() _current_file = Path(__file__).resolve()
_script_dir = _current_file.parent _script_dir = _current_file.parent
@ -35,30 +42,28 @@ for module in _test_imports:
except ImportError as e: except ImportError as e:
print(f" {module}: {e}") print(f" {module}: {e}")
import asyncio
import argparse
import logging
from datetime import datetime
from sqlalchemy import select
try: try:
from database.db_config import async_session_factory from database.db_config import async_session_factory
from database.models.user import User
from database.models.api_key import SwaggerUser, ApiKey from database.models.api_key import SwaggerUser, ApiKey
from services.api_key import ApiKeyService from services.api_key import ApiKeyService
from security.auth import hash_password from security.auth import hash_password
except ImportError as e: except ImportError as e:
print(f"\n ERREUR D'IMPORT: {e}") print(f"\n ERREUR D'IMPORT: {e}")
print(f" Vérifiez que vous êtes dans /app") print(" Vérifiez que vous êtes dans /app")
print(f" Commande correcte: cd /app && python scripts/manage_security.py ...") print(" Commande correcte: cd /app && python scripts/manage_security.py ...")
sys.exit(1) sys.exit(1)
logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s") logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
async def add_swagger_user(username: str, password: str, full_name: str = None): async def add_swagger_user(
username: str,
password: str,
full_name: str = None,
tags: Optional[List[str]] = None,
):
"""Ajouter un utilisateur Swagger""" """Ajouter un utilisateur Swagger"""
async with async_session_factory() as session: async with async_session_factory() as session:
result = await session.execute( result = await session.execute(
@ -75,6 +80,7 @@ async def add_swagger_user(username: str, password: str, full_name: str = None):
hashed_password=hash_password(password), hashed_password=hash_password(password),
full_name=full_name or username, full_name=full_name or username,
is_active=True, is_active=True,
allowed_tags=json.dumps(tags) if tags else None,
) )
session.add(swagger_user) session.add(swagger_user)
@ -96,14 +102,28 @@ async def list_swagger_users():
logger.info(f"👥 {len(users)} utilisateur(s) Swagger:\n") logger.info(f"👥 {len(users)} utilisateur(s) Swagger:\n")
for user in users: for user in users:
status = "" if user.is_active else "" status = "ACTIF" if user.is_active else "NON ACTIF"
logger.info(f" {status} {user.username}") logger.info(f" {status} {user.username}")
logger.info(f" Nom: {user.full_name}") logger.info(f" Nom: {user.full_name}")
logger.info(f" Créé: {user.created_at}") logger.info(f" Créé: {user.created_at}")
logger.info(f" Dernière connexion: {user.last_login or 'Jamais'}\n") logger.info(f" Dernière connexion: {user.last_login or 'Jamais'}")
if user.allowed_tags:
try:
tags = json.loads(user.allowed_tags)
if tags:
logger.info(f" Tags autorisés: {', '.join(tags)}")
else:
logger.info(" Tags autorisés: Tous (admin)")
except json.JSONDecodeError:
logger.info(" Tags: Erreur format")
else:
logger.info(" Tags autorisés: Tous (admin)")
logger.info("")
async def delete_swagger_user(username: str): async def delete_swagger_user(username: str, tags: Optional[List[str]] = None):
"""Supprimer un utilisateur Swagger""" """Supprimer un utilisateur Swagger"""
async with async_session_factory() as session: async with async_session_factory() as session:
result = await session.execute( result = await session.execute(
@ -117,7 +137,7 @@ async def delete_swagger_user(username: str):
await session.delete(user) await session.delete(user)
await session.commit() await session.commit()
logger.info(f"🗑️ Utilisateur Swagger supprimé: {username}") logger.info("🗑️ Utilisateur Swagger supprimé: {}".format(username))
async def create_api_key( async def create_api_key(
@ -143,21 +163,23 @@ async def create_api_key(
logger.info("=" * 70) logger.info("=" * 70)
logger.info("🔑 Clé API créée avec succès") logger.info("🔑 Clé API créée avec succès")
logger.info("=" * 70) logger.info("=" * 70)
logger.info(f" ID: {api_key_obj.id}") logger.info(" ID: {}".format(api_key_obj.id))
logger.info(f" Nom: {api_key_obj.name}") logger.info(" Nom: {}".format(api_key_obj.name))
logger.info(f" Clé: {api_key_plain}") logger.info(" Clé: {}".format(api_key_plain))
logger.info(f" Préfixe: {api_key_obj.key_prefix}") logger.info(" Préfixe: {}".format(api_key_obj.key_prefix))
logger.info(f" Rate limit: {api_key_obj.rate_limit_per_minute} req/min") logger.info(
logger.info(f" Expire le: {api_key_obj.expires_at}") " Rate limit: {} req/min".format(api_key_obj.rate_limit_per_minute)
)
logger.info(" Expire le: {}".format(api_key_obj.expires_at))
if api_key_obj.allowed_endpoints: if api_key_obj.allowed_endpoints:
import json import json
try: try:
endpoints_list = json.loads(api_key_obj.allowed_endpoints) endpoints_list = json.loads(api_key_obj.allowed_endpoints)
logger.info(f" Endpoints: {', '.join(endpoints_list)}") logger.info(" Endpoints: {}".format(", ".join(endpoints_list)))
except: except Exception:
logger.info(f" Endpoints: {api_key_obj.allowed_endpoints}") logger.info(" Endpoints: {}".format(api_key_obj.allowed_endpoints))
else: else:
logger.info(" Endpoints: Tous (aucune restriction)") logger.info(" Endpoints: Tous (aucune restriction)")
@ -200,7 +222,7 @@ async def list_api_keys():
if len(endpoints) > 4: if len(endpoints) > 4:
display += f"... (+{len(endpoints) - 4})" display += f"... (+{len(endpoints) - 4})"
logger.info(f" Endpoints: {display}") logger.info(f" Endpoints: {display}")
except: except Exception:
pass pass
else: else:
logger.info(" Endpoints: Tous") logger.info(" Endpoints: Tous")
@ -250,7 +272,7 @@ async def verify_api_key(api_key: str):
try: try:
endpoints = json.loads(key.allowed_endpoints) endpoints = json.loads(key.allowed_endpoints)
logger.info(f" Endpoints autorisés: {endpoints}") logger.info(f" Endpoints autorisés: {endpoints}")
except: except Exception:
pass pass
else: else:
logger.info(" Endpoints autorisés: Tous") logger.info(" Endpoints autorisés: Tous")
@ -269,6 +291,8 @@ Exemples:
python scripts/manage_security.py apikey create "SDK-ReadOnly" --endpoints "/clients" "/clients/*" "/devis" "/devis/*" python scripts/manage_security.py apikey create "SDK-ReadOnly" --endpoints "/clients" "/clients/*" "/devis" "/devis/*"
python scripts/manage_security.py apikey list python scripts/manage_security.py apikey list
python scripts/manage_security.py apikey verify sdk_live_xxxxx python scripts/manage_security.py apikey verify sdk_live_xxxxx
python scripts/manage_security.py swagger add client_user Secret123 --full-name "Client Tech IT" --tags Authentication Clients Devis Factures
python scripts/manage_security.py swagger add admin_user AdminPass --tags # vide = tout voir
""", """,
) )
subparsers = parser.add_subparsers(dest="command", help="Commandes") subparsers = parser.add_subparsers(dest="command", help="Commandes")
@ -280,6 +304,11 @@ Exemples:
add_p.add_argument("username", help="Nom d'utilisateur") add_p.add_argument("username", help="Nom d'utilisateur")
add_p.add_argument("password", help="Mot de passe") add_p.add_argument("password", help="Mot de passe")
add_p.add_argument("--full-name", help="Nom complet") add_p.add_argument("--full-name", help="Nom complet")
add_p.add_argument(
"--tags",
nargs="*",
help="Tags OpenAPI autorisés (ex. Clients Devis Authentication)",
)
swagger_sub.add_parser("list", help="Lister utilisateurs") swagger_sub.add_parser("list", help="Lister utilisateurs")
@ -312,11 +341,13 @@ Exemples:
if args.command == "swagger": if args.command == "swagger":
if args.swagger_command == "add": if args.swagger_command == "add":
await add_swagger_user(args.username, args.password, args.full_name) await add_swagger_user(
args.username, args.password, args.full_name, args.tags
)
elif args.swagger_command == "list": elif args.swagger_command == "list":
await list_swagger_users() await list_swagger_users()
elif args.swagger_command == "delete": elif args.swagger_command == "delete":
await delete_swagger_user(args.username) await delete_swagger_user(args.username, args.tags)
else: else:
swagger_parser.print_help() swagger_parser.print_help()