feat(api): add tag-based OpenAPI schema filtering for Swagger users
This commit is contained in:
parent
a7457c3979
commit
92597a1143
3 changed files with 166 additions and 43 deletions
95
api.py
95
api.py
|
|
@ -1,4 +1,5 @@
|
||||||
from fastapi import FastAPI, HTTPException, Path, Query, Depends, status, Body
|
from fastapi import FastAPI, HTTPException, Path, Query, Depends, status, Body, Request
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
from fastapi.openapi.utils import get_openapi
|
from fastapi.openapi.utils import get_openapi
|
||||||
from fastapi.responses import StreamingResponse, HTMLResponse, Response
|
from fastapi.responses import StreamingResponse, HTMLResponse, Response
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
|
|
@ -175,18 +176,17 @@ app = FastAPI(
|
||||||
) """
|
) """
|
||||||
|
|
||||||
|
|
||||||
def custom_openapi():
|
def generate_filtered_openapi(app: FastAPI, allowed_tags: Optional[List[str]] = None):
|
||||||
if app.openapi_schema:
|
"""Génère le schéma OpenAPI filtré selon les tags autorisés"""
|
||||||
return app.openapi_schema
|
|
||||||
|
|
||||||
openapi_schema = get_openapi(
|
base_schema = get_openapi(
|
||||||
title=app.title,
|
title=app.title,
|
||||||
version=app.version,
|
version=app.version,
|
||||||
description=app.description,
|
description=app.description,
|
||||||
routes=app.routes,
|
routes=app.routes,
|
||||||
)
|
)
|
||||||
|
|
||||||
openapi_schema["components"]["securitySchemes"] = {
|
base_schema["components"]["securitySchemes"] = {
|
||||||
"HTTPBearer": {
|
"HTTPBearer": {
|
||||||
"type": "http",
|
"type": "http",
|
||||||
"scheme": "bearer",
|
"scheme": "bearer",
|
||||||
|
|
@ -201,13 +201,68 @@ def custom_openapi():
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
openapi_schema["security"] = [{"HTTPBearer": []}, {"ApiKeyAuth": []}]
|
base_schema["security"] = [{"HTTPBearer": []}, {"ApiKeyAuth": []}]
|
||||||
|
|
||||||
app.openapi_schema = openapi_schema
|
if not allowed_tags:
|
||||||
return app.openapi_schema
|
return base_schema
|
||||||
|
|
||||||
|
filtered_paths = {}
|
||||||
|
|
||||||
|
for path, path_item in base_schema.get("paths", {}).items():
|
||||||
|
for method, operation in path_item.items():
|
||||||
|
if method in ["get", "post", "put", "delete", "patch", "options"]:
|
||||||
|
operation_tags = operation.get("tags", [])
|
||||||
|
|
||||||
|
if any(tag in allowed_tags for tag in operation_tags):
|
||||||
|
if path not in filtered_paths:
|
||||||
|
filtered_paths[path] = {}
|
||||||
|
filtered_paths[path][method] = operation
|
||||||
|
|
||||||
|
base_schema["paths"] = filtered_paths
|
||||||
|
|
||||||
|
if "tags" in base_schema:
|
||||||
|
base_schema["tags"] = [
|
||||||
|
tag_obj
|
||||||
|
for tag_obj in base_schema["tags"]
|
||||||
|
if tag_obj.get("name") in allowed_tags
|
||||||
|
]
|
||||||
|
|
||||||
|
return base_schema
|
||||||
|
|
||||||
|
|
||||||
app.openapi = custom_openapi
|
async def get_swagger_user_from_request(request: Request) -> Optional[dict]:
|
||||||
|
"""Récupère l'utilisateur Swagger depuis la requête authentifiée"""
|
||||||
|
auth_header = request.headers.get("Authorization")
|
||||||
|
|
||||||
|
if not auth_header or not auth_header.startswith("Basic "):
|
||||||
|
return None
|
||||||
|
|
||||||
|
import base64
|
||||||
|
from fastapi.security import HTTPBasicCredentials
|
||||||
|
from database.db_config import async_session_factory
|
||||||
|
from database.models.api_key import SwaggerUser
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
try:
|
||||||
|
encoded_credentials = auth_header.split(" ")[1]
|
||||||
|
decoded_credentials = base64.b64decode(encoded_credentials).decode("utf-8")
|
||||||
|
username, _ = decoded_credentials.split(":", 1)
|
||||||
|
|
||||||
|
async with async_session_factory() as session:
|
||||||
|
result = await session.execute(
|
||||||
|
select(SwaggerUser).where(SwaggerUser.username == username)
|
||||||
|
)
|
||||||
|
swagger_user = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if swagger_user and swagger_user.is_active:
|
||||||
|
return {
|
||||||
|
"username": swagger_user.username,
|
||||||
|
"allowed_tags": swagger_user.allowed_tags_list,
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Erreur récupération utilisateur Swagger: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
setup_cors(app, mode="open")
|
setup_cors(app, mode="open")
|
||||||
|
|
@ -221,6 +276,26 @@ app.include_router(universign_router)
|
||||||
app.include_router(entreprises_router)
|
app.include_router(entreprises_router)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/openapi.json", include_in_schema=False)
|
||||||
|
async def get_openapi_filtered(request: Request):
|
||||||
|
"""Retourne le schéma OpenAPI filtré selon l'utilisateur"""
|
||||||
|
|
||||||
|
swagger_user = await get_swagger_user_from_request(request)
|
||||||
|
|
||||||
|
if not swagger_user:
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=401,
|
||||||
|
content={"detail": "Authentification requise"},
|
||||||
|
headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'},
|
||||||
|
)
|
||||||
|
|
||||||
|
allowed_tags = swagger_user.get("allowed_tags")
|
||||||
|
|
||||||
|
schema = generate_filtered_openapi(app, allowed_tags)
|
||||||
|
|
||||||
|
return JSONResponse(content=schema)
|
||||||
|
|
||||||
|
|
||||||
@app.get("/clients", response_model=List[ClientDetails], tags=["Clients"])
|
@app.get("/clients", response_model=List[ClientDetails], tags=["Clients"])
|
||||||
async def obtenir_clients(
|
async def obtenir_clients(
|
||||||
query: Optional[str] = Query(None),
|
query: Optional[str] = Query(None),
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,6 @@
|
||||||
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Text
|
from sqlalchemy import Column, String, Boolean, DateTime, Integer, Text
|
||||||
|
from typing import Optional, List
|
||||||
|
import json
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
|
|
@ -49,8 +51,23 @@ class SwaggerUser(Base):
|
||||||
|
|
||||||
is_active = Column(Boolean, default=True, nullable=False)
|
is_active = Column(Boolean, default=True, nullable=False)
|
||||||
|
|
||||||
|
allowed_tags = Column(Text, nullable=True)
|
||||||
|
|
||||||
created_at = Column(DateTime, default=datetime.now, nullable=False)
|
created_at = Column(DateTime, default=datetime.now, nullable=False)
|
||||||
last_login = Column(DateTime, nullable=True)
|
last_login = Column(DateTime, nullable=True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def allowed_tags_list(self) -> Optional[List[str]]:
|
||||||
|
if self.allowed_tags:
|
||||||
|
try:
|
||||||
|
return json.loads(self.allowed_tags)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
|
|
||||||
|
@allowed_tags_list.setter
|
||||||
|
def allowed_tags_list(self, tags: Optional[List[str]]):
|
||||||
|
self.allowed_tags = json.dumps(tags) if tags is not None else None
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<SwaggerUser(username='{self.username}', active={self.is_active})>"
|
return f"<SwaggerUser(username='{self.username}', active={self.is_active})>"
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,13 @@
|
||||||
import sys
|
import sys
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
import asyncio
|
||||||
|
import argparse
|
||||||
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List
|
||||||
|
import json
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
_current_file = Path(__file__).resolve()
|
_current_file = Path(__file__).resolve()
|
||||||
_script_dir = _current_file.parent
|
_script_dir = _current_file.parent
|
||||||
|
|
@ -35,30 +42,28 @@ for module in _test_imports:
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(f" {module}: {e}")
|
print(f" {module}: {e}")
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import argparse
|
|
||||||
import logging
|
|
||||||
from datetime import datetime
|
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from database.db_config import async_session_factory
|
from database.db_config import async_session_factory
|
||||||
from database.models.user import User
|
|
||||||
from database.models.api_key import SwaggerUser, ApiKey
|
from database.models.api_key import SwaggerUser, ApiKey
|
||||||
from services.api_key import ApiKeyService
|
from services.api_key import ApiKeyService
|
||||||
from security.auth import hash_password
|
from security.auth import hash_password
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
print(f"\n ERREUR D'IMPORT: {e}")
|
print(f"\n ERREUR D'IMPORT: {e}")
|
||||||
print(f" Vérifiez que vous êtes dans /app")
|
print(" Vérifiez que vous êtes dans /app")
|
||||||
print(f" Commande correcte: cd /app && python scripts/manage_security.py ...")
|
print(" Commande correcte: cd /app && python scripts/manage_security.py ...")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s")
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
async def add_swagger_user(username: str, password: str, full_name: str = None):
|
async def add_swagger_user(
|
||||||
|
username: str,
|
||||||
|
password: str,
|
||||||
|
full_name: str = None,
|
||||||
|
tags: Optional[List[str]] = None,
|
||||||
|
):
|
||||||
"""Ajouter un utilisateur Swagger"""
|
"""Ajouter un utilisateur Swagger"""
|
||||||
async with async_session_factory() as session:
|
async with async_session_factory() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
|
|
@ -75,6 +80,7 @@ async def add_swagger_user(username: str, password: str, full_name: str = None):
|
||||||
hashed_password=hash_password(password),
|
hashed_password=hash_password(password),
|
||||||
full_name=full_name or username,
|
full_name=full_name or username,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
|
allowed_tags=json.dumps(tags) if tags else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
session.add(swagger_user)
|
session.add(swagger_user)
|
||||||
|
|
@ -96,14 +102,28 @@ async def list_swagger_users():
|
||||||
|
|
||||||
logger.info(f"👥 {len(users)} utilisateur(s) Swagger:\n")
|
logger.info(f"👥 {len(users)} utilisateur(s) Swagger:\n")
|
||||||
for user in users:
|
for user in users:
|
||||||
status = "" if user.is_active else ""
|
status = "ACTIF" if user.is_active else "NON ACTIF"
|
||||||
logger.info(f" {status} {user.username}")
|
logger.info(f" {status} {user.username}")
|
||||||
logger.info(f" Nom: {user.full_name}")
|
logger.info(f" Nom: {user.full_name}")
|
||||||
logger.info(f" Créé: {user.created_at}")
|
logger.info(f" Créé: {user.created_at}")
|
||||||
logger.info(f" Dernière connexion: {user.last_login or 'Jamais'}\n")
|
logger.info(f" Dernière connexion: {user.last_login or 'Jamais'}")
|
||||||
|
|
||||||
|
if user.allowed_tags:
|
||||||
|
try:
|
||||||
|
tags = json.loads(user.allowed_tags)
|
||||||
|
if tags:
|
||||||
|
logger.info(f" Tags autorisés: {', '.join(tags)}")
|
||||||
|
else:
|
||||||
|
logger.info(" Tags autorisés: Tous (admin)")
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.info(" Tags: Erreur format")
|
||||||
|
else:
|
||||||
|
logger.info(" Tags autorisés: Tous (admin)")
|
||||||
|
|
||||||
|
logger.info("")
|
||||||
|
|
||||||
|
|
||||||
async def delete_swagger_user(username: str):
|
async def delete_swagger_user(username: str, tags: Optional[List[str]] = None):
|
||||||
"""Supprimer un utilisateur Swagger"""
|
"""Supprimer un utilisateur Swagger"""
|
||||||
async with async_session_factory() as session:
|
async with async_session_factory() as session:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
|
|
@ -117,7 +137,7 @@ async def delete_swagger_user(username: str):
|
||||||
|
|
||||||
await session.delete(user)
|
await session.delete(user)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
logger.info(f"🗑️ Utilisateur Swagger supprimé: {username}")
|
logger.info("🗑️ Utilisateur Swagger supprimé: {}".format(username))
|
||||||
|
|
||||||
|
|
||||||
async def create_api_key(
|
async def create_api_key(
|
||||||
|
|
@ -143,21 +163,23 @@ async def create_api_key(
|
||||||
logger.info("=" * 70)
|
logger.info("=" * 70)
|
||||||
logger.info("🔑 Clé API créée avec succès")
|
logger.info("🔑 Clé API créée avec succès")
|
||||||
logger.info("=" * 70)
|
logger.info("=" * 70)
|
||||||
logger.info(f" ID: {api_key_obj.id}")
|
logger.info(" ID: {}".format(api_key_obj.id))
|
||||||
logger.info(f" Nom: {api_key_obj.name}")
|
logger.info(" Nom: {}".format(api_key_obj.name))
|
||||||
logger.info(f" Clé: {api_key_plain}")
|
logger.info(" Clé: {}".format(api_key_plain))
|
||||||
logger.info(f" Préfixe: {api_key_obj.key_prefix}")
|
logger.info(" Préfixe: {}".format(api_key_obj.key_prefix))
|
||||||
logger.info(f" Rate limit: {api_key_obj.rate_limit_per_minute} req/min")
|
logger.info(
|
||||||
logger.info(f" Expire le: {api_key_obj.expires_at}")
|
" Rate limit: {} req/min".format(api_key_obj.rate_limit_per_minute)
|
||||||
|
)
|
||||||
|
logger.info(" Expire le: {}".format(api_key_obj.expires_at))
|
||||||
|
|
||||||
if api_key_obj.allowed_endpoints:
|
if api_key_obj.allowed_endpoints:
|
||||||
import json
|
import json
|
||||||
|
|
||||||
try:
|
try:
|
||||||
endpoints_list = json.loads(api_key_obj.allowed_endpoints)
|
endpoints_list = json.loads(api_key_obj.allowed_endpoints)
|
||||||
logger.info(f" Endpoints: {', '.join(endpoints_list)}")
|
logger.info(" Endpoints: {}".format(", ".join(endpoints_list)))
|
||||||
except:
|
except Exception:
|
||||||
logger.info(f" Endpoints: {api_key_obj.allowed_endpoints}")
|
logger.info(" Endpoints: {}".format(api_key_obj.allowed_endpoints))
|
||||||
else:
|
else:
|
||||||
logger.info(" Endpoints: Tous (aucune restriction)")
|
logger.info(" Endpoints: Tous (aucune restriction)")
|
||||||
|
|
||||||
|
|
@ -200,7 +222,7 @@ async def list_api_keys():
|
||||||
if len(endpoints) > 4:
|
if len(endpoints) > 4:
|
||||||
display += f"... (+{len(endpoints) - 4})"
|
display += f"... (+{len(endpoints) - 4})"
|
||||||
logger.info(f" Endpoints: {display}")
|
logger.info(f" Endpoints: {display}")
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
logger.info(" Endpoints: Tous")
|
logger.info(" Endpoints: Tous")
|
||||||
|
|
@ -250,7 +272,7 @@ async def verify_api_key(api_key: str):
|
||||||
try:
|
try:
|
||||||
endpoints = json.loads(key.allowed_endpoints)
|
endpoints = json.loads(key.allowed_endpoints)
|
||||||
logger.info(f" Endpoints autorisés: {endpoints}")
|
logger.info(f" Endpoints autorisés: {endpoints}")
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
logger.info(" Endpoints autorisés: Tous")
|
logger.info(" Endpoints autorisés: Tous")
|
||||||
|
|
@ -269,6 +291,8 @@ Exemples:
|
||||||
python scripts/manage_security.py apikey create "SDK-ReadOnly" --endpoints "/clients" "/clients/*" "/devis" "/devis/*"
|
python scripts/manage_security.py apikey create "SDK-ReadOnly" --endpoints "/clients" "/clients/*" "/devis" "/devis/*"
|
||||||
python scripts/manage_security.py apikey list
|
python scripts/manage_security.py apikey list
|
||||||
python scripts/manage_security.py apikey verify sdk_live_xxxxx
|
python scripts/manage_security.py apikey verify sdk_live_xxxxx
|
||||||
|
python scripts/manage_security.py swagger add client_user Secret123 --full-name "Client Tech IT" --tags Authentication Clients Devis Factures
|
||||||
|
python scripts/manage_security.py swagger add admin_user AdminPass --tags # vide = tout voir
|
||||||
""",
|
""",
|
||||||
)
|
)
|
||||||
subparsers = parser.add_subparsers(dest="command", help="Commandes")
|
subparsers = parser.add_subparsers(dest="command", help="Commandes")
|
||||||
|
|
@ -280,6 +304,11 @@ Exemples:
|
||||||
add_p.add_argument("username", help="Nom d'utilisateur")
|
add_p.add_argument("username", help="Nom d'utilisateur")
|
||||||
add_p.add_argument("password", help="Mot de passe")
|
add_p.add_argument("password", help="Mot de passe")
|
||||||
add_p.add_argument("--full-name", help="Nom complet")
|
add_p.add_argument("--full-name", help="Nom complet")
|
||||||
|
add_p.add_argument(
|
||||||
|
"--tags",
|
||||||
|
nargs="*",
|
||||||
|
help="Tags OpenAPI autorisés (ex. Clients Devis Authentication)",
|
||||||
|
)
|
||||||
|
|
||||||
swagger_sub.add_parser("list", help="Lister utilisateurs")
|
swagger_sub.add_parser("list", help="Lister utilisateurs")
|
||||||
|
|
||||||
|
|
@ -312,11 +341,13 @@ Exemples:
|
||||||
|
|
||||||
if args.command == "swagger":
|
if args.command == "swagger":
|
||||||
if args.swagger_command == "add":
|
if args.swagger_command == "add":
|
||||||
await add_swagger_user(args.username, args.password, args.full_name)
|
await add_swagger_user(
|
||||||
|
args.username, args.password, args.full_name, args.tags
|
||||||
|
)
|
||||||
elif args.swagger_command == "list":
|
elif args.swagger_command == "list":
|
||||||
await list_swagger_users()
|
await list_swagger_users()
|
||||||
elif args.swagger_command == "delete":
|
elif args.swagger_command == "delete":
|
||||||
await delete_swagger_user(args.username)
|
await delete_swagger_user(args.username, args.tags)
|
||||||
else:
|
else:
|
||||||
swagger_parser.print_help()
|
swagger_parser.print_help()
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue