feat(auth): implement comprehensive security enhancements

This commit is contained in:
Fanilo-Nantenaina 2026-01-02 17:56:28 +03:00
parent 81843dfaee
commit e97ff73e16
23 changed files with 3085 additions and 673 deletions

View file

@ -1,32 +1,97 @@
# ============================================ # === Environment ===
# Configuration Linux VPS - API Principale ENVIRONMENT=development
# ============================================ # Options: development, staging, production
# === Sage Gateway Windows === # === JWT & Authentication ===
SAGE_GATEWAY_URL=http://192.168.1.50:8100 # IMPORTANT: Generer des secrets uniques et forts en production
SAGE_GATEWAY_TOKEN=4e8f9c2a7b1d5e3f9a0c8b7d6e5f4a3b2c1d0e9f8a7b6c5d4e3f2a1b0c9d8e7f # python -c "import secrets; print(secrets.token_urlsafe(64))"
JWT_SECRET=CHANGE_ME_IN_PRODUCTION_USE_STRONG_SECRET_64_CHARS_MIN
JWT_ALGORITHM=HS256
ACCESS_TOKEN_EXPIRE_MINUTES=15
REFRESH_TOKEN_EXPIRE_DAYS=7
CSRF_TOKEN_EXPIRE_MINUTES=60
# === Base de données === # === Cookie Settings ===
COOKIE_DOMAIN=
# Laisser vide pour localhost, sinon ".example.com" pour sous-domaines
COOKIE_SECURE=false
# Mettre true en production avec HTTPS
COOKIE_SAMESITE=strict
# Options: strict, lax, none
COOKIE_HTTPONLY=true
COOKIE_ACCESS_TOKEN_NAME=access_token
COOKIE_REFRESH_TOKEN_NAME=refresh_token
COOKIE_CSRF_TOKEN_NAME=csrf_token
# === Redis (Token Blacklist & Rate Limiting) ===
REDIS_URL=redis://localhost:6379/0
REDIS_PASSWORD=
REDIS_SSL=false
TOKEN_BLACKLIST_PREFIX=blacklist:
RATE_LIMIT_PREFIX=ratelimit:
# === Rate Limiting ===
RATE_LIMIT_LOGIN_ATTEMPTS=5
RATE_LIMIT_LOGIN_WINDOW_MINUTES=15
RATE_LIMIT_API_REQUESTS=100
RATE_LIMIT_API_WINDOW_SECONDS=60
# === Password Security ===
PASSWORD_MIN_LENGTH=8
PASSWORD_REQUIRE_UPPERCASE=true
PASSWORD_REQUIRE_LOWERCASE=true
PASSWORD_REQUIRE_DIGIT=true
PASSWORD_REQUIRE_SPECIAL=true
ACCOUNT_LOCKOUT_THRESHOLD=5
ACCOUNT_LOCKOUT_DURATION_MINUTES=30
# === Device Fingerprint ===
FINGERPRINT_SECRET=
# Si vide, utilise JWT_SECRET
FINGERPRINT_COMPONENTS=user_agent,accept_language,accept_encoding
# === Refresh Token Rotation ===
REFRESH_TOKEN_ROTATION_ENABLED=true
REFRESH_TOKEN_REUSE_WINDOW_SECONDS=10
# === Database ===
DATABASE_URL=sqlite+aiosqlite:///./data/sage_dataven.db DATABASE_URL=sqlite+aiosqlite:///./data/sage_dataven.db
# PostgreSQL: postgresql+asyncpg://user:password@localhost:5432/dbname
# === SMTP === # === Sage Gateway (Windows) ===
SMTP_HOST=smtp.office365.com SAGE_GATEWAY_URL=http://windows-server:5000
SAGE_GATEWAY_TOKEN=your_gateway_token
# === Frontend ===
FRONTEND_URL=http://localhost:3000
# === SMTP (Email) ===
SMTP_HOST=smtp.example.com
SMTP_PORT=587 SMTP_PORT=587
SMTP_USER=commercial@monentreprise.fr SMTP_USER=noreply@example.com
SMTP_PASSWORD=MonMotDePasseEmail123! SMTP_PASSWORD=your_smtp_password
SMTP_FROM=commercial@monentreprise.fr SMTP_FROM=noreply@example.com
SMTP_USE_TLS=true
# === Universign === # === Universign (Signature electronique) ===
UNIVERSIGN_API_KEY=your_real_universign_key_here UNIVERSIGN_API_KEY=your_universign_api_key
UNIVERSIGN_API_URL=https://api.universign.com/v1 UNIVERSIGN_API_URL=https://api.universign.com/v1
# === API === # === API Server ===
API_HOST=0.0.0.0 API_HOST=0.0.0.0
API_PORT=8002 API_PORT=8000
API_RELOAD=False API_RELOAD=true
# Mettre false en production
# === Email Queue === # === CORS ===
MAX_EMAIL_WORKERS=3 # Liste separee par virgules des origines autorisees
CORS_ORIGINS=["*"]
# === Logs === # === Sage Document Types ===
LOG_LEVEL=INFO SAGE_TYPE_DEVIS=0
SAGE_TYPE_BON_COMMANDE=10
SAGE_TYPE_PREPARATION=20
SAGE_TYPE_BON_LIVRAISON=30
SAGE_TYPE_BON_RETOUR=40
SAGE_TYPE_BON_AVOIR=50
SAGE_TYPE_FACTURE=60

116
api.py
View file

@ -1,6 +1,6 @@
from fastapi import FastAPI, HTTPException, Path, Query, Depends, status, Body, Request from fastapi import FastAPI, HTTPException, Path, Query, Depends, status, Body, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.encoders import jsonable_encoder from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, Field, EmailStr from pydantic import BaseModel, Field, EmailStr
from typing import List, Optional from typing import List, Optional
@ -20,17 +20,17 @@ from routes.auth import router as auth_router
from config.config import settings from config.config import settings
from database import ( from database import (
init_db, init_db,
close_db,
async_session_factory, async_session_factory,
get_session, get_session,
EmailLog, EmailLog,
StatutEmail as StatutEmailEnum, StatutEmail as StatutEmailEnum,
WorkflowLog, WorkflowLog,
SignatureLog, SignatureLog,
StatutSignature as StatutSignatureEnum, StatutSignature,
) )
from services.email_queue import email_queue from services.email_queue import email_queue
from sage.sage_client import sage_client, SageGatewayClient from sage.sage_client import sage_client, SageGatewayClient
from schemas import ( from schemas import (
TiersDetails, TiersDetails,
BaremeRemiseResponse, BaremeRemiseResponse,
@ -58,7 +58,6 @@ from schemas import (
LivraisonCreateRequest, LivraisonCreateRequest,
LivraisonUpdateRequest, LivraisonUpdateRequest,
SignatureRequest, SignatureRequest,
StatutSignature,
ArticleCreateRequest, ArticleCreateRequest,
ArticleResponse, ArticleResponse,
ArticleUpdateRequest, ArticleUpdateRequest,
@ -72,13 +71,20 @@ from schemas import (
ContactUpdate, ContactUpdate,
) )
from utils.normalization import normaliser_type_tiers from utils.normalization import normaliser_type_tiers
from routes.sage_gateway import router as sage_gateway_router from routes.sage_gateway import router as sage_gateway_router
from services.redis_service import redis_service
from core.sage_context import ( from core.sage_context import (
get_sage_client_for_user, get_sage_client_for_user,
get_gateway_context_for_user, get_gateway_context_for_user,
GatewayContext, GatewayContext,
) )
from middleware.security import (
setup_security_middleware,
init_security_services,
shutdown_security_services,
RateLimitMiddleware,
)
if os.path.exists("/app"): if os.path.exists("/app"):
LOGS_DIR = FilePath("/app/logs") LOGS_DIR = FilePath("/app/logs")
@ -112,33 +118,61 @@ async def lifespan(app: FastAPI):
email_queue.start(num_workers=settings.max_email_workers) email_queue.start(num_workers=settings.max_email_workers)
logger.info("Email queue démarrée") logger.info("Email queue démarrée")
try:
await init_security_services()
logger.info("Services de securite initialises")
except Exception as e:
logger.warning(f"Redis non disponible, mode degrade active: {e}")
yield yield
await shutdown_security_services()
await close_db()
email_queue.stop() email_queue.stop()
logger.info("Services arrêtés") logger.info("Services arrêtés")
app = FastAPI( app = FastAPI(
title="Sage Gateways", title="Sage API Securisee",
version="3.0.0", version="3.0.0",
description="Configuration multi-tenant des connexions Sage Gateway", description="API avec authentification securisee par cookies HttpOnly",
lifespan=lifespan, lifespan=lifespan,
openapi_tags=TAGS_METADATA, openapi_tags=TAGS_METADATA,
docs_url="/docs" if settings.is_development else None,
redoc_url="/redoc" if settings.is_development else None,
) )
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=settings.cors_origins, allow_origins=settings.cors_origins,
allow_methods=["GET", "POST", "PUT", "DELETE"], allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["*"], allow_headers=["*"],
allow_credentials=True, allow_credentials=True,
expose_headers=["X-RateLimit-Limit", "X-RateLimit-Remaining", "X-RateLimit-Reset"],
) )
setup_security_middleware(app)
if settings.is_production:
app.add_middleware(RateLimitMiddleware)
app.include_router(auth_router) app.include_router(auth_router)
app.include_router(sage_gateway_router) app.include_router(sage_gateway_router)
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
"""Gestionnaire global d'exceptions."""
logger.error(f"Erreur non geree: {exc}", exc_info=True)
return JSONResponse(
status_code=500,
content={"detail": "Erreur interne du serveur", "type": "internal_error"},
)
async def universign_envoyer( async def universign_envoyer(
doc_id: str, doc_id: str,
pdf_bytes: bytes, pdf_bytes: bytes,
@ -1135,7 +1169,7 @@ async def envoyer_signature_optimise(
signer_url=resultat["signer_url"], signer_url=resultat["signer_url"],
email_signataire=demande.email_signataire, email_signataire=demande.email_signataire,
nom_signataire=demande.nom_signataire, nom_signataire=demande.nom_signataire,
statut=StatutSignatureEnum.ENVOYE, statut=StatutSignature.ENVOYE,
date_envoi=datetime.now(), date_envoi=datetime.now(),
) )
@ -1191,7 +1225,7 @@ async def webhook_universign(
return {"status": "not_found"} return {"status": "not_found"}
if event_type == "transaction.completed": if event_type == "transaction.completed":
signature_log.statut = StatutSignatureEnum.SIGNE signature_log.statut = StatutSignature.SIGNE
signature_log.date_signature = datetime.now() signature_log.date_signature = datetime.now()
logger.info(f"Signature confirmée: {signature_log.document_id}") logger.info(f"Signature confirmée: {signature_log.document_id}")
@ -1242,11 +1276,11 @@ async def webhook_universign(
) )
elif event_type == "transaction.refused": elif event_type == "transaction.refused":
signature_log.statut = StatutSignatureEnum.REFUSE signature_log.statut = StatutSignature.REFUSE
logger.warning(f"Signature refusée: {signature_log.document_id}") logger.warning(f"Signature refusée: {signature_log.document_id}")
elif event_type == "transaction.expired": elif event_type == "transaction.expired":
signature_log.statut = StatutSignatureEnum.EXPIRE signature_log.statut = StatutSignature.EXPIRE
logger.warning(f"⏰ Transaction expirée: {signature_log.document_id}") logger.warning(f"⏰ Transaction expirée: {signature_log.document_id}")
await session.commit() await session.commit()
@ -1271,7 +1305,7 @@ async def relancer_signatures_automatique(session: AsyncSession = Depends(get_se
query = select(SignatureLog).where( query = select(SignatureLog).where(
SignatureLog.statut.in_( SignatureLog.statut.in_(
[StatutSignatureEnum.EN_ATTENTE, StatutSignatureEnum.ENVOYE] [StatutSignature.EN_ATTENTE, StatutSignature.ENVOYE]
), ),
SignatureLog.date_envoi < date_limite, SignatureLog.date_envoi < date_limite,
SignatureLog.nb_relances < 3, # Max 3 relances SignatureLog.nb_relances < 3, # Max 3 relances
@ -1288,7 +1322,7 @@ async def relancer_signatures_automatique(session: AsyncSession = Depends(get_se
jours_restants = 30 - nb_jours # Lien expire après 30 jours jours_restants = 30 - nb_jours # Lien expire après 30 jours
if jours_restants <= 0: if jours_restants <= 0:
signature.statut = StatutSignatureEnum.EXPIRE signature.statut = StatutSignature.EXPIRE
continue continue
template = templates_signature_email["relance_signature"] template = templates_signature_email["relance_signature"]
@ -1394,7 +1428,7 @@ async def lister_signatures(
query = select(SignatureLog).order_by(SignatureLog.date_envoi.desc()) query = select(SignatureLog).order_by(SignatureLog.date_envoi.desc())
if statut: if statut:
statut_db = StatutSignatureEnum[statut.value] statut_db = StatutSignature[statut.value]
query = query.where(SignatureLog.statut == statut_db) query = query.where(SignatureLog.statut == statut_db)
query = query.limit(limit) query = query.limit(limit)
@ -1437,15 +1471,15 @@ async def statut_signature_detail(
if statut_universign.get("statut") != "ERREUR": if statut_universign.get("statut") != "ERREUR":
statut_map = { statut_map = {
"EN_ATTENTE": StatutSignatureEnum.EN_ATTENTE, "EN_ATTENTE": StatutSignature.EN_ATTENTE,
"ENVOYE": StatutSignatureEnum.ENVOYE, "ENVOYE": StatutSignature.ENVOYE,
"SIGNE": StatutSignatureEnum.SIGNE, "SIGNE": StatutSignature.SIGNE,
"REFUSE": StatutSignatureEnum.REFUSE, "REFUSE": StatutSignature.REFUSE,
"EXPIRE": StatutSignatureEnum.EXPIRE, "EXPIRE": StatutSignature.EXPIRE,
} }
nouveau_statut = statut_map.get( nouveau_statut = statut_map.get(
statut_universign["statut"], StatutSignatureEnum.EN_ATTENTE statut_universign["statut"], StatutSignature.EN_ATTENTE
) )
signature_log.statut = nouveau_statut signature_log.statut = nouveau_statut
@ -1477,9 +1511,7 @@ async def statut_signature_detail(
@app.post("/signatures/refresh-all", tags=["Signatures"]) @app.post("/signatures/refresh-all", tags=["Signatures"])
async def rafraichir_statuts_signatures(session: AsyncSession = Depends(get_session)): async def rafraichir_statuts_signatures(session: AsyncSession = Depends(get_session)):
query = select(SignatureLog).where( query = select(SignatureLog).where(
SignatureLog.statut.in_( SignatureLog.statut.in_([StatutSignature.EN_ATTENTE, StatutSignature.ENVOYE])
[StatutSignatureEnum.EN_ATTENTE, StatutSignatureEnum.ENVOYE]
)
) )
result = await session.execute(query) result = await session.execute(query)
@ -1492,9 +1524,9 @@ async def rafraichir_statuts_signatures(session: AsyncSession = Depends(get_sess
if statut_universign.get("statut") != "ERREUR": if statut_universign.get("statut") != "ERREUR":
statut_map = { statut_map = {
"SIGNE": StatutSignatureEnum.SIGNE, "SIGNE": StatutSignature.SIGNE,
"REFUSE": StatutSignatureEnum.REFUSE, "REFUSE": StatutSignature.REFUSE,
"EXPIRE": StatutSignatureEnum.EXPIRE, "EXPIRE": StatutSignature.EXPIRE,
} }
nouveau = statut_map.get(statut_universign["statut"]) nouveau = statut_map.get(statut_universign["statut"])
@ -1548,7 +1580,7 @@ async def envoyer_devis_signature(
signer_url=resultat["signer_url"], signer_url=resultat["signer_url"],
email_signataire=request.email_signataire, email_signataire=request.email_signataire,
nom_signataire=request.nom_signataire, nom_signataire=request.nom_signataire,
statut=StatutSignatureEnum.ENVOYE, statut=StatutSignature.ENVOYE,
date_envoi=datetime.now(), date_envoi=datetime.now(),
) )
@ -1694,7 +1726,7 @@ async def relancer_devis_signature(
signer_url=resultat["signer_url"], signer_url=resultat["signer_url"],
email_signataire=contact["email"], email_signataire=contact["email"],
nom_signataire=contact["nom"] or contact["client_intitule"], nom_signataire=contact["nom"] or contact["client_intitule"],
statut=StatutSignatureEnum.ENVOYE, statut=StatutSignature.ENVOYE,
date_envoi=datetime.now(), date_envoi=datetime.now(),
est_relance=True, est_relance=True,
nb_relances=1, nb_relances=1,
@ -3158,17 +3190,26 @@ async def health_check(
sage: SageGatewayClient = Depends(get_sage_client_for_user), sage: SageGatewayClient = Depends(get_sage_client_for_user),
): ):
gateway_health = sage.health() gateway_health = sage.health()
redis_status = "connected"
try:
if not await redis_service.is_connected():
redis_status = "disconnected"
except Exception:
redis_status = "error"
return { return {
"status": "healthy", "status": "healthy",
"sage_gateway": gateway_health, "sage_gateway": gateway_health,
"using_gateway_id": sage.gateway_id, "using_gateway_id": sage.gateway_id,
"timestamp": datetime.now().isoformat(),
"environment": settings.environment.value,
"services": {"redis": redis_status},
"email_queue": { "email_queue": {
"running": email_queue.running, "running": email_queue.running,
"workers": len(email_queue.workers), "workers": len(email_queue.workers),
"queue_size": email_queue.queue.qsize(), "queue_size": email_queue.queue.qsize(),
}, },
"timestamp": datetime.now().isoformat(),
} }
@ -3177,22 +3218,13 @@ async def root():
return { return {
"api": "Sage 100c Dataven - VPS Linux", "api": "Sage 100c Dataven - VPS Linux",
"version": "2.0.0", "version": "2.0.0",
"documentation": "/docs", "documentation": "/docs"
if settings.is_development
else "Disabled in production",
"health": "/health", "health": "/health",
} }
@app.get("/admin/cache/info", tags=["Admin"])
async def info_cache():
try:
cache_info = sage_client.get_cache_info()
return cache_info
except Exception as e:
logger.error(f"Erreur info cache: {e}")
raise HTTPException(500, str(e))
@app.get("/admin/queue/status", tags=["Admin"]) @app.get("/admin/queue/status", tags=["Admin"])
async def statut_queue(): async def statut_queue():
return { return {

View file

@ -1,5 +1,12 @@
from pydantic_settings import BaseSettings, SettingsConfigDict from pydantic_settings import BaseSettings, SettingsConfigDict
from typing import List from typing import List, Optional
from enum import Enum
class Environment(str, Enum):
DEVELOPMENT = "development"
STAGING = "staging"
PRODUCTION = "production"
class Settings(BaseSettings): class Settings(BaseSettings):
@ -7,12 +14,60 @@ class Settings(BaseSettings):
env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="ignore" env_file=".env", env_file_encoding="utf-8", case_sensitive=False, extra="ignore"
) )
# === Environment ===
environment: Environment = Environment.DEVELOPMENT
# === JWT & Auth === # === JWT & Auth ===
jwt_secret: str jwt_secret: str
jwt_algorithm: str jwt_algorithm: str = "HS256"
access_token_expire_minutes: int access_token_expire_minutes: int = 15
refresh_token_expire_days: int refresh_token_expire_days: int = 7
csrf_token_expire_minutes: int = 60
# === Cookie Settings ===
cookie_domain: Optional[str] = None
cookie_secure: bool = True
cookie_samesite: str = "strict"
cookie_httponly: bool = True
cookie_access_token_name: str = "access_token"
cookie_refresh_token_name: str = "refresh_token"
cookie_csrf_token_name: str = "csrf_token"
# === Redis (Token Blacklist & Rate Limiting) ===
redis_url: str = "redis://localhost:6379/0"
redis_password: Optional[str] = None
redis_ssl: bool = False
token_blacklist_prefix: str = "blacklist:"
rate_limit_prefix: str = "ratelimit:"
# === Rate Limiting ===
rate_limit_login_attempts: int = 5
rate_limit_login_window_minutes: int = 15
rate_limit_api_requests: int = 100
rate_limit_api_window_seconds: int = 60
# === Security ===
password_min_length: int = 8
password_require_uppercase: bool = True
password_require_lowercase: bool = True
password_require_digit: bool = True
password_require_special: bool = True
account_lockout_threshold: int = 5
account_lockout_duration_minutes: int = 30
# === Fingerprint ===
fingerprint_secret: str = ""
fingerprint_components: List[str] = [
"user_agent",
"accept_language",
"accept_encoding",
]
# === Refresh Token Rotation ===
refresh_token_rotation_enabled: bool = True
refresh_token_reuse_window_seconds: int = 10
# === Sage Types ===
SAGE_TYPE_DEVIS: int = 0 SAGE_TYPE_DEVIS: int = 0
SAGE_TYPE_BON_COMMANDE: int = 10 SAGE_TYPE_BON_COMMANDE: int = 10
SAGE_TYPE_PREPARATION: int = 20 SAGE_TYPE_PREPARATION: int = 20
@ -21,12 +76,12 @@ class Settings(BaseSettings):
SAGE_TYPE_BON_AVOIR: int = 50 SAGE_TYPE_BON_AVOIR: int = 50
SAGE_TYPE_FACTURE: int = 60 SAGE_TYPE_FACTURE: int = 60
# === Sage Gateway (Windows) === # === Sage Gateway ===
sage_gateway_url: str sage_gateway_url: str
sage_gateway_token: str sage_gateway_token: str
frontend_url: str frontend_url: str
# === Base de données === # === Database ===
database_url: str = "sqlite+aiosqlite:///./data/sage_dataven.db" database_url: str = "sqlite+aiosqlite:///./data/sage_dataven.db"
# === SMTP === # === SMTP ===
@ -42,9 +97,9 @@ class Settings(BaseSettings):
universign_api_url: str universign_api_url: str
# === API === # === API ===
api_host: str api_host: str = "0.0.0.0"
api_port: int api_port: int = 8000
api_reload: bool = False api_reload: bool = True
# === Email Queue === # === Email Queue ===
max_email_workers: int = 3 max_email_workers: int = 3
@ -54,5 +109,13 @@ class Settings(BaseSettings):
# === CORS === # === CORS ===
cors_origins: List[str] = ["*"] cors_origins: List[str] = ["*"]
@property
def is_production(self) -> bool:
return self.environment == Environment.PRODUCTION
@property
def is_development(self) -> bool:
return self.environment == Environment.DEVELOPMENT
settings = Settings() settings = Settings()

View file

@ -1,41 +1,49 @@
from fastapi import Depends, HTTPException, status from fastapi import Depends, HTTPException, status, Request
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select from sqlalchemy import select
from database import get_session, User from typing import Optional, Tuple
from security.auth import decode_token
from typing import Optional
from datetime import datetime from datetime import datetime
import logging
security = HTTPBearer() from database import get_session
from database import User, AuditEventType
from services.token_service import TokenService
from services.audit_service import AuditService
from security.cookies import CookieManager
from security.fingerprint import DeviceFingerprint, get_client_ip
from security.csrf import CSRFProtection
logger = logging.getLogger(__name__)
async def get_current_user( async def get_current_user(
credentials: HTTPAuthorizationCredentials = Depends(security), request: Request, session: AsyncSession = Depends(get_session)
session: AsyncSession = Depends(get_session),
) -> User: ) -> User:
token = credentials.credentials token = CookieManager.get_access_token(request)
if not token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentification requise",
headers={"WWW-Authenticate": "Bearer"},
)
fingerprint_hash = DeviceFingerprint.generate_hash(request)
payload = await TokenService.validate_access_token(token, fingerprint_hash)
payload = decode_token(token)
if not payload: if not payload:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token invalide ou expiré", detail="Token invalide ou expire",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
if payload.get("type") != "access": user_id = payload.get("sub")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Type de token incorrect",
headers={"WWW-Authenticate": "Bearer"},
)
user_id: str = payload.get("sub")
if not user_id: if not user_id:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token malformé", detail="Token malformed",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
@ -51,33 +59,31 @@ async def get_current_user(
if not user.is_active: if not user.is_active:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Compte désactivé" status_code=status.HTTP_403_FORBIDDEN, detail="Compte desactive"
) )
if not user.is_verified: if not user.is_verified:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN, detail="Email non verifie"
detail="Email non vérifié. Consultez votre boîte de réception.",
) )
if user.locked_until and user.locked_until > datetime.now(): if user.locked_until and user.locked_until > datetime.now():
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="Compte temporairement verrouillé suite à trop de tentatives échouées", detail="Compte temporairement verrouille",
) )
request.state.user = user
request.state.session_id = payload.get("sid")
return user return user
async def get_current_user_optional( async def get_current_user_optional(
credentials: Optional[HTTPAuthorizationCredentials] = Depends(security), request: Request, session: AsyncSession = Depends(get_session)
session: AsyncSession = Depends(get_session),
) -> Optional[User]: ) -> Optional[User]:
if not credentials:
return None
try: try:
return await get_current_user(credentials, session) return await get_current_user(request, session)
except HTTPException: except HTTPException:
return None return None
@ -87,8 +93,99 @@ def require_role(*allowed_roles: str):
if user.role not in allowed_roles: if user.role not in allowed_roles:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=f"Accès refusé. Rôles requis: {', '.join(allowed_roles)}", detail=f"Acces refuse. Roles requis: {', '.join(allowed_roles)}",
) )
return user return user
return role_checker return role_checker
def require_verified_email():
async def email_checker(user: User = Depends(get_current_user)) -> User:
if not user.is_verified:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Verification email requise",
)
return user
return email_checker
async def verify_csrf_token(
request: Request, user: User = Depends(get_current_user)
) -> None:
if CSRFProtection.is_exempt(request):
return
session_id = getattr(request.state, "session_id", None)
if not CSRFProtection.validate_request(request, session_id):
logger.warning(
f"CSRF validation echouee pour user {user.id} "
f"sur {request.method} {request.url.path}"
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Verification CSRF echouee"
)
async def get_auth_context(
request: Request, session: AsyncSession = Depends(get_session)
) -> Tuple[Optional[User], str, str]:
ip_address = get_client_ip(request)
fingerprint_hash = DeviceFingerprint.generate_hash(request)
try:
user = await get_current_user(request, session)
except HTTPException:
user = None
return user, ip_address, fingerprint_hash
class AuthenticatedRoute:
def __init__(
self,
require_csrf: bool = True,
allowed_roles: Optional[Tuple[str, ...]] = None,
audit_event: Optional[AuditEventType] = None,
):
self.require_csrf = require_csrf
self.allowed_roles = allowed_roles
self.audit_event = audit_event
async def __call__(
self, request: Request, session: AsyncSession = Depends(get_session)
) -> User:
user = await get_current_user(request, session)
if self.allowed_roles and user.role not in self.allowed_roles:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Acces refuse pour ce role",
)
if self.require_csrf and not CSRFProtection.is_exempt(request):
session_id = getattr(request.state, "session_id", None)
if not CSRFProtection.validate_request(request, session_id):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Verification CSRF echouee",
)
if self.audit_event:
await AuditService.log_event(
session=session,
event_type=self.audit_event,
request=request,
user_id=user.id,
success=True,
)
return user
require_admin = require_role("admin")
require_manager = require_role("admin", "manager")
require_user = require_role("admin", "manager", "user")

View file

@ -5,13 +5,18 @@ from database.db_config import (
get_session, get_session,
close_db, close_db,
) )
from database.models.generic_model import (
CacheMetadata, from database.models.generic_model import Base
AuditLog,
from database.models.auth_models import (
User,
RefreshToken, RefreshToken,
AuditLog,
AuditEventType,
LoginAttempt, LoginAttempt,
UserSession,
) )
from database.models.user import User
from database.models.email import EmailLog from database.models.email import EmailLog
from database.models.signature import SignatureLog from database.models.signature import SignatureLog
from database.models.sage_config import SageGatewayConfig from database.models.sage_config import SageGatewayConfig
@ -28,15 +33,16 @@ __all__ = [
"get_session", "get_session",
"close_db", "close_db",
"Base", "Base",
"User",
"RefreshToken",
"AuditLog",
"AuditEventType",
"LoginAttempt",
"UserSession",
"EmailLog", "EmailLog",
"SignatureLog", "SignatureLog",
"WorkflowLog", "WorkflowLog",
"CacheMetadata",
"AuditLog",
"StatutEmail", "StatutEmail",
"StatutSignature", "StatutSignature",
"User",
"RefreshToken",
"LoginAttempt",
"SageGatewayConfig", "SageGatewayConfig",
] ]

View file

@ -0,0 +1,214 @@
from sqlalchemy import (
Column,
Integer,
String,
DateTime,
Boolean,
Text,
ForeignKey,
Index,
Enum as SQLEnum,
)
from sqlalchemy.orm import relationship
from datetime import datetime
from enum import Enum
from database.models.generic_model import Base
class User(Base):
__tablename__ = "users"
id = Column(String(36), primary_key=True)
email = Column(String(255), unique=True, nullable=False, index=True)
hashed_password = Column(String(255), nullable=False)
nom = Column(String(100), nullable=False)
prenom = Column(String(100), nullable=False)
role = Column(String(50), default="user")
is_verified = Column(Boolean, default=False, index=True)
verification_token = Column(String(255), nullable=True, unique=True, index=True)
verification_token_expires = Column(DateTime, nullable=True)
is_active = Column(Boolean, default=True, index=True)
failed_login_attempts = Column(Integer, default=0)
locked_until = Column(DateTime, nullable=True)
reset_token = Column(String(255), nullable=True, unique=True, index=True)
reset_token_expires = Column(DateTime, nullable=True)
password_changed_at = Column(DateTime, nullable=True)
must_change_password = Column(Boolean, default=False)
created_at = Column(DateTime, default=datetime.now, nullable=False)
updated_at = Column(DateTime, default=datetime.now, onupdate=datetime.now)
last_login = Column(DateTime, nullable=True)
last_login_ip = Column(String(45), nullable=True)
refresh_tokens = relationship(
"RefreshToken", back_populates="user", cascade="all, delete-orphan"
)
audit_logs = relationship(
"AuditLog", back_populates="user", cascade="all, delete-orphan"
)
def __repr__(self):
return f"<User {self.email} verified={self.is_verified}>"
class RefreshToken(Base):
__tablename__ = "refresh_tokens"
id = Column(String(36), primary_key=True)
user_id = Column(
String(36),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
token_hash = Column(String(64), unique=True, nullable=False, index=True)
token_id = Column(String(32), unique=True, nullable=False, index=True)
fingerprint_hash = Column(String(64), nullable=True)
device_info = Column(String(500), nullable=True)
ip_address = Column(String(45), nullable=True)
is_revoked = Column(Boolean, default=False, index=True)
revoked_at = Column(DateTime, nullable=True)
revoked_reason = Column(String(100), nullable=True)
is_used = Column(Boolean, default=False)
used_at = Column(DateTime, nullable=True)
replaced_by = Column(String(36), nullable=True)
expires_at = Column(DateTime, nullable=False, index=True)
created_at = Column(DateTime, default=datetime.now, nullable=False)
last_used_at = Column(DateTime, nullable=True)
user = relationship("User", back_populates="refresh_tokens")
__table_args__ = (
Index("ix_refresh_tokens_user_valid", "user_id", "is_revoked", "expires_at"),
)
def __repr__(self):
return f"<RefreshToken {self.token_id[:8]}... user={self.user_id}>"
class AuditEventType(str, Enum):
LOGIN_SUCCESS = "login_success"
LOGIN_FAILED = "login_failed"
LOGOUT = "logout"
PASSWORD_CHANGE = "password_change"
PASSWORD_RESET_REQUEST = "password_reset_request"
PASSWORD_RESET_COMPLETE = "password_reset_complete"
EMAIL_VERIFICATION = "email_verification"
ACCOUNT_LOCKED = "account_locked"
ACCOUNT_UNLOCKED = "account_unlocked"
TOKEN_REFRESH = "token_refresh"
TOKEN_REVOKED = "token_revoked"
SUSPICIOUS_ACTIVITY = "suspicious_activity"
SESSION_CREATED = "session_created"
SESSION_TERMINATED = "session_terminated"
class AuditLog(Base):
__tablename__ = "audit_logs"
id = Column(String(36), primary_key=True)
user_id = Column(
String(36),
ForeignKey("users.id", ondelete="SET NULL"),
nullable=True,
index=True,
)
event_type = Column(SQLEnum(AuditEventType), nullable=False, index=True)
event_description = Column(Text, nullable=True)
ip_address = Column(String(45), nullable=True, index=True)
user_agent = Column(String(500), nullable=True)
fingerprint_hash = Column(String(64), nullable=True)
resource_type = Column(String(50), nullable=True)
resource_id = Column(String(100), nullable=True)
request_method = Column(String(10), nullable=True)
request_path = Column(String(500), nullable=True)
meta = Column("metadata", Text, nullable=True)
success = Column(Boolean, default=True)
failure_reason = Column(String(255), nullable=True)
created_at = Column(DateTime, default=datetime.now, nullable=False, index=True)
user = relationship("User", back_populates="audit_logs")
__table_args__ = (
Index("ix_audit_logs_user_event", "user_id", "event_type", "created_at"),
Index("ix_audit_logs_ip_event", "ip_address", "event_type", "created_at"),
)
def __repr__(self):
return f"<AuditLog {self.event_type.value} user={self.user_id}>"
class LoginAttempt(Base):
__tablename__ = "login_attempts"
id = Column(Integer, primary_key=True, autoincrement=True)
email = Column(String(255), nullable=False, index=True)
ip_address = Column(String(45), nullable=True, index=True)
user_agent = Column(String(500), nullable=True)
fingerprint_hash = Column(String(64), nullable=True)
success = Column(Boolean, default=False, index=True)
failure_reason = Column(String(255), nullable=True)
timestamp = Column(DateTime, default=datetime.now, nullable=False, index=True)
__table_args__ = (
Index("ix_login_attempts_email_time", "email", "timestamp"),
Index("ix_login_attempts_ip_time", "ip_address", "timestamp"),
)
def __repr__(self):
return f"<LoginAttempt {self.email} success={self.success}>"
class UserSession(Base):
__tablename__ = "user_sessions"
id = Column(String(36), primary_key=True)
user_id = Column(
String(36),
ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
index=True,
)
session_token_hash = Column(String(64), unique=True, nullable=False, index=True)
refresh_token_id = Column(String(36), nullable=True)
device_info = Column(String(500), nullable=True)
ip_address = Column(String(45), nullable=True)
fingerprint_hash = Column(String(64), nullable=True)
location = Column(String(255), nullable=True)
is_active = Column(Boolean, default=True, index=True)
terminated_at = Column(DateTime, nullable=True)
termination_reason = Column(String(100), nullable=True)
created_at = Column(DateTime, default=datetime.now, nullable=False)
last_activity = Column(DateTime, default=datetime.now, nullable=False)
expires_at = Column(DateTime, nullable=False)
__table_args__ = (Index("ix_user_sessions_user_active", "user_id", "is_active"),)
def __repr__(self):
return f"<UserSession {self.id[:8]}... user={self.user_id}>"

View file

@ -5,7 +5,6 @@ from sqlalchemy import (
DateTime, DateTime,
Float, Float,
Text, Text,
Boolean,
) )
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
from datetime import datetime from datetime import datetime
@ -29,63 +28,3 @@ class CacheMetadata(Base):
def __repr__(self): def __repr__(self):
return f"<CacheMetadata type={self.cache_type} items={self.item_count}>" return f"<CacheMetadata type={self.cache_type} items={self.item_count}>"
class AuditLog(Base):
__tablename__ = "audit_logs"
id = Column(Integer, primary_key=True, autoincrement=True)
action = Column(String(100), nullable=False, index=True)
ressource_type = Column(String(50), nullable=True)
ressource_id = Column(String(100), nullable=True, index=True)
utilisateur = Column(String(100), nullable=True)
ip_address = Column(String(45), nullable=True)
succes = Column(Boolean, default=True)
details = Column(Text, nullable=True)
erreur = Column(Text, nullable=True)
date_action = Column(DateTime, default=datetime.now, nullable=False, index=True)
def __repr__(self):
return f"<AuditLog {self.action} on {self.ressource_type}/{self.ressource_id}>"
class RefreshToken(Base):
__tablename__ = "refresh_tokens"
id = Column(String(36), primary_key=True)
user_id = Column(String(36), nullable=False, index=True)
token_hash = Column(String(255), nullable=False, unique=True, index=True)
device_info = Column(String(500), nullable=True)
ip_address = Column(String(45), nullable=True)
expires_at = Column(DateTime, nullable=False)
created_at = Column(DateTime, default=datetime.now, nullable=False)
is_revoked = Column(Boolean, default=False)
revoked_at = Column(DateTime, nullable=True)
def __repr__(self):
return f"<RefreshToken user={self.user_id} revoked={self.is_revoked}>"
class LoginAttempt(Base):
__tablename__ = "login_attempts"
id = Column(Integer, primary_key=True, autoincrement=True)
email = Column(String(255), nullable=False, index=True)
ip_address = Column(String(45), nullable=False, index=True)
user_agent = Column(String(500), nullable=True)
success = Column(Boolean, default=False)
failure_reason = Column(String(255), nullable=True)
timestamp = Column(DateTime, default=datetime.now, nullable=False, index=True)
def __repr__(self):
return f"<LoginAttempt {self.email} success={self.success}>"

View file

@ -22,9 +22,6 @@ class SageGatewayConfig(Base):
gateway_url = Column(String(500), nullable=False) gateway_url = Column(String(500), nullable=False)
gateway_token = Column(String(255), nullable=False) gateway_token = Column(String(255), nullable=False)
sage_database = Column(String(255), nullable=True)
sage_company = Column(String(255), nullable=True)
is_active = Column(Boolean, default=False, index=True) is_active = Column(Boolean, default=False, index=True)
is_default = Column(Boolean, default=False) is_default = Column(Boolean, default=False)
priority = Column(Integer, default=0) priority = Column(Integer, default=0)

181
middleware/security.py Normal file
View file

@ -0,0 +1,181 @@
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware, RequestResponseEndpoint
from typing import Set
import logging
import time
from config.config import settings
from security.csrf import CSRFProtection
from security.fingerprint import get_client_ip
from services.redis_service import redis_service
logger = logging.getLogger(__name__)
RATE_LIMIT_EXEMPT_PATHS: Set[str] = {
"/health",
"/docs",
"/redoc",
"/openapi.json",
}
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
response = await call_next(request)
response.headers["X-Content-Type-Options"] = "nosniff"
response.headers["X-Frame-Options"] = "DENY"
response.headers["X-XSS-Protection"] = "1; mode=block"
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
if settings.is_production:
response.headers["Strict-Transport-Security"] = (
"max-age=31536000; includeSubDomains"
)
response.headers["Cache-Control"] = "no-store, no-cache, must-revalidate"
response.headers["Pragma"] = "no-cache"
response.headers["Permissions-Policy"] = (
"geolocation=(), microphone=(), camera=()"
)
return response
class CSRFMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
if CSRFProtection.is_exempt(request):
return await call_next(request)
if not CSRFProtection.validate_double_submit(request):
logger.warning(
f"CSRF validation echouee: {request.method} {request.url.path} "
f"depuis {get_client_ip(request)}"
)
return JSONResponse(
status_code=status.HTTP_403_FORBIDDEN,
content={"detail": "Verification CSRF echouee"},
)
return await call_next(request)
class RateLimitMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
path = request.url.path.rstrip("/")
if path in RATE_LIMIT_EXEMPT_PATHS:
return await call_next(request)
ip = get_client_ip(request)
key = f"api:{ip}"
window_seconds = settings.rate_limit_api_window_seconds
max_requests = settings.rate_limit_api_requests
try:
count = await redis_service.increment_rate_limit(key, window_seconds)
remaining = max(0, max_requests - count)
response = await call_next(request)
response.headers["X-RateLimit-Limit"] = str(max_requests)
response.headers["X-RateLimit-Remaining"] = str(remaining)
response.headers["X-RateLimit-Reset"] = str(window_seconds)
if count > max_requests:
logger.warning(
f"Rate limit depasse pour IP {ip}: {count}/{max_requests}"
)
return JSONResponse(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
content={"detail": "Limite de requetes atteinte"},
headers={
"X-RateLimit-Limit": str(max_requests),
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(window_seconds),
"Retry-After": str(window_seconds),
},
)
return response
except Exception as e:
logger.error(f"Erreur rate limiting: {e}")
return await call_next(request)
class RequestLoggingMiddleware(BaseHTTPMiddleware):
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
start_time = time.time()
ip = get_client_ip(request)
method = request.method
path = request.url.path
response = await call_next(request)
duration_ms = (time.time() - start_time) * 1000
log_level = logging.INFO
if response.status_code >= 500:
log_level = logging.ERROR
elif response.status_code >= 400:
log_level = logging.WARNING
logger.log(
log_level,
f"{method} {path} - {response.status_code} - {duration_ms:.2f}ms - {ip}",
)
return response
class FingerprintValidationMiddleware(BaseHTTPMiddleware):
VALIDATION_PATHS: Set[str] = {
"/auth/refresh",
"/auth/logout",
}
async def dispatch(
self, request: Request, call_next: RequestResponseEndpoint
) -> Response:
path = request.url.path.rstrip("/")
if path not in self.VALIDATION_PATHS:
return await call_next(request)
return await call_next(request)
def setup_security_middleware(app) -> None:
app.add_middleware(RequestLoggingMiddleware)
app.add_middleware(SecurityHeadersMiddleware)
app.add_middleware(FingerprintValidationMiddleware)
async def init_security_services() -> None:
try:
await redis_service.connect()
logger.info("Services de securite initialises")
except Exception as e:
logger.warning(f"Redis non disponible, fonctionnement en mode degrade: {e}")
async def shutdown_security_services() -> None:
try:
await redis_service.disconnect()
logger.info("Services de securite arretes")
except Exception as e:
logger.error(f"Erreur arret services securite: {e}")

View file

@ -1,10 +1,14 @@
fastapi fastapi
uvicorn[standard] uvicorn[standard]
starlette
structlog
pydantic pydantic
pydantic-settings pydantic-settings
reportlab reportlab
requests requests
msal msal
aiosmtplib
python-multipart python-multipart
email-validator email-validator
@ -13,9 +17,13 @@ python-dotenv
python-jose[cryptography] python-jose[cryptography]
passlib[bcrypt] passlib[bcrypt]
bcrypt==4.2.0 bcrypt==4.2.0
PyJWT
sqlalchemy sqlalchemy[asyncio]
aiosqlite aiosqlite
tenacity tenacity
asyncpg
httpx httpx
redis[hiredis]

View file

@ -1,27 +1,29 @@
from fastapi import APIRouter, Depends, HTTPException, status, Request from fastapi import APIRouter, Depends, HTTPException, status, Request, Response
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select from sqlalchemy import false, select
from pydantic import BaseModel, EmailStr, Field from pydantic import BaseModel, EmailStr, Field
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional from typing import Optional, List
import uuid import uuid
import logging
from database import get_session, User, RefreshToken, LoginAttempt from config.config import settings
from database import get_session
from database import User, RefreshToken, AuditEventType
from security.auth import ( from security.auth import (
hash_password, hash_password,
verify_password, verify_password,
validate_password_strength, validate_password_strength,
create_access_token,
create_refresh_token,
decode_token,
generate_verification_token, generate_verification_token,
generate_reset_token, generate_reset_token,
hash_token,
) )
from security.cookies import CookieManager, set_auth_cookies
from security.fingerprint import DeviceFingerprint, get_client_ip
from security.rate_limiter import RateLimiter
from services.token_service import TokenService
from services.audit_service import AuditService
from services.email_service import AuthEmailService from services.email_service import AuthEmailService
from core.dependencies import get_current_user from core.dependencies import get_current_user
from config.config import settings
import logging
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/auth", tags=["Authentication"]) router = APIRouter(prefix="/auth", tags=["Authentication"])
@ -29,25 +31,20 @@ router = APIRouter(prefix="/auth", tags=["Authentication"])
class RegisterRequest(BaseModel): class RegisterRequest(BaseModel):
email: EmailStr email: EmailStr
password: str = Field(..., min_length=8) password: str = Field(..., min_length=8, max_length=128)
nom: str = Field(..., min_length=2, max_length=100) nom: str = Field(..., min_length=2, max_length=100)
prenom: str = Field(..., min_length=2, max_length=100) prenom: str = Field(..., min_length=2, max_length=100)
class LoginRequest(BaseModel): class LoginRequest(BaseModel):
email: EmailStr email: EmailStr
password: str password: str = Field(..., min_length=1, max_length=128)
class TokenResponse(BaseModel): class TokenResponse(BaseModel):
access_token: str message: str
refresh_token: str user: dict
token_type: str = "bearer" expires_in: int
expires_in: int = 86400
class RefreshTokenRequest(BaseModel):
refresh_token: str
class ForgotPasswordRequest(BaseModel): class ForgotPasswordRequest(BaseModel):
@ -56,7 +53,7 @@ class ForgotPasswordRequest(BaseModel):
class ResetPasswordRequest(BaseModel): class ResetPasswordRequest(BaseModel):
token: str token: str
new_password: str = Field(..., min_length=8) new_password: str = Field(..., min_length=8, max_length=128)
class VerifyEmailRequest(BaseModel): class VerifyEmailRequest(BaseModel):
@ -67,44 +64,17 @@ class ResendVerificationRequest(BaseModel):
email: EmailStr email: EmailStr
async def log_login_attempt( class ChangePasswordRequest(BaseModel):
session: AsyncSession, current_password: str
email: str, new_password: str = Field(..., min_length=8, max_length=128)
ip: str,
user_agent: str,
success: bool,
failure_reason: Optional[str] = None,
):
attempt = LoginAttempt(
email=email,
ip_address=ip,
user_agent=user_agent,
success=success,
failure_reason=failure_reason,
timestamp=datetime.now(),
)
session.add(attempt)
await session.commit()
async def check_rate_limit( class SessionResponse(BaseModel):
session: AsyncSession, email: str, ip: str id: str
) -> tuple[bool, str]: device_info: Optional[str]
time_window = datetime.now() - timedelta(minutes=15) ip_address: Optional[str]
created_at: str
result = await session.execute( last_used_at: Optional[str]
select(LoginAttempt).where(
LoginAttempt.email == email,
LoginAttempt.success,
LoginAttempt.timestamp >= time_window,
)
)
failed_attempts = result.scalars().all()
if len(failed_attempts) >= 5:
return False, "Trop de tentatives échouées. Réessayez dans 15 minutes."
return True, ""
@router.post("/register", status_code=status.HTTP_201_CREATED) @router.post("/register", status_code=status.HTTP_201_CREATED)
@ -113,12 +83,18 @@ async def register(
request: Request, request: Request,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
): ):
result = await session.execute(select(User).where(User.email == data.email)) ip = get_client_ip(request)
existing_user = result.scalar_one_or_none()
if existing_user: allowed, error_msg = await RateLimiter.check_registration_rate_limit(ip)
if not allowed:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Cet email est déjà utilisé" status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=error_msg
)
result = await session.execute(select(User).where(User.email == data.email.lower()))
if result.scalar_one_or_none():
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Cet email est deja utilise"
) )
is_valid, error_msg = validate_password_strength(data.password) is_valid, error_msg = validate_password_strength(data.password)
@ -143,23 +119,233 @@ async def register(
await session.commit() await session.commit()
base_url = str(request.base_url).rstrip("/") base_url = str(request.base_url).rstrip("/")
email_sent = AuthEmailService.send_verification_email( AuthEmailService.send_verification_email(data.email, verification_token, base_url)
data.email, verification_token, base_url
)
if not email_sent: logger.info(f"Nouvel utilisateur inscrit: {data.email}")
logger.warning(f"Échec envoi email vérification pour {data.email}")
logger.info(f" Nouvel utilisateur inscrit: {data.email} (ID: {new_user.id})")
return { return {
"success": True, "success": True,
"message": "Inscription réussie ! Consultez votre email pour vérifier votre compte.", "message": "Inscription reussie. Consultez votre email pour verifier votre compte.",
"user_id": new_user.id, "user_id": new_user.id,
"email": data.email,
} }
@router.post("/login")
async def login(
data: LoginRequest,
request: Request,
response: Response,
session: AsyncSession = Depends(get_session),
):
ip = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "")
fingerprint_hash = DeviceFingerprint.generate_hash(request)
allowed, error_msg, _ = await RateLimiter.check_login_rate_limit(
data.email.lower(), ip
)
if not allowed:
await AuditService.log_login_failed(session, request, data.email, "rate_limit")
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=error_msg
)
result = await session.execute(select(User).where(User.email == data.email.lower()))
user = result.scalar_one_or_none()
if not user or not verify_password(data.password, user.hashed_password):
await RateLimiter.record_login_attempt(data.email.lower(), ip, success=False)
await AuditService.record_login_attempt(
session, request, data.email, False, "invalid_credentials"
)
if user:
user.failed_login_attempts = (user.failed_login_attempts or 0) + 1
if user.failed_login_attempts >= settings.account_lockout_threshold:
user.locked_until = datetime.now() + timedelta(
minutes=settings.account_lockout_duration_minutes
)
await AuditService.log_account_locked(
session, request, user.id, "too_many_failed_attempts"
)
await session.commit()
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Compte verrouille. Reessayez dans {settings.account_lockout_duration_minutes} minutes.",
)
await session.commit()
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Email ou mot de passe incorrect",
)
if not user.is_active:
await AuditService.log_login_failed(
session, request, data.email, "account_disabled", user.id
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Compte desactive"
)
if not user.is_verified:
await AuditService.log_login_failed(
session, request, data.email, "email_not_verified", user.id
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Email non verifie. Consultez votre boite de reception.",
)
if user.locked_until and user.locked_until > datetime.now():
await AuditService.log_login_failed(
session, request, data.email, "account_locked", user.id
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Compte temporairement verrouille",
)
user.failed_login_attempts = 0
user.locked_until = None
user.last_login = datetime.now()
user.last_login_ip = ip
(
access_token,
refresh_token,
csrf_token,
session_id,
) = await TokenService.create_token_pair(
session=session,
user=user,
fingerprint_hash=fingerprint_hash,
device_info=user_agent,
ip_address=ip,
)
await session.commit()
await RateLimiter.record_login_attempt(data.email.lower(), ip, success=True)
await AuditService.log_login_success(session, request, user.id, user.email)
set_auth_cookies(response, access_token, refresh_token, csrf_token)
logger.info(f"Connexion reussie: {user.email} depuis {ip}")
return TokenResponse(
message="Connexion reussie",
user={
"id": user.id,
"email": user.email,
"nom": user.nom,
"prenom": user.prenom,
"role": user.role,
},
expires_in=settings.access_token_expire_minutes * 60,
)
@router.post("/refresh")
async def refresh_tokens(
request: Request, response: Response, session: AsyncSession = Depends(get_session)
):
refresh_token = CookieManager.get_refresh_token(request)
if not refresh_token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token manquant"
)
ip = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "")
fingerprint_hash = DeviceFingerprint.generate_hash(request)
result = await TokenService.refresh_tokens(
session=session,
refresh_token=refresh_token,
fingerprint_hash=fingerprint_hash,
device_info=user_agent,
ip_address=ip,
)
if not result:
CookieManager.clear_all_auth_cookies(response)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token invalide ou expire",
)
new_access, new_refresh, new_csrf, session_id = result
await session.commit()
set_auth_cookies(response, new_access, new_refresh, new_csrf)
logger.debug("Tokens rafraichis avec succes")
return {
"message": "Tokens rafraichis",
"expires_in": settings.access_token_expire_minutes * 60,
}
@router.post("/logout")
async def logout(
request: Request,
response: Response,
session: AsyncSession = Depends(get_session),
user: User = Depends(get_current_user),
):
refresh_token = CookieManager.get_refresh_token(request)
if refresh_token:
await TokenService.revoke_token(
session=session, refresh_token=refresh_token, reason="user_logout"
)
await AuditService.log_logout(session, request, user.id)
await session.commit()
CookieManager.clear_all_auth_cookies(response)
logger.info(f"Deconnexion: {user.email}")
return {"success": True, "message": "Deconnexion reussie"}
@router.post("/logout-all")
async def logout_all_sessions(
request: Request,
response: Response,
session: AsyncSession = Depends(get_session),
user: User = Depends(get_current_user),
):
count = await TokenService.revoke_all_user_tokens(
session=session, user_id=user.id, reason="user_logout_all"
)
await AuditService.log_event(
session=session,
event_type=AuditEventType.SESSION_TERMINATED,
request=request,
user_id=user.id,
description=f"Toutes les sessions terminees ({count} tokens revoques)",
)
await session.commit()
CookieManager.clear_all_auth_cookies(response)
logger.info(f"Toutes les sessions terminees pour {user.email}: {count} tokens")
return {"success": True, "message": f"{count} session(s) terminee(s)"}
@router.get("/verify-email") @router.get("/verify-email")
async def verify_email_get(token: str, session: AsyncSession = Depends(get_session)): async def verify_email_get(token: str, session: AsyncSession = Depends(get_session)):
result = await session.execute(select(User).where(User.verification_token == token)) result = await session.execute(select(User).where(User.verification_token == token))
@ -168,13 +354,16 @@ async def verify_email_get(token: str, session: AsyncSession = Depends(get_sessi
if not user: if not user:
return { return {
"success": False, "success": False,
"message": "Token de vérification invalide ou déjà utilisé.", "message": "Token de verification invalide ou deja utilise.",
} }
if user.verification_token_expires < datetime.now(): if (
user.verification_token_expires
and user.verification_token_expires < datetime.now()
):
return { return {
"success": False, "success": False,
"message": "Token expiré. Veuillez demander un nouvel email de vérification.", "message": "Token expire. Demandez un nouveau lien de verification.",
"expired": True, "expired": True,
} }
@ -183,18 +372,19 @@ async def verify_email_get(token: str, session: AsyncSession = Depends(get_sessi
user.verification_token_expires = None user.verification_token_expires = None
await session.commit() await session.commit()
logger.info(f" Email vérifié: {user.email}") logger.info(f"Email verifie: {user.email}")
return { return {
"success": True, "success": True,
"message": " Email vérifié avec succès ! Vous pouvez maintenant vous connecter.", "message": "Email verifie avec succes. Vous pouvez maintenant vous connecter.",
"email": user.email,
} }
@router.post("/verify-email") @router.post("/verify-email")
async def verify_email_post( async def verify_email_post(
data: VerifyEmailRequest, session: AsyncSession = Depends(get_session) data: VerifyEmailRequest,
request: Request,
session: AsyncSession = Depends(get_session),
): ):
result = await session.execute( result = await session.execute(
select(User).where(User.verification_token == data.token) select(User).where(User.verification_token == data.token)
@ -204,26 +394,35 @@ async def verify_email_post(
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Token de vérification invalide", detail="Token de verification invalide",
) )
if user.verification_token_expires < datetime.now(): if (
user.verification_token_expires
and user.verification_token_expires < datetime.now()
):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Token expiré. Demandez un nouvel email de vérification.", detail="Token expire. Demandez un nouveau lien de verification.",
) )
user.is_verified = True user.is_verified = True
user.verification_token = None user.verification_token = None
user.verification_token_expires = None user.verification_token_expires = None
await AuditService.log_event(
session=session,
event_type=AuditEventType.EMAIL_VERIFICATION,
request=request,
user_id=user.id,
description="Email verifie avec succes",
)
await session.commit() await session.commit()
logger.info(f" Email vérifié: {user.email}") logger.info(f"Email verifie: {user.email}")
return { return {"success": True, "message": "Email verifie avec succes."}
"success": True,
"message": "Email vérifié avec succès ! Vous pouvez maintenant vous connecter.",
}
@router.post("/resend-verification") @router.post("/resend-verification")
@ -238,12 +437,12 @@ async def resend_verification(
if not user: if not user:
return { return {
"success": True, "success": True,
"message": "Si cet email existe, un nouveau lien de vérification a été envoyé.", "message": "Si cet email existe, un nouveau lien a ete envoye.",
} }
if user.is_verified: if user.is_verified:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="Ce compte est déjà vérifié" status_code=status.HTTP_400_BAD_REQUEST, detail="Ce compte est deja verifie"
) )
verification_token = generate_verification_token() verification_token = generate_verification_token()
@ -254,165 +453,7 @@ async def resend_verification(
base_url = str(request.base_url).rstrip("/") base_url = str(request.base_url).rstrip("/")
AuthEmailService.send_verification_email(user.email, verification_token, base_url) AuthEmailService.send_verification_email(user.email, verification_token, base_url)
return {"success": True, "message": "Un nouveau lien de vérification a été envoyé."} return {"success": True, "message": "Un nouveau lien de verification a ete envoye."}
@router.post("/login", response_model=TokenResponse)
async def login(
data: LoginRequest, request: Request, session: AsyncSession = Depends(get_session)
):
ip = request.client.host if request.client else "unknown"
user_agent = request.headers.get("user-agent", "unknown")
is_allowed, error_msg = await check_rate_limit(session, data.email.lower(), ip)
if not is_allowed:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=error_msg
)
result = await session.execute(select(User).where(User.email == data.email.lower()))
user = result.scalar_one_or_none()
if not user or not verify_password(data.password, user.hashed_password):
await log_login_attempt(
session,
data.email.lower(),
ip,
user_agent,
False,
"Identifiants incorrects",
)
if user:
user.failed_login_attempts += 1
if user.failed_login_attempts >= 5:
user.locked_until = datetime.now() + timedelta(minutes=15)
await session.commit()
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Compte verrouillé suite à trop de tentatives. Réessayez dans 15 minutes.",
)
await session.commit()
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Email ou mot de passe incorrect",
)
if not user.is_active:
await log_login_attempt(
session, data.email.lower(), ip, user_agent, False, "Compte désactivé"
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Compte désactivé"
)
if not user.is_verified:
await log_login_attempt(
session, data.email.lower(), ip, user_agent, False, "Email non vérifié"
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Email non vérifié. Consultez votre boîte de réception.",
)
if user.locked_until and user.locked_until > datetime.now():
await log_login_attempt(
session, data.email.lower(), ip, user_agent, False, "Compte verrouillé"
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Compte temporairement verrouillé",
)
user.failed_login_attempts = 0
user.locked_until = None
user.last_login = datetime.now()
access_token = create_access_token(
{"sub": user.id, "email": user.email, "role": user.role}
)
refresh_token_jwt = create_refresh_token(user.id)
refresh_token_record = RefreshToken(
id=str(uuid.uuid4()),
user_id=user.id,
token_hash=hash_token(refresh_token_jwt),
device_info=user_agent[:500],
ip_address=ip,
expires_at=datetime.now() + timedelta(days=7),
created_at=datetime.now(),
)
session.add(refresh_token_record)
await session.commit()
await log_login_attempt(session, data.email.lower(), ip, user_agent, True)
logger.info(f" Connexion réussie: {user.email}")
return TokenResponse(
access_token=access_token,
refresh_token=refresh_token_jwt,
expires_in=86400,
)
@router.post("/refresh", response_model=TokenResponse)
async def refresh_access_token(
data: RefreshTokenRequest, session: AsyncSession = Depends(get_session)
):
payload = decode_token(data.refresh_token)
if not payload or payload.get("type") != "refresh":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token invalide"
)
user_id = payload.get("sub")
token_hash = hash_token(data.refresh_token)
result = await session.execute(
select(RefreshToken).where(
RefreshToken.user_id == user_id,
RefreshToken.token_hash == token_hash,
not RefreshToken.is_revoked,
)
)
token_record = result.scalar_one_or_none()
if not token_record:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Refresh token révoqué ou introuvable",
)
if token_record.expires_at < datetime.now():
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, detail="Refresh token expiré"
)
result = await session.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if not user or not user.is_active:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Utilisateur introuvable ou désactivé",
)
new_access_token = create_access_token(
{"sub": user.id, "email": user.email, "role": user.role}
)
logger.info(f" Token rafraîchi: {user.email}")
return TokenResponse(
access_token=new_access_token,
refresh_token=data.refresh_token,
expires_in=86400,
)
@router.post("/forgot-password") @router.post("/forgot-password")
@ -421,13 +462,27 @@ async def forgot_password(
request: Request, request: Request,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
): ):
ip = get_client_ip(request)
allowed, error_msg = await RateLimiter.check_password_reset_rate_limit(
data.email.lower(), ip
)
if not allowed:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail=error_msg
)
result = await session.execute(select(User).where(User.email == data.email.lower())) result = await session.execute(select(User).where(User.email == data.email.lower()))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
await AuditService.log_password_reset_request(
session, request, data.email, user.id if user else None
)
if not user: if not user:
return { return {
"success": True, "success": True,
"message": "Si cet email existe, un lien de réinitialisation a été envoyé.", "message": "Si cet email existe, un lien de reinitialisation a ete envoye.",
} }
reset_token = generate_reset_token() reset_token = generate_reset_token()
@ -435,24 +490,23 @@ async def forgot_password(
user.reset_token_expires = datetime.now() + timedelta(hours=1) user.reset_token_expires = datetime.now() + timedelta(hours=1)
await session.commit() await session.commit()
frontend_url = ( frontend_url = settings.frontend_url or str(request.base_url).rstrip("/")
settings.frontend_url
if hasattr(settings, "frontend_url")
else str(request.base_url).rstrip("/")
)
AuthEmailService.send_password_reset_email(user.email, reset_token, frontend_url) AuthEmailService.send_password_reset_email(user.email, reset_token, frontend_url)
logger.info(f" Reset password demandé: {user.email}") logger.info(f"Reset password demande: {user.email}")
return { return {
"success": True, "success": True,
"message": "Si cet email existe, un lien de réinitialisation a été envoyé.", "message": "Si cet email existe, un lien de reinitialisation a ete envoye.",
} }
@router.post("/reset-password") @router.post("/reset-password")
async def reset_password( async def reset_password(
data: ResetPasswordRequest, session: AsyncSession = Depends(get_session) data: ResetPasswordRequest,
request: Request,
response: Response,
session: AsyncSession = Depends(get_session),
): ):
result = await session.execute(select(User).where(User.reset_token == data.token)) result = await session.execute(select(User).where(User.reset_token == data.token))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
@ -460,13 +514,13 @@ async def reset_password(
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Token de réinitialisation invalide", detail="Token de reinitialisation invalide",
) )
if user.reset_token_expires < datetime.now(): if user.reset_token_expires and user.reset_token_expires < datetime.now():
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
detail="Token expiré. Demandez un nouveau lien de réinitialisation.", detail="Token expire. Demandez un nouveau lien.",
) )
is_valid, error_msg = validate_password_strength(data.new_password) is_valid, error_msg = validate_password_strength(data.new_password)
@ -478,41 +532,67 @@ async def reset_password(
user.reset_token_expires = None user.reset_token_expires = None
user.failed_login_attempts = 0 user.failed_login_attempts = 0
user.locked_until = None user.locked_until = None
user.password_changed_at = datetime.now()
await TokenService.revoke_all_user_tokens(
session=session, user_id=user.id, reason="password_reset"
)
await AuditService.log_password_change(session, request, user.id, "reset")
await session.commit() await session.commit()
CookieManager.clear_all_auth_cookies(response)
AuthEmailService.send_password_changed_notification(user.email) AuthEmailService.send_password_changed_notification(user.email)
logger.info(f" Mot de passe réinitialisé: {user.email}") logger.info(f"Mot de passe reinitialise: {user.email}")
return { return {
"success": True, "success": True,
"message": "Mot de passe réinitialisé avec succès. Vous pouvez maintenant vous connecter.", "message": "Mot de passe reinitialise. Vous pouvez maintenant vous connecter.",
} }
@router.post("/logout") @router.post("/change-password")
async def logout( async def change_password(
data: RefreshTokenRequest, data: ChangePasswordRequest,
request: Request,
response: Response,
session: AsyncSession = Depends(get_session), session: AsyncSession = Depends(get_session),
user: User = Depends(get_current_user), user: User = Depends(get_current_user),
): ):
token_hash = hash_token(data.refresh_token) if not verify_password(data.current_password, user.hashed_password):
raise HTTPException(
result = await session.execute( status_code=status.HTTP_400_BAD_REQUEST,
select(RefreshToken).where( detail="Mot de passe actuel incorrect",
RefreshToken.user_id == user.id, RefreshToken.token_hash == token_hash
) )
)
token_record = result.scalar_one_or_none()
if token_record: is_valid, error_msg = validate_password_strength(data.new_password)
token_record.is_revoked = True if not is_valid:
token_record.revoked_at = datetime.now() raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=error_msg)
user.hashed_password = hash_password(data.new_password)
user.password_changed_at = datetime.now()
await TokenService.revoke_all_user_tokens(
session=session, user_id=user.id, reason="password_change"
)
await AuditService.log_password_change(session, request, user.id, "user_initiated")
await session.commit() await session.commit()
logger.info(f"👋 Déconnexion: {user.email}") CookieManager.clear_all_auth_cookies(response)
return {"success": True, "message": "Déconnexion réussie"} AuthEmailService.send_password_changed_notification(user.email)
logger.info(f"Mot de passe change: {user.email}")
return {
"success": True,
"message": "Mot de passe modifie. Veuillez vous reconnecter.",
}
@router.get("/me") @router.get("/me")
@ -524,6 +604,69 @@ async def get_current_user_info(user: User = Depends(get_current_user)):
"prenom": user.prenom, "prenom": user.prenom,
"role": user.role, "role": user.role,
"is_verified": user.is_verified, "is_verified": user.is_verified,
"created_at": user.created_at.isoformat(), "created_at": user.created_at.isoformat() if user.created_at else None,
"last_login": user.last_login.isoformat() if user.last_login else None, "last_login": user.last_login.isoformat() if user.last_login else None,
} }
@router.get("/sessions", response_model=List[SessionResponse])
async def get_active_sessions(
session: AsyncSession = Depends(get_session), user: User = Depends(get_current_user)
):
sessions = await TokenService.get_user_active_sessions(session, user.id)
return [SessionResponse(**s) for s in sessions]
@router.delete("/sessions/{session_id}")
async def revoke_session(
session_id: str,
request: Request,
session: AsyncSession = Depends(get_session),
user: User = Depends(get_current_user),
):
result = await session.execute(
select(RefreshToken).where(
RefreshToken.id == session_id,
RefreshToken.user_id == user.id,
RefreshToken.is_revoked.is_(false()),
)
)
token_record = result.scalar_one_or_none()
if not token_record:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Session introuvable"
)
token_record.is_revoked = True
token_record.revoked_at = datetime.now()
token_record.revoked_reason = "user_revoked"
await AuditService.log_event(
session=session,
event_type=AuditEventType.SESSION_TERMINATED,
request=request,
user_id=user.id,
description=f"Session {session_id[:8]}... revoquee",
)
await session.commit()
return {"success": True, "message": "Session revoquee"}
@router.get("/csrf-token")
async def get_csrf_token(
request: Request, response: Response, user: User = Depends(get_current_user)
):
from security.auth import generate_session_id, create_csrf_token
session_id = getattr(request.state, "session_id", None)
if not session_id:
session_id = generate_session_id()
csrf_token = create_csrf_token(session_id)
CookieManager.set_csrf_token(response, csrf_token)
return {"csrf_token": csrf_token}

View file

@ -12,7 +12,6 @@ class GatewayHealthStatus(str, Enum):
# === CREATE === # === CREATE ===
class SageGatewayCreate(BaseModel): class SageGatewayCreate(BaseModel):
name: str = Field( name: str = Field(
..., min_length=2, max_length=100, description="Nom de la gateway" ..., min_length=2, max_length=100, description="Nom de la gateway"
) )
@ -24,8 +23,6 @@ class SageGatewayCreate(BaseModel):
gateway_token: str = Field( gateway_token: str = Field(
..., min_length=10, description="Token d'authentification" ..., min_length=10, description="Token d'authentification"
) )
sage_database: Optional[str] = Field(None, max_length=255)
sage_company: Optional[str] = Field(None, max_length=255) sage_company: Optional[str] = Field(None, max_length=255)
is_active: bool = Field(False, description="Activer immédiatement cette gateway") is_active: bool = Field(False, description="Activer immédiatement cette gateway")
@ -54,9 +51,6 @@ class SageGatewayUpdate(BaseModel):
gateway_url: Optional[str] = None gateway_url: Optional[str] = None
gateway_token: Optional[str] = Field(None, min_length=10) gateway_token: Optional[str] = Field(None, min_length=10)
sage_database: Optional[str] = None
sage_company: Optional[str] = None
is_default: Optional[bool] = None is_default: Optional[bool] = None
priority: Optional[int] = Field(None, ge=0, le=100) priority: Optional[int] = Field(None, ge=0, le=100)
@ -73,7 +67,6 @@ class SageGatewayUpdate(BaseModel):
# === RESPONSE === # === RESPONSE ===
class SageGatewayResponse(BaseModel): class SageGatewayResponse(BaseModel):
id: str id: str
user_id: str user_id: str
@ -83,9 +76,6 @@ class SageGatewayResponse(BaseModel):
gateway_url: str gateway_url: str
token_preview: str token_preview: str
sage_database: Optional[str] = None
sage_company: Optional[str] = None
is_active: bool is_active: bool
is_default: bool is_default: bool
priority: int priority: int
@ -111,7 +101,6 @@ class SageGatewayResponse(BaseModel):
class SageGatewayListResponse(BaseModel): class SageGatewayListResponse(BaseModel):
items: List[SageGatewayResponse] items: List[SageGatewayResponse]
total: int total: int
active_gateway: Optional[SageGatewayResponse] = None active_gateway: Optional[SageGatewayResponse] = None

55
security/__init__.py Normal file
View file

@ -0,0 +1,55 @@
from security.auth import (
hash_password,
verify_password,
validate_password_strength,
generate_verification_token,
generate_reset_token,
generate_csrf_token,
generate_secure_token,
hash_token,
constant_time_compare,
create_access_token,
create_refresh_token,
decode_token,
generate_session_id,
)
from security.cookies import CookieManager, set_auth_cookies
from security.fingerprint import (
DeviceFingerprint,
get_fingerprint_hash,
validate_fingerprint,
get_client_ip,
)
from security.csrf import CSRFProtection, verify_csrf, generate_csrf_for_session
from security.rate_limiter import RateLimiter, check_rate_limit_dependency
__all__ = [
"hash_password",
"verify_password",
"validate_password_strength",
"generate_verification_token",
"generate_reset_token",
"generate_csrf_token",
"generate_secure_token",
"hash_token",
"constant_time_compare",
"create_access_token",
"create_refresh_token",
"decode_token",
"generate_session_id",
"CookieManager",
"set_auth_cookies",
"DeviceFingerprint",
"get_fingerprint_hash",
"validate_fingerprint",
"get_client_ip",
"CSRFProtection",
"verify_csrf",
"generate_csrf_for_session",
"RateLimiter",
"check_rate_limit_dependency",
]

View file

@ -1,16 +1,17 @@
from passlib.context import CryptContext from passlib.context import CryptContext
from datetime import datetime, timedelta from datetime import datetime, timedelta, timezone
from typing import Optional, Dict from typing import Optional, Dict, Any, Tuple
import jwt import jwt
import secrets import secrets
import hashlib import hashlib
import hmac
import logging
SECRET_KEY = "VOTRE_SECRET_KEY_A_METTRE_EN_.ENV" from config.config import settings
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 30
REFRESH_TOKEN_EXPIRE_DAYS = 7
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto") logger = logging.getLogger(__name__)
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto", bcrypt__rounds=12)
def hash_password(password: str) -> str: def hash_password(password: str) -> str:
@ -18,75 +19,192 @@ def hash_password(password: str) -> str:
def verify_password(plain_password: str, hashed_password: str) -> bool: def verify_password(plain_password: str, hashed_password: str) -> bool:
try:
return pwd_context.verify(plain_password, hashed_password) return pwd_context.verify(plain_password, hashed_password)
except Exception as e:
logger.warning(f"Erreur verification mot de passe: {e}")
return False
def generate_secure_token(length: int = 32) -> str:
return secrets.token_urlsafe(length)
def generate_verification_token() -> str: def generate_verification_token() -> str:
return secrets.token_urlsafe(32) return generate_secure_token(32)
def generate_reset_token() -> str: def generate_reset_token() -> str:
return secrets.token_urlsafe(32) return generate_secure_token(32)
def generate_csrf_token() -> str:
return generate_secure_token(32)
def generate_refresh_token_id() -> str:
return generate_secure_token(16)
def hash_token(token: str) -> str: def hash_token(token: str) -> str:
return hashlib.sha256(token.encode()).hexdigest() return hashlib.sha256(token.encode()).hexdigest()
def create_access_token(data: Dict, expires_delta: Optional[timedelta] = None) -> str: def constant_time_compare(val1: str, val2: str) -> bool:
return hmac.compare_digest(val1.encode(), val2.encode())
def create_access_token(
data: Dict[str, Any],
expires_delta: Optional[timedelta] = None,
fingerprint_hash: Optional[str] = None,
) -> str:
to_encode = data.copy() to_encode = data.copy()
now = datetime.now(timezone.utc)
if expires_delta: if expires_delta:
expire = datetime.utcnow() + expires_delta expire = now + expires_delta
else: else:
expire = datetime.utcnow() + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) expire = now + timedelta(minutes=settings.access_token_expire_minutes)
to_encode.update({"exp": expire, "iat": datetime.utcnow(), "type": "access"}) to_encode.update(
{
"exp": expire,
"iat": now,
"nbf": now,
"type": "access",
"jti": generate_secure_token(8),
}
)
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) if fingerprint_hash:
return encoded_jwt to_encode["fph"] = fingerprint_hash
return jwt.encode(to_encode, settings.jwt_secret, algorithm=settings.jwt_algorithm)
def create_refresh_token(user_id: str) -> str: def create_refresh_token(
expire = datetime.utcnow() + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS) user_id: str,
token_id: Optional[str] = None,
fingerprint_hash: Optional[str] = None,
expires_delta: Optional[timedelta] = None,
) -> Tuple[str, str]:
now = datetime.now(timezone.utc)
if expires_delta:
expire = now + expires_delta
else:
expire = now + timedelta(days=settings.refresh_token_expire_days)
if not token_id:
token_id = generate_refresh_token_id()
to_encode = { to_encode = {
"sub": user_id, "sub": user_id,
"exp": expire, "exp": expire,
"iat": datetime.utcnow(), "iat": now,
"nbf": now,
"type": "refresh", "type": "refresh",
"jti": secrets.token_urlsafe(16), # Unique ID "jti": token_id,
} }
encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) if fingerprint_hash:
return encoded_jwt to_encode["fph"] = fingerprint_hash
token = jwt.encode(to_encode, settings.jwt_secret, algorithm=settings.jwt_algorithm)
return token, token_id
def decode_token(token: str) -> Optional[Dict]: def create_csrf_token(session_id: str) -> str:
now = datetime.now(timezone.utc)
expire = now + timedelta(minutes=settings.csrf_token_expire_minutes)
to_encode = {
"sid": session_id,
"exp": expire,
"iat": now,
"type": "csrf",
"jti": generate_secure_token(8),
}
return jwt.encode(to_encode, settings.jwt_secret, algorithm=settings.jwt_algorithm)
def decode_token(
token: str, expected_type: Optional[str] = None
) -> Optional[Dict[str, Any]]:
try: try:
payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) payload = jwt.decode(
token,
settings.jwt_secret,
algorithms=[settings.jwt_algorithm],
options={
"require": ["exp", "iat", "type"],
"verify_exp": True,
"verify_iat": True,
"verify_nbf": True,
},
)
if expected_type and payload.get("type") != expected_type:
logger.warning(
f"Type de token incorrect: attendu={expected_type}, recu={payload.get('type')}"
)
return None
return payload return payload
except jwt.ExpiredSignatureError: except jwt.ExpiredSignatureError:
logger.debug("Token expire")
return None return None
except jwt.JWTError: except jwt.InvalidTokenError as e:
logger.warning(f"Token invalide: {e}")
return None
except Exception as e:
logger.error(f"Erreur decodage token: {e}")
return None return None
def validate_password_strength(password: str) -> tuple[bool, str]: def validate_password_strength(password: str) -> Tuple[bool, str]:
if len(password) < 8: if len(password) < settings.password_min_length:
return False, "Le mot de passe doit contenir au moins 8 caractères" return (
False,
f"Le mot de passe doit contenir au moins {settings.password_min_length} caracteres",
)
if not any(c.isupper() for c in password): if settings.password_require_uppercase and not any(c.isupper() for c in password):
return False, "Le mot de passe doit contenir au moins une majuscule" return False, "Le mot de passe doit contenir au moins une majuscule"
if not any(c.islower() for c in password): if settings.password_require_lowercase and not any(c.islower() for c in password):
return False, "Le mot de passe doit contenir au moins une minuscule" return False, "Le mot de passe doit contenir au moins une minuscule"
if not any(c.isdigit() for c in password): if settings.password_require_digit and not any(c.isdigit() for c in password):
return False, "Le mot de passe doit contenir au moins un chiffre" return False, "Le mot de passe doit contenir au moins un chiffre"
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?" if settings.password_require_special:
special_chars = "!@#$%^&*()_+-=[]{}|;:,.<>?/~`"
if not any(c in special_chars for c in password): if not any(c in special_chars for c in password):
return False, "Le mot de passe doit contenir au moins un caractère spécial" return False, "Le mot de passe doit contenir au moins un caractere special"
common_passwords = [
"password",
"123456",
"qwerty",
"admin",
"letmein",
"welcome",
"monkey",
"dragon",
"master",
"login",
]
if password.lower() in common_passwords:
return False, "Ce mot de passe est trop courant"
return True, "" return True, ""
def generate_session_id() -> str:
"""Genere un identifiant de session unique."""
return generate_secure_token(24)

157
security/cookies.py Normal file
View file

@ -0,0 +1,157 @@
from fastapi import Response, Request
from typing import Optional
import logging
from config.config import settings
logger = logging.getLogger(__name__)
class CookieManager:
@staticmethod
def _get_samesite_value() -> str:
value = settings.cookie_samesite.lower()
if value in ("strict", "lax", "none"):
return value
return "strict"
@staticmethod
def _should_be_secure() -> bool:
if settings.is_development and not settings.cookie_secure:
return False
return True
@classmethod
def set_access_token(
cls, response: Response, token: str, max_age: Optional[int] = None
) -> None:
if max_age is None:
max_age = settings.access_token_expire_minutes * 60
response.set_cookie(
key=settings.cookie_access_token_name,
value=token,
max_age=max_age,
expires=max_age,
path="/",
domain=settings.cookie_domain,
secure=cls._should_be_secure(),
httponly=settings.cookie_httponly,
samesite=cls._get_samesite_value(),
)
logger.debug("Cookie access_token defini")
@classmethod
def set_refresh_token(
cls, response: Response, token: str, max_age: Optional[int] = None
) -> None:
if max_age is None:
max_age = settings.refresh_token_expire_days * 24 * 60 * 60
response.set_cookie(
key=settings.cookie_refresh_token_name,
value=token,
max_age=max_age,
expires=max_age,
path="/auth",
domain=settings.cookie_domain,
secure=cls._should_be_secure(),
httponly=settings.cookie_httponly,
samesite=cls._get_samesite_value(),
)
logger.debug("Cookie refresh_token defini")
@classmethod
def set_csrf_token(
cls, response: Response, token: str, max_age: Optional[int] = None
) -> None:
if max_age is None:
max_age = settings.csrf_token_expire_minutes * 60
response.set_cookie(
key=settings.cookie_csrf_token_name,
value=token,
max_age=max_age,
expires=max_age,
path="/",
domain=settings.cookie_domain,
secure=cls._should_be_secure(),
httponly=False,
samesite=cls._get_samesite_value(),
)
logger.debug("Cookie csrf_token defini")
@classmethod
def clear_access_token(cls, response: Response) -> None:
response.delete_cookie(
key=settings.cookie_access_token_name,
path="/",
domain=settings.cookie_domain,
secure=cls._should_be_secure(),
httponly=settings.cookie_httponly,
samesite=cls._get_samesite_value(),
)
logger.debug("Cookie access_token supprime")
@classmethod
def clear_refresh_token(cls, response: Response) -> None:
response.delete_cookie(
key=settings.cookie_refresh_token_name,
path="/auth",
domain=settings.cookie_domain,
secure=cls._should_be_secure(),
httponly=settings.cookie_httponly,
samesite=cls._get_samesite_value(),
)
logger.debug("Cookie refresh_token supprime")
@classmethod
def clear_csrf_token(cls, response: Response) -> None:
response.delete_cookie(
key=settings.cookie_csrf_token_name,
path="/",
domain=settings.cookie_domain,
secure=cls._should_be_secure(),
httponly=False,
samesite=cls._get_samesite_value(),
)
logger.debug("Cookie csrf_token supprime")
@classmethod
def clear_all_auth_cookies(cls, response: Response) -> None:
cls.clear_access_token(response)
cls.clear_refresh_token(response)
cls.clear_csrf_token(response)
logger.debug("Tous les cookies auth supprimes")
@classmethod
def get_access_token(cls, request: Request) -> Optional[str]:
token = request.cookies.get(settings.cookie_access_token_name)
if token:
return token
auth_header = request.headers.get("Authorization")
if auth_header and auth_header.startswith("Bearer "):
return auth_header[7:]
return None
@classmethod
def get_refresh_token(cls, request: Request) -> Optional[str]:
return request.cookies.get(settings.cookie_refresh_token_name)
@classmethod
def get_csrf_token(cls, request: Request) -> Optional[str]:
csrf_header = request.headers.get("X-CSRF-Token")
if csrf_header:
return csrf_header
return request.cookies.get(settings.cookie_csrf_token_name)
def set_auth_cookies(
response: Response, access_token: str, refresh_token: str, csrf_token: str
) -> None:
CookieManager.set_access_token(response, access_token)
CookieManager.set_refresh_token(response, refresh_token)
CookieManager.set_csrf_token(response, csrf_token)

117
security/csrf.py Normal file
View file

@ -0,0 +1,117 @@
"""
security/csrf.py - Protection contre les attaques Cross-Site Request Forgery
"""
from fastapi import Request, HTTPException, status
from typing import Optional, Set
import logging
from config.config import settings
from security.auth import decode_token, create_csrf_token, constant_time_compare
logger = logging.getLogger(__name__)
SAFE_METHODS: Set[str] = {"GET", "HEAD", "OPTIONS", "TRACE"}
CSRF_EXEMPT_PATHS: Set[str] = {
"/auth/login",
"/auth/register",
"/auth/forgot-password",
"/auth/verify-email",
"/auth/resend-verification",
"/health",
"/docs",
"/redoc",
"/openapi.json",
"/webhooks/universign",
}
class CSRFProtection:
@classmethod
def is_exempt(cls, request: Request) -> bool:
if request.method in SAFE_METHODS:
return True
path = request.url.path.rstrip("/")
if path in CSRF_EXEMPT_PATHS:
return True
for exempt_path in CSRF_EXEMPT_PATHS:
if path.startswith(exempt_path):
return True
return False
@classmethod
def generate_token(cls, session_id: str) -> str:
return create_csrf_token(session_id)
@classmethod
def validate_token(cls, request: Request, session_id: Optional[str] = None) -> bool:
csrf_header = request.headers.get("X-CSRF-Token")
if not csrf_header:
logger.warning("Token CSRF manquant dans le header")
return False
payload = decode_token(csrf_header, expected_type="csrf")
if not payload:
logger.warning("Token CSRF invalide ou expire")
return False
if session_id and payload.get("sid") != session_id:
logger.warning("Token CSRF ne correspond pas a la session")
return False
return True
@classmethod
def validate_double_submit(cls, request: Request) -> bool:
header_token = request.headers.get("X-CSRF-Token")
cookie_token = request.cookies.get(settings.cookie_csrf_token_name)
if not header_token or not cookie_token:
logger.warning("Token CSRF manquant (header ou cookie)")
return False
if not constant_time_compare(header_token, cookie_token):
logger.warning("Tokens CSRF ne correspondent pas")
return False
return True
@classmethod
def validate_request(
cls,
request: Request,
session_id: Optional[str] = None,
use_double_submit: bool = True,
) -> bool:
if cls.is_exempt(request):
return True
if use_double_submit:
if not cls.validate_double_submit(request):
return False
return cls.validate_token(request, session_id)
async def verify_csrf(request: Request, session_id: Optional[str] = None) -> None:
if CSRFProtection.is_exempt(request):
return
if not CSRFProtection.validate_request(request, session_id):
logger.warning(
f"Verification CSRF echouee pour {request.method} {request.url.path}"
)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, detail="Verification CSRF echouee"
)
def generate_csrf_for_session(session_id: str) -> str:
return CSRFProtection.generate_token(session_id)

122
security/fingerprint.py Normal file
View file

@ -0,0 +1,122 @@
from fastapi import Request
from typing import Dict
import hashlib
import hmac
import logging
from config.config import settings
logger = logging.getLogger(__name__)
class DeviceFingerprint:
COMPONENT_EXTRACTORS = {
"user_agent": lambda r: r.headers.get("User-Agent", ""),
"accept_language": lambda r: r.headers.get("Accept-Language", ""),
"accept_encoding": lambda r: r.headers.get("Accept-Encoding", ""),
"accept": lambda r: r.headers.get("Accept", ""),
"connection": lambda r: r.headers.get("Connection", ""),
"cache_control": lambda r: r.headers.get("Cache-Control", ""),
"client_ip": lambda r: DeviceFingerprint._get_client_ip(r),
"sec_ch_ua": lambda r: r.headers.get("Sec-CH-UA", ""),
"sec_ch_ua_platform": lambda r: r.headers.get("Sec-CH-UA-Platform", ""),
"sec_ch_ua_mobile": lambda r: r.headers.get("Sec-CH-UA-Mobile", ""),
}
@staticmethod
def _get_client_ip(request: Request) -> str:
forwarded = request.headers.get("X-Forwarded-For")
if forwarded:
return forwarded.split(",")[0].strip()
real_ip = request.headers.get("X-Real-IP")
if real_ip:
return real_ip
if request.client:
return request.client.host
return ""
@classmethod
def extract_components(cls, request: Request) -> Dict[str, str]:
components = {}
for component_name in settings.fingerprint_components:
extractor = cls.COMPONENT_EXTRACTORS.get(component_name)
if extractor:
try:
value = extractor(request)
components[component_name] = value if value else ""
except Exception as e:
logger.warning(f"Erreur extraction composant {component_name}: {e}")
components[component_name] = ""
else:
logger.warning(f"Extracteur inconnu pour composant: {component_name}")
return components
@classmethod
def generate_hash(cls, request: Request, include_ip: bool = False) -> str:
components = cls.extract_components(request)
if not include_ip and "client_ip" in components:
del components["client_ip"]
sorted_keys = sorted(components.keys())
fingerprint_data = "|".join(f"{k}:{components[k]}" for k in sorted_keys)
secret = settings.fingerprint_secret or settings.jwt_secret
signature = hmac.new(
secret.encode(), fingerprint_data.encode(), hashlib.sha256
).hexdigest()
return signature
@classmethod
def generate_from_components(cls, components: Dict[str, str]) -> str:
sorted_keys = sorted(components.keys())
fingerprint_data = "|".join(f"{k}:{components.get(k, '')}" for k in sorted_keys)
secret = settings.fingerprint_secret or settings.jwt_secret
signature = hmac.new(
secret.encode(), fingerprint_data.encode(), hashlib.sha256
).hexdigest()
return signature
@classmethod
def validate(
cls, request: Request, stored_hash: str, include_ip: bool = False
) -> bool:
if not stored_hash:
return True
current_hash = cls.generate_hash(request, include_ip=include_ip)
return hmac.compare_digest(current_hash, stored_hash)
@classmethod
def get_device_info(cls, request: Request) -> Dict[str, str]:
user_agent = request.headers.get("User-Agent", "")
return {
"user_agent": user_agent[:500] if user_agent else "",
"ip_address": cls._get_client_ip(request),
"accept_language": request.headers.get("Accept-Language", "")[:100],
"fingerprint_hash": cls.generate_hash(request),
}
def get_fingerprint_hash(request: Request) -> str:
return DeviceFingerprint.generate_hash(request)
def validate_fingerprint(request: Request, stored_hash: str) -> bool:
return DeviceFingerprint.validate(request, stored_hash)
def get_client_ip(request: Request) -> str:
return DeviceFingerprint._get_client_ip(request)

147
security/rate_limiter.py Normal file
View file

@ -0,0 +1,147 @@
from fastapi import Request, HTTPException, status
from typing import Optional, Tuple
import logging
from config.config import settings
from services.redis_service import redis_service
from security.fingerprint import get_client_ip
logger = logging.getLogger(__name__)
class RateLimiter:
@staticmethod
def _make_key(identifier: str, action: str) -> str:
return f"{action}:{identifier}"
@classmethod
async def check_login_rate_limit(
cls, email: str, ip_address: str
) -> Tuple[bool, Optional[str], int]:
window_seconds = settings.rate_limit_login_window_minutes * 60
max_attempts = settings.rate_limit_login_attempts
email_key = cls._make_key(email.lower(), "login_email")
email_count = await redis_service.get_rate_limit_count(email_key)
if email_count >= max_attempts:
return (
False,
f"Trop de tentatives pour cet email. Reessayez dans {settings.rate_limit_login_window_minutes} minutes.",
window_seconds,
)
ip_key = cls._make_key(ip_address, "login_ip")
ip_count = await redis_service.get_rate_limit_count(ip_key)
ip_limit = max_attempts * 3
if ip_count >= ip_limit:
return (
False,
window_seconds,
)
return (True, None, 0)
@classmethod
async def record_login_attempt(
cls, email: str, ip_address: str, success: bool
) -> None:
window_seconds = settings.rate_limit_login_window_minutes * 60
if success:
email_key = cls._make_key(email.lower(), "login_email")
await redis_service.reset_rate_limit(email_key)
logger.debug(f"Rate limit reinitialise pour {email}")
else:
email_key = cls._make_key(email.lower(), "login_email")
await redis_service.increment_rate_limit(email_key, window_seconds)
ip_key = cls._make_key(ip_address, "login_ip")
await redis_service.increment_rate_limit(ip_key, window_seconds)
logger.debug(
f"Tentative echouee enregistree pour {email} depuis {ip_address}"
)
@classmethod
async def check_api_rate_limit(
cls, identifier: str, endpoint: Optional[str] = None
) -> Tuple[bool, int, int]:
window_seconds = settings.rate_limit_api_window_seconds
max_requests = settings.rate_limit_api_requests
if endpoint:
key = cls._make_key(f"{identifier}:{endpoint}", "api")
else:
key = cls._make_key(identifier, "api")
count = await redis_service.increment_rate_limit(key, window_seconds)
remaining = max(0, max_requests - count)
if count > max_requests:
return (False, remaining, window_seconds)
return (True, remaining, window_seconds)
@classmethod
async def check_password_reset_rate_limit(
cls, email: str, ip_address: str
) -> Tuple[bool, Optional[str]]:
window_seconds = 3600
max_attempts_email = 3
max_attempts_ip = 10
email_key = cls._make_key(email.lower(), "reset_email")
email_count = await redis_service.get_rate_limit_count(email_key)
if email_count >= max_attempts_email:
return (False, "Trop de demandes de reinitialisation pour cet email.")
ip_key = cls._make_key(ip_address, "reset_ip")
ip_count = await redis_service.get_rate_limit_count(ip_key)
if ip_count >= max_attempts_ip:
return (False, "Trop de demandes depuis cette adresse IP.")
await redis_service.increment_rate_limit(email_key, window_seconds)
await redis_service.increment_rate_limit(ip_key, window_seconds)
return (True, None)
@classmethod
async def check_registration_rate_limit(
cls, ip_address: str
) -> Tuple[bool, Optional[str]]:
window_seconds = 3600
max_registrations = 5
key = cls._make_key(ip_address, "register_ip")
count = await redis_service.get_rate_limit_count(key)
if count >= max_registrations:
return (False, "Trop d'inscriptions depuis cette adresse IP.")
await redis_service.increment_rate_limit(key, window_seconds)
return (True, None)
async def check_rate_limit_dependency(request: Request) -> None:
ip = get_client_ip(request)
allowed, remaining, reset_seconds = await RateLimiter.check_api_rate_limit(ip)
request.state.rate_limit_remaining = remaining
request.state.rate_limit_reset = reset_seconds
if not allowed:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail="Limite de requetes atteinte. Reessayez plus tard.",
headers={
"X-RateLimit-Remaining": "0",
"X-RateLimit-Reset": str(reset_seconds),
"Retry-After": str(reset_seconds),
},
)

318
services/audit_service.py Normal file
View file

@ -0,0 +1,318 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import false, select, and_
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, List
from fastapi import Request
import uuid
import json
import logging
from database import AuditLog, AuditEventType, LoginAttempt
from security.fingerprint import DeviceFingerprint, get_client_ip
logger = logging.getLogger(__name__)
class AuditService:
@classmethod
async def log_event(
cls,
session: AsyncSession,
event_type: AuditEventType,
request: Optional[Request] = None,
user_id: Optional[str] = None,
description: Optional[str] = None,
success: bool = True,
failure_reason: Optional[str] = None,
resource_type: Optional[str] = None,
resource_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> AuditLog:
ip_address = None
user_agent = None
fingerprint_hash = None
request_method = None
request_path = None
if request:
ip_address = get_client_ip(request)
user_agent = request.headers.get("User-Agent", "")[:500]
fingerprint_hash = DeviceFingerprint.generate_hash(request)
request_method = request.method
request_path = str(request.url.path)[:500]
metadata_json = None
if metadata:
try:
metadata_json = json.dumps(metadata, default=str)
except Exception as e:
logger.warning(f"Erreur serialisation metadata audit: {e}")
audit_log = AuditLog(
id=str(uuid.uuid4()),
user_id=user_id,
event_type=event_type,
event_description=description,
ip_address=ip_address,
user_agent=user_agent,
fingerprint_hash=fingerprint_hash,
resource_type=resource_type,
resource_id=resource_id,
request_method=request_method,
request_path=request_path,
metadata=metadata_json,
success=success,
failure_reason=failure_reason,
created_at=datetime.now(),
)
session.add(audit_log)
await session.flush()
log_level = logging.INFO if success else logging.WARNING
logger.log(
log_level,
f"Audit: {event_type.value} user={user_id} success={success} ip={ip_address}",
)
return audit_log
@classmethod
async def log_login_success(
cls, session: AsyncSession, request: Request, user_id: str, email: str
) -> AuditLog:
return await cls.log_event(
session=session,
event_type=AuditEventType.LOGIN_SUCCESS,
request=request,
user_id=user_id,
description=f"Connexion reussie pour {email}",
success=True,
metadata={"email": email},
)
@classmethod
async def log_login_failed(
cls,
session: AsyncSession,
request: Request,
email: str,
reason: str,
user_id: Optional[str] = None,
) -> AuditLog:
return await cls.log_event(
session=session,
event_type=AuditEventType.LOGIN_FAILED,
request=request,
user_id=user_id,
description=f"Echec connexion pour {email}: {reason}",
success=False,
failure_reason=reason,
metadata={"email": email},
)
@classmethod
async def log_logout(
cls, session: AsyncSession, request: Request, user_id: str
) -> AuditLog:
return await cls.log_event(
session=session,
event_type=AuditEventType.LOGOUT,
request=request,
user_id=user_id,
description="Deconnexion utilisateur",
success=True,
)
@classmethod
async def log_password_change(
cls,
session: AsyncSession,
request: Request,
user_id: str,
method: str = "user_initiated",
) -> AuditLog:
return await cls.log_event(
session=session,
event_type=AuditEventType.PASSWORD_CHANGE,
request=request,
user_id=user_id,
description=f"Mot de passe modifie ({method})",
success=True,
metadata={"method": method},
)
@classmethod
async def log_password_reset_request(
cls,
session: AsyncSession,
request: Request,
email: str,
user_id: Optional[str] = None,
) -> AuditLog:
return await cls.log_event(
session=session,
event_type=AuditEventType.PASSWORD_RESET_REQUEST,
request=request,
user_id=user_id,
description=f"Demande reset mot de passe pour {email}",
success=True,
metadata={"email": email},
)
@classmethod
async def log_account_locked(
cls, session: AsyncSession, request: Request, user_id: str, reason: str
) -> AuditLog:
return await cls.log_event(
session=session,
event_type=AuditEventType.ACCOUNT_LOCKED,
request=request,
user_id=user_id,
description=f"Compte verrouille: {reason}",
success=True,
metadata={"reason": reason},
)
@classmethod
async def log_token_refresh(
cls, session: AsyncSession, request: Request, user_id: str
) -> AuditLog:
return await cls.log_event(
session=session,
event_type=AuditEventType.TOKEN_REFRESH,
request=request,
user_id=user_id,
description="Token rafraichi",
success=True,
)
@classmethod
async def log_suspicious_activity(
cls,
session: AsyncSession,
request: Request,
user_id: Optional[str],
activity_type: str,
details: str,
) -> AuditLog:
return await cls.log_event(
session=session,
event_type=AuditEventType.SUSPICIOUS_ACTIVITY,
request=request,
user_id=user_id,
description=f"Activite suspecte: {activity_type} - {details}",
success=False,
failure_reason=activity_type,
metadata={"activity_type": activity_type, "details": details},
)
@classmethod
async def record_login_attempt(
cls,
session: AsyncSession,
request: Request,
email: str,
success: bool,
failure_reason: Optional[str] = None,
) -> LoginAttempt:
attempt = LoginAttempt(
email=email.lower(),
ip_address=get_client_ip(request),
user_agent=request.headers.get("User-Agent", "")[:500],
fingerprint_hash=DeviceFingerprint.generate_hash(request),
success=success,
failure_reason=failure_reason,
timestamp=datetime.now(),
)
session.add(attempt)
await session.flush()
return attempt
@classmethod
async def get_recent_failed_attempts(
cls, session: AsyncSession, email: str, window_minutes: int = 15
) -> int:
time_threshold = datetime.now() - timedelta(minutes=window_minutes)
result = await session.execute(
select(LoginAttempt).where(
and_(
LoginAttempt.email == email.lower(),
LoginAttempt.success.is_(false()),
LoginAttempt.timestamp >= time_threshold,
)
)
)
return len(result.scalars().all())
@classmethod
async def get_user_audit_history(
cls,
session: AsyncSession,
user_id: str,
limit: int = 50,
event_types: Optional[List[AuditEventType]] = None,
) -> List[AuditLog]:
query = select(AuditLog).where(AuditLog.user_id == user_id)
if event_types:
query = query.where(AuditLog.event_type.in_(event_types))
query = query.order_by(AuditLog.created_at.desc()).limit(limit)
result = await session.execute(query)
return list(result.scalars().all())
@classmethod
async def detect_suspicious_patterns(
cls, session: AsyncSession, user_id: str
) -> Dict[str, Any]:
one_hour_ago = datetime.now() - timedelta(hours=1)
one_day_ago = datetime.now() - timedelta(days=1)
result = await session.execute(
select(AuditLog).where(
and_(
AuditLog.user_id == user_id,
AuditLog.event_type == AuditEventType.LOGIN_FAILED,
AuditLog.created_at >= one_hour_ago,
)
)
)
failed_logins_hour = len(result.scalars().all())
result = await session.execute(
select(AuditLog).where(
and_(
AuditLog.user_id == user_id,
AuditLog.event_type == AuditEventType.LOGIN_SUCCESS,
AuditLog.created_at >= one_day_ago,
)
)
)
login_logs = result.scalars().all()
unique_ips = set(log.ip_address for log in login_logs if log.ip_address)
result = await session.execute(
select(AuditLog).where(
and_(
AuditLog.user_id == user_id,
AuditLog.event_type == AuditEventType.PASSWORD_RESET_REQUEST,
AuditLog.created_at >= one_day_ago,
)
)
)
password_resets = len(result.scalars().all())
return {
"failed_logins_last_hour": failed_logins_hour,
"unique_ips_last_day": len(unique_ips),
"password_reset_requests_last_day": password_resets,
"is_suspicious": (
failed_logins_hour >= 5 or len(unique_ips) >= 5 or password_resets >= 3
),
}

View file

@ -1,22 +1,39 @@
import smtplib import smtplib
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText from email.mime.text import MIMEText
from config.config import settings from typing import Optional, List
import logging import logging
from config.config import settings
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class AuthEmailService: class AuthEmailService:
@staticmethod @staticmethod
def _send_email(to: str, subject: str, html_body: str) -> bool: def _send_email(
to: str,
subject: str,
html_body: str,
cc: Optional[List[str]] = None,
bcc: Optional[List[str]] = None,
) -> bool:
try: try:
msg = MIMEMultipart() msg = MIMEMultipart("alternative")
msg["From"] = settings.smtp_from msg["From"] = settings.smtp_from
msg["To"] = to msg["To"] = to
msg["Subject"] = subject msg["Subject"] = subject
msg.attach(MIMEText(html_body, "html")) if cc:
msg["Cc"] = ", ".join(cc)
msg.attach(MIMEText(html_body, "html", "utf-8"))
recipients = [to]
if cc:
recipients.extend(cc)
if bcc:
recipients.extend(bcc)
with smtplib.SMTP( with smtplib.SMTP(
settings.smtp_host, settings.smtp_port, timeout=30 settings.smtp_host, settings.smtp_port, timeout=30
@ -27,176 +44,263 @@ class AuthEmailService:
if settings.smtp_user and settings.smtp_password: if settings.smtp_user and settings.smtp_password:
server.login(settings.smtp_user, settings.smtp_password) server.login(settings.smtp_user, settings.smtp_password)
server.send_message(msg) server.sendmail(settings.smtp_from, recipients, msg.as_string())
logger.info(f" Email envoyé: {subject} {to}") logger.info(f"Email envoye: {subject} vers {to}")
return True return True
except smtplib.SMTPException as e:
logger.error(f"Erreur SMTP envoi email: {e}")
return False
except Exception as e: except Exception as e:
logger.error(f"Erreur envoi email: {e}") logger.error(f"Erreur envoi email: {e}")
return False return False
@staticmethod @classmethod
def send_verification_email(email: str, token: str, base_url: str) -> bool: def send_verification_email(cls, email: str, token: str, base_url: str) -> bool:
verification_link = f"{base_url}/auth/verify-email?token={token}" verification_link = f"{base_url}/auth/verify-email?token={token}"
html_body = f""" html_body = f"""
<!DOCTYPE html> <!DOCTYPE html>
<html> <html lang="fr">
<head> <head>
<style> <meta charset="UTF-8">
body {{ font-family: Arial, sans-serif; line-height: 1.6; color: #333; }} <meta name="viewport" content="width=device-width, initial-scale=1.0">
.container {{ max-width: 600px; margin: 0 auto; padding: 20px; }} <title>Verification de votre email</title>
.header {{ background: #4F46E5; color: white; padding: 20px; text-align: center; border-radius: 8px 8px 0 0; }}
.content {{ background: #f9fafb; padding: 30px; border-radius: 0 0 8px 8px; }}
.button {{
display: inline-block;
background: #4F46E5;
color: white;
padding: 12px 30px;
text-decoration: none;
border-radius: 6px;
margin: 20px 0;
}}
.footer {{ text-align: center; margin-top: 20px; font-size: 12px; color: #6b7280; }}
</style>
</head> </head>
<body> <body style="margin: 0; padding: 0; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; background-color: #f5f5f5;">
<div class="container"> <table role="presentation" width="100%" cellspacing="0" cellpadding="0" style="background-color: #f5f5f5; padding: 40px 20px;">
<div class="header"> <tr>
<h1>🎉 Bienvenue sur Sage Dataven</h1> <td align="center">
</div> <table role="presentation" width="600" cellspacing="0" cellpadding="0" style="background-color: #ffffff; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
<div class="content"> <tr>
<h2>Vérifiez votre adresse email</h2> <td style="background-color: #4F46E5; padding: 32px; text-align: center; border-radius: 8px 8px 0 0;">
<p>Merci de vous être inscrit ! Pour activer votre compte, veuillez cliquer sur le bouton ci-dessous :</p> <h1 style="color: #ffffff; margin: 0; font-size: 24px; font-weight: 600;">Verification de votre email</h1>
</td>
<div style="text-align: center;"> </tr>
<a href="{verification_link}" class="button">Vérifier mon email</a> <tr>
</div> <td style="padding: 40px 32px;">
<p style="color: #374151; font-size: 16px; line-height: 1.6; margin: 0 0 24px;">
<p style="margin-top: 30px;">Ou copiez ce lien dans votre navigateur :</p> Bienvenue sur Sage Dataven. Pour activer votre compte, veuillez verifier votre adresse email en cliquant sur le bouton ci-dessous.
<p style="word-break: break-all; background: #e5e7eb; padding: 10px; border-radius: 4px;"> </p>
<table role="presentation" width="100%" cellspacing="0" cellpadding="0">
<tr>
<td align="center" style="padding: 24px 0;">
<a href="{verification_link}" style="display: inline-block; background-color: #4F46E5; color: #ffffff; text-decoration: none; padding: 14px 32px; border-radius: 6px; font-size: 16px; font-weight: 500;">Verifier mon email</a>
</td>
</tr>
</table>
<p style="color: #6B7280; font-size: 14px; line-height: 1.6; margin: 24px 0 0;">
Si le bouton ne fonctionne pas, copiez ce lien dans votre navigateur :
</p>
<p style="color: #4F46E5; font-size: 14px; word-break: break-all; background-color: #F3F4F6; padding: 12px; border-radius: 4px; margin: 12px 0 24px;">
{verification_link} {verification_link}
</p> </p>
<p style="color: #EF4444; font-size: 14px; margin: 0;">
<p style="margin-top: 30px; color: #ef4444;"> Ce lien expire dans 24 heures.
Ce lien expire dans <strong>24 heures</strong>
</p> </p>
</td>
<p style="margin-top: 30px; font-size: 14px; color: #6b7280;"> </tr>
Si vous n'avez pas créé de compte, ignorez cet email. <tr>
<td style="background-color: #F9FAFB; padding: 24px 32px; border-radius: 0 0 8px 8px; border-top: 1px solid #E5E7EB;">
<p style="color: #9CA3AF; font-size: 12px; margin: 0; text-align: center;">
Si vous n'avez pas cree de compte, ignorez cet email.
</p> </p>
</div> </td>
<div class="footer"> </tr>
<p>© 2024 Sage Dataven - API de gestion commerciale</p> </table>
</div> </td>
</div> </tr>
</table>
</body> </body>
</html> </html>
""" """
return AuthEmailService._send_email( return cls._send_email(
email, "rifiez votre adresse email - Sage Dataven", html_body email, "Verifiez votre adresse email - Sage Dataven", html_body
) )
@staticmethod @classmethod
def send_password_reset_email(email: str, token: str, base_url: str) -> bool: def send_password_reset_email(
reset_link = f"{base_url}/reset?token={token}" cls, email: str, token: str, frontend_url: str
) -> bool:
reset_link = f"{frontend_url}/reset-password?token={token}"
html_body = f""" html_body = f"""
<!DOCTYPE html> <!DOCTYPE html>
<html> <html lang="fr">
<head> <head>
<style> <meta charset="UTF-8">
body {{ font-family: Arial, sans-serif; line-height: 1.6; color: #333; }} <meta name="viewport" content="width=device-width, initial-scale=1.0">
.container {{ max-width: 600px; margin: 0 auto; padding: 20px; }} <title>Reinitialisation de mot de passe</title>
.header {{ background: #EF4444; color: white; padding: 20px; text-align: center; border-radius: 8px 8px 0 0; }}
.content {{ background: #f9fafb; padding: 30px; border-radius: 0 0 8px 8px; }}
.button {{
display: inline-block;
background: #EF4444;
color: white;
padding: 12px 30px;
text-decoration: none;
border-radius: 6px;
margin: 20px 0;
}}
.footer {{ text-align: center; margin-top: 20px; font-size: 12px; color: #6b7280; }}
</style>
</head> </head>
<body> <body style="margin: 0; padding: 0; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; background-color: #f5f5f5;">
<div class="container"> <table role="presentation" width="100%" cellspacing="0" cellpadding="0" style="background-color: #f5f5f5; padding: 40px 20px;">
<div class="header"> <tr>
<h1> Réinitialisation de mot de passe</h1> <td align="center">
</div> <table role="presentation" width="600" cellspacing="0" cellpadding="0" style="background-color: #ffffff; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
<div class="content"> <tr>
<h2>Demande de réinitialisation</h2> <td style="background-color: #DC2626; padding: 32px; text-align: center; border-radius: 8px 8px 0 0;">
<p>Vous avez demandé à réinitialiser votre mot de passe. Cliquez sur le bouton ci-dessous pour créer un nouveau mot de passe :</p> <h1 style="color: #ffffff; margin: 0; font-size: 24px; font-weight: 600;">Reinitialisation du mot de passe</h1>
</td>
<div style="text-align: center;"> </tr>
<a href="{reset_link}" class="button">Réinitialiser mon mot de passe</a> <tr>
</div> <td style="padding: 40px 32px;">
<p style="color: #374151; font-size: 16px; line-height: 1.6; margin: 0 0 24px;">
<p style="margin-top: 30px;">Ou copiez ce lien dans votre navigateur :</p> Vous avez demande la reinitialisation de votre mot de passe. Cliquez sur le bouton ci-dessous pour creer un nouveau mot de passe.
<p style="word-break: break-all; background: #e5e7eb; padding: 10px; border-radius: 4px;"> </p>
<table role="presentation" width="100%" cellspacing="0" cellpadding="0">
<tr>
<td align="center" style="padding: 24px 0;">
<a href="{reset_link}" style="display: inline-block; background-color: #DC2626; color: #ffffff; text-decoration: none; padding: 14px 32px; border-radius: 6px; font-size: 16px; font-weight: 500;">Reinitialiser mon mot de passe</a>
</td>
</tr>
</table>
<p style="color: #6B7280; font-size: 14px; line-height: 1.6; margin: 24px 0 0;">
Si le bouton ne fonctionne pas, copiez ce lien :
</p>
<p style="color: #DC2626; font-size: 14px; word-break: break-all; background-color: #FEF2F2; padding: 12px; border-radius: 4px; margin: 12px 0 24px;">
{reset_link} {reset_link}
</p> </p>
<p style="color: #EF4444; font-size: 14px; margin: 0;">
<p style="margin-top: 30px; color: #ef4444;"> Ce lien expire dans 1 heure.
Ce lien expire dans <strong>1 heure</strong>
</p> </p>
</td>
<p style="margin-top: 30px; font-size: 14px; color: #6b7280;"> </tr>
Si vous n'avez pas demandé cette réinitialisation, ignorez cet email. Votre mot de passe actuel reste inchangé. <tr>
<td style="background-color: #FEF2F2; padding: 24px 32px; border-radius: 0 0 8px 8px; border-top: 1px solid #FECACA;">
<p style="color: #991B1B; font-size: 12px; margin: 0; text-align: center;">
Si vous n'avez pas demande cette reinitialisation, ignorez cet email. Votre mot de passe restera inchange.
</p> </p>
</div> </td>
<div class="footer"> </tr>
<p>© 2024 Sage Dataven - API de gestion commerciale</p> </table>
</div> </td>
</div> </tr>
</table>
</body> </body>
</html> </html>
""" """
return AuthEmailService._send_email( return cls._send_email(
email, "initialisation de votre mot de passe - Sage Dataven", html_body email, "Reinitialisation de votre mot de passe - Sage Dataven", html_body
) )
@staticmethod @classmethod
def send_password_changed_notification(email: str) -> bool: def send_password_changed_notification(cls, email: str) -> bool:
html_body = """ html_body = """
<!DOCTYPE html> <!DOCTYPE html>
<html> <html lang="fr">
<head> <head>
<style> <meta charset="UTF-8">
body { font-family: Arial, sans-serif; line-height: 1.6; color: #333; } <meta name="viewport" content="width=device-width, initial-scale=1.0">
.container { max-width: 600px; margin: 0 auto; padding: 20px; } <title>Mot de passe modifie</title>
.header { background: #10B981; color: white; padding: 20px; text-align: center; border-radius: 8px 8px 0 0; }
.content { background: #f9fafb; padding: 30px; border-radius: 0 0 8px 8px; }
.footer { text-align: center; margin-top: 20px; font-size: 12px; color: #6b7280; }
</style>
</head> </head>
<body> <body style="margin: 0; padding: 0; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; background-color: #f5f5f5;">
<div class="container"> <table role="presentation" width="100%" cellspacing="0" cellpadding="0" style="background-color: #f5f5f5; padding: 40px 20px;">
<div class="header"> <tr>
<h1> Mot de passe modifié</h1> <td align="center">
</div> <table role="presentation" width="600" cellspacing="0" cellpadding="0" style="background-color: #ffffff; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
<div class="content"> <tr>
<h2>Votre mot de passe a été changé avec succès</h2> <td style="background-color: #059669; padding: 32px; text-align: center; border-radius: 8px 8px 0 0;">
<p>Ce message confirme que le mot de passe de votre compte Sage Dataven a été modifié.</p> <h1 style="color: #ffffff; margin: 0; font-size: 24px; font-weight: 600;">Mot de passe modifie</h1>
</td>
<p style="margin-top: 30px; padding: 15px; background: #FEF3C7; border-left: 4px solid #F59E0B; border-radius: 4px;"> </tr>
Si vous n'êtes pas à l'origine de ce changement, contactez immédiatement notre support. <tr>
<td style="padding: 40px 32px;">
<p style="color: #374151; font-size: 16px; line-height: 1.6; margin: 0 0 24px;">
Votre mot de passe a ete modifie avec succes.
</p> </p>
</div> <p style="color: #374151; font-size: 16px; line-height: 1.6; margin: 0 0 24px;">
<div class="footer"> Si vous n'etes pas a l'origine de ce changement, contactez immediatement notre support.
<p>© 2024 Sage Dataven - API de gestion commerciale</p> </p>
</div> <table role="presentation" width="100%" cellspacing="0" cellpadding="0" style="background-color: #FEF3C7; border-left: 4px solid #F59E0B; border-radius: 4px;">
</div> <tr>
<td style="padding: 16px;">
<p style="color: #92400E; font-size: 14px; margin: 0;">
<strong>Securite :</strong> Toutes vos sessions actives ont ete deconnectees. Vous devrez vous reconnecter sur tous vos appareils.
</p>
</td>
</tr>
</table>
</td>
</tr>
<tr>
<td style="background-color: #F9FAFB; padding: 24px 32px; border-radius: 0 0 8px 8px; border-top: 1px solid #E5E7EB;">
<p style="color: #9CA3AF; font-size: 12px; margin: 0; text-align: center;">
Sage Dataven - Notification de securite
</p>
</td>
</tr>
</table>
</td>
</tr>
</table>
</body> </body>
</html> </html>
""" """
return AuthEmailService._send_email( return cls._send_email(
email, " Votre mot de passe a été modifié - Sage Dataven", html_body email, "Votre mot de passe a ete modifie - Sage Dataven", html_body
)
@classmethod
def send_security_alert(
cls, email: str, alert_type: str, details: str, ip_address: Optional[str] = None
) -> bool:
ip_info = (
f"<p style='color: #6B7280; font-size: 14px;'>Adresse IP : {ip_address}</p>"
if ip_address
else ""
)
html_body = f"""
<!DOCTYPE html>
<html lang="fr">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>Alerte de securite</title>
</head>
<body style="margin: 0; padding: 0; font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, Helvetica, Arial, sans-serif; background-color: #f5f5f5;">
<table role="presentation" width="100%" cellspacing="0" cellpadding="0" style="background-color: #f5f5f5; padding: 40px 20px;">
<tr>
<td align="center">
<table role="presentation" width="600" cellspacing="0" cellpadding="0" style="background-color: #ffffff; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
<tr>
<td style="background-color: #B91C1C; padding: 32px; text-align: center; border-radius: 8px 8px 0 0;">
<h1 style="color: #ffffff; margin: 0; font-size: 24px; font-weight: 600;">Alerte de securite</h1>
</td>
</tr>
<tr>
<td style="padding: 40px 32px;">
<p style="color: #374151; font-size: 16px; line-height: 1.6; margin: 0 0 16px;">
<strong>{alert_type}</strong>
</p>
<p style="color: #374151; font-size: 16px; line-height: 1.6; margin: 0 0 24px;">
{details}
</p>
{ip_info}
<p style="color: #6B7280; font-size: 14px; margin: 24px 0 0;">
Si vous reconnaissez cette activite, vous pouvez ignorer ce message. Sinon, nous vous recommandons de changer votre mot de passe immediatement.
</p>
</td>
</tr>
<tr>
<td style="background-color: #FEF2F2; padding: 24px 32px; border-radius: 0 0 8px 8px; border-top: 1px solid #FECACA;">
<p style="color: #991B1B; font-size: 12px; margin: 0; text-align: center;">
Sage Dataven - Alerte de securite automatique
</p>
</td>
</tr>
</table>
</td>
</tr>
</table>
</body>
</html>
"""
return cls._send_email(
email, f"Alerte de securite : {alert_type} - Sage Dataven", html_body
) )

200
services/redis_service.py Normal file
View file

@ -0,0 +1,200 @@
import redis.asyncio as redis
from typing import Optional
import logging
import json
from config.config import settings
logger = logging.getLogger(__name__)
class RedisService:
_instance: Optional["RedisService"] = None
_client: Optional[redis.Redis] = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
async def connect(self) -> None:
if self._client is not None:
return
try:
self._client = redis.from_url(
settings.redis_url,
password=settings.redis_password,
encoding="utf-8",
decode_responses=True,
socket_timeout=5.0,
socket_connect_timeout=5.0,
)
await self._client.ping()
logger.info("Connexion Redis etablie")
except Exception as e:
logger.error(f"Erreur connexion Redis: {e}")
self._client = None
raise
async def disconnect(self) -> None:
if self._client:
await self._client.close()
self._client = None
logger.info("Connexion Redis fermee")
async def is_connected(self) -> bool:
if not self._client:
return False
try:
await self._client.ping()
return True
except Exception:
return False
@property
def client(self) -> redis.Redis:
if not self._client:
raise RuntimeError("Redis non connecte. Appelez connect() d'abord.")
return self._client
async def blacklist_token(self, token_id: str, ttl_seconds: int) -> bool:
try:
key = f"{settings.token_blacklist_prefix}{token_id}"
await self.client.setex(key, ttl_seconds, "1")
logger.debug(f"Token {token_id[:8]}... ajoute a la blacklist")
return True
except Exception as e:
logger.error(f"Erreur blacklist token: {e}")
return False
async def is_token_blacklisted(self, token_id: str) -> bool:
try:
key = f"{settings.token_blacklist_prefix}{token_id}"
result = await self.client.exists(key)
return result > 0
except Exception as e:
logger.error(f"Erreur verification blacklist: {e}")
return False
async def blacklist_user_tokens(
self, user_id: str, ttl_seconds: int = 86400
) -> bool:
try:
key = f"{settings.token_blacklist_prefix}user:{user_id}"
import time
await self.client.setex(key, ttl_seconds, str(int(time.time())))
logger.info(f"Tokens utilisateur {user_id} invalides")
return True
except Exception as e:
logger.error(f"Erreur invalidation tokens utilisateur: {e}")
return False
async def get_user_token_invalidation_time(self, user_id: str) -> Optional[int]:
try:
key = f"{settings.token_blacklist_prefix}user:{user_id}"
result = await self.client.get(key)
return int(result) if result else None
except Exception as e:
logger.error(f"Erreur lecture invalidation: {e}")
return None
async def increment_rate_limit(self, key: str, window_seconds: int) -> int:
try:
full_key = f"{settings.rate_limit_prefix}{key}"
pipe = self.client.pipeline()
pipe.incr(full_key)
pipe.expire(full_key, window_seconds)
results = await pipe.execute()
return results[0]
except Exception as e:
logger.error(f"Erreur increment rate limit: {e}")
return 0
async def get_rate_limit_count(self, key: str) -> int:
try:
full_key = f"{settings.rate_limit_prefix}{key}"
result = await self.client.get(full_key)
return int(result) if result else 0
except Exception as e:
logger.error(f"Erreur lecture rate limit: {e}")
return 0
async def reset_rate_limit(self, key: str) -> bool:
try:
full_key = f"{settings.rate_limit_prefix}{key}"
await self.client.delete(full_key)
return True
except Exception as e:
logger.error(f"Erreur reset rate limit: {e}")
return False
async def store_refresh_token_metadata(
self, token_id: str, user_id: str, fingerprint_hash: str, ttl_seconds: int
) -> bool:
try:
key = f"refresh_token:{token_id}"
data = json.dumps(
{
"user_id": user_id,
"fingerprint_hash": fingerprint_hash,
"used": False,
}
)
await self.client.setex(key, ttl_seconds, data)
return True
except Exception as e:
logger.error(f"Erreur stockage metadata refresh token: {e}")
return False
async def get_refresh_token_metadata(self, token_id: str) -> Optional[dict]:
try:
key = f"refresh_token:{token_id}"
data = await self.client.get(key)
return json.loads(data) if data else None
except Exception as e:
logger.error(f"Erreur lecture metadata refresh token: {e}")
return None
async def mark_refresh_token_used(self, token_id: str) -> bool:
try:
key = f"refresh_token:{token_id}"
data = await self.client.get(key)
if not data:
return False
metadata = json.loads(data)
metadata["used"] = True
metadata["used_at"] = int(__import__("time").time())
ttl = await self.client.ttl(key)
if ttl > 0:
await self.client.setex(key, ttl, json.dumps(metadata))
return True
except Exception as e:
logger.error(f"Erreur marquage refresh token: {e}")
return False
async def delete_refresh_token(self, token_id: str) -> bool:
try:
key = f"refresh_token:{token_id}"
result = await self.client.delete(key)
return result > 0
except Exception as e:
logger.error(f"Erreur suppression refresh token: {e}")
return False
redis_service = RedisService()
async def get_redis() -> RedisService:
if not await redis_service.is_connected():
await redis_service.connect()
return redis_service

View file

@ -6,7 +6,7 @@ import httpx
from datetime import datetime from datetime import datetime
from typing import Optional, Tuple, List from typing import Optional, Tuple, List
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select, update, and_ from sqlalchemy import false, select, update, and_
import logging import logging
from config.config import settings from config.config import settings
@ -20,8 +20,6 @@ class SageGatewayService:
self.session = session self.session = session
async def create(self, user_id: str, data: dict) -> SageGatewayConfig: async def create(self, user_id: str, data: dict) -> SageGatewayConfig:
"""Créer une nouvelle configuration gateway"""
if data.get("is_active"): if data.get("is_active"):
await self._deactivate_all_for_user(user_id) await self._deactivate_all_for_user(user_id)
@ -55,7 +53,6 @@ class SageGatewayService:
and_( and_(
SageGatewayConfig.id == gateway_id, SageGatewayConfig.id == gateway_id,
SageGatewayConfig.user_id == user_id, SageGatewayConfig.user_id == user_id,
SageGatewayConfig.is_deleted,
) )
) )
) )
@ -67,7 +64,7 @@ class SageGatewayService:
query = select(SageGatewayConfig).where(SageGatewayConfig.user_id == user_id) query = select(SageGatewayConfig).where(SageGatewayConfig.user_id == user_id)
if not include_deleted: if not include_deleted:
query = query.where(SageGatewayConfig.is_deleted) query = query.where(SageGatewayConfig.is_deleted.is_(false()))
query = query.order_by( query = query.order_by(
SageGatewayConfig.is_active.desc(), SageGatewayConfig.is_active.desc(),
@ -81,8 +78,6 @@ class SageGatewayService:
async def update( async def update(
self, gateway_id: str, user_id: str, data: dict self, gateway_id: str, user_id: str, data: dict
) -> Optional[SageGatewayConfig]: ) -> Optional[SageGatewayConfig]:
"""Mettre à jour une gateway"""
gateway = await self.get_by_id(gateway_id, user_id) gateway = await self.get_by_id(gateway_id, user_id)
if not gateway: if not gateway:
return None return None
@ -131,7 +126,6 @@ class SageGatewayService:
async def activate( async def activate(
self, gateway_id: str, user_id: str self, gateway_id: str, user_id: str
) -> Optional[SageGatewayConfig]: ) -> Optional[SageGatewayConfig]:
"""Activer une gateway (désactive les autres)"""
gateway = await self.get_by_id(gateway_id, user_id) gateway = await self.get_by_id(gateway_id, user_id)
if not gateway: if not gateway:
return None return None
@ -167,7 +161,7 @@ class SageGatewayService:
and_( and_(
SageGatewayConfig.user_id == user_id, SageGatewayConfig.user_id == user_id,
SageGatewayConfig.is_active, SageGatewayConfig.is_active,
SageGatewayConfig.is_deleted, SageGatewayConfig.is_deleted.is_(false()),
) )
) )
) )
@ -277,8 +271,6 @@ class SageGatewayService:
return {"success": False, "status": "error", "error": str(e)} return {"success": False, "status": "error", "error": str(e)}
async def record_request(self, gateway_id: str, success: bool) -> None: async def record_request(self, gateway_id: str, success: bool) -> None:
"""Enregistrer une requête (succès/échec)"""
if not gateway_id: if not gateway_id:
return return
@ -297,7 +289,6 @@ class SageGatewayService:
await self.session.commit() await self.session.commit()
async def get_stats(self, user_id: str) -> dict: async def get_stats(self, user_id: str) -> dict:
"""Statistiques d'utilisation pour un utilisateur"""
gateways = await self.list_for_user(user_id) gateways = await self.list_for_user(user_id)
total_requests = sum(g.total_requests for g in gateways) total_requests = sum(g.total_requests for g in gateways)
@ -323,8 +314,6 @@ class SageGatewayService:
} }
async def _deactivate_all_for_user(self, user_id: str) -> None: async def _deactivate_all_for_user(self, user_id: str) -> None:
"""Désactiver toutes les gateways d'un utilisateur"""
await self.session.execute( await self.session.execute(
update(SageGatewayConfig) update(SageGatewayConfig)
.where(SageGatewayConfig.user_id == user_id) .where(SageGatewayConfig.user_id == user_id)
@ -332,8 +321,6 @@ class SageGatewayService:
) )
async def _unset_default_for_user(self, user_id: str) -> None: async def _unset_default_for_user(self, user_id: str) -> None:
"""Retirer le flag default de toutes les gateways"""
await self.session.execute( await self.session.execute(
update(SageGatewayConfig) update(SageGatewayConfig)
.where(SageGatewayConfig.user_id == user_id) .where(SageGatewayConfig.user_id == user_id)
@ -342,8 +329,6 @@ class SageGatewayService:
def gateway_response_from_model(gateway: SageGatewayConfig) -> dict: def gateway_response_from_model(gateway: SageGatewayConfig) -> dict:
"""Convertir un model en réponse API (masque le token)"""
token_preview = ( token_preview = (
f"****{gateway.gateway_token[-4:]}" if gateway.gateway_token else "****" f"****{gateway.gateway_token[-4:]}" if gateway.gateway_token else "****"
) )
@ -380,8 +365,6 @@ def gateway_response_from_model(gateway: SageGatewayConfig) -> dict:
"description": gateway.description, "description": gateway.description,
"gateway_url": gateway.gateway_url, "gateway_url": gateway.gateway_url,
"token_preview": token_preview, "token_preview": token_preview,
"sage_database": gateway.sage_database,
"sage_company": gateway.sage_company,
"is_active": gateway.is_active, "is_active": gateway.is_active,
"is_default": gateway.is_default, "is_default": gateway.is_default,
"priority": gateway.priority, "priority": gateway.priority,

357
services/token_service.py Normal file
View file

@ -0,0 +1,357 @@
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import false, select, and_, or_, delete, true
from datetime import datetime, timedelta
from typing import Optional, Tuple, Dict, Any
import uuid
import logging
import time
from config.config import settings
from database import RefreshToken, User
from services.redis_service import redis_service
from security.auth import (
create_access_token,
create_refresh_token,
create_csrf_token,
decode_token,
hash_token,
generate_session_id,
)
logger = logging.getLogger(__name__)
class TokenService:
@classmethod
async def create_token_pair(
cls,
session: AsyncSession,
user: User,
fingerprint_hash: str,
device_info: str,
ip_address: str,
) -> Tuple[str, str, str, str]:
session_id = generate_session_id()
access_token = create_access_token(
data={
"sub": user.id,
"email": user.email,
"role": user.role,
"sid": session_id,
},
fingerprint_hash=fingerprint_hash,
)
refresh_token_jwt, token_id = create_refresh_token(
user_id=user.id, fingerprint_hash=fingerprint_hash
)
csrf_token = create_csrf_token(session_id)
token_record = RefreshToken(
id=str(uuid.uuid4()),
user_id=user.id,
token_hash=hash_token(refresh_token_jwt),
token_id=token_id,
fingerprint_hash=fingerprint_hash,
device_info=device_info[:500] if device_info else None,
ip_address=ip_address,
expires_at=datetime.now()
+ timedelta(days=settings.refresh_token_expire_days),
created_at=datetime.now(),
)
session.add(token_record)
await session.flush()
await redis_service.store_refresh_token_metadata(
token_id=token_id,
user_id=user.id,
fingerprint_hash=fingerprint_hash,
ttl_seconds=settings.refresh_token_expire_days * 24 * 60 * 60,
)
logger.info(f"Token pair cree pour utilisateur {user.email}")
return access_token, refresh_token_jwt, csrf_token, session_id
@classmethod
async def refresh_tokens(
cls,
session: AsyncSession,
refresh_token: str,
fingerprint_hash: str,
device_info: str,
ip_address: str,
) -> Optional[Tuple[str, str, str, str]]:
payload = decode_token(refresh_token, expected_type="refresh")
if not payload:
logger.warning("Refresh token invalide ou expire")
return None
user_id = payload.get("sub")
token_id = payload.get("jti")
stored_fingerprint = payload.get("fph")
if not user_id or not token_id:
logger.warning("Refresh token malformed")
return None
if await redis_service.is_token_blacklisted(token_id):
logger.warning(f"Refresh token {token_id[:8]}... est blackliste")
return None
token_hash = hash_token(refresh_token)
result = await session.execute(
select(RefreshToken).where(
and_(
RefreshToken.token_hash == token_hash,
RefreshToken.user_id == user_id,
RefreshToken.is_revoked.is_(false()),
RefreshToken.expires_at > datetime.now(),
)
)
)
token_record = result.scalar_one_or_none()
if not token_record:
logger.warning(f"Refresh token non trouve en DB pour user {user_id}")
await cls._handle_potential_token_theft(session, user_id, token_id)
return None
if settings.refresh_token_rotation_enabled and token_record.is_used:
used_at = token_record.used_at
if used_at:
time_since_use = (datetime.now() - used_at).total_seconds()
if time_since_use > settings.refresh_token_reuse_window_seconds:
logger.warning(
f"Reutilisation de refresh token detectee pour user {user_id}"
)
await cls._handle_potential_token_theft(session, user_id, token_id)
return None
if stored_fingerprint and fingerprint_hash:
if stored_fingerprint != fingerprint_hash:
logger.warning(f"Fingerprint mismatch pour user {user_id}")
return None
result = await session.execute(
select(User).where(and_(User.id == user_id, User.is_active.is_(true())))
)
user = result.scalar_one_or_none()
if not user:
logger.warning(f"Utilisateur {user_id} introuvable ou inactif")
return None
session_id = generate_session_id()
new_access_token = create_access_token(
data={
"sub": user.id,
"email": user.email,
"role": user.role,
"sid": session_id,
},
fingerprint_hash=fingerprint_hash,
)
new_csrf_token = create_csrf_token(session_id)
if settings.refresh_token_rotation_enabled:
token_record.is_used = True
token_record.used_at = datetime.now()
new_refresh_jwt, new_token_id = create_refresh_token(
user_id=user.id, fingerprint_hash=fingerprint_hash
)
new_token_record = RefreshToken(
id=str(uuid.uuid4()),
user_id=user.id,
token_hash=hash_token(new_refresh_jwt),
token_id=new_token_id,
fingerprint_hash=fingerprint_hash,
device_info=device_info[:500] if device_info else None,
ip_address=ip_address,
expires_at=datetime.now()
+ timedelta(days=settings.refresh_token_expire_days),
created_at=datetime.now(),
)
token_record.replaced_by = new_token_record.id
session.add(new_token_record)
await redis_service.mark_refresh_token_used(token_id)
await redis_service.store_refresh_token_metadata(
token_id=new_token_id,
user_id=user.id,
fingerprint_hash=fingerprint_hash,
ttl_seconds=settings.refresh_token_expire_days * 24 * 60 * 60,
)
logger.info(f"Refresh token rotation pour user {user.email}")
return new_access_token, new_refresh_jwt, new_csrf_token, session_id
else:
token_record.last_used_at = datetime.now()
return new_access_token, refresh_token, new_csrf_token, session_id
@classmethod
async def revoke_token(
cls, session: AsyncSession, refresh_token: str, reason: str = "user_logout"
) -> bool:
payload = decode_token(refresh_token, expected_type="refresh")
if not payload:
return False
token_id = payload.get("jti")
user_id = payload.get("sub")
exp = payload.get("exp", 0)
ttl_seconds = max(0, exp - int(time.time()))
await redis_service.blacklist_token(token_id, ttl_seconds)
token_hash = hash_token(refresh_token)
result = await session.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
token_record = result.scalar_one_or_none()
if token_record:
token_record.is_revoked = True
token_record.revoked_at = datetime.now()
token_record.revoked_reason = reason
await redis_service.delete_refresh_token(token_id)
logger.info(f"Token revoque pour user {user_id}: {reason}")
return True
@classmethod
async def revoke_all_user_tokens(
cls, session: AsyncSession, user_id: str, reason: str = "security_action"
) -> int:
result = await session.execute(
select(RefreshToken).where(
and_(
RefreshToken.user_id == user_id,
RefreshToken.is_revoked.is_(false()),
)
)
)
tokens = result.scalars().all()
count = 0
for token in tokens:
token.is_revoked = True
token.revoked_at = datetime.now()
token.revoked_reason = reason
await redis_service.blacklist_token(
token.token_id, settings.refresh_token_expire_days * 24 * 60 * 60
)
await redis_service.delete_refresh_token(token.token_id)
count += 1
await redis_service.blacklist_user_tokens(
user_id, settings.refresh_token_expire_days * 24 * 60 * 60
)
logger.info(f"{count} tokens revoques pour user {user_id}: {reason}")
return count
@classmethod
async def _handle_potential_token_theft(
cls, session: AsyncSession, user_id: str, token_id: str
) -> None:
logger.warning(
f"Potentiel vol de token detecte pour user {user_id}, token {token_id[:8]}..."
)
await cls.revoke_all_user_tokens(
session, user_id, reason="potential_token_theft"
)
@classmethod
async def validate_access_token(
cls, token: str, fingerprint_hash: Optional[str] = None
) -> Optional[Dict[str, Any]]:
payload = decode_token(token, expected_type="access")
if not payload:
return None
token_id = payload.get("jti")
if token_id and await redis_service.is_token_blacklisted(token_id):
logger.debug(f"Access token {token_id[:8]}... est blackliste")
return None
user_id = payload.get("sub")
if user_id:
invalidation_time = await redis_service.get_user_token_invalidation_time(
user_id
)
if invalidation_time:
token_iat = payload.get("iat", 0)
if token_iat < invalidation_time:
logger.debug("Access token emis avant invalidation globale")
return None
if fingerprint_hash:
stored_fingerprint = payload.get("fph")
if stored_fingerprint and stored_fingerprint != fingerprint_hash:
logger.warning("Fingerprint mismatch sur access token")
return None
return payload
@classmethod
async def cleanup_expired_tokens(cls, session: AsyncSession) -> int:
result = await session.execute(
delete(RefreshToken).where(
or_(
RefreshToken.expires_at < datetime.now(),
and_(
RefreshToken.is_revoked.is_(true()),
RefreshToken.revoked_at < datetime.now() - timedelta(days=7),
),
)
)
)
count = result.rowcount
logger.info(f"{count} tokens expires nettoyes")
return count
@classmethod
async def get_user_active_sessions(
cls, session: AsyncSession, user_id: str
) -> list:
result = await session.execute(
select(RefreshToken)
.where(
and_(
RefreshToken.user_id == user_id,
RefreshToken.is_revoked.is_(false()),
RefreshToken.expires_at > datetime.now(),
)
)
.order_by(RefreshToken.created_at.desc())
)
tokens = result.scalars().all()
return [
{
"id": t.id,
"device_info": t.device_info,
"ip_address": t.ip_address,
"created_at": t.created_at.isoformat(),
"last_used_at": t.last_used_at.isoformat() if t.last_used_at else None,
}
for t in tokens
]