191 lines
5.8 KiB
Python
191 lines
5.8 KiB
Python
from fastapi import Depends, HTTPException, status, Request
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import select
|
|
from typing import Optional, Tuple
|
|
from datetime import datetime
|
|
import logging
|
|
|
|
from database import get_session
|
|
from database import User, AuditEventType
|
|
from services.token_service import TokenService
|
|
from services.audit_service import AuditService
|
|
from security.cookies import CookieManager
|
|
from security.fingerprint import DeviceFingerprint, get_client_ip
|
|
from security.csrf import CSRFProtection
|
|
|
|
# Module logger, named after this module so records can be filtered per file.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def _unauthorized(detail: str) -> HTTPException:
    """Build a 401 error carrying the ``WWW-Authenticate: Bearer`` challenge."""
    return HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail=detail,
        headers={"WWW-Authenticate": "Bearer"},
    )


def _is_locked(locked_until: Optional[datetime]) -> bool:
    """Return True while a lockout deadline is still in the future.

    ``datetime.now`` is given the deadline's own ``tzinfo``. A naive deadline
    is therefore compared with naive local time, as before, and an aware
    deadline with an aware "now". Comparing naive with aware datetimes would
    raise ``TypeError``.
    """
    if not locked_until:
        return False
    return locked_until > datetime.now(locked_until.tzinfo)


async def get_current_user(
    request: Request, session: AsyncSession = Depends(get_session)
) -> User:
    """Authenticate the request from its access-token cookie and return the user.

    Steps:
      1. Read the access token with ``CookieManager.get_access_token``.
      2. Validate it with ``TokenService.validate_access_token``, passing a
         hash of the request's device fingerprint.
      3. Load the ``User`` whose id equals the token's ``sub`` claim.
      4. Enforce account state: active, verified and not locked.

    On success, the user and the token's ``sid`` claim are stored on
    ``request.state.user`` and ``request.state.session_id``.

    Args:
        request: Incoming request (cookies, client fingerprint, state).
        session: Async database session used to load the user.

    Returns:
        The authenticated ``User``.

    Raises:
        HTTPException: 401 when the token is missing or fails validation,
            when the ``sub`` claim is absent, or when no such user exists.
            403 when the account is inactive, unverified or temporarily
            locked.
    """
    token = CookieManager.get_access_token(request)
    if not token:
        raise _unauthorized("Authentification requise")

    # The validator also receives the current request's fingerprint hash
    # (presumably to bind the token to the device; confirm in TokenService).
    fingerprint_hash = DeviceFingerprint.generate_hash(request)
    payload = await TokenService.validate_access_token(token, fingerprint_hash)
    if not payload:
        raise _unauthorized("Token invalide ou expire")

    user_id = payload.get("sub")
    if not user_id:
        raise _unauthorized("Token malformed")

    result = await session.execute(select(User).where(User.id == user_id))
    user = result.scalar_one_or_none()
    if user is None:
        raise _unauthorized("Utilisateur introuvable")

    if not user.is_active:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Compte desactive"
        )

    if not user.is_verified:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Email non verifie"
        )

    # NOTE(review): whether ``locked_until`` is stored naive or timezone-aware
    # is not visible here; _is_locked handles both without raising TypeError.
    if _is_locked(user.locked_until):
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Compte temporairement verrouille",
        )

    # Expose the user and session id to later dependencies in the same request.
    request.state.user = user
    request.state.session_id = payload.get("sid")

    return user
|
|
|
|
|
|
async def get_current_user_optional(
    request: Request, session: AsyncSession = Depends(get_session)
) -> Optional[User]:
    """Return the authenticated user, or ``None`` when authentication fails.

    Runs exactly the same checks as ``get_current_user``. Any
    ``HTTPException`` it raises (401 or 403) is turned into ``None`` instead
    of propagating.
    """
    try:
        current = await get_current_user(request, session)
    except HTTPException:
        current = None
    return current
|
|
|
|
|
|
def require_role(*allowed_roles: str):
    """Create a dependency that only admits users holding one of *allowed_roles*.

    The returned coroutine depends on ``get_current_user``, so the caller is
    authenticated first. It then compares ``user.role`` against the given
    roles.

    Args:
        *allowed_roles: Role names that are granted access.

    Returns:
        An async dependency that yields the user, or raises ``HTTPException``
        403 listing the required roles.
    """

    async def role_checker(user: User = Depends(get_current_user)) -> User:
        # Happy path first: an authorised role passes straight through.
        if user.role in allowed_roles:
            return user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Acces refuse. Roles requis: {', '.join(allowed_roles)}",
        )

    return role_checker
|
|
|
|
|
|
def require_verified_email():
    """Create a dependency that only admits users whose email is verified.

    NOTE(review): ``get_current_user`` already raises 403 ("Email non
    verifie") for unverified users. The 403 below is therefore only reachable
    if that earlier check is removed.

    Returns:
        An async dependency that yields the authenticated user, or raises
        ``HTTPException`` 403.
    """

    async def email_checker(user: User = Depends(get_current_user)) -> User:
        if user.is_verified:
            return user
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Verification email requise",
        )

    return email_checker
|
|
|
|
|
|
async def verify_csrf_token(
    request: Request, user: User = Depends(get_current_user)
) -> None:
    """Reject the request with 403 unless its CSRF token validates.

    Requests for which ``CSRFProtection.is_exempt`` returns true are allowed
    without a check. Otherwise the token is validated against the session id
    that ``get_current_user`` stored on ``request.state``. The id is
    ``None`` if it was never set.

    Args:
        request: Incoming request carrying the CSRF token.
        user: Authenticated user; depending on ``get_current_user`` also
            guarantees ``request.state.session_id`` has been populated.

    Raises:
        HTTPException: 403 when CSRF validation fails.
    """
    if CSRFProtection.is_exempt(request):
        return

    session_id = getattr(request.state, "session_id", None)

    if not CSRFProtection.validate_request(request, session_id):
        # Lazy %-style arguments: the message is only formatted if the record
        # is actually emitted.
        logger.warning(
            "CSRF validation echouee pour user %s sur %s %s",
            user.id,
            request.method,
            request.url.path,
        )
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN, detail="Verification CSRF echouee"
        )
|
|
|
|
|
|
async def get_auth_context(
    request: Request, session: AsyncSession = Depends(get_session)
) -> Tuple[Optional[User], str, str]:
    """Collect the request's client IP, device fingerprint and optional user.

    Authentication failures do not raise: the user slot is ``None`` in that
    case.

    Returns:
        A tuple ``(user_or_none, client_ip, fingerprint_hash)``.
    """
    client_ip = get_client_ip(request)
    device_hash = DeviceFingerprint.generate_hash(request)
    # Same semantics as calling get_current_user and mapping any
    # HTTPException to None.
    current_user = await get_current_user_optional(request, session)
    return current_user, client_ip, device_hash
|
|
|
|
|
|
class AuthenticatedRoute:
    """Configurable dependency combining authentication, role, CSRF and audit.

    Instances are callable. ``__call__`` takes a ``Request`` and a
    ``Depends(get_session)`` session, so an instance can be used as a FastAPI
    dependency. A call runs the following checks in order:
    ``get_current_user``, the optional role check, the optional CSRF check,
    and finally an optional audit log entry.
    """

    def __init__(
        self,
        require_csrf: bool = True,
        allowed_roles: Optional[Tuple[str, ...]] = None,
        audit_event: Optional[AuditEventType] = None,
    ):
        """Configure the checks performed on each call.

        Args:
            require_csrf: Validate the CSRF token unless
                ``CSRFProtection.is_exempt`` returns true for the request.
            allowed_roles: Roles granted access. ``None`` or an empty
                collection admits any authenticated user. A single role given
                as a bare string is accepted and wrapped into a 1-tuple.
            audit_event: If not ``None``, recorded via
                ``AuditService.log_event`` with ``success=True`` after all
                checks pass.
        """
        self.require_csrf = require_csrf
        # A bare string would turn ``role in allowed_roles`` into a substring
        # test ("min" in "admin" is True), so normalise it to a 1-tuple.
        if isinstance(allowed_roles, str):
            allowed_roles = (allowed_roles,)
        self.allowed_roles = allowed_roles
        self.audit_event = audit_event

    async def __call__(
        self, request: Request, session: AsyncSession = Depends(get_session)
    ) -> User:
        """Authenticate the request, apply the configured checks, return the user.

        Raises:
            HTTPException: 401 or 403 from ``get_current_user``; 403 when the
                role is not allowed or CSRF validation fails.
        """
        user = await get_current_user(request, session)

        if self.allowed_roles and user.role not in self.allowed_roles:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="Acces refuse pour ce role",
            )

        if self.require_csrf and not CSRFProtection.is_exempt(request):
            session_id = getattr(request.state, "session_id", None)
            if not CSRFProtection.validate_request(request, session_id):
                # Log like verify_csrf_token does, so CSRF failures are
                # visible whichever dependency rejected the request.
                logger.warning(
                    "CSRF validation echouee pour user %s sur %s %s",
                    user.id,
                    request.method,
                    request.url.path,
                )
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="Verification CSRF echouee",
                )

        # Explicit ``is not None``: a truthiness test would silently skip an
        # enum member whose mixed-in value is falsy (e.g. an IntEnum 0).
        if self.audit_event is not None:
            await AuditService.log_event(
                session=session,
                event_type=self.audit_event,
                request=request,
                user_id=user.id,
                success=True,
            )

        return user
|
|
|
|
|
|
# Ready-made role guards, from most to least restrictive. Each one also
# admits every role listed before it.
require_admin, require_manager, require_user = (
    require_role("admin"),
    require_role("admin", "manager"),
    require_role("admin", "manager", "user"),
)
|