From e97ff73e168c828b59308053857aa59364b41886 Mon Sep 17 00:00:00 2001 From: Fanilo-Nantenaina Date: Fri, 2 Jan 2026 17:56:28 +0300 Subject: [PATCH] feat(auth): implement comprehensive security enhancements --- .env.example | 107 ++++- api.py | 116 ++++-- config/config.py | 81 +++- core/dependencies.py | 159 +++++-- database/__init__.py | 24 +- database/models/auth_models.py | 214 ++++++++++ database/models/generic_model.py | 61 --- database/models/sage_config.py | 3 - middleware/security.py | 181 ++++++++ requirements.txt | 12 +- routes/auth.py | 695 +++++++++++++++++++------------ schemas/sage/sage_gateway.py | 11 - security/__init__.py | 55 +++ security/auth.py | 186 +++++++-- security/cookies.py | 157 +++++++ security/csrf.py | 117 ++++++ security/fingerprint.py | 122 ++++++ security/rate_limiter.py | 147 +++++++ services/audit_service.py | 318 ++++++++++++++ services/email_service.py | 412 +++++++++++------- services/redis_service.py | 200 +++++++++ services/sage_gateway.py | 23 +- services/token_service.py | 357 ++++++++++++++++ 23 files changed, 3085 insertions(+), 673 deletions(-) create mode 100644 database/models/auth_models.py create mode 100644 middleware/security.py create mode 100644 security/__init__.py create mode 100644 security/cookies.py create mode 100644 security/csrf.py create mode 100644 security/fingerprint.py create mode 100644 security/rate_limiter.py create mode 100644 services/audit_service.py create mode 100644 services/redis_service.py create mode 100644 services/token_service.py diff --git a/.env.example b/.env.example index 314aa07..05d93de 100644 --- a/.env.example +++ b/.env.example @@ -1,32 +1,97 @@ -# ============================================ -# Configuration Linux VPS - API Principale -# ============================================ +# === Environment === +ENVIRONMENT=development +# Options: development, staging, production -# === Sage Gateway Windows === -SAGE_GATEWAY_URL=http://192.168.1.50:8100 -SAGE_GATEWAY_TOKEN=4e8f9c2a7b1d5e3f9a0c8b7d6e5f4a3b2c1d0e9f8a7b6c5d4e3f2a1b0c9d8e7f +# === JWT & Authentication === +# IMPORTANT: Generer des secrets uniques et forts en production +# python -c "import secrets; print(secrets.token_urlsafe(64))" +JWT_SECRET=CHANGE_ME_IN_PRODUCTION_USE_STRONG_SECRET_64_CHARS_MIN +JWT_ALGORITHM=HS256 +ACCESS_TOKEN_EXPIRE_MINUTES=15 +REFRESH_TOKEN_EXPIRE_DAYS=7 +CSRF_TOKEN_EXPIRE_MINUTES=60 -# === Base de données === +# === Cookie Settings === +COOKIE_DOMAIN= +# Laisser vide pour localhost, sinon ".example.com" pour sous-domaines +COOKIE_SECURE=false +# Mettre true en production avec HTTPS +COOKIE_SAMESITE=strict +# Options: strict, lax, none +COOKIE_HTTPONLY=true +COOKIE_ACCESS_TOKEN_NAME=access_token +COOKIE_REFRESH_TOKEN_NAME=refresh_token +COOKIE_CSRF_TOKEN_NAME=csrf_token + +# === Redis (Token Blacklist & Rate Limiting) === +REDIS_URL=redis://localhost:6379/0 +REDIS_PASSWORD= +REDIS_SSL=false +TOKEN_BLACKLIST_PREFIX=blacklist: +RATE_LIMIT_PREFIX=ratelimit: + +# === Rate Limiting === +RATE_LIMIT_LOGIN_ATTEMPTS=5 +RATE_LIMIT_LOGIN_WINDOW_MINUTES=15 +RATE_LIMIT_API_REQUESTS=100 +RATE_LIMIT_API_WINDOW_SECONDS=60 + +# === Password Security === +PASSWORD_MIN_LENGTH=8 +PASSWORD_REQUIRE_UPPERCASE=true +PASSWORD_REQUIRE_LOWERCASE=true +PASSWORD_REQUIRE_DIGIT=true +PASSWORD_REQUIRE_SPECIAL=true +ACCOUNT_LOCKOUT_THRESHOLD=5 +ACCOUNT_LOCKOUT_DURATION_MINUTES=30 + +# === Device Fingerprint === +FINGERPRINT_SECRET= +# Si vide, utilise JWT_SECRET +FINGERPRINT_COMPONENTS=user_agent,accept_language,accept_encoding + +# === Refresh Token Rotation === +REFRESH_TOKEN_ROTATION_ENABLED=true +REFRESH_TOKEN_REUSE_WINDOW_SECONDS=10 + +# === Database === DATABASE_URL=sqlite+aiosqlite:///./data/sage_dataven.db +# PostgreSQL: postgresql+asyncpg://user:password@localhost:5432/dbname -# === SMTP === -SMTP_HOST=smtp.office365.com +# === Sage Gateway (Windows) === +SAGE_GATEWAY_URL=http://windows-server:5000 +SAGE_GATEWAY_TOKEN=your_gateway_token + +# === Frontend === +FRONTEND_URL=http://localhost:3000 + +# === SMTP (Email) === +SMTP_HOST=smtp.example.com SMTP_PORT=587 -SMTP_USER=commercial@monentreprise.fr -SMTP_PASSWORD=MonMotDePasseEmail123! -SMTP_FROM=commercial@monentreprise.fr +SMTP_USER=noreply@example.com +SMTP_PASSWORD=your_smtp_password +SMTP_FROM=noreply@example.com +SMTP_USE_TLS=true -# === Universign === -UNIVERSIGN_API_KEY=your_real_universign_key_here +# === Universign (Signature electronique) === +UNIVERSIGN_API_KEY=your_universign_api_key UNIVERSIGN_API_URL=https://api.universign.com/v1 -# === API === +# === API Server === API_HOST=0.0.0.0 -API_PORT=8002 -API_RELOAD=False +API_PORT=8000 +API_RELOAD=true +# Mettre false en production -# === Email Queue === -MAX_EMAIL_WORKERS=3 +# === CORS === +# Liste separee par virgules des origines autorisees +CORS_ORIGINS=["*"] -# === Logs === -LOG_LEVEL=INFO \ No newline at end of file +# === Sage Document Types === +SAGE_TYPE_DEVIS=0 +SAGE_TYPE_BON_COMMANDE=10 +SAGE_TYPE_PREPARATION=20 +SAGE_TYPE_BON_LIVRAISON=30 +SAGE_TYPE_BON_RETOUR=40 +SAGE_TYPE_BON_AVOIR=50 +SAGE_TYPE_FACTURE=60 \ No newline at end of file diff --git a/api.py b/api.py index 4a1fa05..111cae8 100644 --- a/api.py +++ b/api.py @@ -1,6 +1,6 @@ from fastapi import FastAPI, HTTPException, Path, Query, Depends, status, Body, Request from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse +from fastapi.responses import JSONResponse, StreamingResponse from fastapi.encoders import jsonable_encoder from pydantic import BaseModel, Field, EmailStr from typing import List, Optional @@ -20,17 +20,17 @@ from routes.auth import router as auth_router from config.config import settings from database import ( init_db, + close_db, async_session_factory, get_session, EmailLog, StatutEmail as StatutEmailEnum, WorkflowLog, SignatureLog, - StatutSignature as StatutSignatureEnum, + StatutSignature, ) from services.email_queue import email_queue from sage.sage_client import sage_client, SageGatewayClient - from schemas import ( TiersDetails, BaremeRemiseResponse, @@ -58,7 +58,6 @@ from schemas import ( LivraisonCreateRequest, LivraisonUpdateRequest, SignatureRequest, - StatutSignature, ArticleCreateRequest, ArticleResponse, ArticleUpdateRequest, @@ -72,13 +71,20 @@ from schemas import ( ContactUpdate, ) from utils.normalization import normaliser_type_tiers - from routes.sage_gateway import router as sage_gateway_router +from services.redis_service import redis_service from core.sage_context import ( get_sage_client_for_user, get_gateway_context_for_user, GatewayContext, ) +from middleware.security import ( + setup_security_middleware, + init_security_services, + shutdown_security_services, + RateLimitMiddleware, +) + if os.path.exists("/app"): LOGS_DIR = FilePath("/app/logs") @@ -112,33 +118,61 @@ async def lifespan(app: FastAPI): email_queue.start(num_workers=settings.max_email_workers) logger.info("Email queue démarrée") + try: + await init_security_services() + logger.info("Services de securite initialises") + except Exception as e: + logger.warning(f"Redis non disponible, mode degrade active: {e}") + yield + await shutdown_security_services() + await close_db() + email_queue.stop() logger.info("Services arrêtés") app = FastAPI( - title="Sage Gateways", + title="Sage API Securisee", version="3.0.0", - description="Configuration multi-tenant des connexions Sage Gateway", + description="API avec authentification securisee par cookies HttpOnly", lifespan=lifespan, openapi_tags=TAGS_METADATA, + docs_url="/docs" if settings.is_development else None, + redoc_url="/redoc" if settings.is_development else None, ) app.add_middleware( CORSMiddleware, allow_origins=settings.cors_origins, - allow_methods=["GET", "POST", "PUT", "DELETE"], + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["*"], allow_credentials=True, + expose_headers=["X-RateLimit-Limit", "X-RateLimit-Remaining", "X-RateLimit-Reset"], ) +setup_security_middleware(app) + +if settings.is_production: + app.add_middleware(RateLimitMiddleware) + app.include_router(auth_router) app.include_router(sage_gateway_router) +@app.exception_handler(Exception) +async def global_exception_handler(request: Request, exc: Exception): + """Gestionnaire global d'exceptions.""" + logger.error(f"Erreur non geree: {exc}", exc_info=True) + + return JSONResponse( + status_code=500, + content={"detail": "Erreur interne du serveur", "type": "internal_error"}, + ) + + async def universign_envoyer( doc_id: str, pdf_bytes: bytes, @@ -1135,7 +1169,7 @@ async def envoyer_signature_optimise( signer_url=resultat["signer_url"], email_signataire=demande.email_signataire, nom_signataire=demande.nom_signataire, - statut=StatutSignatureEnum.ENVOYE, + statut=StatutSignature.ENVOYE, date_envoi=datetime.now(), ) @@ -1191,7 +1225,7 @@ async def webhook_universign( return {"status": "not_found"} if event_type == "transaction.completed": - signature_log.statut = StatutSignatureEnum.SIGNE + signature_log.statut = StatutSignature.SIGNE signature_log.date_signature = datetime.now() logger.info(f"Signature confirmée: {signature_log.document_id}") @@ -1242,11 +1276,11 @@ async def webhook_universign( ) elif event_type == "transaction.refused": - signature_log.statut = StatutSignatureEnum.REFUSE + signature_log.statut = StatutSignature.REFUSE logger.warning(f"Signature refusée: {signature_log.document_id}") elif event_type == "transaction.expired": - signature_log.statut = StatutSignatureEnum.EXPIRE + signature_log.statut = StatutSignature.EXPIRE logger.warning(f"⏰ Transaction expirée: {signature_log.document_id}") await session.commit() @@ -1271,7 +1305,7 @@ async def relancer_signatures_automatique(session: AsyncSession = Depends(get_se query = select(SignatureLog).where( SignatureLog.statut.in_( - [StatutSignatureEnum.EN_ATTENTE, StatutSignatureEnum.ENVOYE] + [StatutSignature.EN_ATTENTE, StatutSignature.ENVOYE] ), SignatureLog.date_envoi < date_limite, SignatureLog.nb_relances < 3, # Max 3 relances @@ -1288,7 +1322,7 @@ async def relancer_signatures_automatique(session: AsyncSession = Depends(get_se jours_restants = 30 - nb_jours # Lien expire après 30 jours if jours_restants <= 0: - signature.statut = StatutSignatureEnum.EXPIRE + signature.statut = StatutSignature.EXPIRE continue template = templates_signature_email["relance_signature"] @@ -1394,7 +1428,7 @@ async def lister_signatures( query = select(SignatureLog).order_by(SignatureLog.date_envoi.desc()) if statut: - statut_db = StatutSignatureEnum[statut.value] + statut_db = StatutSignature[statut.value] query = query.where(SignatureLog.statut == statut_db) query = query.limit(limit) @@ -1437,15 +1471,15 @@ async def statut_signature_detail( if statut_universign.get("statut") != "ERREUR": statut_map = { - "EN_ATTENTE": StatutSignatureEnum.EN_ATTENTE, - "ENVOYE": StatutSignatureEnum.ENVOYE, - "SIGNE": StatutSignatureEnum.SIGNE, - "REFUSE": StatutSignatureEnum.REFUSE, - "EXPIRE": StatutSignatureEnum.EXPIRE, + "EN_ATTENTE": StatutSignature.EN_ATTENTE, + "ENVOYE": StatutSignature.ENVOYE, + "SIGNE": StatutSignature.SIGNE, + "REFUSE": StatutSignature.REFUSE, + "EXPIRE": StatutSignature.EXPIRE, } nouveau_statut = statut_map.get( - statut_universign["statut"], StatutSignatureEnum.EN_ATTENTE + statut_universign["statut"], StatutSignature.EN_ATTENTE ) signature_log.statut = nouveau_statut @@ -1477,9 +1511,7 @@ async def statut_signature_detail( @app.post("/signatures/refresh-all", tags=["Signatures"]) async def rafraichir_statuts_signatures(session: AsyncSession = Depends(get_session)): query = select(SignatureLog).where( - SignatureLog.statut.in_( - [StatutSignatureEnum.EN_ATTENTE, StatutSignatureEnum.ENVOYE] - ) + SignatureLog.statut.in_([StatutSignature.EN_ATTENTE, StatutSignature.ENVOYE]) ) result = await session.execute(query) @@ -1492,9 +1524,9 @@ async def rafraichir_statuts_signatures(session: AsyncSession = Depends(get_sess if statut_universign.get("statut") != "ERREUR": statut_map = { - "SIGNE": StatutSignatureEnum.SIGNE, - "REFUSE": StatutSignatureEnum.REFUSE, - "EXPIRE": StatutSignatureEnum.EXPIRE, + "SIGNE": StatutSignature.SIGNE, + "REFUSE": StatutSignature.REFUSE, + "EXPIRE": StatutSignature.EXPIRE, } nouveau = statut_map.get(statut_universign["statut"]) @@ -1548,7 +1580,7 @@ async def envoyer_devis_signature( signer_url=resultat["signer_url"], email_signataire=request.email_signataire, nom_signataire=request.nom_signataire, - statut=StatutSignatureEnum.ENVOYE, + statut=StatutSignature.ENVOYE, date_envoi=datetime.now(), ) @@ -1694,7 +1726,7 @@ async def relancer_devis_signature( signer_url=resultat["signer_url"], email_signataire=contact["email"], nom_signataire=contact["nom"] or contact["client_intitule"], - statut=StatutSignatureEnum.ENVOYE, + statut=StatutSignature.ENVOYE, date_envoi=datetime.now(), est_relance=True, nb_relances=1, @@ -3158,17 +3190,26 @@ async def health_check( sage: SageGatewayClient = Depends(get_sage_client_for_user), ): gateway_health = sage.health() + redis_status = "connected" + + try: + if not await redis_service.is_connected(): + redis_status = "disconnected" + except Exception: + redis_status = "error" return { "status": "healthy", "sage_gateway": gateway_health, "using_gateway_id": sage.gateway_id, + "timestamp": datetime.now().isoformat(), + "environment": settings.environment.value, + "services": {"redis": redis_status}, "email_queue": { "running": email_queue.running, "workers": len(email_queue.workers), "queue_size": email_queue.queue.qsize(), }, - "timestamp": datetime.now().isoformat(), } @@ -3177,22 +3218,13 @@ async def root(): return { "api": "Sage 100c Dataven - VPS Linux", "version": "2.0.0", - "documentation": "/docs", + "documentation": "/docs" + if settings.is_development + else "Disabled in production", "health": "/health", } -@app.get("/admin/cache/info", tags=["Admin"]) -async def info_cache(): - try: - cache_info = sage_client.get_cache_info() - return cache_info - - except Exception as e: - logger.error(f"Erreur info cache: {e}") - raise HTTPException(500, str(e)) - - @app.get("/admin/queue/status", tags=["Admin"]) async def statut_queue(): return { diff --git a/config/config.py b/config/config.py index 63bf99b..637cc3f 100644 --- a/config/config.py +++ b/config/config.py @@ -1,5 +1,12 @@ from pydantic_settings import BaseSettings, SettingsConfigDict -from typing import List +from typing import List, Optional +from enum import Enum + + +class Environment(str, Enum): + DEVELOPMENT = "development" + STAGING = "staging" + PRODUCTION = "production" class Settings(BaseSettings): @@ -7,12 +14,60 @@ class Settings(BaseSettings): env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="ignore" ) + # === Environment === + environment: Environment = Environment.DEVELOPMENT + # === JWT & Auth === jwt_secret: str - jwt_algorithm: str - access_token_expire_minutes: int - refresh_token_expire_days: int + jwt_algorithm: str = "HS256" + access_token_expire_minutes: int = 15 + refresh_token_expire_days: int = 7 + csrf_token_expire_minutes: int = 60 + # === Cookie Settings === + cookie_domain: Optional[str] = None + cookie_secure: bool = True + cookie_samesite: str = "strict" + cookie_httponly: bool = True + cookie_access_token_name: str = "access_token" + cookie_refresh_token_name: str = "refresh_token" + cookie_csrf_token_name: str = "csrf_token" + + # === Redis (Token Blacklist & Rate Limiting) === + redis_url: str = "redis://localhost:6379/0" + redis_password: Optional[str] = None + redis_ssl: bool = False + token_blacklist_prefix: str = "blacklist:" + rate_limit_prefix: str = "ratelimit:" + + # === Rate Limiting === + rate_limit_login_attempts: int = 5 + rate_limit_login_window_minutes: int = 15 + rate_limit_api_requests: int = 100 + rate_limit_api_window_seconds: int = 60 + + # === Security === + password_min_length: int = 8 + password_require_uppercase: bool = True + password_require_lowercase: bool = True + password_require_digit: bool = True + password_require_special: bool = True + account_lockout_threshold: int = 5 + account_lockout_duration_minutes: int = 30 + + # === Fingerprint === + fingerprint_secret: str = "" + fingerprint_components: List[str] = [ + "user_agent", + "accept_language", + "accept_encoding", + ] + + # === Refresh Token Rotation === + refresh_token_rotation_enabled: bool = True + refresh_token_reuse_window_seconds: int = 10 + + # === Sage Types === SAGE_TYPE_DEVIS: int = 0 SAGE_TYPE_BON_COMMANDE: int = 10 SAGE_TYPE_PREPARATION: int = 20 @@ -21,12 +76,12 @@ class Settings(BaseSettings): SAGE_TYPE_BON_AVOIR: int = 50 SAGE_TYPE_FACTURE: int = 60 - # === Sage Gateway (Windows) === + # === Sage Gateway === sage_gateway_url: str sage_gateway_token: str frontend_url: str - # === Base de données === + # === Database === database_url: str = "sqlite+aiosqlite:///./data/sage_dataven.db" # === SMTP === @@ -42,9 +97,9 @@ class Settings(BaseSettings): universign_api_url: str # === API === - api_host: str - api_port: int - api_reload: bool = False + api_host: str = "0.0.0.0" + api_port: int = 8000 + api_reload: bool = True # === Email Queue === max_email_workers: int = 3 @@ -54,5 +109,13 @@ class Settings(BaseSettings): # === CORS === cors_origins: List[str] = ["*"] + @property + def is_production(self) -> bool: + return self.environment == Environment.PRODUCTION + + @property + def is_development(self) -> bool: + return self.environment == Environment.DEVELOPMENT + settings = Settings() diff --git a/core/dependencies.py b/core/dependencies.py index 039081c..f71620f 100644 --- a/core/dependencies.py +++ b/core/dependencies.py @@ -1,41 +1,49 @@ -from fastapi import Depends, HTTPException, status -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi import Depends, HTTPException, status, Request from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy import select -from database import get_session, User -from security.auth import decode_token -from typing import Optional +from typing import Optional, Tuple from datetime import datetime +import logging -security = HTTPBearer() +from database import get_session +from database import User, AuditEventType +from services.token_service import TokenService +from services.audit_service import AuditService +from security.cookies import CookieManager +from security.fingerprint import DeviceFingerprint, get_client_ip +from security.csrf import CSRFProtection + +logger = logging.getLogger(__name__) async def get_current_user( - credentials: HTTPAuthorizationCredentials = Depends(security), - session: AsyncSession = Depends(get_session), + request: Request, session: AsyncSession = Depends(get_session) ) -> User: - token = credentials.credentials + token = CookieManager.get_access_token(request) + + if not token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Authentification requise", + headers={"WWW-Authenticate": "Bearer"}, + ) + + fingerprint_hash = DeviceFingerprint.generate_hash(request) + + payload = await TokenService.validate_access_token(token, fingerprint_hash) - payload = decode_token(token) if not payload: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Token invalide ou expiré", + detail="Token invalide ou expire", headers={"WWW-Authenticate": "Bearer"}, ) - if payload.get("type") != "access": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Type de token incorrect", - headers={"WWW-Authenticate": "Bearer"}, - ) - - user_id: str = payload.get("sub") + user_id = payload.get("sub") if not user_id: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, - detail="Token malformé", + detail="Token malformed", headers={"WWW-Authenticate": "Bearer"}, ) @@ -51,33 +59,31 @@ async def get_current_user( if not user.is_active: raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail="Compte désactivé" + status_code=status.HTTP_403_FORBIDDEN, detail="Compte desactive" ) if not user.is_verified: raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Email non vérifié. Consultez votre boîte de réception.", + status_code=status.HTTP_403_FORBIDDEN, detail="Email non verifie" ) if user.locked_until and user.locked_until > datetime.now(): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Compte temporairement verrouillé suite à trop de tentatives échouées", + detail="Compte temporairement verrouille", ) + request.state.user = user + request.state.session_id = payload.get("sid") + return user async def get_current_user_optional( - credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), - session: AsyncSession = Depends(get_session), + request: Request, session: AsyncSession = Depends(get_session) ) -> Optional[User]: - if not credentials: - return None - try: - return await get_current_user(credentials, session) + return await get_current_user(request, session) except HTTPException: return None @@ -87,8 +93,99 @@ def require_role(*allowed_roles: str): if user.role not in allowed_roles: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail=f"Accès refusé. Rôles requis: {', '.join(allowed_roles)}", + detail=f"Acces refuse. Roles requis: {', '.join(allowed_roles)}", ) return user return role_checker + + +def require_verified_email(): + async def email_checker(user: User = Depends(get_current_user)) -> User: + if not user.is_verified: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Verification email requise", + ) + return user + + return email_checker + + +async def verify_csrf_token( + request: Request, user: User = Depends(get_current_user) +) -> None: + if CSRFProtection.is_exempt(request): + return + + session_id = getattr(request.state, "session_id", None) + + if not CSRFProtection.validate_request(request, session_id): + logger.warning( + f"CSRF validation echouee pour user {user.id} " + f"sur {request.method} {request.url.path}" + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="Verification CSRF echouee" + ) + + +async def get_auth_context( + request: Request, session: AsyncSession = Depends(get_session) +) -> Tuple[Optional[User], str, str]: + ip_address = get_client_ip(request) + fingerprint_hash = DeviceFingerprint.generate_hash(request) + + try: + user = await get_current_user(request, session) + except HTTPException: + user = None + + return user, ip_address, fingerprint_hash + + +class AuthenticatedRoute: + def __init__( + self, + require_csrf: bool = True, + allowed_roles: Optional[Tuple[str, ...]] = None, + audit_event: Optional[AuditEventType] = None, + ): + self.require_csrf = require_csrf + self.allowed_roles = allowed_roles + self.audit_event = audit_event + + async def __call__( + self, request: Request, session: AsyncSession = Depends(get_session) + ) -> User: + user = await get_current_user(request, session) + + if self.allowed_roles and user.role not in self.allowed_roles: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Acces refuse pour ce role", + ) + + if self.require_csrf and not CSRFProtection.is_exempt(request): + session_id = getattr(request.state, "session_id", None) + if not CSRFProtection.validate_request(request, session_id): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Verification CSRF echouee", + ) + + if self.audit_event: + await AuditService.log_event( + session=session, + event_type=self.audit_event, + request=request, + user_id=user.id, + success=True, + ) + + return user + + +require_admin = require_role("admin") +require_manager = require_role("admin", "manager") +require_user = require_role("admin", "manager", "user") diff --git a/database/__init__.py b/database/__init__.py index 7e41efd..97af367 100644 --- a/database/__init__.py +++ b/database/__init__.py @@ -5,13 +5,18 @@ from database.db_config import ( get_session, close_db, ) -from database.models.generic_model import ( - CacheMetadata, - AuditLog, + +from database.models.generic_model import Base + +from database.models.auth_models import ( + User, RefreshToken, + AuditLog, + AuditEventType, LoginAttempt, + UserSession, ) -from database.models.user import User + from database.models.email import EmailLog from database.models.signature import SignatureLog from database.models.sage_config import SageGatewayConfig @@ -28,15 +33,16 @@ __all__ = [ "get_session", "close_db", "Base", + "User", + "RefreshToken", + "AuditLog", + "AuditEventType", + "LoginAttempt", + "UserSession", "EmailLog", "SignatureLog", "WorkflowLog", - "CacheMetadata", - "AuditLog", "StatutEmail", "StatutSignature", - "User", - "RefreshToken", - "LoginAttempt", "SageGatewayConfig", ] diff --git a/database/models/auth_models.py b/database/models/auth_models.py new file mode 100644 index 0000000..254a344 --- /dev/null +++ b/database/models/auth_models.py @@ -0,0 +1,214 @@ +from sqlalchemy import ( + Column, + Integer, + String, + DateTime, + Boolean, + Text, + ForeignKey, + Index, + Enum as SQLEnum, +) +from sqlalchemy.orm import relationship +from datetime import datetime +from enum import Enum + +from database.models.generic_model import Base + + +class User(Base): + __tablename__ = "users" + + id = Column(String(36), primary_key=True) + email = Column(String(255), unique=True, nullable=False, index=True) + hashed_password = Column(String(255), nullable=False) + + nom = Column(String(100), nullable=False) + prenom = Column(String(100), nullable=False) + role = Column(String(50), default="user") + + is_verified = Column(Boolean, default=False, index=True) + verification_token = Column(String(255), nullable=True, unique=True, index=True) + verification_token_expires = Column(DateTime, nullable=True) + + is_active = Column(Boolean, default=True, index=True) + failed_login_attempts = Column(Integer, default=0) + locked_until = Column(DateTime, nullable=True) + + reset_token = Column(String(255), nullable=True, unique=True, index=True) + reset_token_expires = Column(DateTime, nullable=True) + + password_changed_at = Column(DateTime, nullable=True) + must_change_password = Column(Boolean, default=False) + + created_at = Column(DateTime, default=datetime.now, nullable=False) + updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now) + last_login = Column(DateTime, nullable=True) + last_login_ip = Column(String(45), nullable=True) + + refresh_tokens = relationship( + "RefreshToken", back_populates="user", cascade="all, delete-orphan" + ) + + audit_logs = relationship( + "AuditLog", back_populates="user", cascade="all, delete-orphan" + ) + + def __repr__(self): + return f"" + + +class RefreshToken(Base): + __tablename__ = "refresh_tokens" + + id = Column(String(36), primary_key=True) + user_id = Column( + String(36), + ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + + token_hash = Column(String(64), unique=True, nullable=False, index=True) + token_id = Column(String(32), unique=True, nullable=False, index=True) + + fingerprint_hash = Column(String(64), nullable=True) + device_info = Column(String(500), nullable=True) + ip_address = Column(String(45), nullable=True) + + is_revoked = Column(Boolean, default=False, index=True) + revoked_at = Column(DateTime, nullable=True) + revoked_reason = Column(String(100), nullable=True) + + is_used = Column(Boolean, default=False) + used_at = Column(DateTime, nullable=True) + replaced_by = Column(String(36), nullable=True) + + expires_at = Column(DateTime, nullable=False, index=True) + created_at = Column(DateTime, default=datetime.now, nullable=False) + last_used_at = Column(DateTime, nullable=True) + + user = relationship("User", back_populates="refresh_tokens") + + __table_args__ = ( + Index("ix_refresh_tokens_user_valid", "user_id", "is_revoked", "expires_at"), + ) + + def __repr__(self): + return f"" + + +class AuditEventType(str, Enum): + LOGIN_SUCCESS = "login_success" + LOGIN_FAILED = "login_failed" + LOGOUT = "logout" + PASSWORD_CHANGE = "password_change" + PASSWORD_RESET_REQUEST = "password_reset_request" + PASSWORD_RESET_COMPLETE = "password_reset_complete" + EMAIL_VERIFICATION = "email_verification" + ACCOUNT_LOCKED = "account_locked" + ACCOUNT_UNLOCKED = "account_unlocked" + TOKEN_REFRESH = "token_refresh" + TOKEN_REVOKED = "token_revoked" + SUSPICIOUS_ACTIVITY = "suspicious_activity" + SESSION_CREATED = "session_created" + SESSION_TERMINATED = "session_terminated" + + +class AuditLog(Base): + __tablename__ = "audit_logs" + + id = Column(String(36), primary_key=True) + user_id = Column( + String(36), + ForeignKey("users.id", ondelete="SET NULL"), + nullable=True, + index=True, + ) + + event_type = Column(SQLEnum(AuditEventType), nullable=False, index=True) + event_description = Column(Text, nullable=True) + + ip_address = Column(String(45), nullable=True, index=True) + user_agent = Column(String(500), nullable=True) + fingerprint_hash = Column(String(64), nullable=True) + + resource_type = Column(String(50), nullable=True) + resource_id = Column(String(100), nullable=True) + + request_method = Column(String(10), nullable=True) + request_path = Column(String(500), nullable=True) + + meta = Column("metadata", Text, nullable=True) + + success = Column(Boolean, default=True) + failure_reason = Column(String(255), nullable=True) + + created_at = Column(DateTime, default=datetime.now, nullable=False, index=True) + + user = relationship("User", back_populates="audit_logs") + + __table_args__ = ( + Index("ix_audit_logs_user_event", "user_id", "event_type", "created_at"), + Index("ix_audit_logs_ip_event", "ip_address", "event_type", "created_at"), + ) + + def __repr__(self): + return f"" + + +class LoginAttempt(Base): + __tablename__ = "login_attempts" + + id = Column(Integer, primary_key=True, autoincrement=True) + + email = Column(String(255), nullable=False, index=True) + ip_address = Column(String(45), nullable=True, index=True) + user_agent = Column(String(500), nullable=True) + fingerprint_hash = Column(String(64), nullable=True) + + success = Column(Boolean, default=False, index=True) + failure_reason = Column(String(255), nullable=True) + + timestamp = Column(DateTime, default=datetime.now, nullable=False, index=True) + + __table_args__ = ( + Index("ix_login_attempts_email_time", "email", "timestamp"), + Index("ix_login_attempts_ip_time", "ip_address", "timestamp"), + ) + + def __repr__(self): + return f"" + + +class UserSession(Base): + __tablename__ = "user_sessions" + + id = Column(String(36), primary_key=True) + user_id = Column( + String(36), + ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + index=True, + ) + + session_token_hash = Column(String(64), unique=True, nullable=False, index=True) + refresh_token_id = Column(String(36), nullable=True) + + device_info = Column(String(500), nullable=True) + ip_address = Column(String(45), nullable=True) + fingerprint_hash = Column(String(64), nullable=True) + location = Column(String(255), nullable=True) + + is_active = Column(Boolean, default=True, index=True) + terminated_at = Column(DateTime, nullable=True) + termination_reason = Column(String(100), nullable=True) + + created_at = Column(DateTime, default=datetime.now, nullable=False) + last_activity = Column(DateTime, default=datetime.now, nullable=False) + expires_at = Column(DateTime, nullable=False) + + __table_args__ = (Index("ix_user_sessions_user_active", "user_id", "is_active"),) + + def __repr__(self): + return f"" diff --git a/database/models/generic_model.py b/database/models/generic_model.py index 840b614..2265e62 100644 --- a/database/models/generic_model.py +++ b/database/models/generic_model.py @@ -5,7 +5,6 @@ from sqlalchemy import ( DateTime, Float, Text, - Boolean, ) from sqlalchemy.ext.declarative import declarative_base from datetime import datetime @@ -29,63 +28,3 @@ class CacheMetadata(Base): def __repr__(self): return f"" - - -class AuditLog(Base): - __tablename__ = "audit_logs" - - id = Column(Integer, primary_key=True, autoincrement=True) - - action = Column(String(100), nullable=False, index=True) - ressource_type = Column(String(50), nullable=True) - ressource_id = Column(String(100), nullable=True, index=True) - - utilisateur = Column(String(100), nullable=True) - ip_address = Column(String(45), nullable=True) - - succes = Column(Boolean, default=True) - details = Column(Text, nullable=True) - erreur = Column(Text, nullable=True) - - date_action = Column(DateTime, default=datetime.now, nullable=False, index=True) - - def __repr__(self): - return f"" - - -class RefreshToken(Base): - __tablename__ = "refresh_tokens" - - id = Column(String(36), primary_key=True) - user_id = Column(String(36), nullable=False, index=True) - token_hash = Column(String(255), nullable=False, unique=True, index=True) - - device_info = Column(String(500), nullable=True) - ip_address = Column(String(45), nullable=True) - - expires_at = Column(DateTime, nullable=False) - created_at = Column(DateTime, default=datetime.now, nullable=False) - - is_revoked = Column(Boolean, default=False) - revoked_at = Column(DateTime, nullable=True) - - def __repr__(self): - return f"" - - -class LoginAttempt(Base): - __tablename__ = "login_attempts" - - id = Column(Integer, primary_key=True, autoincrement=True) - - email = Column(String(255), nullable=False, index=True) - ip_address = Column(String(45), nullable=False, index=True) - user_agent = Column(String(500), nullable=True) - - success = Column(Boolean, default=False) - failure_reason = Column(String(255), nullable=True) - - timestamp = Column(DateTime, default=datetime.now, nullable=False, index=True) - - def __repr__(self): - return f"" diff --git a/database/models/sage_config.py b/database/models/sage_config.py index f6ed363..a48dac1 100644 --- a/database/models/sage_config.py +++ b/database/models/sage_config.py @@ -22,9 +22,6 @@ class SageGatewayConfig(Base): gateway_url = Column(String(500), nullable=False) gateway_token = Column(String(255), nullable=False) - sage_database = Column(String(255), nullable=True) - sage_company = Column(String(255), nullable=True) - is_active = Column(Boolean, default=False, index=True) is_default = Column(Boolean, default=False) priority = Column(Integer, default=0) diff --git a/middleware/security.py b/middleware/security.py new file mode 100644 index 0000000..6e9409b --- /dev/null +++ b/middleware/security.py @@ -0,0 +1,181 @@ +from fastapi import Request, Response, status +from fastapi.responses import JSONResponse +from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint +from typing import Set +import logging +import time + +from config.config import settings +from security.csrf import CSRFProtection +from security.fingerprint import get_client_ip +from services.redis_service import redis_service + +logger = logging.getLogger(__name__) + + +RATE_LIMIT_EXEMPT_PATHS: Set[str] = { + "/health", + "/docs", + "/redoc", + "/openapi.json", +} + + +class SecurityHeadersMiddleware(BaseHTTPMiddleware): + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + response = await call_next(request) + + response.headers["X-Content-Type-Options"] = "nosniff" + response.headers["X-Frame-Options"] = "DENY" + response.headers["X-XSS-Protection"] = "1; mode=block" + response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" + + if settings.is_production: + response.headers["Strict-Transport-Security"] = ( + "max-age=31536000; includeSubDomains" + ) + + response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate" + response.headers["Pragma"] = "no-cache" + + response.headers["Permissions-Policy"] = ( + "geolocation=(), microphone=(), camera=()" + ) + + return response + + +class CSRFMiddleware(BaseHTTPMiddleware): + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + if CSRFProtection.is_exempt(request): + return await call_next(request) + + if not CSRFProtection.validate_double_submit(request): + logger.warning( + f"CSRF validation echouee: {request.method} {request.url.path} " + f"depuis {get_client_ip(request)}" + ) + return JSONResponse( + status_code=status.HTTP_403_FORBIDDEN, + content={"detail": "Verification CSRF echouee"}, + ) + + return await call_next(request) + + +class RateLimitMiddleware(BaseHTTPMiddleware): + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + path = request.url.path.rstrip("/") + if path in RATE_LIMIT_EXEMPT_PATHS: + return await call_next(request) + + ip = get_client_ip(request) + key = f"api:{ip}" + window_seconds = settings.rate_limit_api_window_seconds + max_requests = settings.rate_limit_api_requests + + try: + count = await redis_service.increment_rate_limit(key, window_seconds) + remaining = max(0, max_requests - count) + + response = await call_next(request) + + response.headers["X-RateLimit-Limit"] = str(max_requests) + response.headers["X-RateLimit-Remaining"] = str(remaining) + response.headers["X-RateLimit-Reset"] = str(window_seconds) + + if count > max_requests: + logger.warning( + f"Rate limit depasse pour IP {ip}: {count}/{max_requests}" + ) + return JSONResponse( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + content={"detail": "Limite de requetes atteinte"}, + headers={ + "X-RateLimit-Limit": str(max_requests), + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": str(window_seconds), + "Retry-After": str(window_seconds), + }, + ) + + return response + + except Exception as e: + logger.error(f"Erreur rate limiting: {e}") + return await call_next(request) + + +class RequestLoggingMiddleware(BaseHTTPMiddleware): + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + start_time = time.time() + + ip = get_client_ip(request) + method = request.method + path = request.url.path + + response = await call_next(request) + + duration_ms = (time.time() - start_time) * 1000 + + log_level = logging.INFO + if response.status_code >= 500: + log_level = logging.ERROR + elif response.status_code >= 400: + log_level = logging.WARNING + + logger.log( + log_level, + f"{method} {path} - {response.status_code} - {duration_ms:.2f}ms - {ip}", + ) + + return response + + +class FingerprintValidationMiddleware(BaseHTTPMiddleware): + VALIDATION_PATHS: Set[str] = { + "/auth/refresh", + "/auth/logout", + } + + async def dispatch( + self, request: Request, call_next: RequestResponseEndpoint + ) -> Response: + path = request.url.path.rstrip("/") + + if path not in self.VALIDATION_PATHS: + return await call_next(request) + + return await call_next(request) + + +def setup_security_middleware(app) -> None: + app.add_middleware(RequestLoggingMiddleware) + + app.add_middleware(SecurityHeadersMiddleware) + + app.add_middleware(FingerprintValidationMiddleware) + + +async def init_security_services() -> None: + try: + await redis_service.connect() + logger.info("Services de securite initialises") + except Exception as e: + logger.warning(f"Redis non disponible, fonctionnement en mode degrade: {e}") + + +async def shutdown_security_services() -> None: + try: + await redis_service.disconnect() + logger.info("Services de securite arretes") + except Exception as e: + logger.error(f"Erreur arret services securite: {e}") diff --git a/requirements.txt b/requirements.txt index 2ece0f4..e8dbe28 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,14 @@ fastapi uvicorn[standard] +starlette +structlog + pydantic pydantic-settings reportlab requests msal +aiosmtplib python-multipart email-validator @@ -13,9 +17,13 @@ python-dotenv python-jose[cryptography] passlib[bcrypt] bcrypt==4.2.0 +PyJWT -sqlalchemy +sqlalchemy[asyncio] aiosqlite tenacity +asyncpg -httpx \ No newline at end of file +httpx + +redis[hiredis] \ No newline at end of file diff --git a/routes/auth.py b/routes/auth.py index 9b7d377..386df98 100644 --- a/routes/auth.py +++ b/routes/auth.py @@ -1,27 +1,29 @@ -from fastapi import APIRouter, Depends, HTTPException, status, Request +from fastapi import APIRouter, Depends, HTTPException, status, Request, Response from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select +from sqlalchemy import false, select from pydantic import BaseModel, EmailStr, Field from datetime import datetime, timedelta -from typing import Optional +from typing import Optional, List import uuid +import logging -from database import get_session, User, RefreshToken, LoginAttempt +from config.config import settings +from database import get_session +from database import User, RefreshToken, AuditEventType from security.auth import ( hash_password, verify_password, validate_password_strength, - create_access_token, - create_refresh_token, - decode_token, generate_verification_token, generate_reset_token, - hash_token, ) +from security.cookies import CookieManager, set_auth_cookies +from security.fingerprint import DeviceFingerprint, get_client_ip +from security.rate_limiter import RateLimiter +from services.token_service import TokenService +from services.audit_service import AuditService from services.email_service import AuthEmailService from core.dependencies import get_current_user -from config.config import settings -import logging logger = logging.getLogger(__name__) router = APIRouter(prefix="/auth", tags=["Authentication"]) @@ -29,25 +31,20 @@ router = APIRouter(prefix="/auth", tags=["Authentication"]) class RegisterRequest(BaseModel): email: EmailStr - password: str = Field(..., min_length=8) + password: str = Field(..., min_length=8, max_length=128) nom: str = Field(..., min_length=2, max_length=100) prenom: str = Field(..., min_length=2, max_length=100) class LoginRequest(BaseModel): email: EmailStr - password: str + password: str = Field(..., min_length=1, max_length=128) class TokenResponse(BaseModel): - access_token: str - refresh_token: str - token_type: str = "bearer" - expires_in: int = 86400 - - -class RefreshTokenRequest(BaseModel): - refresh_token: str + message: str + user: dict + expires_in: int class ForgotPasswordRequest(BaseModel): @@ -56,7 +53,7 @@ class ForgotPasswordRequest(BaseModel): class ResetPasswordRequest(BaseModel): token: str - new_password: str = Field(..., min_length=8) + new_password: str = Field(..., min_length=8, max_length=128) class VerifyEmailRequest(BaseModel): @@ -67,44 +64,17 @@ class ResendVerificationRequest(BaseModel): email: EmailStr -async def log_login_attempt( - session: AsyncSession, - email: str, - ip: str, - user_agent: str, - success: bool, - failure_reason: Optional[str] = None, -): - attempt = LoginAttempt( - email=email, - ip_address=ip, - user_agent=user_agent, - success=success, - failure_reason=failure_reason, - timestamp=datetime.now(), - ) - session.add(attempt) - await session.commit() +class ChangePasswordRequest(BaseModel): + current_password: str + new_password: str = Field(..., min_length=8, max_length=128) -async def check_rate_limit( - session: AsyncSession, email: str, ip: str -) -> tuple[bool, str]: - time_window = datetime.now() - timedelta(minutes=15) - - result = await session.execute( - select(LoginAttempt).where( - LoginAttempt.email == email, - LoginAttempt.success, - LoginAttempt.timestamp >= time_window, - ) - ) - failed_attempts = result.scalars().all() - - if len(failed_attempts) >= 5: - return False, "Trop de tentatives échouées. Réessayez dans 15 minutes." - - return True, "" +class SessionResponse(BaseModel): + id: str + device_info: Optional[str] + ip_address: Optional[str] + created_at: str + last_used_at: Optional[str] @router.post("/register", status_code=status.HTTP_201_CREATED) @@ -113,12 +83,18 @@ async def register( request: Request, session: AsyncSession = Depends(get_session), ): - result = await session.execute(select(User).where(User.email == data.email)) - existing_user = result.scalar_one_or_none() + ip = get_client_ip(request) - if existing_user: + allowed, error_msg = await RateLimiter.check_registration_rate_limit(ip) + if not allowed: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Cet email est déjà utilisé" + status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=error_msg + ) + + result = await session.execute(select(User).where(User.email == data.email.lower())) + if result.scalar_one_or_none(): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, detail="Cet email est deja utilise" ) is_valid, error_msg = validate_password_strength(data.password) @@ -143,23 +119,233 @@ async def register( await session.commit() base_url = str(request.base_url).rstrip("/") - email_sent = AuthEmailService.send_verification_email( - data.email, verification_token, base_url - ) + AuthEmailService.send_verification_email(data.email, verification_token, base_url) - if not email_sent: - logger.warning(f"Échec envoi email vérification pour {data.email}") - - logger.info(f" Nouvel utilisateur inscrit: {data.email} (ID: {new_user.id})") + logger.info(f"Nouvel utilisateur inscrit: {data.email}") return { "success": True, - "message": "Inscription réussie ! Consultez votre email pour vérifier votre compte.", + "message": "Inscription reussie. Consultez votre email pour verifier votre compte.", "user_id": new_user.id, - "email": data.email, } +@router.post("/login") +async def login( + data: LoginRequest, + request: Request, + response: Response, + session: AsyncSession = Depends(get_session), +): + ip = get_client_ip(request) + user_agent = request.headers.get("User-Agent", "") + fingerprint_hash = DeviceFingerprint.generate_hash(request) + + allowed, error_msg, _ = await RateLimiter.check_login_rate_limit( + data.email.lower(), ip + ) + if not allowed: + await AuditService.log_login_failed(session, request, data.email, "rate_limit") + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=error_msg + ) + + result = await session.execute(select(User).where(User.email == data.email.lower())) + user = result.scalar_one_or_none() + + if not user or not verify_password(data.password, user.hashed_password): + await RateLimiter.record_login_attempt(data.email.lower(), ip, success=False) + await AuditService.record_login_attempt( + session, request, data.email, False, "invalid_credentials" + ) + + if user: + user.failed_login_attempts = (user.failed_login_attempts or 0) + 1 + + if user.failed_login_attempts >= settings.account_lockout_threshold: + user.locked_until = datetime.now() + timedelta( + minutes=settings.account_lockout_duration_minutes + ) + await AuditService.log_account_locked( + session, request, user.id, "too_many_failed_attempts" + ) + await session.commit() + + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Compte verrouille. Reessayez dans {settings.account_lockout_duration_minutes} minutes.", + ) + + await session.commit() + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Email ou mot de passe incorrect", + ) + + if not user.is_active: + await AuditService.log_login_failed( + session, request, data.email, "account_disabled", user.id + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="Compte desactive" + ) + + if not user.is_verified: + await AuditService.log_login_failed( + session, request, data.email, "email_not_verified", user.id + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Email non verifie. Consultez votre boite de reception.", + ) + + if user.locked_until and user.locked_until > datetime.now(): + await AuditService.log_login_failed( + session, request, data.email, "account_locked", user.id + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Compte temporairement verrouille", + ) + + user.failed_login_attempts = 0 + user.locked_until = None + user.last_login = datetime.now() + user.last_login_ip = ip + + ( + access_token, + refresh_token, + csrf_token, + session_id, + ) = await TokenService.create_token_pair( + session=session, + user=user, + fingerprint_hash=fingerprint_hash, + device_info=user_agent, + ip_address=ip, + ) + + await session.commit() + + await RateLimiter.record_login_attempt(data.email.lower(), ip, success=True) + await AuditService.log_login_success(session, request, user.id, user.email) + + set_auth_cookies(response, access_token, refresh_token, csrf_token) + + logger.info(f"Connexion reussie: {user.email} depuis {ip}") + + return TokenResponse( + message="Connexion reussie", + user={ + "id": user.id, + "email": user.email, + "nom": user.nom, + "prenom": user.prenom, + "role": user.role, + }, + expires_in=settings.access_token_expire_minutes * 60, + ) + + +@router.post("/refresh") +async def refresh_tokens( + request: Request, response: Response, session: AsyncSession = Depends(get_session) +): + refresh_token = CookieManager.get_refresh_token(request) + + if not refresh_token: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token manquant" + ) + + ip = get_client_ip(request) + user_agent = request.headers.get("User-Agent", "") + fingerprint_hash = DeviceFingerprint.generate_hash(request) + + result = await TokenService.refresh_tokens( + session=session, + refresh_token=refresh_token, + fingerprint_hash=fingerprint_hash, + device_info=user_agent, + ip_address=ip, + ) + + if not result: + CookieManager.clear_all_auth_cookies(response) + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Refresh token invalide ou expire", + ) + + new_access, new_refresh, new_csrf, session_id = result + + await session.commit() + + set_auth_cookies(response, new_access, new_refresh, new_csrf) + + logger.debug("Tokens rafraichis avec succes") + + return { + "message": "Tokens rafraichis", + "expires_in": settings.access_token_expire_minutes * 60, + } + + +@router.post("/logout") +async def logout( + request: Request, + response: Response, + session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), +): + refresh_token = CookieManager.get_refresh_token(request) + + if refresh_token: + await TokenService.revoke_token( + session=session, refresh_token=refresh_token, reason="user_logout" + ) + + await AuditService.log_logout(session, request, user.id) + + await session.commit() + + CookieManager.clear_all_auth_cookies(response) + + logger.info(f"Deconnexion: {user.email}") + + return {"success": True, "message": "Deconnexion reussie"} + + +@router.post("/logout-all") +async def logout_all_sessions( + request: Request, + response: Response, + session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), +): + count = await TokenService.revoke_all_user_tokens( + session=session, user_id=user.id, reason="user_logout_all" + ) + + await AuditService.log_event( + session=session, + event_type=AuditEventType.SESSION_TERMINATED, + request=request, + user_id=user.id, + description=f"Toutes les sessions terminees ({count} tokens revoques)", + ) + + await session.commit() + + CookieManager.clear_all_auth_cookies(response) + + logger.info(f"Toutes les sessions terminees pour {user.email}: {count} tokens") + + return {"success": True, "message": f"{count} session(s) terminee(s)"} + + @router.get("/verify-email") async def verify_email_get(token: str, session: AsyncSession = Depends(get_session)): result = await session.execute(select(User).where(User.verification_token == token)) @@ -168,13 +354,16 @@ async def verify_email_get(token: str, session: AsyncSession = Depends(get_sessi if not user: return { "success": False, - "message": "Token de vérification invalide ou déjà utilisé.", + "message": "Token de verification invalide ou deja utilise.", } - if user.verification_token_expires < datetime.now(): + if ( + user.verification_token_expires + and user.verification_token_expires < datetime.now() + ): return { "success": False, - "message": "Token expiré. Veuillez demander un nouvel email de vérification.", + "message": "Token expire. Demandez un nouveau lien de verification.", "expired": True, } @@ -183,18 +372,19 @@ async def verify_email_get(token: str, session: AsyncSession = Depends(get_sessi user.verification_token_expires = None await session.commit() - logger.info(f" Email vérifié: {user.email}") + logger.info(f"Email verifie: {user.email}") return { "success": True, - "message": " Email vérifié avec succès ! Vous pouvez maintenant vous connecter.", - "email": user.email, + "message": "Email verifie avec succes. Vous pouvez maintenant vous connecter.", } @router.post("/verify-email") async def verify_email_post( - data: VerifyEmailRequest, session: AsyncSession = Depends(get_session) + data: VerifyEmailRequest, + request: Request, + session: AsyncSession = Depends(get_session), ): result = await session.execute( select(User).where(User.verification_token == data.token) @@ -204,26 +394,35 @@ async def verify_email_post( if not user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Token de vérification invalide", + detail="Token de verification invalide", ) - if user.verification_token_expires < datetime.now(): + if ( + user.verification_token_expires + and user.verification_token_expires < datetime.now() + ): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Token expiré. Demandez un nouvel email de vérification.", + detail="Token expire. Demandez un nouveau lien de verification.", ) user.is_verified = True user.verification_token = None user.verification_token_expires = None + + await AuditService.log_event( + session=session, + event_type=AuditEventType.EMAIL_VERIFICATION, + request=request, + user_id=user.id, + description="Email verifie avec succes", + ) + await session.commit() - logger.info(f" Email vérifié: {user.email}") + logger.info(f"Email verifie: {user.email}") - return { - "success": True, - "message": "Email vérifié avec succès ! Vous pouvez maintenant vous connecter.", - } + return {"success": True, "message": "Email verifie avec succes."} @router.post("/resend-verification") @@ -238,12 +437,12 @@ async def resend_verification( if not user: return { "success": True, - "message": "Si cet email existe, un nouveau lien de vérification a été envoyé.", + "message": "Si cet email existe, un nouveau lien a ete envoye.", } if user.is_verified: raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, detail="Ce compte est déjà vérifié" + status_code=status.HTTP_400_BAD_REQUEST, detail="Ce compte est deja verifie" ) verification_token = generate_verification_token() @@ -254,165 +453,7 @@ async def resend_verification( base_url = str(request.base_url).rstrip("/") AuthEmailService.send_verification_email(user.email, verification_token, base_url) - return {"success": True, "message": "Un nouveau lien de vérification a été envoyé."} - - -@router.post("/login", response_model=TokenResponse) -async def login( - data: LoginRequest, request: Request, session: AsyncSession = Depends(get_session) -): - ip = request.client.host if request.client else "unknown" - user_agent = request.headers.get("user-agent", "unknown") - - is_allowed, error_msg = await check_rate_limit(session, data.email.lower(), ip) - if not is_allowed: - raise HTTPException( - status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=error_msg - ) - - result = await session.execute(select(User).where(User.email == data.email.lower())) - user = result.scalar_one_or_none() - - if not user or not verify_password(data.password, user.hashed_password): - await log_login_attempt( - session, - data.email.lower(), - ip, - user_agent, - False, - "Identifiants incorrects", - ) - - if user: - user.failed_login_attempts += 1 - - if user.failed_login_attempts >= 5: - user.locked_until = datetime.now() + timedelta(minutes=15) - await session.commit() - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Compte verrouillé suite à trop de tentatives. Réessayez dans 15 minutes.", - ) - - await session.commit() - - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Email ou mot de passe incorrect", - ) - - if not user.is_active: - await log_login_attempt( - session, data.email.lower(), ip, user_agent, False, "Compte désactivé" - ) - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail="Compte désactivé" - ) - - if not user.is_verified: - await log_login_attempt( - session, data.email.lower(), ip, user_agent, False, "Email non vérifié" - ) - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Email non vérifié. Consultez votre boîte de réception.", - ) - - if user.locked_until and user.locked_until > datetime.now(): - await log_login_attempt( - session, data.email.lower(), ip, user_agent, False, "Compte verrouillé" - ) - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="Compte temporairement verrouillé", - ) - - user.failed_login_attempts = 0 - user.locked_until = None - user.last_login = datetime.now() - - access_token = create_access_token( - {"sub": user.id, "email": user.email, "role": user.role} - ) - refresh_token_jwt = create_refresh_token(user.id) - - refresh_token_record = RefreshToken( - id=str(uuid.uuid4()), - user_id=user.id, - token_hash=hash_token(refresh_token_jwt), - device_info=user_agent[:500], - ip_address=ip, - expires_at=datetime.now() + timedelta(days=7), - created_at=datetime.now(), - ) - - session.add(refresh_token_record) - await session.commit() - - await log_login_attempt(session, data.email.lower(), ip, user_agent, True) - - logger.info(f" Connexion réussie: {user.email}") - - return TokenResponse( - access_token=access_token, - refresh_token=refresh_token_jwt, - expires_in=86400, - ) - - -@router.post("/refresh", response_model=TokenResponse) -async def refresh_access_token( - data: RefreshTokenRequest, session: AsyncSession = Depends(get_session) -): - payload = decode_token(data.refresh_token) - if not payload or payload.get("type") != "refresh": - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token invalide" - ) - - user_id = payload.get("sub") - token_hash = hash_token(data.refresh_token) - - result = await session.execute( - select(RefreshToken).where( - RefreshToken.user_id == user_id, - RefreshToken.token_hash == token_hash, - not RefreshToken.is_revoked, - ) - ) - token_record = result.scalar_one_or_none() - - if not token_record: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Refresh token révoqué ou introuvable", - ) - - if token_record.expires_at < datetime.now(): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token expiré" - ) - - result = await session.execute(select(User).where(User.id == user_id)) - user = result.scalar_one_or_none() - - if not user or not user.is_active: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Utilisateur introuvable ou désactivé", - ) - - new_access_token = create_access_token( - {"sub": user.id, "email": user.email, "role": user.role} - ) - - logger.info(f" Token rafraîchi: {user.email}") - - return TokenResponse( - access_token=new_access_token, - refresh_token=data.refresh_token, - expires_in=86400, - ) + return {"success": True, "message": "Un nouveau lien de verification a ete envoye."} @router.post("/forgot-password") @@ -421,13 +462,27 @@ async def forgot_password( request: Request, session: AsyncSession = Depends(get_session), ): + ip = get_client_ip(request) + + allowed, error_msg = await RateLimiter.check_password_reset_rate_limit( + data.email.lower(), ip + ) + if not allowed: + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=error_msg + ) + result = await session.execute(select(User).where(User.email == data.email.lower())) user = result.scalar_one_or_none() + await AuditService.log_password_reset_request( + session, request, data.email, user.id if user else None + ) + if not user: return { "success": True, - "message": "Si cet email existe, un lien de réinitialisation a été envoyé.", + "message": "Si cet email existe, un lien de reinitialisation a ete envoye.", } reset_token = generate_reset_token() @@ -435,24 +490,23 @@ async def forgot_password( user.reset_token_expires = datetime.now() + timedelta(hours=1) await session.commit() - frontend_url = ( - settings.frontend_url - if hasattr(settings, "frontend_url") - else str(request.base_url).rstrip("/") - ) + frontend_url = settings.frontend_url or str(request.base_url).rstrip("/") AuthEmailService.send_password_reset_email(user.email, reset_token, frontend_url) - logger.info(f" Reset password demandé: {user.email}") + logger.info(f"Reset password demande: {user.email}") return { "success": True, - "message": "Si cet email existe, un lien de réinitialisation a été envoyé.", + "message": "Si cet email existe, un lien de reinitialisation a ete envoye.", } @router.post("/reset-password") async def reset_password( - data: ResetPasswordRequest, session: AsyncSession = Depends(get_session) + data: ResetPasswordRequest, + request: Request, + response: Response, + session: AsyncSession = Depends(get_session), ): result = await session.execute(select(User).where(User.reset_token == data.token)) user = result.scalar_one_or_none() @@ -460,13 +514,13 @@ async def reset_password( if not user: raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Token de réinitialisation invalide", + detail="Token de reinitialisation invalide", ) - if user.reset_token_expires < datetime.now(): + if user.reset_token_expires and user.reset_token_expires < datetime.now(): raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, - detail="Token expiré. Demandez un nouveau lien de réinitialisation.", + detail="Token expire. Demandez un nouveau lien.", ) is_valid, error_msg = validate_password_strength(data.new_password) @@ -478,41 +532,67 @@ async def reset_password( user.reset_token_expires = None user.failed_login_attempts = 0 user.locked_until = None + user.password_changed_at = datetime.now() + + await TokenService.revoke_all_user_tokens( + session=session, user_id=user.id, reason="password_reset" + ) + + await AuditService.log_password_change(session, request, user.id, "reset") + await session.commit() + CookieManager.clear_all_auth_cookies(response) + AuthEmailService.send_password_changed_notification(user.email) - logger.info(f" Mot de passe réinitialisé: {user.email}") + logger.info(f"Mot de passe reinitialise: {user.email}") return { "success": True, - "message": "Mot de passe réinitialisé avec succès. Vous pouvez maintenant vous connecter.", + "message": "Mot de passe reinitialise. Vous pouvez maintenant vous connecter.", } -@router.post("/logout") -async def logout( - data: RefreshTokenRequest, +@router.post("/change-password") +async def change_password( + data: ChangePasswordRequest, + request: Request, + response: Response, session: AsyncSession = Depends(get_session), user: User = Depends(get_current_user), ): - token_hash = hash_token(data.refresh_token) - - result = await session.execute( - select(RefreshToken).where( - RefreshToken.user_id == user.id, RefreshToken.token_hash == token_hash + if not verify_password(data.current_password, user.hashed_password): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Mot de passe actuel incorrect", ) + + is_valid, error_msg = validate_password_strength(data.new_password) + if not is_valid: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=error_msg) + + user.hashed_password = hash_password(data.new_password) + user.password_changed_at = datetime.now() + + await TokenService.revoke_all_user_tokens( + session=session, user_id=user.id, reason="password_change" ) - token_record = result.scalar_one_or_none() - if token_record: - token_record.is_revoked = True - token_record.revoked_at = datetime.now() - await session.commit() + await AuditService.log_password_change(session, request, user.id, "user_initiated") - logger.info(f"👋 Déconnexion: {user.email}") + await session.commit() - return {"success": True, "message": "Déconnexion réussie"} + CookieManager.clear_all_auth_cookies(response) + + AuthEmailService.send_password_changed_notification(user.email) + + logger.info(f"Mot de passe change: {user.email}") + + return { + "success": True, + "message": "Mot de passe modifie. Veuillez vous reconnecter.", + } @router.get("/me") @@ -524,6 +604,69 @@ async def get_current_user_info(user: User = Depends(get_current_user)): "prenom": user.prenom, "role": user.role, "is_verified": user.is_verified, - "created_at": user.created_at.isoformat(), + "created_at": user.created_at.isoformat() if user.created_at else None, "last_login": user.last_login.isoformat() if user.last_login else None, } + + +@router.get("/sessions", response_model=List[SessionResponse]) +async def get_active_sessions( + session: AsyncSession = Depends(get_session), user: User = Depends(get_current_user) +): + sessions = await TokenService.get_user_active_sessions(session, user.id) + return [SessionResponse(**s) for s in sessions] + + +@router.delete("/sessions/{session_id}") +async def revoke_session( + session_id: str, + request: Request, + session: AsyncSession = Depends(get_session), + user: User = Depends(get_current_user), +): + result = await session.execute( + select(RefreshToken).where( + RefreshToken.id == session_id, + RefreshToken.user_id == user.id, + RefreshToken.is_revoked.is_(false()), + ) + ) + token_record = result.scalar_one_or_none() + + if not token_record: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Session introuvable" + ) + + token_record.is_revoked = True + token_record.revoked_at = datetime.now() + token_record.revoked_reason = "user_revoked" + + await AuditService.log_event( + session=session, + event_type=AuditEventType.SESSION_TERMINATED, + request=request, + user_id=user.id, + description=f"Session {session_id[:8]}... revoquee", + ) + + await session.commit() + + return {"success": True, "message": "Session revoquee"} + + +@router.get("/csrf-token") +async def get_csrf_token( + request: Request, response: Response, user: User = Depends(get_current_user) +): + from security.auth import generate_session_id, create_csrf_token + + session_id = getattr(request.state, "session_id", None) + if not session_id: + session_id = generate_session_id() + + csrf_token = create_csrf_token(session_id) + + CookieManager.set_csrf_token(response, csrf_token) + + return {"csrf_token": csrf_token} diff --git a/schemas/sage/sage_gateway.py b/schemas/sage/sage_gateway.py index 93b2b4a..1d0f657 100644 --- a/schemas/sage/sage_gateway.py +++ b/schemas/sage/sage_gateway.py @@ -12,7 +12,6 @@ class GatewayHealthStatus(str, Enum): # === CREATE === class SageGatewayCreate(BaseModel): - name: str = Field( ..., min_length=2, max_length=100, description="Nom de la gateway" ) @@ -24,8 +23,6 @@ class SageGatewayCreate(BaseModel): gateway_token: str = Field( ..., min_length=10, description="Token d'authentification" ) - - sage_database: Optional[str] = Field(None, max_length=255) sage_company: Optional[str] = Field(None, max_length=255) is_active: bool = Field(False, description="Activer immédiatement cette gateway") @@ -54,9 +51,6 @@ class SageGatewayUpdate(BaseModel): gateway_url: Optional[str] = None gateway_token: Optional[str] = Field(None, min_length=10) - sage_database: Optional[str] = None - sage_company: Optional[str] = None - is_default: Optional[bool] = None priority: Optional[int] = Field(None, ge=0, le=100) @@ -73,7 +67,6 @@ class SageGatewayUpdate(BaseModel): # === RESPONSE === class SageGatewayResponse(BaseModel): - id: str user_id: str @@ -83,9 +76,6 @@ class SageGatewayResponse(BaseModel): gateway_url: str token_preview: str - sage_database: Optional[str] = None - sage_company: Optional[str] = None - is_active: bool is_default: bool priority: int @@ -111,7 +101,6 @@ class SageGatewayResponse(BaseModel): class SageGatewayListResponse(BaseModel): - items: List[SageGatewayResponse] total: int active_gateway: Optional[SageGatewayResponse] = None diff --git a/security/__init__.py b/security/__init__.py new file mode 100644 index 0000000..bb37ac7 --- /dev/null +++ b/security/__init__.py @@ -0,0 +1,55 @@ +from security.auth import ( + hash_password, + verify_password, + validate_password_strength, + generate_verification_token, + generate_reset_token, + generate_csrf_token, + generate_secure_token, + hash_token, + constant_time_compare, + create_access_token, + create_refresh_token, + decode_token, + generate_session_id, +) + +from security.cookies import CookieManager, set_auth_cookies + +from security.fingerprint import ( + DeviceFingerprint, + get_fingerprint_hash, + validate_fingerprint, + get_client_ip, +) + +from security.csrf import CSRFProtection, verify_csrf, generate_csrf_for_session + +from security.rate_limiter import RateLimiter, check_rate_limit_dependency + +__all__ = [ + "hash_password", + "verify_password", + "validate_password_strength", + "generate_verification_token", + "generate_reset_token", + "generate_csrf_token", + "generate_secure_token", + "hash_token", + "constant_time_compare", + "create_access_token", + "create_refresh_token", + "decode_token", + "generate_session_id", + "CookieManager", + "set_auth_cookies", + "DeviceFingerprint", + "get_fingerprint_hash", + "validate_fingerprint", + "get_client_ip", + "CSRFProtection", + "verify_csrf", + "generate_csrf_for_session", + "RateLimiter", + "check_rate_limit_dependency", +] diff --git a/security/auth.py b/security/auth.py index 7821a52..627d627 100644 --- a/security/auth.py +++ b/security/auth.py @@ -1,16 +1,17 @@ from passlib.context import CryptContext -from datetime import datetime, timedelta -from typing import Optional, Dict +from datetime import datetime, timedelta, timezone +from typing import Optional, Dict, Any, Tuple import jwt import secrets import hashlib +import hmac +import logging -SECRET_KEY = "VOTRE_SECRET_KEY_A_METTRE_EN_.ENV" -ALGORITHM = "HS256" -ACCESS_TOKEN_EXPIRE_MINUTES = 30 -REFRESH_TOKEN_EXPIRE_DAYS = 7 +from config.config import settings -pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") +logger = logging.getLogger(__name__) + +pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto", bcrypt__rounds=12) def hash_password(password: str) -> str: @@ -18,75 +19,192 @@ def hash_password(password: str) -> str: def verify_password(plain_password: str, hashed_password: str) -> bool: - return pwd_context.verify(plain_password, hashed_password) + try: + return pwd_context.verify(plain_password, hashed_password) + except Exception as e: + logger.warning(f"Erreur verification mot de passe: {e}") + return False + + +def generate_secure_token(length: int = 32) -> str: + return secrets.token_urlsafe(length) def generate_verification_token() -> str: - return secrets.token_urlsafe(32) + return generate_secure_token(32) def generate_reset_token() -> str: - return secrets.token_urlsafe(32) + return generate_secure_token(32) + + +def generate_csrf_token() -> str: + return generate_secure_token(32) + + +def generate_refresh_token_id() -> str: + return generate_secure_token(16) def hash_token(token: str) -> str: return hashlib.sha256(token.encode()).hexdigest() -def create_access_token(data: Dict, expires_delta: Optional[timedelta] = None) -> str: +def constant_time_compare(val1: str, val2: str) -> bool: + return hmac.compare_digest(val1.encode(), val2.encode()) + + +def create_access_token( + data: Dict[str, Any], + expires_delta: Optional[timedelta] = None, + fingerprint_hash: Optional[str] = None, +) -> str: to_encode = data.copy() + now = datetime.now(timezone.utc) if expires_delta: - expire = datetime.utcnow() + expires_delta + expire = now + expires_delta else: - expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + expire = now + timedelta(minutes=settings.access_token_expire_minutes) - to_encode.update({"exp": expire, "iat": datetime.utcnow(), "type": "access"}) + to_encode.update( + { + "exp": expire, + "iat": now, + "nbf": now, + "type": "access", + "jti": generate_secure_token(8), + } + ) - encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) - return encoded_jwt + if fingerprint_hash: + to_encode["fph"] = fingerprint_hash + + return jwt.encode(to_encode, settings.jwt_secret, algorithm=settings.jwt_algorithm) -def create_refresh_token(user_id: str) -> str: - expire = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) +def create_refresh_token( + user_id: str, + token_id: Optional[str] = None, + fingerprint_hash: Optional[str] = None, + expires_delta: Optional[timedelta] = None, +) -> Tuple[str, str]: + now = datetime.now(timezone.utc) + + if expires_delta: + expire = now + expires_delta + else: + expire = now + timedelta(days=settings.refresh_token_expire_days) + + if not token_id: + token_id = generate_refresh_token_id() to_encode = { "sub": user_id, "exp": expire, - "iat": datetime.utcnow(), + "iat": now, + "nbf": now, "type": "refresh", - "jti": secrets.token_urlsafe(16), # Unique ID + "jti": token_id, } - encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) - return encoded_jwt + if fingerprint_hash: + to_encode["fph"] = fingerprint_hash + + token = jwt.encode(to_encode, settings.jwt_secret, algorithm=settings.jwt_algorithm) + + return token, token_id -def decode_token(token: str) -> Optional[Dict]: +def create_csrf_token(session_id: str) -> str: + now = datetime.now(timezone.utc) + expire = now + timedelta(minutes=settings.csrf_token_expire_minutes) + + to_encode = { + "sid": session_id, + "exp": expire, + "iat": now, + "type": "csrf", + "jti": generate_secure_token(8), + } + + return jwt.encode(to_encode, settings.jwt_secret, algorithm=settings.jwt_algorithm) + + +def decode_token( + token: str, expected_type: Optional[str] = None +) -> Optional[Dict[str, Any]]: try: - payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) + payload = jwt.decode( + token, + settings.jwt_secret, + algorithms=[settings.jwt_algorithm], + options={ + "require": ["exp", "iat", "type"], + "verify_exp": True, + "verify_iat": True, + "verify_nbf": True, + }, + ) + + if expected_type and payload.get("type") != expected_type: + logger.warning( + f"Type de token incorrect: attendu={expected_type}, recu={payload.get('type')}" + ) + return None + return payload + except jwt.ExpiredSignatureError: + logger.debug("Token expire") return None - except jwt.JWTError: + except jwt.InvalidTokenError as e: + logger.warning(f"Token invalide: {e}") + return None + except Exception as e: + logger.error(f"Erreur decodage token: {e}") return None -def validate_password_strength(password: str) -> tuple[bool, str]: - if len(password) < 8: - return False, "Le mot de passe doit contenir au moins 8 caractères" +def validate_password_strength(password: str) -> Tuple[bool, str]: + if len(password) < settings.password_min_length: + return ( + False, + f"Le mot de passe doit contenir au moins {settings.password_min_length} caracteres", + ) - if not any(c.isupper() for c in password): + if settings.password_require_uppercase and not any(c.isupper() for c in password): return False, "Le mot de passe doit contenir au moins une majuscule" - if not any(c.islower() for c in password): + if settings.password_require_lowercase and not any(c.islower() for c in password): return False, "Le mot de passe doit contenir au moins une minuscule" - if not any(c.isdigit() for c in password): + if settings.password_require_digit and not any(c.isdigit() for c in password): return False, "Le mot de passe doit contenir au moins un chiffre" - special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?" - if not any(c in special_chars for c in password): - return False, "Le mot de passe doit contenir au moins un caractère spécial" + if settings.password_require_special: + special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?/~`" + if not any(c in special_chars for c in password): + return False, "Le mot de passe doit contenir au moins un caractere special" + + common_passwords = [ + "password", + "123456", + "qwerty", + "admin", + "letmein", + "welcome", + "monkey", + "dragon", + "master", + "login", + ] + if password.lower() in common_passwords: + return False, "Ce mot de passe est trop courant" return True, "" + + +def generate_session_id() -> str: + """Genere un identifiant de session unique.""" + return generate_secure_token(24) diff --git a/security/cookies.py b/security/cookies.py new file mode 100644 index 0000000..58e48f6 --- /dev/null +++ b/security/cookies.py @@ -0,0 +1,157 @@ +from fastapi import Response, Request +from typing import Optional +import logging + +from config.config import settings + +logger = logging.getLogger(__name__) + + +class CookieManager: + @staticmethod + def _get_samesite_value() -> str: + value = settings.cookie_samesite.lower() + if value in ("strict", "lax", "none"): + return value + return "strict" + + @staticmethod + def _should_be_secure() -> bool: + if settings.is_development and not settings.cookie_secure: + return False + return True + + @classmethod + def set_access_token( + cls, response: Response, token: str, max_age: Optional[int] = None + ) -> None: + if max_age is None: + max_age = settings.access_token_expire_minutes * 60 + + response.set_cookie( + key=settings.cookie_access_token_name, + value=token, + max_age=max_age, + expires=max_age, + path="/", + domain=settings.cookie_domain, + secure=cls._should_be_secure(), + httponly=settings.cookie_httponly, + samesite=cls._get_samesite_value(), + ) + logger.debug("Cookie access_token defini") + + @classmethod + def set_refresh_token( + cls, response: Response, token: str, max_age: Optional[int] = None + ) -> None: + if max_age is None: + max_age = settings.refresh_token_expire_days * 24 * 60 * 60 + + response.set_cookie( + key=settings.cookie_refresh_token_name, + value=token, + max_age=max_age, + expires=max_age, + path="/auth", + domain=settings.cookie_domain, + secure=cls._should_be_secure(), + httponly=settings.cookie_httponly, + samesite=cls._get_samesite_value(), + ) + logger.debug("Cookie refresh_token defini") + + @classmethod + def set_csrf_token( + cls, response: Response, token: str, max_age: Optional[int] = None + ) -> None: + if max_age is None: + max_age = settings.csrf_token_expire_minutes * 60 + + response.set_cookie( + key=settings.cookie_csrf_token_name, + value=token, + max_age=max_age, + expires=max_age, + path="/", + domain=settings.cookie_domain, + secure=cls._should_be_secure(), + httponly=False, + samesite=cls._get_samesite_value(), + ) + logger.debug("Cookie csrf_token defini") + + @classmethod + def clear_access_token(cls, response: Response) -> None: + response.delete_cookie( + key=settings.cookie_access_token_name, + path="/", + domain=settings.cookie_domain, + secure=cls._should_be_secure(), + httponly=settings.cookie_httponly, + samesite=cls._get_samesite_value(), + ) + logger.debug("Cookie access_token supprime") + + @classmethod + def clear_refresh_token(cls, response: Response) -> None: + response.delete_cookie( + key=settings.cookie_refresh_token_name, + path="/auth", + domain=settings.cookie_domain, + secure=cls._should_be_secure(), + httponly=settings.cookie_httponly, + samesite=cls._get_samesite_value(), + ) + logger.debug("Cookie refresh_token supprime") + + @classmethod + def clear_csrf_token(cls, response: Response) -> None: + response.delete_cookie( + key=settings.cookie_csrf_token_name, + path="/", + domain=settings.cookie_domain, + secure=cls._should_be_secure(), + httponly=False, + samesite=cls._get_samesite_value(), + ) + logger.debug("Cookie csrf_token supprime") + + @classmethod + def clear_all_auth_cookies(cls, response: Response) -> None: + cls.clear_access_token(response) + cls.clear_refresh_token(response) + cls.clear_csrf_token(response) + logger.debug("Tous les cookies auth supprimes") + + @classmethod + def get_access_token(cls, request: Request) -> Optional[str]: + token = request.cookies.get(settings.cookie_access_token_name) + if token: + return token + + auth_header = request.headers.get("Authorization") + if auth_header and auth_header.startswith("Bearer "): + return auth_header[7:] + + return None + + @classmethod + def get_refresh_token(cls, request: Request) -> Optional[str]: + return request.cookies.get(settings.cookie_refresh_token_name) + + @classmethod + def get_csrf_token(cls, request: Request) -> Optional[str]: + csrf_header = request.headers.get("X-CSRF-Token") + if csrf_header: + return csrf_header + + return request.cookies.get(settings.cookie_csrf_token_name) + + +def set_auth_cookies( + response: Response, access_token: str, refresh_token: str, csrf_token: str +) -> None: + CookieManager.set_access_token(response, access_token) + CookieManager.set_refresh_token(response, refresh_token) + CookieManager.set_csrf_token(response, csrf_token) diff --git a/security/csrf.py b/security/csrf.py new file mode 100644 index 0000000..b4d60d0 --- /dev/null +++ b/security/csrf.py @@ -0,0 +1,117 @@ +""" +security/csrf.py - Protection contre les attaques Cross-Site Request Forgery +""" + +from fastapi import Request, HTTPException, status +from typing import Optional, Set +import logging + +from config.config import settings +from security.auth import decode_token, create_csrf_token, constant_time_compare + +logger = logging.getLogger(__name__) + + +SAFE_METHODS: Set[str] = {"GET", "HEAD", "OPTIONS", "TRACE"} + +CSRF_EXEMPT_PATHS: Set[str] = { + "/auth/login", + "/auth/register", + "/auth/forgot-password", + "/auth/verify-email", + "/auth/resend-verification", + "/health", + "/docs", + "/redoc", + "/openapi.json", + "/webhooks/universign", +} + + +class CSRFProtection: + @classmethod + def is_exempt(cls, request: Request) -> bool: + if request.method in SAFE_METHODS: + return True + + path = request.url.path.rstrip("/") + if path in CSRF_EXEMPT_PATHS: + return True + + for exempt_path in CSRF_EXEMPT_PATHS: + if path.startswith(exempt_path): + return True + + return False + + @classmethod + def generate_token(cls, session_id: str) -> str: + return create_csrf_token(session_id) + + @classmethod + def validate_token(cls, request: Request, session_id: Optional[str] = None) -> bool: + csrf_header = request.headers.get("X-CSRF-Token") + + if not csrf_header: + logger.warning("Token CSRF manquant dans le header") + return False + + payload = decode_token(csrf_header, expected_type="csrf") + + if not payload: + logger.warning("Token CSRF invalide ou expire") + return False + + if session_id and payload.get("sid") != session_id: + logger.warning("Token CSRF ne correspond pas a la session") + return False + + return True + + @classmethod + def validate_double_submit(cls, request: Request) -> bool: + header_token = request.headers.get("X-CSRF-Token") + cookie_token = request.cookies.get(settings.cookie_csrf_token_name) + + if not header_token or not cookie_token: + logger.warning("Token CSRF manquant (header ou cookie)") + return False + + if not constant_time_compare(header_token, cookie_token): + logger.warning("Tokens CSRF ne correspondent pas") + return False + + return True + + @classmethod + def validate_request( + cls, + request: Request, + session_id: Optional[str] = None, + use_double_submit: bool = True, + ) -> bool: + if cls.is_exempt(request): + return True + + if use_double_submit: + if not cls.validate_double_submit(request): + return False + + return cls.validate_token(request, session_id) + + +async def verify_csrf(request: Request, session_id: Optional[str] = None) -> None: + if CSRFProtection.is_exempt(request): + return + + if not CSRFProtection.validate_request(request, session_id): + logger.warning( + f"Verification CSRF echouee pour {request.method} {request.url.path}" + ) + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail="Verification CSRF echouee" + ) + + +def generate_csrf_for_session(session_id: str) -> str: + return CSRFProtection.generate_token(session_id) diff --git a/security/fingerprint.py b/security/fingerprint.py new file mode 100644 index 0000000..57258ed --- /dev/null +++ b/security/fingerprint.py @@ -0,0 +1,122 @@ +from fastapi import Request +from typing import Dict +import hashlib +import hmac +import logging + +from config.config import settings + +logger = logging.getLogger(__name__) + + +class DeviceFingerprint: + COMPONENT_EXTRACTORS = { + "user_agent": lambda r: r.headers.get("User-Agent", ""), + "accept_language": lambda r: r.headers.get("Accept-Language", ""), + "accept_encoding": lambda r: r.headers.get("Accept-Encoding", ""), + "accept": lambda r: r.headers.get("Accept", ""), + "connection": lambda r: r.headers.get("Connection", ""), + "cache_control": lambda r: r.headers.get("Cache-Control", ""), + "client_ip": lambda r: DeviceFingerprint._get_client_ip(r), + "sec_ch_ua": lambda r: r.headers.get("Sec-CH-UA", ""), + "sec_ch_ua_platform": lambda r: r.headers.get("Sec-CH-UA-Platform", ""), + "sec_ch_ua_mobile": lambda r: r.headers.get("Sec-CH-UA-Mobile", ""), + } + + @staticmethod + def _get_client_ip(request: Request) -> str: + forwarded = request.headers.get("X-Forwarded-For") + if forwarded: + return forwarded.split(",")[0].strip() + + real_ip = request.headers.get("X-Real-IP") + if real_ip: + return real_ip + + if request.client: + return request.client.host + + return "" + + @classmethod + def extract_components(cls, request: Request) -> Dict[str, str]: + components = {} + + for component_name in settings.fingerprint_components: + extractor = cls.COMPONENT_EXTRACTORS.get(component_name) + if extractor: + try: + value = extractor(request) + components[component_name] = value if value else "" + except Exception as e: + logger.warning(f"Erreur extraction composant {component_name}: {e}") + components[component_name] = "" + else: + logger.warning(f"Extracteur inconnu pour composant: {component_name}") + + return components + + @classmethod + def generate_hash(cls, request: Request, include_ip: bool = False) -> str: + components = cls.extract_components(request) + + if not include_ip and "client_ip" in components: + del components["client_ip"] + + sorted_keys = sorted(components.keys()) + fingerprint_data = "|".join(f"{k}:{components[k]}" for k in sorted_keys) + + secret = settings.fingerprint_secret or settings.jwt_secret + + signature = hmac.new( + secret.encode(), fingerprint_data.encode(), hashlib.sha256 + ).hexdigest() + + return signature + + @classmethod + def generate_from_components(cls, components: Dict[str, str]) -> str: + sorted_keys = sorted(components.keys()) + fingerprint_data = "|".join(f"{k}:{components.get(k, '')}" for k in sorted_keys) + + secret = settings.fingerprint_secret or settings.jwt_secret + + signature = hmac.new( + secret.encode(), fingerprint_data.encode(), hashlib.sha256 + ).hexdigest() + + return signature + + @classmethod + def validate( + cls, request: Request, stored_hash: str, include_ip: bool = False + ) -> bool: + if not stored_hash: + return True + + current_hash = cls.generate_hash(request, include_ip=include_ip) + + return hmac.compare_digest(current_hash, stored_hash) + + @classmethod + def get_device_info(cls, request: Request) -> Dict[str, str]: + user_agent = request.headers.get("User-Agent", "") + + return { + "user_agent": user_agent[:500] if user_agent else "", + "ip_address": cls._get_client_ip(request), + "accept_language": request.headers.get("Accept-Language", "")[:100], + "fingerprint_hash": cls.generate_hash(request), + } + + +def get_fingerprint_hash(request: Request) -> str: + return DeviceFingerprint.generate_hash(request) + + +def validate_fingerprint(request: Request, stored_hash: str) -> bool: + return DeviceFingerprint.validate(request, stored_hash) + + +def get_client_ip(request: Request) -> str: + return DeviceFingerprint._get_client_ip(request) diff --git a/security/rate_limiter.py b/security/rate_limiter.py new file mode 100644 index 0000000..da15844 --- /dev/null +++ b/security/rate_limiter.py @@ -0,0 +1,147 @@ +from fastapi import Request, HTTPException, status +from typing import Optional, Tuple +import logging + +from config.config import settings +from services.redis_service import redis_service +from security.fingerprint import get_client_ip + +logger = logging.getLogger(__name__) + + +class RateLimiter: + @staticmethod + def _make_key(identifier: str, action: str) -> str: + return f"{action}:{identifier}" + + @classmethod + async def check_login_rate_limit( + cls, email: str, ip_address: str + ) -> Tuple[bool, Optional[str], int]: + window_seconds = settings.rate_limit_login_window_minutes * 60 + max_attempts = settings.rate_limit_login_attempts + + email_key = cls._make_key(email.lower(), "login_email") + email_count = await redis_service.get_rate_limit_count(email_key) + + if email_count >= max_attempts: + return ( + False, + f"Trop de tentatives pour cet email. Reessayez dans {settings.rate_limit_login_window_minutes} minutes.", + window_seconds, + ) + + ip_key = cls._make_key(ip_address, "login_ip") + ip_count = await redis_service.get_rate_limit_count(ip_key) + + ip_limit = max_attempts * 3 + if ip_count >= ip_limit: + return ( + False, + window_seconds, + ) + + return (True, None, 0) + + @classmethod + async def record_login_attempt( + cls, email: str, ip_address: str, success: bool + ) -> None: + window_seconds = settings.rate_limit_login_window_minutes * 60 + + if success: + email_key = cls._make_key(email.lower(), "login_email") + await redis_service.reset_rate_limit(email_key) + logger.debug(f"Rate limit reinitialise pour {email}") + else: + email_key = cls._make_key(email.lower(), "login_email") + await redis_service.increment_rate_limit(email_key, window_seconds) + + ip_key = cls._make_key(ip_address, "login_ip") + await redis_service.increment_rate_limit(ip_key, window_seconds) + + logger.debug( + f"Tentative echouee enregistree pour {email} depuis {ip_address}" + ) + + @classmethod + async def check_api_rate_limit( + cls, identifier: str, endpoint: Optional[str] = None + ) -> Tuple[bool, int, int]: + window_seconds = settings.rate_limit_api_window_seconds + max_requests = settings.rate_limit_api_requests + + if endpoint: + key = cls._make_key(f"{identifier}:{endpoint}", "api") + else: + key = cls._make_key(identifier, "api") + + count = await redis_service.increment_rate_limit(key, window_seconds) + remaining = max(0, max_requests - count) + + if count > max_requests: + return (False, remaining, window_seconds) + + return (True, remaining, window_seconds) + + @classmethod + async def check_password_reset_rate_limit( + cls, email: str, ip_address: str + ) -> Tuple[bool, Optional[str]]: + window_seconds = 3600 + max_attempts_email = 3 + max_attempts_ip = 10 + + email_key = cls._make_key(email.lower(), "reset_email") + email_count = await redis_service.get_rate_limit_count(email_key) + + if email_count >= max_attempts_email: + return (False, "Trop de demandes de reinitialisation pour cet email.") + + ip_key = cls._make_key(ip_address, "reset_ip") + ip_count = await redis_service.get_rate_limit_count(ip_key) + + if ip_count >= max_attempts_ip: + return (False, "Trop de demandes depuis cette adresse IP.") + + await redis_service.increment_rate_limit(email_key, window_seconds) + await redis_service.increment_rate_limit(ip_key, window_seconds) + + return (True, None) + + @classmethod + async def check_registration_rate_limit( + cls, ip_address: str + ) -> Tuple[bool, Optional[str]]: + window_seconds = 3600 + max_registrations = 5 + + key = cls._make_key(ip_address, "register_ip") + count = await redis_service.get_rate_limit_count(key) + + if count >= max_registrations: + return (False, "Trop d'inscriptions depuis cette adresse IP.") + + await redis_service.increment_rate_limit(key, window_seconds) + + return (True, None) + + +async def check_rate_limit_dependency(request: Request) -> None: + ip = get_client_ip(request) + + allowed, remaining, reset_seconds = await RateLimiter.check_api_rate_limit(ip) + + request.state.rate_limit_remaining = remaining + request.state.rate_limit_reset = reset_seconds + + if not allowed: + raise HTTPException( + status_code=status.HTTP_429_TOO_MANY_REQUESTS, + detail="Limite de requetes atteinte. Reessayez plus tard.", + headers={ + "X-RateLimit-Remaining": "0", + "X-RateLimit-Reset": str(reset_seconds), + "Retry-After": str(reset_seconds), + }, + ) diff --git a/services/audit_service.py b/services/audit_service.py new file mode 100644 index 0000000..59dcdba --- /dev/null +++ b/services/audit_service.py @@ -0,0 +1,318 @@ +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import false, select, and_ +from datetime import datetime, timedelta +from typing import Optional, Dict, Any, List +from fastapi import Request +import uuid +import json +import logging + +from database import AuditLog, AuditEventType, LoginAttempt +from security.fingerprint import DeviceFingerprint, get_client_ip + +logger = logging.getLogger(__name__) + + +class AuditService: + @classmethod + async def log_event( + cls, + session: AsyncSession, + event_type: AuditEventType, + request: Optional[Request] = None, + user_id: Optional[str] = None, + description: Optional[str] = None, + success: bool = True, + failure_reason: Optional[str] = None, + resource_type: Optional[str] = None, + resource_id: Optional[str] = None, + metadata: Optional[Dict[str, Any]] = None, + ) -> AuditLog: + ip_address = None + user_agent = None + fingerprint_hash = None + request_method = None + request_path = None + + if request: + ip_address = get_client_ip(request) + user_agent = request.headers.get("User-Agent", "")[:500] + fingerprint_hash = DeviceFingerprint.generate_hash(request) + request_method = request.method + request_path = str(request.url.path)[:500] + + metadata_json = None + if metadata: + try: + metadata_json = json.dumps(metadata, default=str) + except Exception as e: + logger.warning(f"Erreur serialisation metadata audit: {e}") + + audit_log = AuditLog( + id=str(uuid.uuid4()), + user_id=user_id, + event_type=event_type, + event_description=description, + ip_address=ip_address, + user_agent=user_agent, + fingerprint_hash=fingerprint_hash, + resource_type=resource_type, + resource_id=resource_id, + request_method=request_method, + request_path=request_path, + metadata=metadata_json, + success=success, + failure_reason=failure_reason, + created_at=datetime.now(), + ) + + session.add(audit_log) + await session.flush() + + log_level = logging.INFO if success else logging.WARNING + logger.log( + log_level, + f"Audit: {event_type.value} user={user_id} success={success} ip={ip_address}", + ) + + return audit_log + + @classmethod + async def log_login_success( + cls, session: AsyncSession, request: Request, user_id: str, email: str + ) -> AuditLog: + return await cls.log_event( + session=session, + event_type=AuditEventType.LOGIN_SUCCESS, + request=request, + user_id=user_id, + description=f"Connexion reussie pour {email}", + success=True, + metadata={"email": email}, + ) + + @classmethod + async def log_login_failed( + cls, + session: AsyncSession, + request: Request, + email: str, + reason: str, + user_id: Optional[str] = None, + ) -> AuditLog: + return await cls.log_event( + session=session, + event_type=AuditEventType.LOGIN_FAILED, + request=request, + user_id=user_id, + description=f"Echec connexion pour {email}: {reason}", + success=False, + failure_reason=reason, + metadata={"email": email}, + ) + + @classmethod + async def log_logout( + cls, session: AsyncSession, request: Request, user_id: str + ) -> AuditLog: + return await cls.log_event( + session=session, + event_type=AuditEventType.LOGOUT, + request=request, + user_id=user_id, + description="Deconnexion utilisateur", + success=True, + ) + + @classmethod + async def log_password_change( + cls, + session: AsyncSession, + request: Request, + user_id: str, + method: str = "user_initiated", + ) -> AuditLog: + return await cls.log_event( + session=session, + event_type=AuditEventType.PASSWORD_CHANGE, + request=request, + user_id=user_id, + description=f"Mot de passe modifie ({method})", + success=True, + metadata={"method": method}, + ) + + @classmethod + async def log_password_reset_request( + cls, + session: AsyncSession, + request: Request, + email: str, + user_id: Optional[str] = None, + ) -> AuditLog: + return await cls.log_event( + session=session, + event_type=AuditEventType.PASSWORD_RESET_REQUEST, + request=request, + user_id=user_id, + description=f"Demande reset mot de passe pour {email}", + success=True, + metadata={"email": email}, + ) + + @classmethod + async def log_account_locked( + cls, session: AsyncSession, request: Request, user_id: str, reason: str + ) -> AuditLog: + return await cls.log_event( + session=session, + event_type=AuditEventType.ACCOUNT_LOCKED, + request=request, + user_id=user_id, + description=f"Compte verrouille: {reason}", + success=True, + metadata={"reason": reason}, + ) + + @classmethod + async def log_token_refresh( + cls, session: AsyncSession, request: Request, user_id: str + ) -> AuditLog: + return await cls.log_event( + session=session, + event_type=AuditEventType.TOKEN_REFRESH, + request=request, + user_id=user_id, + description="Token rafraichi", + success=True, + ) + + @classmethod + async def log_suspicious_activity( + cls, + session: AsyncSession, + request: Request, + user_id: Optional[str], + activity_type: str, + details: str, + ) -> AuditLog: + return await cls.log_event( + session=session, + event_type=AuditEventType.SUSPICIOUS_ACTIVITY, + request=request, + user_id=user_id, + description=f"Activite suspecte: {activity_type} - {details}", + success=False, + failure_reason=activity_type, + metadata={"activity_type": activity_type, "details": details}, + ) + + @classmethod + async def record_login_attempt( + cls, + session: AsyncSession, + request: Request, + email: str, + success: bool, + failure_reason: Optional[str] = None, + ) -> LoginAttempt: + attempt = LoginAttempt( + email=email.lower(), + ip_address=get_client_ip(request), + user_agent=request.headers.get("User-Agent", "")[:500], + fingerprint_hash=DeviceFingerprint.generate_hash(request), + success=success, + failure_reason=failure_reason, + timestamp=datetime.now(), + ) + + session.add(attempt) + await session.flush() + + return attempt + + @classmethod + async def get_recent_failed_attempts( + cls, session: AsyncSession, email: str, window_minutes: int = 15 + ) -> int: + time_threshold = datetime.now() - timedelta(minutes=window_minutes) + + result = await session.execute( + select(LoginAttempt).where( + and_( + LoginAttempt.email == email.lower(), + LoginAttempt.success.is_(false()), + LoginAttempt.timestamp >= time_threshold, + ) + ) + ) + + return len(result.scalars().all()) + + @classmethod + async def get_user_audit_history( + cls, + session: AsyncSession, + user_id: str, + limit: int = 50, + event_types: Optional[List[AuditEventType]] = None, + ) -> List[AuditLog]: + query = select(AuditLog).where(AuditLog.user_id == user_id) + + if event_types: + query = query.where(AuditLog.event_type.in_(event_types)) + + query = query.order_by(AuditLog.created_at.desc()).limit(limit) + + result = await session.execute(query) + return list(result.scalars().all()) + + @classmethod + async def detect_suspicious_patterns( + cls, session: AsyncSession, user_id: str + ) -> Dict[str, Any]: + one_hour_ago = datetime.now() - timedelta(hours=1) + one_day_ago = datetime.now() - timedelta(days=1) + + result = await session.execute( + select(AuditLog).where( + and_( + AuditLog.user_id == user_id, + AuditLog.event_type == AuditEventType.LOGIN_FAILED, + AuditLog.created_at >= one_hour_ago, + ) + ) + ) + failed_logins_hour = len(result.scalars().all()) + + result = await session.execute( + select(AuditLog).where( + and_( + AuditLog.user_id == user_id, + AuditLog.event_type == AuditEventType.LOGIN_SUCCESS, + AuditLog.created_at >= one_day_ago, + ) + ) + ) + login_logs = result.scalars().all() + unique_ips = set(log.ip_address for log in login_logs if log.ip_address) + + result = await session.execute( + select(AuditLog).where( + and_( + AuditLog.user_id == user_id, + AuditLog.event_type == AuditEventType.PASSWORD_RESET_REQUEST, + AuditLog.created_at >= one_day_ago, + ) + ) + ) + password_resets = len(result.scalars().all()) + + return { + "failed_logins_last_hour": failed_logins_hour, + "unique_ips_last_day": len(unique_ips), + "password_reset_requests_last_day": password_resets, + "is_suspicious": ( + failed_logins_hour >= 5 or len(unique_ips) >= 5 or password_resets >= 3 + ), + } diff --git a/services/email_service.py b/services/email_service.py index 2a6e9e3..fcc4d78 100644 --- a/services/email_service.py +++ b/services/email_service.py @@ -1,22 +1,39 @@ import smtplib from email.mime.multipart import MIMEMultipart from email.mime.text import MIMEText -from config.config import settings +from typing import Optional, List import logging +from config.config import settings + logger = logging.getLogger(__name__) class AuthEmailService: @staticmethod - def _send_email(to: str, subject: str, html_body: str) -> bool: + def _send_email( + to: str, + subject: str, + html_body: str, + cc: Optional[List[str]] = None, + bcc: Optional[List[str]] = None, + ) -> bool: try: - msg = MIMEMultipart() + msg = MIMEMultipart("alternative") msg["From"] = settings.smtp_from msg["To"] = to msg["Subject"] = subject - msg.attach(MIMEText(html_body, "html")) + if cc: + msg["Cc"] = ", ".join(cc) + + msg.attach(MIMEText(html_body, "html", "utf-8")) + + recipients = [to] + if cc: + recipients.extend(cc) + if bcc: + recipients.extend(bcc) with smtplib.SMTP( settings.smtp_host, settings.smtp_port, timeout=30 @@ -27,176 +44,263 @@ class AuthEmailService: if settings.smtp_user and settings.smtp_password: server.login(settings.smtp_user, settings.smtp_password) - server.send_message(msg) + server.sendmail(settings.smtp_from, recipients, msg.as_string()) - logger.info(f" Email envoyé: {subject} → {to}") + logger.info(f"Email envoye: {subject} vers {to}") return True + except smtplib.SMTPException as e: + logger.error(f"Erreur SMTP envoi email: {e}") + return False except Exception as e: - logger.error(f" Erreur envoi email: {e}") + logger.error(f"Erreur envoi email: {e}") return False - @staticmethod - def send_verification_email(email: str, token: str, base_url: str) -> bool: + @classmethod + def send_verification_email(cls, email: str, token: str, base_url: str) -> bool: verification_link = f"{base_url}/auth/verify-email?token={token}" html_body = f""" - - - - - - -
-
-

🎉 Bienvenue sur Sage Dataven

-
-
-

Vérifiez votre adresse email

-

Merci de vous être inscrit ! Pour activer votre compte, veuillez cliquer sur le bouton ci-dessous :

- - - -

Ou copiez ce lien dans votre navigateur :

-

- {verification_link} -

- -

- Ce lien expire dans 24 heures -

- -

- Si vous n'avez pas créé de compte, ignorez cet email. -

-
- -
- - + + + + + + Verification de votre email + + + + + + +
+ + + + + + + + + + +
+

Verification de votre email

+
+

+ Bienvenue sur Sage Dataven. Pour activer votre compte, veuillez verifier votre adresse email en cliquant sur le bouton ci-dessous. +

+ + + + +
+ Verifier mon email +
+

+ Si le bouton ne fonctionne pas, copiez ce lien dans votre navigateur : +

+

+ {verification_link} +

+

+ Ce lien expire dans 24 heures. +

+
+

+ Si vous n'avez pas cree de compte, ignorez cet email. +

+
+
+ + """ - return AuthEmailService._send_email( - email, " Vérifiez votre adresse email - Sage Dataven", html_body + return cls._send_email( + email, "Verifiez votre adresse email - Sage Dataven", html_body ) - @staticmethod - def send_password_reset_email(email: str, token: str, base_url: str) -> bool: - reset_link = f"{base_url}/reset?token={token}" + @classmethod + def send_password_reset_email( + cls, email: str, token: str, frontend_url: str + ) -> bool: + reset_link = f"{frontend_url}/reset-password?token={token}" html_body = f""" - - - - - - -
-
-

Réinitialisation de mot de passe

-
-
-

Demande de réinitialisation

-

Vous avez demandé à réinitialiser votre mot de passe. Cliquez sur le bouton ci-dessous pour créer un nouveau mot de passe :

- - - -

Ou copiez ce lien dans votre navigateur :

-

- {reset_link} -

- -

- Ce lien expire dans 1 heure -

- -

- Si vous n'avez pas demandé cette réinitialisation, ignorez cet email. Votre mot de passe actuel reste inchangé. -

-
- -
- - + + + + + + Reinitialisation de mot de passe + + + + + + +
+ + + + + + + + + + +
+

Reinitialisation du mot de passe

+
+

+ Vous avez demande la reinitialisation de votre mot de passe. Cliquez sur le bouton ci-dessous pour creer un nouveau mot de passe. +

+ + + + +
+ Reinitialiser mon mot de passe +
+

+ Si le bouton ne fonctionne pas, copiez ce lien : +

+

+ {reset_link} +

+

+ Ce lien expire dans 1 heure. +

+
+

+ Si vous n'avez pas demande cette reinitialisation, ignorez cet email. Votre mot de passe restera inchange. +

+
+
+ + """ - return AuthEmailService._send_email( - email, " Réinitialisation de votre mot de passe - Sage Dataven", html_body + return cls._send_email( + email, "Reinitialisation de votre mot de passe - Sage Dataven", html_body ) - @staticmethod - def send_password_changed_notification(email: str) -> bool: + @classmethod + def send_password_changed_notification(cls, email: str) -> bool: html_body = """ - - - - - - -
-
-

Mot de passe modifié

-
-
-

Votre mot de passe a été changé avec succès

-

Ce message confirme que le mot de passe de votre compte Sage Dataven a été modifié.

- -

- Si vous n'êtes pas à l'origine de ce changement, contactez immédiatement notre support. -

-
- -
- - + + + + + + Mot de passe modifie + + + + + + +
+ + + + + + + + + + +
+

Mot de passe modifie

+
+

+ Votre mot de passe a ete modifie avec succes. +

+

+ Si vous n'etes pas a l'origine de ce changement, contactez immediatement notre support. +

+ + + + +
+

+ Securite : Toutes vos sessions actives ont ete deconnectees. Vous devrez vous reconnecter sur tous vos appareils. +

+
+
+

+ Sage Dataven - Notification de securite +

+
+
+ + """ - return AuthEmailService._send_email( - email, " Votre mot de passe a été modifié - Sage Dataven", html_body + return cls._send_email( + email, "Votre mot de passe a ete modifie - Sage Dataven", html_body + ) + + @classmethod + def send_security_alert( + cls, email: str, alert_type: str, details: str, ip_address: Optional[str] = None + ) -> bool: + ip_info = ( + f"

Adresse IP : {ip_address}

" + if ip_address + else "" + ) + + html_body = f""" + + + + + + Alerte de securite + + + + + + +
+ + + + + + + + + + +
+

Alerte de securite

+
+

+ {alert_type} +

+

+ {details} +

+ {ip_info} +

+ Si vous reconnaissez cette activite, vous pouvez ignorer ce message. Sinon, nous vous recommandons de changer votre mot de passe immediatement. +

+
+

+ Sage Dataven - Alerte de securite automatique +

+
+
+ + + """ + + return cls._send_email( + email, f"Alerte de securite : {alert_type} - Sage Dataven", html_body ) diff --git a/services/redis_service.py b/services/redis_service.py new file mode 100644 index 0000000..29b5de8 --- /dev/null +++ b/services/redis_service.py @@ -0,0 +1,200 @@ +import redis.asyncio as redis +from typing import Optional +import logging +import json + +from config.config import settings + +logger = logging.getLogger(__name__) + + +class RedisService: + _instance: Optional["RedisService"] = None + _client: Optional[redis.Redis] = None + + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + async def connect(self) -> None: + if self._client is not None: + return + + try: + self._client = redis.from_url( + settings.redis_url, + password=settings.redis_password, + encoding="utf-8", + decode_responses=True, + socket_timeout=5.0, + socket_connect_timeout=5.0, + ) + + await self._client.ping() + logger.info("Connexion Redis etablie") + + except Exception as e: + logger.error(f"Erreur connexion Redis: {e}") + self._client = None + raise + + async def disconnect(self) -> None: + if self._client: + await self._client.close() + self._client = None + logger.info("Connexion Redis fermee") + + async def is_connected(self) -> bool: + if not self._client: + return False + try: + await self._client.ping() + return True + except Exception: + return False + + @property + def client(self) -> redis.Redis: + if not self._client: + raise RuntimeError("Redis non connecte. Appelez connect() d'abord.") + return self._client + + async def blacklist_token(self, token_id: str, ttl_seconds: int) -> bool: + try: + key = f"{settings.token_blacklist_prefix}{token_id}" + await self.client.setex(key, ttl_seconds, "1") + logger.debug(f"Token {token_id[:8]}... ajoute a la blacklist") + return True + except Exception as e: + logger.error(f"Erreur blacklist token: {e}") + return False + + async def is_token_blacklisted(self, token_id: str) -> bool: + try: + key = f"{settings.token_blacklist_prefix}{token_id}" + result = await self.client.exists(key) + return result > 0 + except Exception as e: + logger.error(f"Erreur verification blacklist: {e}") + return False + + async def blacklist_user_tokens( + self, user_id: str, ttl_seconds: int = 86400 + ) -> bool: + try: + key = f"{settings.token_blacklist_prefix}user:{user_id}" + import time + + await self.client.setex(key, ttl_seconds, str(int(time.time()))) + logger.info(f"Tokens utilisateur {user_id} invalides") + return True + except Exception as e: + logger.error(f"Erreur invalidation tokens utilisateur: {e}") + return False + + async def get_user_token_invalidation_time(self, user_id: str) -> Optional[int]: + try: + key = f"{settings.token_blacklist_prefix}user:{user_id}" + result = await self.client.get(key) + return int(result) if result else None + except Exception as e: + logger.error(f"Erreur lecture invalidation: {e}") + return None + + async def increment_rate_limit(self, key: str, window_seconds: int) -> int: + try: + full_key = f"{settings.rate_limit_prefix}{key}" + + pipe = self.client.pipeline() + pipe.incr(full_key) + pipe.expire(full_key, window_seconds) + results = await pipe.execute() + + return results[0] + except Exception as e: + logger.error(f"Erreur increment rate limit: {e}") + return 0 + + async def get_rate_limit_count(self, key: str) -> int: + try: + full_key = f"{settings.rate_limit_prefix}{key}" + result = await self.client.get(full_key) + return int(result) if result else 0 + except Exception as e: + logger.error(f"Erreur lecture rate limit: {e}") + return 0 + + async def reset_rate_limit(self, key: str) -> bool: + try: + full_key = f"{settings.rate_limit_prefix}{key}" + await self.client.delete(full_key) + return True + except Exception as e: + logger.error(f"Erreur reset rate limit: {e}") + return False + + async def store_refresh_token_metadata( + self, token_id: str, user_id: str, fingerprint_hash: str, ttl_seconds: int + ) -> bool: + try: + key = f"refresh_token:{token_id}" + data = json.dumps( + { + "user_id": user_id, + "fingerprint_hash": fingerprint_hash, + "used": False, + } + ) + await self.client.setex(key, ttl_seconds, data) + return True + except Exception as e: + logger.error(f"Erreur stockage metadata refresh token: {e}") + return False + + async def get_refresh_token_metadata(self, token_id: str) -> Optional[dict]: + try: + key = f"refresh_token:{token_id}" + data = await self.client.get(key) + return json.loads(data) if data else None + except Exception as e: + logger.error(f"Erreur lecture metadata refresh token: {e}") + return None + + async def mark_refresh_token_used(self, token_id: str) -> bool: + try: + key = f"refresh_token:{token_id}" + data = await self.client.get(key) + if not data: + return False + + metadata = json.loads(data) + metadata["used"] = True + metadata["used_at"] = int(__import__("time").time()) + + ttl = await self.client.ttl(key) + if ttl > 0: + await self.client.setex(key, ttl, json.dumps(metadata)) + + return True + except Exception as e: + logger.error(f"Erreur marquage refresh token: {e}") + return False + + async def delete_refresh_token(self, token_id: str) -> bool: + try: + key = f"refresh_token:{token_id}" + result = await self.client.delete(key) + return result > 0 + except Exception as e: + logger.error(f"Erreur suppression refresh token: {e}") + return False + + +redis_service = RedisService() + + +async def get_redis() -> RedisService: + if not await redis_service.is_connected(): + await redis_service.connect() + return redis_service diff --git a/services/sage_gateway.py b/services/sage_gateway.py index 48276a4..48c580b 100644 --- a/services/sage_gateway.py +++ b/services/sage_gateway.py @@ -6,7 +6,7 @@ import httpx from datetime import datetime from typing import Optional, Tuple, List from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy import select, update, and_ +from sqlalchemy import false, select, update, and_ import logging from config.config import settings @@ -20,8 +20,6 @@ class SageGatewayService: self.session = session async def create(self, user_id: str, data: dict) -> SageGatewayConfig: - """Créer une nouvelle configuration gateway""" - if data.get("is_active"): await self._deactivate_all_for_user(user_id) @@ -55,7 +53,6 @@ class SageGatewayService: and_( SageGatewayConfig.id == gateway_id, SageGatewayConfig.user_id == user_id, - SageGatewayConfig.is_deleted, ) ) ) @@ -67,7 +64,7 @@ class SageGatewayService: query = select(SageGatewayConfig).where(SageGatewayConfig.user_id == user_id) if not include_deleted: - query = query.where(SageGatewayConfig.is_deleted) + query = query.where(SageGatewayConfig.is_deleted.is_(false())) query = query.order_by( SageGatewayConfig.is_active.desc(), @@ -81,8 +78,6 @@ class SageGatewayService: async def update( self, gateway_id: str, user_id: str, data: dict ) -> Optional[SageGatewayConfig]: - """Mettre à jour une gateway""" - gateway = await self.get_by_id(gateway_id, user_id) if not gateway: return None @@ -131,7 +126,6 @@ class SageGatewayService: async def activate( self, gateway_id: str, user_id: str ) -> Optional[SageGatewayConfig]: - """Activer une gateway (désactive les autres)""" gateway = await self.get_by_id(gateway_id, user_id) if not gateway: return None @@ -167,7 +161,7 @@ class SageGatewayService: and_( SageGatewayConfig.user_id == user_id, SageGatewayConfig.is_active, - SageGatewayConfig.is_deleted, + SageGatewayConfig.is_deleted.is_(false()), ) ) ) @@ -277,8 +271,6 @@ class SageGatewayService: return {"success": False, "status": "error", "error": str(e)} async def record_request(self, gateway_id: str, success: bool) -> None: - """Enregistrer une requête (succès/échec)""" - if not gateway_id: return @@ -297,7 +289,6 @@ class SageGatewayService: await self.session.commit() async def get_stats(self, user_id: str) -> dict: - """Statistiques d'utilisation pour un utilisateur""" gateways = await self.list_for_user(user_id) total_requests = sum(g.total_requests for g in gateways) @@ -323,8 +314,6 @@ class SageGatewayService: } async def _deactivate_all_for_user(self, user_id: str) -> None: - """Désactiver toutes les gateways d'un utilisateur""" - await self.session.execute( update(SageGatewayConfig) .where(SageGatewayConfig.user_id == user_id) @@ -332,8 +321,6 @@ class SageGatewayService: ) async def _unset_default_for_user(self, user_id: str) -> None: - """Retirer le flag default de toutes les gateways""" - await self.session.execute( update(SageGatewayConfig) .where(SageGatewayConfig.user_id == user_id) @@ -342,8 +329,6 @@ class SageGatewayService: def gateway_response_from_model(gateway: SageGatewayConfig) -> dict: - """Convertir un model en réponse API (masque le token)""" - token_preview = ( f"****{gateway.gateway_token[-4:]}" if gateway.gateway_token else "****" ) @@ -380,8 +365,6 @@ def gateway_response_from_model(gateway: SageGatewayConfig) -> dict: "description": gateway.description, "gateway_url": gateway.gateway_url, "token_preview": token_preview, - "sage_database": gateway.sage_database, - "sage_company": gateway.sage_company, "is_active": gateway.is_active, "is_default": gateway.is_default, "priority": gateway.priority, diff --git a/services/token_service.py b/services/token_service.py new file mode 100644 index 0000000..247bfa9 --- /dev/null +++ b/services/token_service.py @@ -0,0 +1,357 @@ +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy import false, select, and_, or_, delete, true +from datetime import datetime, timedelta +from typing import Optional, Tuple, Dict, Any +import uuid +import logging +import time + +from config.config import settings +from database import RefreshToken, User +from services.redis_service import redis_service +from security.auth import ( + create_access_token, + create_refresh_token, + create_csrf_token, + decode_token, + hash_token, + generate_session_id, +) + +logger = logging.getLogger(__name__) + + +class TokenService: + @classmethod + async def create_token_pair( + cls, + session: AsyncSession, + user: User, + fingerprint_hash: str, + device_info: str, + ip_address: str, + ) -> Tuple[str, str, str, str]: + session_id = generate_session_id() + + access_token = create_access_token( + data={ + "sub": user.id, + "email": user.email, + "role": user.role, + "sid": session_id, + }, + fingerprint_hash=fingerprint_hash, + ) + + refresh_token_jwt, token_id = create_refresh_token( + user_id=user.id, fingerprint_hash=fingerprint_hash + ) + + csrf_token = create_csrf_token(session_id) + + token_record = RefreshToken( + id=str(uuid.uuid4()), + user_id=user.id, + token_hash=hash_token(refresh_token_jwt), + token_id=token_id, + fingerprint_hash=fingerprint_hash, + device_info=device_info[:500] if device_info else None, + ip_address=ip_address, + expires_at=datetime.now() + + timedelta(days=settings.refresh_token_expire_days), + created_at=datetime.now(), + ) + + session.add(token_record) + await session.flush() + + await redis_service.store_refresh_token_metadata( + token_id=token_id, + user_id=user.id, + fingerprint_hash=fingerprint_hash, + ttl_seconds=settings.refresh_token_expire_days * 24 * 60 * 60, + ) + + logger.info(f"Token pair cree pour utilisateur {user.email}") + + return access_token, refresh_token_jwt, csrf_token, session_id + + @classmethod + async def refresh_tokens( + cls, + session: AsyncSession, + refresh_token: str, + fingerprint_hash: str, + device_info: str, + ip_address: str, + ) -> Optional[Tuple[str, str, str, str]]: + payload = decode_token(refresh_token, expected_type="refresh") + if not payload: + logger.warning("Refresh token invalide ou expire") + return None + + user_id = payload.get("sub") + token_id = payload.get("jti") + stored_fingerprint = payload.get("fph") + + if not user_id or not token_id: + logger.warning("Refresh token malformed") + return None + + if await redis_service.is_token_blacklisted(token_id): + logger.warning(f"Refresh token {token_id[:8]}... est blackliste") + return None + + token_hash = hash_token(refresh_token) + result = await session.execute( + select(RefreshToken).where( + and_( + RefreshToken.token_hash == token_hash, + RefreshToken.user_id == user_id, + RefreshToken.is_revoked.is_(false()), + RefreshToken.expires_at > datetime.now(), + ) + ) + ) + token_record = result.scalar_one_or_none() + + if not token_record: + logger.warning(f"Refresh token non trouve en DB pour user {user_id}") + await cls._handle_potential_token_theft(session, user_id, token_id) + return None + + if settings.refresh_token_rotation_enabled and token_record.is_used: + used_at = token_record.used_at + if used_at: + time_since_use = (datetime.now() - used_at).total_seconds() + if time_since_use > settings.refresh_token_reuse_window_seconds: + logger.warning( + f"Reutilisation de refresh token detectee pour user {user_id}" + ) + await cls._handle_potential_token_theft(session, user_id, token_id) + return None + + if stored_fingerprint and fingerprint_hash: + if stored_fingerprint != fingerprint_hash: + logger.warning(f"Fingerprint mismatch pour user {user_id}") + return None + + result = await session.execute( + select(User).where(and_(User.id == user_id, User.is_active.is_(true()))) + ) + user = result.scalar_one_or_none() + + if not user: + logger.warning(f"Utilisateur {user_id} introuvable ou inactif") + return None + + session_id = generate_session_id() + + new_access_token = create_access_token( + data={ + "sub": user.id, + "email": user.email, + "role": user.role, + "sid": session_id, + }, + fingerprint_hash=fingerprint_hash, + ) + + new_csrf_token = create_csrf_token(session_id) + + if settings.refresh_token_rotation_enabled: + token_record.is_used = True + token_record.used_at = datetime.now() + + new_refresh_jwt, new_token_id = create_refresh_token( + user_id=user.id, fingerprint_hash=fingerprint_hash + ) + + new_token_record = RefreshToken( + id=str(uuid.uuid4()), + user_id=user.id, + token_hash=hash_token(new_refresh_jwt), + token_id=new_token_id, + fingerprint_hash=fingerprint_hash, + device_info=device_info[:500] if device_info else None, + ip_address=ip_address, + expires_at=datetime.now() + + timedelta(days=settings.refresh_token_expire_days), + created_at=datetime.now(), + ) + + token_record.replaced_by = new_token_record.id + + session.add(new_token_record) + + await redis_service.mark_refresh_token_used(token_id) + await redis_service.store_refresh_token_metadata( + token_id=new_token_id, + user_id=user.id, + fingerprint_hash=fingerprint_hash, + ttl_seconds=settings.refresh_token_expire_days * 24 * 60 * 60, + ) + + logger.info(f"Refresh token rotation pour user {user.email}") + + return new_access_token, new_refresh_jwt, new_csrf_token, session_id + else: + token_record.last_used_at = datetime.now() + return new_access_token, refresh_token, new_csrf_token, session_id + + @classmethod + async def revoke_token( + cls, session: AsyncSession, refresh_token: str, reason: str = "user_logout" + ) -> bool: + payload = decode_token(refresh_token, expected_type="refresh") + if not payload: + return False + + token_id = payload.get("jti") + user_id = payload.get("sub") + exp = payload.get("exp", 0) + + ttl_seconds = max(0, exp - int(time.time())) + await redis_service.blacklist_token(token_id, ttl_seconds) + + token_hash = hash_token(refresh_token) + result = await session.execute( + select(RefreshToken).where(RefreshToken.token_hash == token_hash) + ) + token_record = result.scalar_one_or_none() + + if token_record: + token_record.is_revoked = True + token_record.revoked_at = datetime.now() + token_record.revoked_reason = reason + + await redis_service.delete_refresh_token(token_id) + + logger.info(f"Token revoque pour user {user_id}: {reason}") + + return True + + @classmethod + async def revoke_all_user_tokens( + cls, session: AsyncSession, user_id: str, reason: str = "security_action" + ) -> int: + result = await session.execute( + select(RefreshToken).where( + and_( + RefreshToken.user_id == user_id, + RefreshToken.is_revoked.is_(false()), + ) + ) + ) + tokens = result.scalars().all() + + count = 0 + for token in tokens: + token.is_revoked = True + token.revoked_at = datetime.now() + token.revoked_reason = reason + + await redis_service.blacklist_token( + token.token_id, settings.refresh_token_expire_days * 24 * 60 * 60 + ) + await redis_service.delete_refresh_token(token.token_id) + count += 1 + + await redis_service.blacklist_user_tokens( + user_id, settings.refresh_token_expire_days * 24 * 60 * 60 + ) + + logger.info(f"{count} tokens revoques pour user {user_id}: {reason}") + + return count + + @classmethod + async def _handle_potential_token_theft( + cls, session: AsyncSession, user_id: str, token_id: str + ) -> None: + logger.warning( + f"Potentiel vol de token detecte pour user {user_id}, token {token_id[:8]}..." + ) + + await cls.revoke_all_user_tokens( + session, user_id, reason="potential_token_theft" + ) + + @classmethod + async def validate_access_token( + cls, token: str, fingerprint_hash: Optional[str] = None + ) -> Optional[Dict[str, Any]]: + payload = decode_token(token, expected_type="access") + if not payload: + return None + + token_id = payload.get("jti") + if token_id and await redis_service.is_token_blacklisted(token_id): + logger.debug(f"Access token {token_id[:8]}... est blackliste") + return None + + user_id = payload.get("sub") + if user_id: + invalidation_time = await redis_service.get_user_token_invalidation_time( + user_id + ) + if invalidation_time: + token_iat = payload.get("iat", 0) + if token_iat < invalidation_time: + logger.debug("Access token emis avant invalidation globale") + return None + + if fingerprint_hash: + stored_fingerprint = payload.get("fph") + if stored_fingerprint and stored_fingerprint != fingerprint_hash: + logger.warning("Fingerprint mismatch sur access token") + return None + + return payload + + @classmethod + async def cleanup_expired_tokens(cls, session: AsyncSession) -> int: + result = await session.execute( + delete(RefreshToken).where( + or_( + RefreshToken.expires_at < datetime.now(), + and_( + RefreshToken.is_revoked.is_(true()), + RefreshToken.revoked_at < datetime.now() - timedelta(days=7), + ), + ) + ) + ) + + count = result.rowcount + logger.info(f"{count} tokens expires nettoyes") + + return count + + @classmethod + async def get_user_active_sessions( + cls, session: AsyncSession, user_id: str + ) -> list: + result = await session.execute( + select(RefreshToken) + .where( + and_( + RefreshToken.user_id == user_id, + RefreshToken.is_revoked.is_(false()), + RefreshToken.expires_at > datetime.now(), + ) + ) + .order_by(RefreshToken.created_at.desc()) + ) + tokens = result.scalars().all() + + return [ + { + "id": t.id, + "device_info": t.device_info, + "ip_address": t.ip_address, + "created_at": t.created_at.isoformat(), + "last_used_at": t.last_used_at.isoformat() if t.last_used_at else None, + } + for t in tokens + ]