diff --git a/api.py b/api.py index e2dc23d..7eb5984 100644 --- a/api.py +++ b/api.py @@ -1,5 +1,5 @@ from fastapi import FastAPI, HTTPException, Path, Query, Depends, status, Body -from fastapi.middleware.cors import CORSMiddleware +from fastapi.openapi.utils import get_openapi from fastapi.responses import StreamingResponse, HTMLResponse, Response from fastapi.encoders import jsonable_encoder from pydantic import BaseModel, Field, EmailStr @@ -95,7 +95,11 @@ from utils.generic_functions import ( universign_envoyer, ) + +from middleware.security import SwaggerAuthMiddleware, ApiKeyMiddlewareHTTP from core.dependencies import get_current_user +from config.cors_config import setup_cors +from routes.api_keys import router as api_keys_router if os.path.exists("/app"): LOGS_DIR = FilePath("/app/logs") @@ -133,7 +137,6 @@ async def lifespan(app: FastAPI): api_url=settings.universign_api_url, api_key=settings.universign_api_key ) - # Configuration du service avec les dépendances sync_service.configure( sage_client=sage_client, email_queue=email_queue, settings=settings ) @@ -163,15 +166,55 @@ app = FastAPI( openapi_tags=TAGS_METADATA, ) -app.add_middleware( +""" app.add_middleware( CORSMiddleware, allow_origins=settings.cors_origins, allow_methods=["GET", "POST", "PUT", "DELETE"], allow_headers=["*"], allow_credentials=True, -) +) """ +def custom_openapi(): + if app.openapi_schema: + return app.openapi_schema + + openapi_schema = get_openapi( + title=app.title, + version=app.version, + description=app.description, + routes=app.routes, + ) + + openapi_schema["components"]["securitySchemes"] = { + "HTTPBearer": { + "type": "http", + "scheme": "bearer", + "bearerFormat": "JWT", + "description": "Authentification JWT pour utilisateurs (POST /auth/login)", + }, + "ApiKeyAuth": { + "type": "apiKey", + "in": "header", + "name": "X-API-Key", + "description": "Clé API pour intégrations externes (format: sdk_live_xxx)", + }, + } + + openapi_schema["security"] = [{"HTTPBearer": []}, {"ApiKeyAuth": []}] + + app.openapi_schema = openapi_schema + return app.openapi_schema + + +app.openapi = custom_openapi + + +setup_cors(app, mode="open") +app.add_middleware(SwaggerAuthMiddleware) +app.add_middleware(ApiKeyMiddlewareHTTP) + +app.include_router(api_keys_router) app.include_router(auth_router) app.include_router(sage_gateway_router) app.include_router(universign_router) @@ -181,6 +224,7 @@ app.include_router(entreprises_router) @app.get("/clients", response_model=List[ClientDetails], tags=["Clients"]) async def obtenir_clients( query: Optional[str] = Query(None), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -194,6 +238,7 @@ async def obtenir_clients( @app.get("/clients/{code}", response_model=ClientDetails, tags=["Clients"]) async def lire_client_detail( code: str, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -216,6 +261,7 @@ async def modifier_client( code: str, client_update: ClientUpdate, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -241,6 +287,7 @@ async def modifier_client( async def ajouter_client( client: ClientCreate, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -265,6 +312,7 @@ async def ajouter_client( @app.get("/articles", response_model=List[Article], tags=["Articles"]) async def rechercher_articles( query: Optional[str] = Query(None), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -283,6 +331,7 @@ async def rechercher_articles( ) async def creer_article( article: ArticleCreate, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -323,6 +372,7 @@ async def creer_article( async def modifier_article( reference: str = Path(..., description="Référence de l'article à modifier"), article: ArticleUpdate = Body(...), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -366,6 +416,7 @@ async def modifier_article( @app.get("/articles/{reference}", response_model=Article, tags=["Articles"]) async def lire_article( reference: str = Path(..., description="Référence de l'article"), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -395,6 +446,7 @@ async def lire_article( @app.post("/devis", response_model=Devis, status_code=201, tags=["Devis"]) async def creer_devis( devis: DevisRequest, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -434,6 +486,7 @@ async def modifier_devis( id: str, devis_update: DevisUpdate, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -479,6 +532,7 @@ async def modifier_devis( async def creer_commande( commande: CommandeCreate, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -528,6 +582,7 @@ async def modifier_commande( id: str, commande_update: CommandeUpdate, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -576,6 +631,7 @@ async def lister_devis( inclure_lignes: bool = Query( True, description="Inclure les lignes de chaque devis" ), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -592,6 +648,7 @@ async def lister_devis( @app.get("/devis/{id}", tags=["Devis"]) async def lire_devis( id: str, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -612,6 +669,7 @@ async def lire_devis( @app.get("/devis/{id}/pdf", tags=["Devis"]) async def telecharger_devis_pdf( id: str, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -634,6 +692,7 @@ async def telecharger_document_pdf( description="Type de document (0=Devis, 10=Commande, 30=Livraison, 60=Facture, 50=Avoir)", ), numero: str = Path(..., description="Numéro du document"), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -690,6 +749,7 @@ async def envoyer_devis_email( id: str, request: EmailEnvoi, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -745,6 +805,7 @@ async def changer_statut_document( nouveau_statut: int = Query( ..., ge=0, le=6, description="0=Saisi, 1=Confirmé, 2=Accepté" ), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): document_type_sql = None @@ -861,6 +922,7 @@ async def changer_statut_document( @app.get("/commandes/{id}", tags=["Commandes"]) async def lire_commande( id: str, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -879,6 +941,7 @@ async def lire_commande( async def lister_commandes( limit: int = Query(100, le=1000), statut: Optional[int] = Query(None), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -894,6 +957,7 @@ async def lister_commandes( async def devis_vers_commande( id: str, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -938,6 +1002,7 @@ async def devis_vers_commande( async def commande_vers_facture( id: str, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1039,6 +1104,7 @@ async def envoyer_emails_lot( async def valider_remise( client_id: str = Query(..., min_length=1), remise_pourcentage: float = Query(0.0, ge=0, le=100), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1072,6 +1138,7 @@ async def relancer_devis_signature( id: str, relance: RelanceDevis, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1138,6 +1205,7 @@ class ContactClientResponse(BaseModel): @app.get("/devis/{id}/contact", response_model=ContactClientResponse, tags=["Devis"]) async def recuperer_contact_devis( id: str, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1165,6 +1233,7 @@ async def recuperer_contact_devis( async def lister_factures( limit: int = Query(100, le=1000), statut: Optional[int] = Query(None), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1179,6 +1248,7 @@ async def lister_factures( @app.get("/factures/{numero}", tags=["Factures"]) async def lire_facture_detail( numero: str, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1205,6 +1275,7 @@ class RelanceFacture(BaseModel): async def creer_facture( facture: FactureCreate, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1254,6 +1325,7 @@ async def modifier_facture( id: str, facture_update: FactureUpdate, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1323,6 +1395,7 @@ async def relancer_facture( id: str, relance: RelanceFacture, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1393,6 +1466,7 @@ async def journal_emails( destinataire: Optional[str] = Query(None), limit: int = Query(100, le=1000), session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): query = select(EmailLog) @@ -1428,6 +1502,7 @@ async def journal_emails( async def exporter_logs_csv( statut: Optional[StatutEmail] = Query(None), session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): query = select(EmailLog) @@ -1584,6 +1659,7 @@ async def supprimer_template( @app.post("/templates/emails/preview", tags=["Emails"]) async def previsualiser_email( preview: TemplatePreview, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): if preview.template_id not in templates_email_db: @@ -1622,6 +1698,7 @@ async def previsualiser_email( @app.get("/prospects", tags=["Prospects"]) async def rechercher_prospects( query: Optional[str] = Query(None), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1635,6 +1712,7 @@ async def rechercher_prospects( @app.get("/prospects/{code}", tags=["Prospects"]) async def lire_prospect( code: str, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1654,6 +1732,7 @@ async def lire_prospect( ) async def rechercher_fournisseurs( query: Optional[str] = Query(None), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1675,6 +1754,7 @@ async def rechercher_fournisseurs( async def ajouter_fournisseur( fournisseur: FournisseurCreate, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1704,6 +1784,7 @@ async def modifier_fournisseur( code: str, fournisseur_update: FournisseurUpdate, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1726,6 +1807,7 @@ async def modifier_fournisseur( @app.get("/fournisseurs/{code}", tags=["Fournisseurs"]) async def lire_fournisseur( code: str, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1744,6 +1826,7 @@ async def lire_fournisseur( async def lister_avoirs( limit: int = Query(100, le=1000), statut: Optional[int] = Query(None), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1757,6 +1840,7 @@ async def lister_avoirs( @app.get("/avoirs/{numero}", tags=["Avoirs"]) async def lire_avoir( numero: str, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1775,6 +1859,7 @@ async def lire_avoir( async def creer_avoir( avoir: AvoirCreate, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1822,6 +1907,7 @@ async def modifier_avoir( id: str, avoir_update: AvoirUpdate, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1867,6 +1953,7 @@ async def modifier_avoir( async def lister_livraisons( limit: int = Query(100, le=1000), statut: Optional[int] = Query(None), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1880,6 +1967,7 @@ async def lister_livraisons( @app.get("/livraisons/{numero}", tags=["Livraisons"]) async def lire_livraison( numero: str, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1898,6 +1986,7 @@ async def lire_livraison( async def creer_livraison( livraison: LivraisonCreate, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1951,6 +2040,7 @@ async def modifier_livraison( id: str, livraison_update: LivraisonUpdate, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -1996,6 +2086,7 @@ async def modifier_livraison( async def livraison_vers_facture( id: str, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2039,6 +2130,7 @@ async def livraison_vers_facture( async def devis_vers_facture_direct( id: str, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2099,6 +2191,7 @@ async def devis_vers_facture_direct( async def commande_vers_livraison( id: str, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2170,6 +2263,7 @@ async def commande_vers_livraison( ) async def lister_familles( filtre: Optional[str] = Query(None, description="Filtre sur code ou intitulé"), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2195,6 +2289,7 @@ async def lister_familles( ) async def lire_famille( code: str = Path(..., description="Code de la famille (ex: ZDIVERS)"), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2230,6 +2325,7 @@ async def lire_famille( ) async def creer_famille( famille: FamilleCreate, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2273,6 +2369,7 @@ async def creer_famille( ) async def creer_entree_stock( entree: EntreeStock, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2309,6 +2406,7 @@ async def creer_entree_stock( ) async def creer_sortie_stock( sortie: SortieStock, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2344,6 +2442,7 @@ async def creer_sortie_stock( ) async def lire_mouvement_stock( numero: str = Path(..., description="Numéro du mouvement (ex: ME00123 ou MS00124)"), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2376,6 +2475,7 @@ async def lire_mouvement_stock( summary="Statistiques sur les familles", ) async def statistiques_familles( + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2484,6 +2584,7 @@ async def statistiques_utilisateurs(session: AsyncSession = Depends(get_session) async def creer_contact( numero: str, contact: ContactCreate, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2516,6 +2617,7 @@ async def creer_contact( @app.get("/tiers/{numero}/contacts", response_model=List[Contact], tags=["Contacts"]) async def lister_contacts( numero: str, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2534,6 +2636,7 @@ async def lister_contacts( async def obtenir_contact( numero: str, contact_numero: int, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2559,6 +2662,7 @@ async def modifier_contact( numero: str, contact_numero: int, contact: ContactUpdate, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2590,6 +2694,7 @@ async def modifier_contact( async def supprimer_contact( numero: str, contact_numero: int, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2604,6 +2709,7 @@ async def supprimer_contact( async def definir_contact_defaut( numero: str, contact_numero: int, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2625,6 +2731,7 @@ async def obtenir_tiers( description="Filtre par type: 0/client, 1/fournisseur, 2/prospect, 3/all ou strings", ), query: Optional[str] = Query(None, description="Recherche sur code ou intitulé"), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2639,6 +2746,7 @@ async def obtenir_tiers( @app.get("/tiers/{code}", response_model=TiersDetails, tags=["Tiers"]) async def lire_tiers_detail( code: str, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2666,7 +2774,6 @@ async def get_current_sage_config( } -# Routes Collaborateurs @app.get( "/collaborateurs", response_model=List[CollaborateurDetails], @@ -2677,6 +2784,7 @@ async def lister_collaborateurs( actifs_seulement: bool = Query( True, description="Exclure les collaborateurs en sommeil" ), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): """Liste tous les collaborateurs""" @@ -2695,6 +2803,7 @@ async def lister_collaborateurs( ) async def lire_collaborateur_detail( numero: int, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): """Lit un collaborateur par son numéro""" @@ -2721,6 +2830,7 @@ async def lire_collaborateur_detail( ) async def creer_collaborateur( collaborateur: CollaborateurCreate, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): """Crée un nouveau collaborateur""" @@ -2747,6 +2857,7 @@ async def creer_collaborateur( async def modifier_collaborateur( numero: int, collaborateur: CollaborateurUpdate, + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): """Modifie un collaborateur existant""" @@ -2769,6 +2880,7 @@ async def modifier_collaborateur( @app.get("/societe/info", response_model=SocieteInfo, tags=["Société"]) async def obtenir_informations_societe( + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2788,6 +2900,7 @@ async def obtenir_informations_societe( @app.get("/societe/logo", tags=["Société"]) async def obtenir_logo_societe( + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): """Retourne le logo en tant qu'image directe""" @@ -2812,6 +2925,7 @@ async def obtenir_logo_societe( @app.get("/societe/preview", response_class=HTMLResponse, tags=["Société"]) async def preview_societe( + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): """Page HTML pour visualiser les infos société avec logo""" @@ -2885,6 +2999,7 @@ async def preview_societe( async def valider_facture( numero_facture: str, _: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2908,6 +3023,7 @@ async def valider_facture( async def devalider_facture( numero_facture: str, _: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2931,6 +3047,7 @@ async def devalider_facture( async def get_statut_validation_facture( numero_facture: str, _: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2951,6 +3068,7 @@ async def regler_facture( numero_facture: str, reglement: ReglementFactureCreate, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -2994,6 +3112,7 @@ async def regler_facture( async def regler_factures_multiple( reglement: ReglementMultipleCreate, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -3032,6 +3151,7 @@ async def regler_factures_multiple( async def get_reglements_facture( numero_facture: str, session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -3056,6 +3176,7 @@ async def get_reglements_client( date_fin: Optional[datetime] = Query(None, description="Date fin"), inclure_soldees: bool = Query(True, description="Inclure les factures soldées"), session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -3080,6 +3201,7 @@ async def get_reglements_client( @app.get("/journaux/banque", tags=["Règlements"]) async def get_journaux_banque( + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): try: @@ -3092,6 +3214,7 @@ async def get_journaux_banque( @app.get("/reglements/modes", tags=["Référentiels"]) async def get_modes_reglement( + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): """Liste des modes de règlement disponibles dans Sage""" @@ -3105,6 +3228,7 @@ async def get_modes_reglement( @app.get("/devises", tags=["Référentiels"]) async def get_devises( + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): """Liste des devises disponibles dans Sage""" @@ -3118,6 +3242,7 @@ async def get_devises( @app.get("/journaux/tresorerie", tags=["Référentiels"]) async def get_journaux_tresorerie( + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): """Liste des journaux de trésorerie (banque + caisse)""" @@ -3136,6 +3261,7 @@ async def get_comptes_generaux( None, description="client | fournisseur | banque | caisse | tva | produit | charge", ), + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): """Liste des comptes généraux""" @@ -3149,6 +3275,7 @@ async def get_comptes_generaux( @app.get("/tva/taux", tags=["Référentiels"]) async def get_tva_taux( + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): """Liste des taux de TVA""" @@ -3162,6 +3289,7 @@ async def get_tva_taux( @app.get("/parametres/encaissement", tags=["Référentiels"]) async def get_parametres_encaissement( + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): """Paramètres TVA sur encaissement""" @@ -3208,6 +3336,7 @@ async def get_reglement_detail(rg_no): @app.get("/health", tags=["System"]) async def health_check( + user: User = Depends(get_current_user), sage: SageGatewayClient = Depends(get_sage_client_for_user), ): gateway_health = sage.health() @@ -3229,9 +3358,23 @@ async def health_check( async def root(): return { "api": "Sage 100c Dataven - VPS Linux", - "version": "2.0.0", - "documentation": "/docs", + "version": "3.0.0", + "documentation": "/docs (authentification requise)", "health": "/health", + "authentication": { + "methods": [ + { + "type": "JWT", + "header": "Authorization: Bearer ", + "endpoint": "/api/auth/login", + }, + { + "type": "API Key", + "header": "X-API-Key: sdk_live_xxx", + "endpoint": "/api/api-keys", + }, + ] + }, } diff --git a/config/config.py b/config/config.py index 63bf99b..d60c4d8 100644 --- a/config/config.py +++ b/config/config.py @@ -7,7 +7,6 @@ class Settings(BaseSettings): env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="ignore" ) - # === JWT & Auth === jwt_secret: str jwt_algorithm: str access_token_expire_minutes: int @@ -21,15 +20,12 @@ class Settings(BaseSettings): SAGE_TYPE_BON_AVOIR: int = 50 SAGE_TYPE_FACTURE: int = 60 - # === Sage Gateway (Windows) === sage_gateway_url: str sage_gateway_token: str frontend_url: str - # === Base de données === database_url: str = "sqlite+aiosqlite:///./data/sage_dataven.db" - # === SMTP === smtp_host: str smtp_port: int = 587 smtp_user: str @@ -37,21 +33,17 @@ class Settings(BaseSettings): smtp_from: str smtp_use_tls: bool = True - # === Universign === universign_api_key: str universign_api_url: str - # === API === api_host: str api_port: int api_reload: bool = False - # === Email Queue === max_email_workers: int = 3 max_retry_attempts: int = 3 retry_delay_seconds: int = 3 - # === CORS === cors_origins: List[str] = ["*"] diff --git a/config/cors_config.py b/config/cors_config.py new file mode 100644 index 0000000..0f3a4d2 --- /dev/null +++ b/config/cors_config.py @@ -0,0 +1,125 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from typing import List +import os +import logging + +logger = logging.getLogger(__name__) + + +def configure_cors_open(app: FastAPI): + app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=False, + allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], + allow_headers=["*"], + expose_headers=["X-RateLimit-Limit", "X-RateLimit-Remaining"], + max_age=3600, + ) + + logger.info(" CORS configuré: Mode OUVERT (sécurisé par API Keys)") + logger.info(" - Origins: * (toutes)") + logger.info(" - Headers: * (dont X-API-Key)") + logger.info(" - Credentials: False") + + +def configure_cors_whitelist(app: FastAPI): + allowed_origins_str = os.getenv("CORS_ALLOWED_ORIGINS", "") + + if allowed_origins_str: + allowed_origins = [ + origin.strip() + for origin in allowed_origins_str.split(",") + if origin.strip() + ] + else: + allowed_origins = ["*"] + + app.add_middleware( + CORSMiddleware, + allow_origins=allowed_origins, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "X-API-Key"], + expose_headers=["X-RateLimit-Limit", "X-RateLimit-Remaining"], + max_age=3600, + ) + + logger.info(" CORS configuré: Mode WHITELIST") + logger.info(f" - Origins autorisées: {len(allowed_origins)}") + for origin in allowed_origins: + logger.info(f" • {origin}") + + +def configure_cors_regex(app: FastAPI): + origin_regex = r"*" + + app.add_middleware( + CORSMiddleware, + allow_origin_regex=origin_regex, + allow_credentials=True, + allow_methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "X-API-Key"], + expose_headers=["X-RateLimit-Limit", "X-RateLimit-Remaining"], + max_age=3600, + ) + + logger.info(" CORS configuré: Mode REGEX") + logger.info(f" - Pattern: {origin_regex}") + + +def configure_cors_hybrid(app: FastAPI): + from starlette.middleware.base import BaseHTTPMiddleware + + class HybridCORSMiddleware(BaseHTTPMiddleware): + def __init__(self, app, known_origins: List[str]): + super().__init__(app) + self.known_origins = set(known_origins) + + async def dispatch(self, request, call_next): + origin = request.headers.get("origin") + + if origin in self.known_origins: + response = await call_next(request) + response.headers["Access-Control-Allow-Origin"] = origin + response.headers["Access-Control-Allow-Credentials"] = "true" + response.headers["Access-Control-Allow-Methods"] = ( + "GET, POST, PUT, DELETE, PATCH, OPTIONS" + ) + response.headers["Access-Control-Allow-Headers"] = ( + "Content-Type, Authorization, X-API-Key" + ) + return response + + response = await call_next(request) + response.headers["Access-Control-Allow-Origin"] = "*" + response.headers["Access-Control-Allow-Methods"] = ( + "GET, POST, PUT, DELETE, PATCH, OPTIONS" + ) + response.headers["Access-Control-Allow-Headers"] = "*" + return response + + known_origins = ["*"] + + app.add_middleware(HybridCORSMiddleware, known_origins=known_origins) + + logger.info(" CORS configuré: Mode HYBRIDE") + logger.info(f" - Whitelist: {len(known_origins)} domaines") + logger.info(" - Fallback: * (ouvert)") + + +def setup_cors(app: FastAPI, mode: str = "open"): + if mode == "open": + configure_cors_open(app) + elif mode == "whitelist": + configure_cors_whitelist(app) + elif mode == "regex": + configure_cors_regex(app) + elif mode == "hybrid": + configure_cors_hybrid(app) + else: + logger.warning( + f" Mode CORS inconnu: {mode}. Utilisation de 'open' par défaut." + ) + configure_cors_open(app) diff --git a/core/dependencies.py b/core/dependencies.py index 039081c..c1468dd 100644 --- a/core/dependencies.py +++ b/core/dependencies.py @@ -1,94 +1,118 @@ -from fastapi import Depends, HTTPException, status +from fastapi import Depends, HTTPException, status, Request from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select +from typing import Optional +from jwt.exceptions import InvalidTokenError + from database import get_session, User from security.auth import decode_token -from typing import Optional -from datetime import datetime -security = HTTPBearer() +security = HTTPBearer(auto_error=False) -async def get_current_user( - credentials: HTTPAuthorizationCredentials = Depends(security), - session: AsyncSession = Depends(get_session), -) -> User: - token = credentials.credentials - - payload = decode_token(token) - if not payload: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Token invalide ou expiré", - headers={"WWW-Authenticate": "Bearer"}, - ) - - if payload.get("type") != "access": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Type de token incorrect", - headers={"WWW-Authenticate": "Bearer"}, - ) - - user_id: str = payload.get("sub") - if not user_id: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Token malformé", - headers={"WWW-Authenticate": "Bearer"}, - ) - - result = await session.execute(select(User).where(User.id == user_id)) - user = result.scalar_one_or_none() - - if not user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Utilisateur introuvable", - headers={"WWW-Authenticate": "Bearer"}, - ) - - if not user.is_active: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail="Compte désactivé" - ) - - if not user.is_verified: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Email non vérifié. Consultez votre boîte de réception.", - ) - - if user.locked_until and user.locked_until > datetime.now(): - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Compte temporairement verrouillé suite à trop de tentatives échouées", - ) - - return user - - -async def get_current_user_optional( +async def get_current_user_hybrid( + request: Request, credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), session: AsyncSession = Depends(get_session), -) -> Optional[User]: +) -> User: + api_key_obj = getattr(request.state, "api_key", None) + + if api_key_obj: + if api_key_obj.user_id: + result = await session.execute( + select(User).where(User.id == api_key_obj.user_id) + ) + user = result.scalar_one_or_none() + + if user: + user._is_api_key_user = True + user._api_key_obj = api_key_obj + return user + + virtual_user = User( + id=f"api_key_{api_key_obj.id}", + email=f"api_key_{api_key_obj.id}@virtual.local", + nom=api_key_obj.name, + prenom="API", + hashed_password="", + role="api_client", + is_active=True, + is_verified=True, + ) + + virtual_user._is_api_key_user = True + virtual_user._api_key_obj = api_key_obj + + return virtual_user + if not credentials: - return None + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentification requise (JWT ou API Key)", + headers={"WWW-Authenticate": "Bearer"}, + ) + + token = credentials.credentials try: - return await get_current_user(credentials, session) - except HTTPException: - return None + payload = decode_token(token) + user_id: str = payload.get("sub") + + if user_id is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Token invalide: user_id manquant", + headers={"WWW-Authenticate": "Bearer"}, + ) + + result = await session.execute(select(User).where(User.id == user_id)) + user = result.scalar_one_or_none() + + if user is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Utilisateur introuvable", + headers={"WWW-Authenticate": "Bearer"}, + ) + + if not user.is_active: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Utilisateur inactif", + ) + + return user + + except InvalidTokenError as e: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail=f"Token invalide: {str(e)}", + headers={"WWW-Authenticate": "Bearer"}, + ) -def require_role(*allowed_roles: str): - async def role_checker(user: User = Depends(get_current_user)) -> User: +def require_role_hybrid(*allowed_roles: str): + async def role_checker(user: User = Depends(get_current_user_hybrid)) -> User: if user.role not in allowed_roles: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail=f"Accès refusé. Rôles requis: {', '.join(allowed_roles)}", + detail=f"Accès interdit. Rôles autorisés: {', '.join(allowed_roles)}", ) return user return role_checker + + +def is_api_key_user(user: User) -> bool: + """Vérifie si l'utilisateur est authentifié via API Key""" + return getattr(user, "_is_api_key_user", False) + + +def get_api_key_from_user(user: User): + """Récupère l'objet ApiKey depuis un utilisateur (si applicable)""" + return getattr(user, "_api_key_obj", None) + + +get_current_user = get_current_user_hybrid +require_role = require_role_hybrid diff --git a/create_admin.py b/create_admin.py index d3cb786..96197ec 100644 --- a/create_admin.py +++ b/create_admin.py @@ -19,7 +19,6 @@ async def create_admin(): print(" Création d'un compte administrateur") print("=" * 60 + "\n") - # Saisie des informations email = input("Email de l'admin: ").strip().lower() if not email or "@" not in email: print(" Email invalide") @@ -32,7 +31,6 @@ async def create_admin(): print(" Prénom et nom requis") return False - # Mot de passe avec validation while True: password = input( "Mot de passe (min 8 car., 1 maj, 1 min, 1 chiffre, 1 spécial): " @@ -58,7 +56,6 @@ async def create_admin(): print(f"\n Un utilisateur avec l'email {email} existe déjà") return False - # Créer l'admin admin = User( id=str(uuid.uuid4()), email=email, diff --git a/database/db_config.py b/database/db_config.py index bb98f5c..692822c 100644 --- a/database/db_config.py +++ b/database/db_config.py @@ -1,14 +1,14 @@ -import os from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, async_sessionmaker from sqlalchemy.pool import NullPool from sqlalchemy import event, text import logging +from config.config import settings from database.models.generic_model import Base logger = logging.getLogger(__name__) -DATABASE_URL = os.getenv("DATABASE_URL") +DATABASE_URL = settings.database_url def _configure_sqlite_connection(dbapi_connection, connection_record): diff --git a/database/models/api_key.py b/database/models/api_key.py new file mode 100644 index 0000000..0d246ab --- /dev/null +++ b/database/models/api_key.py @@ -0,0 +1,56 @@ +from sqlalchemy import Column, String, Boolean, DateTime, Integer, Text +from datetime import datetime +import uuid + +from database.models.generic_model import Base + + +class ApiKey(Base): + """Modèle pour les clés API publiques""" + + __tablename__ = "api_keys" + + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + key_hash = Column(String(64), unique=True, nullable=False, index=True) + key_prefix = Column(String(10), nullable=False) + + name = Column(String(255), nullable=False) + description = Column(Text, nullable=True) + + user_id = Column(String(36), nullable=True) + created_by = Column(String(255), nullable=False) + + is_active = Column(Boolean, default=True, nullable=False) + rate_limit_per_minute = Column(Integer, default=60, nullable=False) + allowed_endpoints = Column(Text, nullable=True) + + total_requests = Column(Integer, default=0, nullable=False) + last_used_at = Column(DateTime, nullable=True) + + created_at = Column(DateTime, default=datetime.now, nullable=False) + expires_at = Column(DateTime, nullable=True) + revoked_at = Column(DateTime, nullable=True) + + def __repr__(self): + return f"" + + +class SwaggerUser(Base): + """Modèle pour les utilisateurs autorisés à accéder au Swagger""" + + __tablename__ = "swagger_users" + + id = Column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + username = Column(String(100), unique=True, nullable=False, index=True) + hashed_password = Column(String(255), nullable=False) + + full_name = Column(String(255), nullable=True) + email = Column(String(255), nullable=True) + + is_active = Column(Boolean, default=True, nullable=False) + + created_at = Column(DateTime, default=datetime.now, nullable=False) + last_login = Column(DateTime, nullable=True) + + def __repr__(self): + return f"" diff --git a/middleware/security.py b/middleware/security.py new file mode 100644 index 0000000..88df4f6 --- /dev/null +++ b/middleware/security.py @@ -0,0 +1,267 @@ +from fastapi import Request, status +from fastapi.responses import JSONResponse +from fastapi.security import HTTPBasic, HTTPBasicCredentials +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.types import ASGIApp +from sqlalchemy import select +from typing import Callable +from datetime import datetime +import logging +import base64 + +logger = logging.getLogger(__name__) + +security = HTTPBasic() + + +class SwaggerAuthMiddleware: + PROTECTED_PATHS = ["/docs", "/redoc", "/openapi.json"] + + def __init__(self, app: ASGIApp): + self.app = app + + async def __call__(self, scope, receive, send): + if scope["type"] != "http": + await self.app(scope, receive, send) + return + + request = Request(scope, receive=receive) + path = request.url.path + + if not any(path.startswith(p) for p in self.PROTECTED_PATHS): + await self.app(scope, receive, send) + return + + auth_header = request.headers.get("Authorization") + + if not auth_header or not auth_header.startswith("Basic "): + response = JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Authentification requise pour la documentation"}, + headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'}, + ) + await response(scope, receive, send) + return + + try: + encoded_credentials = auth_header.split(" ")[1] + decoded_credentials = base64.b64decode(encoded_credentials).decode("utf-8") + username, password = decoded_credentials.split(":", 1) + + credentials = HTTPBasicCredentials(username=username, password=password) + + if not await self._verify_credentials(credentials): + response = JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Identifiants invalides"}, + headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'}, + ) + await response(scope, receive, send) + return + + except Exception as e: + logger.error(f"Erreur parsing auth header: {e}") + response = JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={"detail": "Format d'authentification invalide"}, + headers={"WWW-Authenticate": 'Basic realm="Swagger UI"'}, + ) + await response(scope, receive, send) + return + + await self.app(scope, receive, send) + + async def _verify_credentials(self, credentials: HTTPBasicCredentials) -> bool: + """Vérifie les identifiants dans la base de données""" + from database.db_config import async_session_factory + from database.models.api_key import SwaggerUser + from security.auth import verify_password + + try: + async with async_session_factory() as session: + result = await session.execute( + select(SwaggerUser).where( + SwaggerUser.username == credentials.username + ) + ) + swagger_user = result.scalar_one_or_none() + + if swagger_user and swagger_user.is_active: + if verify_password( + credentials.password, swagger_user.hashed_password + ): + swagger_user.last_login = datetime.now() + await session.commit() + logger.info(f"✓ Accès Swagger autorisé: {credentials.username}") + return True + + logger.warning(f"✗ Accès Swagger refusé: {credentials.username}") + return False + + except Exception as e: + logger.error(f"Erreur vérification credentials: {e}") + return False + + +class ApiKeyMiddlewareHTTP(BaseHTTPMiddleware): + EXCLUDED_PATHS = [ + "/docs", + "/redoc", + "/openapi.json", + "/", + "/health", + "/auth", + "/api-keys/verify", + "/universign/webhook", + ] + + def _is_excluded_path(self, path: str) -> bool: + """Vérifie si le chemin est exclu de l'authentification""" + if path == "/": + return True + + for excluded in self.EXCLUDED_PATHS: + if excluded == "/": + continue + if path == excluded or path.startswith(excluded + "/"): + return True + + return False + + async def dispatch(self, request: Request, call_next: Callable): + path = request.url.path + method = request.method + + if self._is_excluded_path(path): + return await call_next(request) + + auth_header = request.headers.get("Authorization") + api_key_header = request.headers.get("X-API-Key") + + if api_key_header: + logger.debug(f"🔑 API Key détectée pour {method} {path}") + return await self._handle_api_key_auth( + request, api_key_header, path, method, call_next + ) + + if auth_header and auth_header.startswith("Bearer "): + token = auth_header.split(" ")[1] + + if token.startswith("sdk_live_"): + logger.warning( + " API Key envoyée dans Authorization au lieu de X-API-Key" + ) + return await self._handle_api_key_auth( + request, token, path, method, call_next + ) + + logger.debug(f"🎫 JWT détecté pour {method} {path} → délégation à FastAPI") + request.state.authenticated_via = "jwt" + return await call_next(request) + + logger.debug(f" Aucune auth pour {method} {path} → délégation à FastAPI") + return await call_next(request) + + async def _handle_api_key_auth( + self, + request: Request, + api_key: str, + path: str, + method: str, + call_next: Callable, + ): + """Gère l'authentification par API Key avec vérification STRICTE""" + try: + from database.db_config import async_session_factory + from services.api_key import ApiKeyService + + async with async_session_factory() as session: + service = ApiKeyService(session) + + api_key_obj = await service.verify_api_key(api_key) + + if not api_key_obj: + logger.warning(f" Clé API invalide: {method} {path}") + return JSONResponse( + status_code=status.HTTP_401_UNAUTHORIZED, + content={ + "detail": "Clé API invalide ou expirée", + "hint": "Vérifiez votre clé X-API-Key", + }, + ) + + is_allowed, rate_info = await service.check_rate_limit(api_key_obj) + if not is_allowed: + logger.warning(f" Rate limit: {api_key_obj.name}") + return JSONResponse( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + content={"detail": "Rate limit dépassé"}, + headers={ + "X-RateLimit-Limit": str(rate_info["limit"]), + "X-RateLimit-Remaining": "0", + }, + ) + + has_access = await service.check_endpoint_access(api_key_obj, path) + + if not has_access: + import json + + allowed = ( + json.loads(api_key_obj.allowed_endpoints) + if api_key_obj.allowed_endpoints + else ["Tous"] + ) + + logger.warning( + f" ACCÈS REFUSÉ: {api_key_obj.name}\n" + f" Endpoint demandé: {path}\n" + f" Endpoints autorisés: {allowed}" + ) + + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={ + "detail": "Accès non autorisé à cet endpoint", + "endpoint_requested": path, + "api_key_name": api_key_obj.name, + "allowed_endpoints": allowed, + "hint": "Cette clé API n'a pas accès à cet endpoint. Contactez l'administrateur.", + }, + ) + + request.state.api_key = api_key_obj + request.state.authenticated_via = "api_key" + + logger.info(f" ACCÈS AUTORISÉ: {api_key_obj.name} → {method} {path}") + + return await call_next(request) + + except Exception as e: + logger.error(f" Erreur validation API Key: {e}", exc_info=True) + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"detail": f"Erreur interne: {str(e)}"}, + ) + + +ApiKeyMiddleware = ApiKeyMiddlewareHTTP + + +def get_api_key_from_request(request: Request): + """Récupère l'objet ApiKey depuis la requête si présent""" + return getattr(request.state, "api_key", None) + + +def get_auth_method(request: Request) -> str: + """Retourne la méthode d'authentification utilisée""" + return getattr(request.state, "authenticated_via", "none") + + +__all__ = [ + "SwaggerAuthMiddleware", + "ApiKeyMiddlewareHTTP", + "ApiKeyMiddleware", + "get_api_key_from_request", + "get_auth_method", +] diff --git a/routes/api_keys.py b/routes/api_keys.py new file mode 100644 index 0000000..1e753de --- /dev/null +++ b/routes/api_keys.py @@ -0,0 +1,154 @@ +from fastapi import APIRouter, Depends, HTTPException, status, Query +from sqlalchemy.ext.asyncio import AsyncSession +import logging + +from database import get_session, User +from core.dependencies import get_current_user, require_role +from services.api_key import ApiKeyService, api_key_to_response +from schemas.api_key import ( + ApiKeyCreate, + ApiKeyCreatedResponse, + ApiKeyResponse, + ApiKeyList, +) + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/api-keys", tags=["API Keys Management"]) + + +@router.post( + "", + response_model=ApiKeyCreatedResponse, + status_code=status.HTTP_201_CREATED, + dependencies=[Depends(require_role("admin", "super_admin"))], +) +async def create_api_key( + data: ApiKeyCreate, + session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), +): + service = ApiKeyService(session) + + api_key_obj, api_key_plain = await service.create_api_key( + name=data.name, + description=data.description, + created_by=user.email, + user_id=user.id, + expires_in_days=data.expires_in_days, + rate_limit_per_minute=data.rate_limit_per_minute, + allowed_endpoints=data.allowed_endpoints, + ) + + logger.info(f" Clé API créée par {user.email}: {data.name}") + + response_data = api_key_to_response(api_key_obj) + response_data["api_key"] = api_key_plain + + return ApiKeyCreatedResponse(**response_data) + + +@router.get("", response_model=ApiKeyList) +async def list_api_keys( + include_revoked: bool = Query(False, description="Inclure les clés révoquées"), + session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), +): + service = ApiKeyService(session) + + user_id = None if user.role in ["admin", "super_admin"] else user.id + + keys = await service.list_api_keys(include_revoked=include_revoked, user_id=user_id) + + items = [ApiKeyResponse(**api_key_to_response(k)) for k in keys] + + return ApiKeyList(total=len(items), items=items) + + +@router.get("/{key_id}", response_model=ApiKeyResponse) +async def get_api_key( + key_id: str, + session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), +): + """Récupérer une clé API par son ID""" + service = ApiKeyService(session) + + api_key_obj = await service.get_by_id(key_id) + + if not api_key_obj: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Clé API {key_id} introuvable", + ) + + if user.role not in ["admin", "super_admin"]: + if api_key_obj.user_id != user.id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Accès refusé à cette clé", + ) + + return ApiKeyResponse(**api_key_to_response(api_key_obj)) + + +@router.delete("/{key_id}", status_code=status.HTTP_200_OK) +async def revoke_api_key( + key_id: str, + session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), +): + service = ApiKeyService(session) + + api_key_obj = await service.get_by_id(key_id) + + if not api_key_obj: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail=f"Clé API {key_id} introuvable", + ) + + if user.role not in ["admin", "super_admin"]: + if api_key_obj.user_id != user.id: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Accès refusé à cette clé", + ) + + success = await service.revoke_api_key(key_id) + + if not success: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail="Erreur lors de la révocation", + ) + + logger.info(f" Clé API révoquée par {user.email}: {api_key_obj.name}") + + return { + "success": True, + "message": f"Clé API '{api_key_obj.name}' révoquée avec succès", + } + + +@router.post("/verify", status_code=status.HTTP_200_OK) +async def verify_api_key_endpoint( + api_key: str = Query(..., description="Clé API à vérifier"), + session: AsyncSession = Depends(get_session), +): + service = ApiKeyService(session) + + api_key_obj = await service.verify_api_key(api_key) + + if not api_key_obj: + return { + "valid": False, + "message": "Clé API invalide, expirée ou révoquée", + } + + return { + "valid": True, + "message": "Clé API valide", + "key_name": api_key_obj.name, + "rate_limit": api_key_obj.rate_limit_per_minute, + "expires_at": api_key_obj.expires_at, + } diff --git a/routes/auth.py b/routes/auth.py index d6e6761..d401d86 100644 --- a/routes/auth.py +++ b/routes/auth.py @@ -510,7 +510,7 @@ async def logout( token_record.revoked_at = datetime.now() await session.commit() - logger.info(f"👋 Déconnexion: {user.email}") + logger.info(f" Déconnexion: {user.email}") return {"success": True, "message": "Déconnexion réussie"} diff --git a/routes/enterprise.py b/routes/enterprise.py index 2ed18d1..2de1e5f 100644 --- a/routes/enterprise.py +++ b/routes/enterprise.py @@ -22,7 +22,6 @@ async def rechercher_entreprise( try: logger.info(f" Recherche entreprise: '{q}'") - # Appel API api_response = await rechercher_entreprise_api(q, per_page) resultats_api = api_response.get("results", []) diff --git a/routes/universign.py b/routes/universign.py index b8a16c3..bada5aa 100644 --- a/routes/universign.py +++ b/routes/universign.py @@ -35,7 +35,6 @@ logger = logging.getLogger(__name__) router = APIRouter( prefix="/universign", tags=["Universign"], - # dependencies=[Depends(get_current_user)] ) sync_service = UniversignSyncService( @@ -512,7 +511,6 @@ async def webhook_universign( transaction_id = None if payload.get("type", "").startswith("transaction.") and "payload" in payload: - # Le transaction_id est dans payload.object.id nested_object = payload.get("payload", {}).get("object", {}) if nested_object.get("object") == "transaction": transaction_id = nested_object.get("id") diff --git a/sage_client.py b/sage_client.py index 0137512..9ad7b50 100644 --- a/sage_client.py +++ b/sage_client.py @@ -1,4 +1,3 @@ -# sage_client.py import requests from typing import Dict, List, Optional from config.config import settings @@ -468,7 +467,6 @@ class SageGatewayClient: "tva_encaissement": tva_encaissement, } - # Champs optionnels if date_reglement: payload["date_reglement"] = date_reglement if code_journal: diff --git a/schemas/api_key.py b/schemas/api_key.py new file mode 100644 index 0000000..4ec49b6 --- /dev/null +++ b/schemas/api_key.py @@ -0,0 +1,77 @@ +from pydantic import BaseModel, Field +from typing import Optional, List +from datetime import datetime + + +class ApiKeyCreate(BaseModel): + """Schema pour créer une clé API""" + + name: str = Field(..., min_length=3, max_length=255, description="Nom de la clé") + description: Optional[str] = Field(None, description="Description de l'usage") + expires_in_days: Optional[int] = Field( + None, ge=1, le=3650, description="Expiration en jours (max 10 ans)" + ) + rate_limit_per_minute: int = Field( + 60, ge=1, le=1000, description="Limite de requêtes par minute" + ) + allowed_endpoints: Optional[List[str]] = Field( + None, description="Endpoints autorisés ([] = tous, ['/clients*'] = wildcard)" + ) + + +class ApiKeyResponse(BaseModel): + """Schema de réponse pour une clé API""" + + id: str + name: str + description: Optional[str] + key_prefix: str + is_active: bool + is_expired: bool + rate_limit_per_minute: int + allowed_endpoints: Optional[List[str]] + total_requests: int + last_used_at: Optional[datetime] + created_at: datetime + expires_at: Optional[datetime] + revoked_at: Optional[datetime] + created_by: str + + +class ApiKeyCreatedResponse(ApiKeyResponse): + """Schema de réponse après création (inclut la clé en clair)""" + + api_key: str = Field( + ..., description=" Clé API en clair - à sauvegarder immédiatement" + ) + + +class ApiKeyList(BaseModel): + """Liste de clés API""" + + total: int + items: List[ApiKeyResponse] + + +class SwaggerUserCreate(BaseModel): + """Schema pour créer un utilisateur Swagger""" + + username: str = Field(..., min_length=3, max_length=100) + password: str = Field(..., min_length=8) + full_name: Optional[str] = None + email: Optional[str] = None + + +class SwaggerUserResponse(BaseModel): + """Schema de réponse pour un utilisateur Swagger""" + + id: str + username: str + full_name: Optional[str] + email: Optional[str] + is_active: bool + created_at: datetime + last_login: Optional[datetime] + + class Config: + from_attributes = True diff --git a/schemas/articles/articles.py b/schemas/articles/articles.py index 79b2d62..26996a7 100644 --- a/schemas/articles/articles.py +++ b/schemas/articles/articles.py @@ -76,7 +76,6 @@ class Article(BaseModel): ) nb_emplacements: int = Field(0, description="Nombre d'emplacements") - # Champs énumérés normalisés suivi_stock: Optional[int] = Field( None, description="Type de suivi de stock (AR_SuiviStock): 0=Aucun, 1=CMUP, 2=FIFO/LIFO, 3=Sérialisé", diff --git a/schemas/documents/reglements.py b/schemas/documents/reglements.py index bf6d178..5cc5e2c 100644 --- a/schemas/documents/reglements.py +++ b/schemas/documents/reglements.py @@ -10,12 +10,10 @@ logger = logging.getLogger(__name__) class ReglementFactureCreate(BaseModel): """Requête de règlement d'une facture côté VPS""" - # Montant et devise montant: Decimal = Field(..., gt=0, description="Montant à régler") devise_code: Optional[int] = Field(0, description="Code devise (0=EUR par défaut)") cours_devise: Optional[Decimal] = Field(1.0, description="Cours de la devise") - # Mode et journal mode_reglement: int = Field( ..., ge=0, description="Code mode règlement depuis /reglements/modes" ) @@ -23,13 +21,11 @@ class ReglementFactureCreate(BaseModel): ..., min_length=1, description="Code journal depuis /journaux/tresorerie" ) - # Dates date_reglement: Optional[date] = Field( None, description="Date du règlement (défaut: aujourd'hui)" ) date_echeance: Optional[date] = Field(None, description="Date d'échéance") - # Références reference: Optional[str] = Field( "", max_length=17, description="Référence pièce règlement" ) @@ -37,7 +33,6 @@ class ReglementFactureCreate(BaseModel): "", max_length=35, description="Libellé du règlement" ) - # TVA sur encaissement tva_encaissement: Optional[bool] = Field( False, description="Appliquer TVA sur encaissement" ) @@ -81,7 +76,6 @@ class ReglementMultipleCreate(BaseModel): libelle: Optional[str] = Field("") tva_encaissement: Optional[bool] = Field(False) - # Factures spécifiques (optionnel) numeros_factures: Optional[List[str]] = Field( None, description="Si vide, règle les plus anciennes en premier" ) diff --git a/schemas/sage/sage_gateway.py b/schemas/sage/sage_gateway.py index e503641..9501129 100644 --- a/schemas/sage/sage_gateway.py +++ b/schemas/sage/sage_gateway.py @@ -10,7 +10,6 @@ class GatewayHealthStatus(str, Enum): UNKNOWN = "unknown" -# === CREATE === class SageGatewayCreate(BaseModel): name: str = Field( @@ -71,7 +70,6 @@ class SageGatewayUpdate(BaseModel): return v.rstrip("/") if v else v -# === RESPONSE === class SageGatewayResponse(BaseModel): id: str diff --git a/schemas/tiers/commercial.py b/schemas/tiers/commercial.py index 5a4685b..de74165 100644 --- a/schemas/tiers/commercial.py +++ b/schemas/tiers/commercial.py @@ -9,7 +9,6 @@ class CollaborateurBase(BaseModel): prenom: Optional[str] = Field(None, max_length=50) fonction: Optional[str] = Field(None, max_length=50) - # Adresse adresse: Optional[str] = Field(None, max_length=100) complement: Optional[str] = Field(None, max_length=100) code_postal: Optional[str] = Field(None, max_length=10) @@ -17,7 +16,6 @@ class CollaborateurBase(BaseModel): code_region: Optional[str] = Field(None, max_length=50) pays: Optional[str] = Field(None, max_length=50) - # Services service: Optional[str] = Field(None, max_length=50) vendeur: bool = Field(default=False) caissier: bool = Field(default=False) @@ -25,18 +23,15 @@ class CollaborateurBase(BaseModel): chef_ventes: bool = Field(default=False) numero_chef_ventes: Optional[int] = None - # Contact telephone: Optional[str] = Field(None, max_length=20) telecopie: Optional[str] = Field(None, max_length=20) email: Optional[EmailStr] = None tel_portable: Optional[str] = Field(None, max_length=20) - # Réseaux sociaux facebook: Optional[str] = Field(None, max_length=100) linkedin: Optional[str] = Field(None, max_length=100) skype: Optional[str] = Field(None, max_length=100) - # Autres matricule: Optional[str] = Field(None, max_length=20) sommeil: bool = Field(default=False) diff --git a/schemas/tiers/tiers.py b/schemas/tiers/tiers.py index 58166a1..7a46ef5 100644 --- a/schemas/tiers/tiers.py +++ b/schemas/tiers/tiers.py @@ -14,7 +14,6 @@ class TypeTiersInt(IntEnum): class TiersDetails(BaseModel): - # IDENTIFICATION numero: Optional[str] = Field(None, description="Code tiers (CT_Num)") intitule: Optional[str] = Field( None, description="Raison sociale ou Nom complet (CT_Intitule)" @@ -37,7 +36,6 @@ class TiersDetails(BaseModel): ) code_naf: Optional[str] = Field(None, description="Code NAF/APE (CT_Ape)") - # ADRESSE contact: Optional[str] = Field( None, description="Nom du contact principal (CT_Contact)" ) @@ -50,7 +48,6 @@ class TiersDetails(BaseModel): region: Optional[str] = Field(None, description="Région/État (CT_CodeRegion)") pays: Optional[str] = Field(None, description="Pays (CT_Pays)") - # TELECOM telephone: Optional[str] = Field(None, description="Téléphone fixe (CT_Telephone)") telecopie: Optional[str] = Field(None, description="Fax (CT_Telecopie)") email: Optional[str] = Field(None, description="Email principal (CT_EMail)") @@ -58,13 +55,11 @@ class TiersDetails(BaseModel): facebook: Optional[str] = Field(None, description="Profil Facebook (CT_Facebook)") linkedin: Optional[str] = Field(None, description="Profil LinkedIn (CT_LinkedIn)") - # TAUX taux01: Optional[float] = Field(None, description="Taux personnalisé 1 (CT_Taux01)") taux02: Optional[float] = Field(None, description="Taux personnalisé 2 (CT_Taux02)") taux03: Optional[float] = Field(None, description="Taux personnalisé 3 (CT_Taux03)") taux04: Optional[float] = Field(None, description="Taux personnalisé 4 (CT_Taux04)") - # STATISTIQUES statistique01: Optional[str] = Field( None, description="Statistique 1 (CT_Statistique01)" ) @@ -96,7 +91,6 @@ class TiersDetails(BaseModel): None, description="Statistique 10 (CT_Statistique10)" ) - # COMMERCIAL encours_autorise: Optional[float] = Field( None, description="Encours maximum autorisé (CT_Encours)" ) @@ -113,7 +107,6 @@ class TiersDetails(BaseModel): None, description="Détails du commercial/collaborateur" ) - # FACTURATION lettrage_auto: Optional[bool] = Field( None, description="Lettrage automatique (CT_Lettrage)" ) @@ -146,7 +139,6 @@ class TiersDetails(BaseModel): None, description="Bon à payer obligatoire (CT_BonAPayer)" ) - # LOGISTIQUE priorite_livraison: Optional[int] = Field( None, description="Priorité livraison (CT_PrioriteLivr)" ) @@ -160,17 +152,14 @@ class TiersDetails(BaseModel): None, description="Délai appro jours (CT_DelaiAppro)" ) - # COMMENTAIRE commentaire: Optional[str] = Field( None, description="Commentaire libre (CT_Commentaire)" ) - # ANALYTIQUE section_analytique: Optional[str] = Field( None, description="Section analytique (CA_Num)" ) - # ORGANISATION / SURVEILLANCE mode_reglement_code: Optional[int] = Field( None, description="Code mode règlement (MR_No)" ) @@ -200,7 +189,6 @@ class TiersDetails(BaseModel): None, description="Résultat financier (CT_SvResultat)" ) - # COMPTE GENERAL ET CATEGORIES compte_general: Optional[str] = Field( None, description="Compte général principal (CG_NumPrinc)" ) @@ -211,7 +199,6 @@ class TiersDetails(BaseModel): None, description="Catégorie comptable (N_CatCompta)" ) - # CONTACTS contacts: Optional[List[Contact]] = Field( default_factory=list, description="Liste des contacts du tiers" ) diff --git a/scripts/manage_security.py b/scripts/manage_security.py new file mode 100644 index 0000000..1e5cab9 --- /dev/null +++ b/scripts/manage_security.py @@ -0,0 +1,353 @@ +import sys +import os +from pathlib import Path + +_current_file = Path(__file__).resolve() +_script_dir = _current_file.parent +_app_dir = _script_dir.parent + +print(f"DEBUG: Script path: {_current_file}") +print(f"DEBUG: App dir: {_app_dir}") +print(f"DEBUG: Current working dir: {os.getcwd()}") + +if str(_app_dir) in sys.path: + sys.path.remove(str(_app_dir)) +sys.path.insert(0, str(_app_dir)) + +os.chdir(str(_app_dir)) + +print(f"DEBUG: sys.path[0]: {sys.path[0]}") +print(f"DEBUG: New working dir: {os.getcwd()}") + +_test_imports = [ + "database", + "database.db_config", + "database.models", + "services", + "security", +] + +print("\nDEBUG: Vérification des imports...") +for module in _test_imports: + try: + __import__(module) + print(f" {module}") + except ImportError as e: + print(f" {module}: {e}") + +import asyncio +import argparse +import logging +from datetime import datetime + +from sqlalchemy import select + +try: + from database.db_config import async_session_factory + from database.models.user import User + from database.models.api_key import SwaggerUser, ApiKey + from services.api_key import ApiKeyService + from security.auth import hash_password +except ImportError as e: + print(f"\n ERREUR D'IMPORT: {e}") + print(f" Vérifiez que vous êtes dans /app") + print(f" Commande correcte: cd /app && python scripts/manage_security.py ...") + sys.exit(1) + +logging.basicConfig(level=logging.INFO, format="%(levelname)s - %(message)s") +logger = logging.getLogger(__name__) + + +async def add_swagger_user(username: str, password: str, full_name: str = None): + """Ajouter un utilisateur Swagger""" + async with async_session_factory() as session: + result = await session.execute( + select(SwaggerUser).where(SwaggerUser.username == username) + ) + existing = result.scalar_one_or_none() + + if existing: + logger.error(f" L'utilisateur '{username}' existe déjà") + return + + swagger_user = SwaggerUser( + username=username, + hashed_password=hash_password(password), + full_name=full_name or username, + is_active=True, + ) + + session.add(swagger_user) + await session.commit() + + logger.info(f" Utilisateur Swagger créé: {username}") + logger.info(f" Nom complet: {swagger_user.full_name}") + + +async def list_swagger_users(): + """Lister tous les utilisateurs Swagger""" + async with async_session_factory() as session: + result = await session.execute(select(SwaggerUser)) + users = result.scalars().all() + + if not users: + logger.info("🔭 Aucun utilisateur Swagger") + return + + logger.info(f"👥 {len(users)} utilisateur(s) Swagger:\n") + for user in users: + status = "" if user.is_active else "" + logger.info(f" {status} {user.username}") + logger.info(f" Nom: {user.full_name}") + logger.info(f" Créé: {user.created_at}") + logger.info(f" Dernière connexion: {user.last_login or 'Jamais'}\n") + + +async def delete_swagger_user(username: str): + """Supprimer un utilisateur Swagger""" + async with async_session_factory() as session: + result = await session.execute( + select(SwaggerUser).where(SwaggerUser.username == username) + ) + user = result.scalar_one_or_none() + + if not user: + logger.error(f" Utilisateur '{username}' introuvable") + return + + await session.delete(user) + await session.commit() + logger.info(f"🗑️ Utilisateur Swagger supprimé: {username}") + + +async def create_api_key( + name: str, + description: str = None, + expires_in_days: int = 365, + rate_limit: int = 60, + endpoints: list = None, +): + """Créer une clé API""" + async with async_session_factory() as session: + service = ApiKeyService(session) + + api_key_obj, api_key_plain = await service.create_api_key( + name=name, + description=description, + created_by="cli", + expires_in_days=expires_in_days, + rate_limit_per_minute=rate_limit, + allowed_endpoints=endpoints, + ) + + logger.info("=" * 70) + logger.info("🔑 Clé API créée avec succès") + logger.info("=" * 70) + logger.info(f" ID: {api_key_obj.id}") + logger.info(f" Nom: {api_key_obj.name}") + logger.info(f" Clé: {api_key_plain}") + logger.info(f" Préfixe: {api_key_obj.key_prefix}") + logger.info(f" Rate limit: {api_key_obj.rate_limit_per_minute} req/min") + logger.info(f" Expire le: {api_key_obj.expires_at}") + + if api_key_obj.allowed_endpoints: + import json + + try: + endpoints_list = json.loads(api_key_obj.allowed_endpoints) + logger.info(f" Endpoints: {', '.join(endpoints_list)}") + except: + logger.info(f" Endpoints: {api_key_obj.allowed_endpoints}") + else: + logger.info(" Endpoints: Tous (aucune restriction)") + + logger.info("=" * 70) + logger.info(" SAUVEGARDEZ CETTE CLÉ - Elle ne sera plus affichée !") + logger.info("=" * 70) + + +async def list_api_keys(): + """Lister toutes les clés API""" + async with async_session_factory() as session: + service = ApiKeyService(session) + keys = await service.list_api_keys() + + if not keys: + logger.info("🔭 Aucune clé API") + return + + logger.info(f"🔑 {len(keys)} clé(s) API:\n") + + for key in keys: + is_valid = key.is_active and ( + not key.expires_at or key.expires_at > datetime.now() + ) + status = "" if is_valid else "" + + logger.info(f" {status} {key.name:<30} ({key.key_prefix}...)") + logger.info(f" ID: {key.id}") + logger.info(f" Rate limit: {key.rate_limit_per_minute} req/min") + logger.info(f" Requêtes: {key.total_requests}") + logger.info(f" Expire: {key.expires_at or 'Jamais'}") + logger.info(f" Dernière utilisation: {key.last_used_at or 'Jamais'}") + + if key.allowed_endpoints: + import json + + try: + endpoints = json.loads(key.allowed_endpoints) + display = ", ".join(endpoints[:4]) + if len(endpoints) > 4: + display += f"... (+{len(endpoints) - 4})" + logger.info(f" Endpoints: {display}") + except: + pass + else: + logger.info(" Endpoints: Tous") + logger.info("") + + +async def revoke_api_key(key_id: str): + """Révoquer une clé API""" + async with async_session_factory() as session: + result = await session.execute(select(ApiKey).where(ApiKey.id == key_id)) + key = result.scalar_one_or_none() + + if not key: + logger.error(f" Clé API '{key_id}' introuvable") + return + + key.is_active = False + key.revoked_at = datetime.now() + await session.commit() + + logger.info(f"🗑️ Clé API révoquée: {key.name}") + logger.info(f" ID: {key.id}") + + +async def verify_api_key(api_key: str): + """Vérifier une clé API""" + async with async_session_factory() as session: + service = ApiKeyService(session) + key = await service.verify_api_key(api_key) + + if not key: + logger.error(" Clé API invalide ou expirée") + return + + logger.info("=" * 60) + logger.info(" Clé API valide") + logger.info("=" * 60) + logger.info(f" Nom: {key.name}") + logger.info(f" ID: {key.id}") + logger.info(f" Rate limit: {key.rate_limit_per_minute} req/min") + logger.info(f" Requêtes totales: {key.total_requests}") + logger.info(f" Expire: {key.expires_at or 'Jamais'}") + + if key.allowed_endpoints: + import json + + try: + endpoints = json.loads(key.allowed_endpoints) + logger.info(f" Endpoints autorisés: {endpoints}") + except: + pass + else: + logger.info(" Endpoints autorisés: Tous") + logger.info("=" * 60) + + +async def main(): + parser = argparse.ArgumentParser( + description="Gestion des utilisateurs Swagger et clés API", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Exemples: + python scripts/manage_security.py swagger add admin MyP@ssw0rd + python scripts/manage_security.py swagger list + python scripts/manage_security.py apikey create "Mon App" --days 365 --rate-limit 100 + python scripts/manage_security.py apikey create "SDK-ReadOnly" --endpoints "/clients" "/clients/*" "/devis" "/devis/*" + python scripts/manage_security.py apikey list + python scripts/manage_security.py apikey verify sdk_live_xxxxx + """, + ) + subparsers = parser.add_subparsers(dest="command", help="Commandes") + + swagger_parser = subparsers.add_parser("swagger", help="Gestion Swagger") + swagger_sub = swagger_parser.add_subparsers(dest="swagger_command") + + add_p = swagger_sub.add_parser("add", help="Ajouter utilisateur") + add_p.add_argument("username", help="Nom d'utilisateur") + add_p.add_argument("password", help="Mot de passe") + add_p.add_argument("--full-name", help="Nom complet") + + swagger_sub.add_parser("list", help="Lister utilisateurs") + + del_p = swagger_sub.add_parser("delete", help="Supprimer utilisateur") + del_p.add_argument("username", help="Nom d'utilisateur") + + apikey_parser = subparsers.add_parser("apikey", help="Gestion clés API") + apikey_sub = apikey_parser.add_subparsers(dest="apikey_command") + + create_p = apikey_sub.add_parser("create", help="Créer clé API") + create_p.add_argument("name", help="Nom de la clé") + create_p.add_argument("--description", help="Description") + create_p.add_argument("--days", type=int, default=365, help="Expiration (jours)") + create_p.add_argument("--rate-limit", type=int, default=60, help="Req/min") + create_p.add_argument("--endpoints", nargs="+", help="Endpoints autorisés") + + apikey_sub.add_parser("list", help="Lister clés") + + rev_p = apikey_sub.add_parser("revoke", help="Révoquer clé") + rev_p.add_argument("key_id", help="ID de la clé") + + ver_p = apikey_sub.add_parser("verify", help="Vérifier clé") + ver_p.add_argument("api_key", help="Clé API complète") + + args = parser.parse_args() + + if not args.command: + parser.print_help() + return + + if args.command == "swagger": + if args.swagger_command == "add": + await add_swagger_user(args.username, args.password, args.full_name) + elif args.swagger_command == "list": + await list_swagger_users() + elif args.swagger_command == "delete": + await delete_swagger_user(args.username) + else: + swagger_parser.print_help() + + elif args.command == "apikey": + if args.apikey_command == "create": + await create_api_key( + name=args.name, + description=args.description, + expires_in_days=args.days, + rate_limit=args.rate_limit, + endpoints=args.endpoints, + ) + elif args.apikey_command == "list": + await list_api_keys() + elif args.apikey_command == "revoke": + await revoke_api_key(args.key_id) + elif args.apikey_command == "verify": + await verify_api_key(args.api_key) + else: + apikey_parser.print_help() + + +if __name__ == "__main__": + try: + asyncio.run(main()) + except KeyboardInterrupt: + print("\nℹ️ Interrupted") + sys.exit(0) + except Exception as e: + logger.error(f" Erreur: {e}") + import traceback + + traceback.print_exc() + sys.exit(1) diff --git a/scripts/test_security.py b/scripts/test_security.py new file mode 100644 index 0000000..497870e --- /dev/null +++ b/scripts/test_security.py @@ -0,0 +1,354 @@ +import requests +import argparse +import sys +from typing import Tuple + + +class SecurityTester: + def __init__(self, base_url: str): + self.base_url = base_url.rstrip("/") + self.results = {"passed": 0, "failed": 0, "tests": []} + + def log_test(self, name: str, passed: bool, details: str = ""): + """Enregistrer le résultat d'un test""" + status = " PASS" if passed else " FAIL" + print(f"{status} - {name}") + if details: + print(f" {details}") + + self.results["tests"].append( + {"name": name, "passed": passed, "details": details} + ) + + if passed: + self.results["passed"] += 1 + else: + self.results["failed"] += 1 + + def test_swagger_without_auth(self) -> bool: + """Test 1: Swagger UI devrait demander une authentification""" + print("\n Test 1: Protection Swagger UI") + + try: + response = requests.get(f"{self.base_url}/docs", timeout=5) + + if response.status_code == 401: + self.log_test( + "Swagger protégé", + True, + "Code 401 retourné sans authentification", + ) + return True + else: + self.log_test( + "Swagger protégé", + False, + f"Code {response.status_code} au lieu de 401", + ) + return False + + except Exception as e: + self.log_test("Swagger protégé", False, f"Erreur: {str(e)}") + return False + + def test_swagger_with_auth(self, username: str, password: str) -> bool: + """Test 2: Swagger UI accessible avec credentials valides""" + print("\n Test 2: Accès Swagger avec authentification") + + try: + response = requests.get( + f"{self.base_url}/docs", auth=(username, password), timeout=5 + ) + + if response.status_code == 200: + self.log_test( + "Accès Swagger avec auth", + True, + f"Authentifié comme {username}", + ) + return True + else: + self.log_test( + "Accès Swagger avec auth", + False, + f"Code {response.status_code}, credentials invalides?", + ) + return False + + except Exception as e: + self.log_test("Accès Swagger avec auth", False, f"Erreur: {str(e)}") + return False + + def test_api_without_auth(self) -> bool: + """Test 3: Endpoints API devraient demander une authentification""" + print("\n Test 3: Protection des endpoints API") + + test_endpoints = ["/api/v1/clients", "/api/v1/documents"] + + all_protected = True + for endpoint in test_endpoints: + try: + response = requests.get(f"{self.base_url}{endpoint}", timeout=5) + + if response.status_code == 401: + print(f" {endpoint} protégé (401)") + else: + print( + f" {endpoint} accessible sans auth (code {response.status_code})" + ) + all_protected = False + + except Exception as e: + print(f" {endpoint} erreur: {str(e)}") + all_protected = False + + self.log_test("Endpoints API protégés", all_protected) + return all_protected + + def test_health_endpoint_public(self) -> bool: + """Test 4: Endpoint /health devrait être accessible sans auth""" + print("\n Test 4: Endpoint /health public") + + try: + response = requests.get(f"{self.base_url}/health", timeout=5) + + if response.status_code == 200: + self.log_test("/health accessible", True, "Endpoint public fonctionne") + return True + else: + self.log_test( + "/health accessible", + False, + f"Code {response.status_code} inattendu", + ) + return False + + except Exception as e: + self.log_test("/health accessible", False, f"Erreur: {str(e)}") + return False + + def test_api_key_creation(self, username: str, password: str) -> Tuple[bool, str]: + """Test 5: Créer une clé API via l'endpoint""" + print("\n Test 5: Création d'une clé API") + + try: + login_response = requests.post( + f"{self.base_url}/api/v1/auth/login", + json={"email": username, "password": password}, + timeout=5, + ) + + if login_response.status_code != 200: + self.log_test( + "Création clé API", + False, + "Impossible de se connecter pour obtenir un JWT", + ) + return False, "" + + jwt_token = login_response.json().get("access_token") + + create_response = requests.post( + f"{self.base_url}/api/v1/api-keys", + headers={"Authorization": f"Bearer {jwt_token}"}, + json={ + "name": "Test API Key", + "description": "Clé de test automatisé", + "rate_limit_per_minute": 60, + "expires_in_days": 30, + }, + timeout=5, + ) + + if create_response.status_code == 201: + api_key = create_response.json().get("api_key") + self.log_test("Création clé API", True, f"Clé créée: {api_key[:20]}...") + return True, api_key + else: + self.log_test( + "Création clé API", + False, + f"Code {create_response.status_code}", + ) + return False, "" + + except Exception as e: + self.log_test("Création clé API", False, f"Erreur: {str(e)}") + return False, "" + + def test_api_key_usage(self, api_key: str) -> bool: + """Test 6: Utiliser une clé API pour accéder à un endpoint""" + print("\n Test 6: Utilisation d'une clé API") + + if not api_key: + self.log_test("Utilisation clé API", False, "Pas de clé disponible") + return False + + try: + response = requests.get( + f"{self.base_url}/api/v1/clients", + headers={"X-API-Key": api_key}, + timeout=5, + ) + + if response.status_code == 200: + self.log_test("Utilisation clé API", True, "Clé acceptée") + return True + else: + self.log_test( + "Utilisation clé API", + False, + f"Code {response.status_code}, clé refusée?", + ) + return False + + except Exception as e: + self.log_test("Utilisation clé API", False, f"Erreur: {str(e)}") + return False + + def test_invalid_api_key(self) -> bool: + """Test 7: Une clé invalide devrait être refusée""" + print("\n Test 7: Rejet de clé API invalide") + + invalid_key = "sdk_live_invalid_key_12345" + + try: + response = requests.get( + f"{self.base_url}/api/v1/clients", + headers={"X-API-Key": invalid_key}, + timeout=5, + ) + + if response.status_code == 401: + self.log_test("Clé invalide rejetée", True, "Code 401 comme attendu") + return True + else: + self.log_test( + "Clé invalide rejetée", + False, + f"Code {response.status_code} au lieu de 401", + ) + return False + + except Exception as e: + self.log_test("Clé invalide rejetée", False, f"Erreur: {str(e)}") + return False + + def test_rate_limiting(self, api_key: str) -> bool: + """Test 8: Rate limiting (optionnel, peut prendre du temps)""" + print("\n Test 8: Rate limiting (test simple)") + + if not api_key: + self.log_test("Rate limiting", False, "Pas de clé disponible") + return False + + print(" Envoi de 70 requêtes rapides...") + + rate_limited = False + for i in range(70): + try: + response = requests.get( + f"{self.base_url}/health", + headers={"X-API-Key": api_key}, + timeout=1, + ) + + if response.status_code == 429: + rate_limited = True + print(f" Rate limit atteint à la requête {i + 1}") + break + + except Exception: + pass + + if rate_limited: + self.log_test("Rate limiting", True, "Rate limit détecté") + return True + else: + self.log_test( + "Rate limiting", + True, + "Aucun rate limit détecté (peut être normal si pas implémenté)", + ) + return True + + def print_summary(self): + """Afficher le résumé des tests""" + print("\n" + "=" * 60) + print(" RÉSUMÉ DES TESTS") + print("=" * 60) + + total = self.results["passed"] + self.results["failed"] + success_rate = (self.results["passed"] / total * 100) if total > 0 else 0 + + print(f"\nTotal: {total} tests") + print(f" Réussis: {self.results['passed']}") + print(f" Échoués: {self.results['failed']}") + print(f"Taux de réussite: {success_rate:.1f}%\n") + + if self.results["failed"] == 0: + print("🎉 Tous les tests sont passés ! Sécurité OK.") + return 0 + else: + print(" Certains tests ont échoué. Vérifiez la configuration.") + return 1 + + +def main(): + parser = argparse.ArgumentParser( + description="Test automatisé de la sécurité de l'API" + ) + + parser.add_argument( + "--url", + required=True, + help="URL de base de l'API (ex: http://localhost:8000)", + ) + + parser.add_argument( + "--swagger-user", required=True, help="Utilisateur Swagger pour les tests" + ) + + parser.add_argument( + "--swagger-pass", required=True, help="Mot de passe Swagger pour les tests" + ) + + parser.add_argument( + "--skip-rate-limit", + action="store_true", + help="Sauter le test de rate limiting (long)", + ) + + args = parser.parse_args() + + print(" Démarrage des tests de sécurité") + print(f" URL cible: {args.url}\n") + + tester = SecurityTester(args.url) + + tester.test_swagger_without_auth() + tester.test_swagger_with_auth(args.swagger_user, args.swagger_pass) + tester.test_api_without_auth() + tester.test_health_endpoint_public() + + success, api_key = tester.test_api_key_creation( + args.swagger_user, args.swagger_pass + ) + + if success and api_key: + tester.test_api_key_usage(api_key) + tester.test_invalid_api_key() + + if not args.skip_rate_limit: + tester.test_rate_limiting(api_key) + else: + print("\n Test de rate limiting sauté") + else: + print("\n Tests avec clé API sautés (création échouée)") + + exit_code = tester.print_summary() + sys.exit(exit_code) + + +if __name__ == "__main__": + main() diff --git a/security/auth.py b/security/auth.py index 970a90f..e05b6a0 100644 --- a/security/auth.py +++ b/security/auth.py @@ -5,10 +5,12 @@ import jwt import secrets import hashlib -SECRET_KEY = "VOTRE_SECRET_KEY_A_METTRE_EN_.ENV" -ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 10080 -REFRESH_TOKEN_EXPIRE_DAYS = 7 +from config.config import settings + +SECRET_KEY = settings.jwt_secret +ALGORITHM = settings.jwt_algorithm +ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes +REFRESH_TOKEN_EXPIRE_DAYS = settings.refresh_token_expire_days pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") @@ -67,9 +69,13 @@ def decode_token(token: str) -> Optional[Dict]: payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) return payload except jwt.ExpiredSignatureError: - return None - except jwt.JWTError: - return None + raise jwt.InvalidTokenError("Token expiré") + except jwt.DecodeError: + raise jwt.InvalidTokenError("Token invalide (format incorrect)") + except jwt.InvalidTokenError as e: + raise jwt.InvalidTokenError(f"Token invalide: {str(e)}") + except Exception as e: + raise jwt.InvalidTokenError(f"Erreur lors du décodage du token: {str(e)}") def validate_password_strength(password: str) -> tuple[bool, str]: diff --git a/services/api_key.py b/services/api_key.py new file mode 100644 index 0000000..04e271e --- /dev/null +++ b/services/api_key.py @@ -0,0 +1,223 @@ +import secrets +import hashlib +import json +from datetime import datetime, timedelta +from typing import Optional, List, Dict +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import select, and_, or_ +import logging + +from database.models.api_key import ApiKey + +logger = logging.getLogger(__name__) + + +class ApiKeyService: + """Service de gestion des clés API""" + + def __init__(self, session: AsyncSession): + self.session = session + + @staticmethod + def generate_api_key() -> str: + """Génère une clé API unique et sécurisée""" + random_part = secrets.token_urlsafe(32) + return f"sdk_live_{random_part}" + + @staticmethod + def hash_api_key(api_key: str) -> str: + """Hash la clé API pour stockage sécurisé""" + return hashlib.sha256(api_key.encode()).hexdigest() + + @staticmethod + def get_key_prefix(api_key: str) -> str: + """Extrait le préfixe de la clé pour identification""" + return api_key[:12] if len(api_key) >= 12 else api_key + + async def create_api_key( + self, + name: str, + description: Optional[str] = None, + created_by: str = "system", + user_id: Optional[str] = None, + expires_in_days: Optional[int] = None, + rate_limit_per_minute: int = 60, + allowed_endpoints: Optional[List[str]] = None, + ) -> tuple[ApiKey, str]: + api_key_plain = self.generate_api_key() + key_hash = self.hash_api_key(api_key_plain) + key_prefix = self.get_key_prefix(api_key_plain) + + expires_at = None + if expires_in_days: + expires_at = datetime.now() + timedelta(days=expires_in_days) + + api_key_obj = ApiKey( + key_hash=key_hash, + key_prefix=key_prefix, + name=name, + description=description, + created_by=created_by, + user_id=user_id, + expires_at=expires_at, + rate_limit_per_minute=rate_limit_per_minute, + allowed_endpoints=json.dumps(allowed_endpoints) + if allowed_endpoints + else None, + ) + + self.session.add(api_key_obj) + await self.session.commit() + await self.session.refresh(api_key_obj) + + logger.info(f" Clé API créée: {name} (prefix: {key_prefix})") + + return api_key_obj, api_key_plain + + async def verify_api_key(self, api_key_plain: str) -> Optional[ApiKey]: + key_hash = self.hash_api_key(api_key_plain) + + result = await self.session.execute( + select(ApiKey).where( + and_( + ApiKey.key_hash == key_hash, + ApiKey.is_active, + ApiKey.revoked_at.is_(None), + or_( + ApiKey.expires_at.is_(None), ApiKey.expires_at > datetime.now() + ), + ) + ) + ) + + api_key_obj = result.scalar_one_or_none() + + if api_key_obj: + api_key_obj.total_requests += 1 + api_key_obj.last_used_at = datetime.now() + await self.session.commit() + + logger.debug(f" Clé API validée: {api_key_obj.name}") + else: + logger.warning(" Clé API invalide ou expirée") + + return api_key_obj + + async def list_api_keys( + self, + include_revoked: bool = False, + user_id: Optional[str] = None, + ) -> List[ApiKey]: + """Liste les clés API""" + query = select(ApiKey) + + if not include_revoked: + query = query.where(ApiKey.revoked_at.is_(None)) + + if user_id: + query = query.where(ApiKey.user_id == user_id) + + query = query.order_by(ApiKey.created_at.desc()) + + result = await self.session.execute(query) + return list(result.scalars().all()) + + async def revoke_api_key(self, key_id: str) -> bool: + """Révoque une clé API""" + result = await self.session.execute(select(ApiKey).where(ApiKey.id == key_id)) + api_key_obj = result.scalar_one_or_none() + + if not api_key_obj: + return False + + api_key_obj.is_active = False + api_key_obj.revoked_at = datetime.now() + await self.session.commit() + + logger.info(f"🗑️ Clé API révoquée: {api_key_obj.name}") + return True + + async def get_by_id(self, key_id: str) -> Optional[ApiKey]: + """Récupère une clé API par son ID""" + result = await self.session.execute(select(ApiKey).where(ApiKey.id == key_id)) + return result.scalar_one_or_none() + + async def check_rate_limit(self, api_key_obj: ApiKey) -> tuple[bool, Dict]: + return True, { + "allowed": True, + "limit": api_key_obj.rate_limit_per_minute, + "remaining": api_key_obj.rate_limit_per_minute, + } + + async def check_endpoint_access(self, api_key_obj: ApiKey, endpoint: str) -> bool: + if not api_key_obj.allowed_endpoints: + logger.debug( + f"🔓 API Key {api_key_obj.name}: Aucune restriction d'endpoint" + ) + return True + + try: + allowed = json.loads(api_key_obj.allowed_endpoints) + + if "*" in allowed or "/*" in allowed: + logger.debug(f"🔓 API Key {api_key_obj.name}: Accès global autorisé") + return True + + for pattern in allowed: + if pattern == endpoint: + logger.debug(f" Match exact: {pattern} == {endpoint}") + return True + + if pattern.endswith("/*"): + base = pattern[:-2] # "/clients/*" → "/clients" + if endpoint == base or endpoint.startswith(base + "/"): + logger.debug(f" Match wildcard: {pattern} ↔ {endpoint}") + return True + + elif pattern.endswith("*"): + base = pattern[:-1] # "/clients*" → "/clients" + if endpoint.startswith(base): + logger.debug(f" Match prefix: {pattern} ↔ {endpoint}") + return True + + logger.warning( + f" API Key {api_key_obj.name}: Accès refusé à {endpoint}\n" + f" Endpoints autorisés: {allowed}" + ) + return False + + except json.JSONDecodeError: + logger.error(f" Erreur parsing allowed_endpoints pour {api_key_obj.id}") + return False + + +def api_key_to_response(api_key_obj: ApiKey, show_key: bool = False) -> Dict: + """Convertit un objet ApiKey en réponse API""" + + allowed_endpoints = None + if api_key_obj.allowed_endpoints: + try: + allowed_endpoints = json.loads(api_key_obj.allowed_endpoints) + except json.JSONDecodeError: + pass + + is_expired = False + if api_key_obj.expires_at: + is_expired = api_key_obj.expires_at < datetime.now() + + return { + "id": api_key_obj.id, + "name": api_key_obj.name, + "description": api_key_obj.description, + "key_prefix": api_key_obj.key_prefix, + "is_active": api_key_obj.is_active, + "is_expired": is_expired, + "rate_limit_per_minute": api_key_obj.rate_limit_per_minute, + "allowed_endpoints": allowed_endpoints, + "total_requests": api_key_obj.total_requests, + "last_used_at": api_key_obj.last_used_at, + "created_at": api_key_obj.created_at, + "expires_at": api_key_obj.expires_at, + "revoked_at": api_key_obj.revoked_at, + "created_by": api_key_obj.created_by, + } diff --git a/services/universign_document.py b/services/universign_document.py index fa899a9..394c3ce 100644 --- a/services/universign_document.py +++ b/services/universign_document.py @@ -23,7 +23,7 @@ class UniversignDocumentService: def fetch_transaction_documents(self, transaction_id: str) -> Optional[List[Dict]]: try: - logger.info(f"📋 Récupération documents pour transaction: {transaction_id}") + logger.info(f" Récupération documents pour transaction: {transaction_id}") response = requests.get( f"{self.api_url}/transactions/{transaction_id}", @@ -38,7 +38,6 @@ class UniversignDocumentService: logger.info(f"{len(documents)} document(s) trouvé(s)") - # Log détaillé de chaque document for idx, doc in enumerate(documents): logger.debug( f" Document {idx}: id={doc.get('id')}, " @@ -64,7 +63,7 @@ class UniversignDocumentService: logger.error(f"⏱️ Timeout récupération transaction {transaction_id}") return None except Exception as e: - logger.error(f"❌ Erreur fetch documents: {e}", exc_info=True) + logger.error(f" Erreur fetch documents: {e}", exc_info=True) return None def download_signed_document( @@ -94,7 +93,6 @@ class UniversignDocumentService: f"Content-Type={content_type}, Size={content_length}" ) - # Vérification du type de contenu if ( "pdf" not in content_type.lower() and "octet-stream" not in content_type.lower() @@ -104,31 +102,30 @@ class UniversignDocumentService: f"Tentative de lecture quand même..." ) - # Lecture du contenu content = response.content if len(content) < 1024: - logger.error(f"❌ Document trop petit: {len(content)} octets") + logger.error(f" Document trop petit: {len(content)} octets") return None return content elif response.status_code == 404: logger.error( - f"❌ Document {document_id} introuvable pour transaction {transaction_id}" + f" Document {document_id} introuvable pour transaction {transaction_id}" ) return None elif response.status_code == 403: logger.error( - f"❌ Accès refusé au document {document_id}. " + f" Accès refusé au document {document_id}. " f"Vérifiez que la transaction est bien signée." ) return None else: logger.error( - f"❌ Erreur HTTP {response.status_code}: {response.text[:500]}" + f" Erreur HTTP {response.status_code}: {response.text[:500]}" ) return None @@ -136,13 +133,12 @@ class UniversignDocumentService: logger.error(f"⏱️ Timeout téléchargement document {document_id}") return None except Exception as e: - logger.error(f"❌ Erreur téléchargement: {e}", exc_info=True) + logger.error(f" Erreur téléchargement: {e}", exc_info=True) return None async def download_and_store_signed_document( self, session: AsyncSession, transaction, force: bool = False ) -> Tuple[bool, Optional[str]]: - # Vérification si déjà téléchargé if not force and transaction.signed_document_path: if os.path.exists(transaction.signed_document_path): logger.debug( @@ -153,7 +149,6 @@ class UniversignDocumentService: transaction.download_attempts += 1 try: - # ÉTAPE 1: Récupérer les documents de la transaction logger.info( f"Récupération document signé pour: {transaction.transaction_id}" ) @@ -167,13 +162,11 @@ class UniversignDocumentService: await session.commit() return False, error - # ÉTAPE 2: Récupérer le premier document (ou chercher celui qui est signé) document_id = None for doc in documents: doc_id = doc.get("id") doc_status = doc.get("status", "").lower() - # Priorité aux documents marqués comme signés/complétés if doc_status in ["signed", "completed", "closed"]: document_id = doc_id logger.info( @@ -181,34 +174,30 @@ class UniversignDocumentService: ) break - # Fallback sur le premier document si aucun n'est explicitement signé if document_id is None: document_id = doc_id if not document_id: error = "Impossible de déterminer l'ID du document à télécharger" - logger.error(f"❌ {error}") + logger.error(f" {error}") transaction.download_error = error await session.commit() return False, error - # Stocker le document_id pour référence future if hasattr(transaction, "universign_document_id"): transaction.universign_document_id = document_id - # ÉTAPE 3: Télécharger le document signé pdf_content = self.download_signed_document( transaction_id=transaction.transaction_id, document_id=document_id ) if not pdf_content: error = f"Échec téléchargement document {document_id}" - logger.error(f"❌ {error}") + logger.error(f" {error}") transaction.download_error = error await session.commit() return False, error - # ÉTAPE 4: Stocker le fichier localement filename = self._generate_filename(transaction) file_path = SIGNED_DOCS_DIR / filename @@ -217,13 +206,11 @@ class UniversignDocumentService: file_size = os.path.getsize(file_path) - # Mise à jour de la transaction transaction.signed_document_path = str(file_path) transaction.signed_document_downloaded_at = datetime.now() transaction.signed_document_size_bytes = file_size transaction.download_error = None - # Stocker aussi l'URL de téléchargement pour référence transaction.document_url = ( f"{self.api_url}/transactions/{transaction.transaction_id}" f"/documents/{document_id}/download" @@ -239,14 +226,14 @@ class UniversignDocumentService: except OSError as e: error = f"Erreur filesystem: {str(e)}" - logger.error(f"❌ {error}") + logger.error(f" {error}") transaction.download_error = error await session.commit() return False, error except Exception as e: error = f"Erreur inattendue: {str(e)}" - logger.error(f"❌ {error}", exc_info=True) + logger.error(f" {error}", exc_info=True) transaction.download_error = error await session.commit() return False, error @@ -294,7 +281,6 @@ class UniversignDocumentService: return deleted, int(size_freed_mb) - # === MÉTHODES DE DIAGNOSTIC === def diagnose_transaction(self, transaction_id: str) -> Dict: """ @@ -308,7 +294,6 @@ class UniversignDocumentService: } try: - # Test 1: Récupération de la transaction logger.info(f"Diagnostic transaction: {transaction_id}") response = requests.get( @@ -334,7 +319,6 @@ class UniversignDocumentService: "participants_count": len(data.get("participants", [])), } - # Test 2: Documents disponibles documents = data.get("documents", []) result["checks"]["documents"] = [] @@ -345,7 +329,6 @@ class UniversignDocumentService: "status": doc.get("status"), } - # Test téléchargement if doc.get("id"): download_url = ( f"{self.api_url}/transactions/{transaction_id}" diff --git a/services/universign_sync.py b/services/universign_sync.py index 28e633c..da634f2 100644 --- a/services/universign_sync.py +++ b/services/universign_sync.py @@ -159,7 +159,6 @@ class UniversignSyncService: return stats - # CORRECTION 1 : process_webhook dans universign_sync.py async def process_webhook( self, session: AsyncSession, payload: Dict, transaction_id: str = None ) -> Tuple[bool, Optional[str]]: @@ -167,9 +166,7 @@ class UniversignSyncService: Traite un webhook Universign - CORRECTION : meilleure gestion des payloads """ try: - # Si transaction_id n'est pas fourni, essayer de l'extraire if not transaction_id: - # Même logique que dans universign.py if ( payload.get("type", "").startswith("transaction.") and "payload" in payload @@ -195,7 +192,6 @@ class UniversignSyncService: f"📨 Traitement webhook: transaction={transaction_id}, event={event_type}" ) - # Récupérer la transaction locale query = ( select(UniversignTransaction) .options(selectinload(UniversignTransaction.signers)) @@ -208,25 +204,20 @@ class UniversignSyncService: logger.warning(f"Transaction {transaction_id} inconnue localement") return False, "Transaction inconnue" - # Marquer comme webhook reçu transaction.webhook_received = True - # Stocker l'ancien statut pour comparaison old_status = transaction.local_status.value - # Force la synchronisation complète success, error = await self.sync_transaction( session, transaction, force=True ) - # Log du changement de statut if success and transaction.local_status.value != old_status: logger.info( f"Webhook traité: {transaction_id} | " f"{old_status} → {transaction.local_status.value}" ) - # Enregistrer le log du webhook await self._log_sync_attempt( session=session, transaction=transaction, @@ -248,7 +239,6 @@ class UniversignSyncService: logger.error(f"💥 Erreur traitement webhook: {e}", exc_info=True) return False, str(e) - # CORRECTION 2 : _sync_signers - Ne pas écraser les signers existants async def _sync_signers( self, session: AsyncSession, @@ -271,7 +261,6 @@ class UniversignSyncService: logger.warning(f"Signataire sans email à l'index {idx}, ignoré") continue - # PROTECTION : gérer les statuts inconnus raw_status = signer_data.get("status") or signer_data.get( "state", "waiting" ) @@ -302,7 +291,6 @@ class UniversignSyncService: if signer_data.get("name") and not signer.name: signer.name = signer_data.get("name") else: - # Nouveau signer avec gestion d'erreur intégrée try: signer = UniversignSigner( id=f"{transaction.id}_signer_{idx}_{int(datetime.now().timestamp())}", @@ -330,7 +318,6 @@ class UniversignSyncService: ): import json - # Si statut final et pas de force, skip if is_final_status(transaction.local_status.value) and not force: logger.debug( f"⏭️ Skip {transaction.transaction_id}: statut final " @@ -340,14 +327,13 @@ class UniversignSyncService: await session.commit() return True, None - # Récupération du statut distant logger.info(f"Synchronisation: {transaction.transaction_id}") result = self.fetch_transaction_status(transaction.transaction_id) if not result: error = "Échec récupération données Universign" - logger.error(f"❌ {error}: {transaction.transaction_id}") + logger.error(f" {error}: {transaction.transaction_id}") transaction.sync_attempts += 1 transaction.sync_error = error await self._log_sync_attempt(session, transaction, "polling", False, error) @@ -358,9 +344,8 @@ class UniversignSyncService: universign_data = result["transaction"] universign_status_raw = universign_data.get("state", "draft") - logger.info(f"📊 Statut Universign brut: {universign_status_raw}") + logger.info(f" Statut Universign brut: {universign_status_raw}") - # Convertir le statut new_local_status = map_universign_to_local(universign_status_raw) previous_local_status = transaction.local_status.value @@ -369,7 +354,6 @@ class UniversignSyncService: f"{new_local_status} (Local) | Actuel: {previous_local_status}" ) - # Vérifier la transition if not is_transition_allowed(previous_local_status, new_local_status): logger.warning( f"Transition refusée: {previous_local_status} → {new_local_status}" @@ -383,10 +367,9 @@ class UniversignSyncService: if status_changed: logger.info( - f"🔔 CHANGEMENT DÉTECTÉ: {previous_local_status} → {new_local_status}" + f"CHANGEMENT DÉTECTÉ: {previous_local_status} → {new_local_status}" ) - # Mise à jour du statut Universign brut try: transaction.universign_status = UniversignTransactionStatus( universign_status_raw @@ -404,14 +387,12 @@ class UniversignSyncService: else: transaction.universign_status = UniversignTransactionStatus.STARTED - # Mise à jour du statut local transaction.local_status = LocalDocumentStatus(new_local_status) transaction.universign_status_updated_at = datetime.now() - # Mise à jour des dates if new_local_status == "EN_COURS" and not transaction.sent_at: transaction.sent_at = datetime.now() - logger.info("📅 Date d'envoi mise à jour") + logger.info("Date d'envoi mise à jour") if new_local_status == "SIGNE" and not transaction.signed_at: transaction.signed_at = datetime.now() @@ -419,15 +400,11 @@ class UniversignSyncService: if new_local_status == "REFUSE" and not transaction.refused_at: transaction.refused_at = datetime.now() - logger.info("❌ Date de refus mise à jour") + logger.info(" Date de refus mise à jour") if new_local_status == "EXPIRE" and not transaction.expired_at: transaction.expired_at = datetime.now() - logger.info("⏰ Date d'expiration mise à jour") - - # === SECTION CORRIGÉE: Gestion des documents === - # Ne plus chercher document_url dans la réponse (elle n'existe pas!) - # Le téléchargement se fait via le service document qui utilise le bon endpoint + logger.info("Date d'expiration mise à jour") documents = universign_data.get("documents", []) if documents: @@ -437,7 +414,6 @@ class UniversignSyncService: f"status={first_doc.get('status')}" ) - # Téléchargement automatique du document signé if new_local_status == "SIGNE" and not transaction.signed_document_path: logger.info("Déclenchement téléchargement document signé...") @@ -455,21 +431,15 @@ class UniversignSyncService: logger.warning(f"Échec téléchargement: {download_error}") except Exception as e: - logger.error( - f"❌ Erreur téléchargement document: {e}", exc_info=True - ) - # === FIN SECTION CORRIGÉE === + logger.error(f" Erreur téléchargement document: {e}", exc_info=True) - # Synchroniser les signataires await self._sync_signers(session, transaction, universign_data) - # Mise à jour des métadonnées de sync transaction.last_synced_at = datetime.now() transaction.sync_attempts += 1 transaction.needs_sync = not is_final_status(new_local_status) transaction.sync_error = None - # Log de la tentative await self._log_sync_attempt( session=session, transaction=transaction, @@ -491,7 +461,6 @@ class UniversignSyncService: await session.commit() - # Exécuter les actions post-changement if status_changed: logger.info(f"🎬 Exécution actions pour statut: {new_local_status}") await self._execute_status_actions( @@ -507,7 +476,7 @@ class UniversignSyncService: except Exception as e: error_msg = f"Erreur lors de la synchronisation: {str(e)}" - logger.error(f"❌ {error_msg}", exc_info=True) + logger.error(f" {error_msg}", exc_info=True) transaction.sync_error = error_msg[:1000] transaction.sync_attempts += 1 @@ -519,20 +488,16 @@ class UniversignSyncService: return False, error_msg - # CORRECTION 3 : Amélioration du logging dans sync_transaction async def _sync_transaction_documents_corrected( self, session, transaction, universign_data: dict, new_local_status: str ): - # Récupérer et stocker les infos documents documents = universign_data.get("documents", []) if documents: - # Stocker le premier document_id pour référence first_doc = documents[0] first_doc_id = first_doc.get("id") if first_doc_id: - # Stocker l'ID du document (si le champ existe dans le modèle) if hasattr(transaction, "universign_document_id"): transaction.universign_document_id = first_doc_id @@ -543,7 +508,6 @@ class UniversignSyncService: else: logger.debug("Aucun document dans la réponse Universign") - # Téléchargement automatique si signé if new_local_status == "SIGNE": if not transaction.signed_document_path: logger.info("Déclenchement téléchargement document signé...") @@ -562,9 +526,7 @@ class UniversignSyncService: logger.warning(f"Échec téléchargement: {download_error}") except Exception as e: - logger.error( - f"❌ Erreur téléchargement document: {e}", exc_info=True - ) + logger.error(f" Erreur téléchargement document: {e}", exc_info=True) else: logger.debug( f"Document déjà téléchargé: {transaction.signed_document_path}" diff --git a/tools/extract_pydantic_models.py b/tools/extract_pydantic_models.py index 595e15f..5718790 100644 --- a/tools/extract_pydantic_models.py +++ b/tools/extract_pydantic_models.py @@ -24,7 +24,6 @@ for node in tree.body: continue other_nodes.append(node) -# --- Extraction des classes --- imports = """ from pydantic import BaseModel, Field from typing import Optional, List @@ -44,7 +43,6 @@ for cls in pydantic_classes: print(f"✅ Modèle extrait : {class_name} → {file_path}") -# --- Réécriture du fichier source sans les modèles --- new_tree = ast.Module(body=other_nodes, type_ignores=[]) new_source = ast.unparse(new_tree) diff --git a/utils/generic_functions.py b/utils/generic_functions.py index 41b734b..29a361e 100644 --- a/utils/generic_functions.py +++ b/utils/generic_functions.py @@ -290,15 +290,11 @@ def _preparer_lignes_document(lignes: List) -> List[Dict]: UNIVERSIGN_TO_LOCAL: Dict[str, str] = { - # États initiaux "draft": "EN_ATTENTE", "ready": "EN_ATTENTE", - # En cours "started": "EN_COURS", - # États finaux (succès) "completed": "SIGNE", "closed": "SIGNE", - # États finaux (échec) "refused": "REFUSE", "expired": "EXPIRE", "canceled": "REFUSE", @@ -429,7 +425,7 @@ STATUS_MESSAGES: Dict[str, Dict[str, str]] = { "REFUSE": { "fr": "Signature refusée", "en": "Signature refused", - "icon": "❌", + "icon": "", "color": "red", }, "EXPIRE": { diff --git a/utils/universign_status_mapping.py b/utils/universign_status_mapping.py index 50e29cc..8391698 100644 --- a/utils/universign_status_mapping.py +++ b/utils/universign_status_mapping.py @@ -96,7 +96,7 @@ STATUS_MESSAGES: Dict[str, Dict[str, str]] = { "REFUSE": { "fr": "Signature refusée", "en": "Signature refused", - "icon": "❌", + "icon": "", "color": "red", }, "EXPIRE": {