from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import false, select, and_, or_, delete, true
from datetime import datetime, timedelta
from typing import Optional, Tuple, Dict, Any
import uuid
import logging
import time

from config.config import settings
from database import RefreshToken, User
from services.redis_service import redis_service
from security.auth import (
    create_access_token,
    create_refresh_token,
    create_csrf_token,
    decode_token,
    hash_token,
    generate_session_id,
)

logger = logging.getLogger(__name__)


class TokenService:
    """Issue, rotate, validate and revoke access/refresh token pairs.

    Refresh tokens are stored hashed in the ``RefreshToken`` table and
    mirrored in Redis (metadata, per-token blacklist, per-user invalidation
    time). Access tokens are validated by decoding them and consulting Redis
    only; no DB lookup is done for them here.
    """

    @staticmethod
    def _refresh_ttl_seconds() -> int:
        """Return the configured refresh-token lifetime in seconds."""
        return settings.refresh_token_expire_days * 24 * 60 * 60

    @staticmethod
    def _issue_access_and_csrf(
        user: User, fingerprint_hash: str
    ) -> Tuple[str, str, str]:
        """Create a new session id with its access token and CSRF token.

        Returns:
            ``(session_id, access_token, csrf_token)``.
        """
        session_id = generate_session_id()
        access_token = create_access_token(
            data={
                "sub": user.id,
                "email": user.email,
                "role": user.role,
                "sid": session_id,
            },
            fingerprint_hash=fingerprint_hash,
        )
        return session_id, access_token, create_csrf_token(session_id)

    @staticmethod
    def _build_refresh_record(
        user: User,
        refresh_jwt: str,
        token_id: str,
        fingerprint_hash: str,
        device_info: str,
        ip_address: str,
    ) -> RefreshToken:
        """Build (but do not add to a session) the DB row for a refresh token."""
        now = datetime.now()
        return RefreshToken(
            id=str(uuid.uuid4()),
            user_id=user.id,
            # Only the hash is persisted, never the raw JWT.
            token_hash=hash_token(refresh_jwt),
            token_id=token_id,
            fingerprint_hash=fingerprint_hash,
            # Truncated to 500 chars; presumably the column width - confirm
            # against the RefreshToken model.
            device_info=device_info[:500] if device_info else None,
            ip_address=ip_address,
            expires_at=now + timedelta(days=settings.refresh_token_expire_days),
            created_at=now,
        )

    @classmethod
    async def create_token_pair(
        cls,
        session: AsyncSession,
        user: User,
        fingerprint_hash: str,
        device_info: str,
        ip_address: str,
    ) -> Tuple[str, str, str, str]:
        """Issue a new access/refresh/CSRF token set for ``user``.

        The refresh token row is added to ``session`` and flushed (not
        committed), and its metadata is stored in Redis with the
        refresh-token TTL.

        Returns:
            ``(access_token, refresh_token, csrf_token, session_id)``.
        """
        session_id, access_token, csrf_token = cls._issue_access_and_csrf(
            user, fingerprint_hash
        )
        refresh_token_jwt, token_id = create_refresh_token(
            user_id=user.id, fingerprint_hash=fingerprint_hash
        )

        session.add(
            cls._build_refresh_record(
                user,
                refresh_token_jwt,
                token_id,
                fingerprint_hash,
                device_info,
                ip_address,
            )
        )
        await session.flush()

        await redis_service.store_refresh_token_metadata(
            token_id=token_id,
            user_id=user.id,
            fingerprint_hash=fingerprint_hash,
            ttl_seconds=cls._refresh_ttl_seconds(),
        )

        logger.info("Token pair cree pour utilisateur %s", user.email)
        return access_token, refresh_token_jwt, csrf_token, session_id

    @classmethod
    async def refresh_tokens(
        cls,
        session: AsyncSession,
        refresh_token: str,
        fingerprint_hash: str,
        device_info: str,
        ip_address: str,
    ) -> Optional[Tuple[str, str, str, str]]:
        """Exchange a refresh token for a new access token.

        With ``settings.refresh_token_rotation_enabled`` the old refresh token
        is marked used and a new one is issued. Otherwise the same refresh
        token is returned and only ``last_used_at`` is updated.

        Returns None (after logging a warning) when the token:
        - fails decoding or is malformed;
        - is blacklisted;
        - is missing, revoked or expired in the DB;
        - is reused after the grace window;
        - is bound to a different fingerprint;
        - or belongs to a missing or inactive user.
        The "missing in DB" and "reused" cases also revoke every refresh
        token of the user.

        Returns:
            ``(access_token, refresh_token, csrf_token, session_id)`` on success.
        """
        payload = decode_token(refresh_token, expected_type="refresh")
        if not payload:
            logger.warning("Refresh token invalide ou expire")
            return None

        user_id = payload.get("sub")
        token_id = payload.get("jti")
        stored_fingerprint = payload.get("fph")

        if not user_id or not token_id:
            logger.warning("Refresh token malformed")
            return None

        if await redis_service.is_token_blacklisted(token_id):
            logger.warning("Refresh token %s... est blackliste", token_id[:8])
            return None

        result = await session.execute(
            select(RefreshToken).where(
                and_(
                    RefreshToken.token_hash == hash_token(refresh_token),
                    RefreshToken.user_id == user_id,
                    RefreshToken.is_revoked.is_(false()),
                    RefreshToken.expires_at > datetime.now(),
                )
            )
        )
        token_record = result.scalar_one_or_none()

        if not token_record:
            logger.warning("Refresh token non trouve en DB pour user %s", user_id)
            await cls._handle_potential_token_theft(session, user_id, token_id)
            return None

        if settings.refresh_token_rotation_enabled and token_record.is_used:
            used_at = token_record.used_at
            # A used token without a timestamp cannot be shown to be inside the
            # grace window, so fail closed and treat it as reuse.
            reused_too_late = (
                used_at is None
                or (datetime.now() - used_at).total_seconds()
                > settings.refresh_token_reuse_window_seconds
            )
            if reused_too_late:
                logger.warning(
                    "Reutilisation de refresh token detectee pour user %s", user_id
                )
                await cls._handle_potential_token_theft(session, user_id, token_id)
                return None

        # A fingerprint-bound token must be presented with the same
        # fingerprint. An empty caller fingerprint must not bypass the check.
        # Otherwise a stolen token could be replayed, and re-issued unbound,
        # simply by omitting the fingerprint.
        if stored_fingerprint and stored_fingerprint != fingerprint_hash:
            logger.warning("Fingerprint mismatch pour user %s", user_id)
            return None

        result = await session.execute(
            select(User).where(and_(User.id == user_id, User.is_active.is_(true())))
        )
        user = result.scalar_one_or_none()

        if not user:
            logger.warning("Utilisateur %s introuvable ou inactif", user_id)
            return None

        session_id, new_access_token, new_csrf_token = cls._issue_access_and_csrf(
            user, fingerprint_hash
        )

        if not settings.refresh_token_rotation_enabled:
            token_record.last_used_at = datetime.now()
            return new_access_token, refresh_token, new_csrf_token, session_id

        # Only the first use starts the reuse grace window. Re-stamping on
        # every in-window reuse would let the window slide forward forever.
        if token_record.used_at is None:
            token_record.used_at = datetime.now()
        token_record.is_used = True

        new_refresh_jwt, new_token_id = create_refresh_token(
            user_id=user.id, fingerprint_hash=fingerprint_hash
        )
        new_token_record = cls._build_refresh_record(
            user,
            new_refresh_jwt,
            new_token_id,
            fingerprint_hash,
            device_info,
            ip_address,
        )
        token_record.replaced_by = new_token_record.id
        session.add(new_token_record)

        await redis_service.mark_refresh_token_used(token_id)
        await redis_service.store_refresh_token_metadata(
            token_id=new_token_id,
            user_id=user.id,
            fingerprint_hash=fingerprint_hash,
            ttl_seconds=cls._refresh_ttl_seconds(),
        )

        logger.info("Refresh token rotation pour user %s", user.email)
        return new_access_token, new_refresh_jwt, new_csrf_token, session_id

    @classmethod
    async def revoke_token(
        cls, session: AsyncSession, refresh_token: str, reason: str = "user_logout"
    ) -> bool:
        """Revoke a single refresh token.

        The token id is blacklisted in Redis until the JWT's ``exp``. The DB
        row, if found, is marked revoked with ``reason``, and the Redis
        metadata is deleted.

        Returns:
            False if the token cannot be decoded, True otherwise (even when
            no matching DB row exists).
        """
        payload = decode_token(refresh_token, expected_type="refresh")
        if not payload:
            return False

        token_id = payload.get("jti")
        user_id = payload.get("sub")

        # Without a jti there is nothing to key the Redis entries on.
        if token_id:
            exp = payload.get("exp", 0)
            ttl_seconds = max(0, exp - int(time.time()))
            await redis_service.blacklist_token(token_id, ttl_seconds)

        result = await session.execute(
            select(RefreshToken).where(
                RefreshToken.token_hash == hash_token(refresh_token)
            )
        )
        token_record = result.scalar_one_or_none()

        if token_record:
            token_record.is_revoked = True
            token_record.revoked_at = datetime.now()
            token_record.revoked_reason = reason

        if token_id:
            await redis_service.delete_refresh_token(token_id)

        logger.info("Token revoque pour user %s: %s", user_id, reason)
        return True

    @classmethod
    async def revoke_all_user_tokens(
        cls, session: AsyncSession, user_id: str, reason: str = "security_action"
    ) -> int:
        """Revoke every non-revoked refresh token of ``user_id``.

        Each token is marked revoked in the DB, blacklisted and removed from
        Redis. The user is also flagged in Redis so that access tokens issued
        before now are rejected by ``validate_access_token``.

        Returns:
            The number of DB rows newly revoked.
        """
        result = await session.execute(
            select(RefreshToken).where(
                and_(
                    RefreshToken.user_id == user_id,
                    RefreshToken.is_revoked.is_(false()),
                )
            )
        )
        tokens = result.scalars().all()

        revoked_at = datetime.now()
        ttl_seconds = cls._refresh_ttl_seconds()
        for token in tokens:
            token.is_revoked = True
            token.revoked_at = revoked_at
            token.revoked_reason = reason
            await redis_service.blacklist_token(token.token_id, ttl_seconds)
            await redis_service.delete_refresh_token(token.token_id)

        await redis_service.blacklist_user_tokens(user_id, ttl_seconds)

        count = len(tokens)
        logger.info("%d tokens revoques pour user %s: %s", count, user_id, reason)
        return count

    @classmethod
    async def _handle_potential_token_theft(
        cls, session: AsyncSession, user_id: str, token_id: str
    ) -> None:
        """Log suspected refresh-token theft and revoke all of the user's tokens."""
        logger.warning(
            "Potentiel vol de token detecte pour user %s, token %s...",
            user_id,
            token_id[:8],
        )
        await cls.revoke_all_user_tokens(
            session, user_id, reason="potential_token_theft"
        )

    @classmethod
    async def validate_access_token(
        cls, token: str, fingerprint_hash: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Decode and check an access token.

        Returns None when the token fails decoding or is blacklisted. It also
        returns None when the token was issued before the user's global
        invalidation time, or when its ``fph`` claim differs from
        ``fingerprint_hash``. The fingerprint check only runs when a
        fingerprint is supplied.

        Returns:
            The decoded payload on success.
        """
        payload = decode_token(token, expected_type="access")
        if not payload:
            return None

        token_id = payload.get("jti")
        if token_id and await redis_service.is_token_blacklisted(token_id):
            logger.debug("Access token %s... est blackliste", token_id[:8])
            return None

        user_id = payload.get("sub")
        if user_id:
            invalidation_time = await redis_service.get_user_token_invalidation_time(
                user_id
            )
            if invalidation_time and payload.get("iat", 0) < invalidation_time:
                logger.debug("Access token emis avant invalidation globale")
                return None

        if fingerprint_hash:
            stored_fingerprint = payload.get("fph")
            if stored_fingerprint and stored_fingerprint != fingerprint_hash:
                logger.warning("Fingerprint mismatch sur access token")
                return None

        return payload

    @classmethod
    async def cleanup_expired_tokens(cls, session: AsyncSession) -> int:
        """Delete expired refresh tokens and those revoked more than 7 days ago.

        Returns:
            The driver-reported number of deleted rows.
        """
        now = datetime.now()
        result = await session.execute(
            delete(RefreshToken).where(
                or_(
                    RefreshToken.expires_at < now,
                    and_(
                        RefreshToken.is_revoked.is_(true()),
                        RefreshToken.revoked_at < now - timedelta(days=7),
                    ),
                )
            )
        )
        count = result.rowcount
        logger.info("%s tokens expires nettoyes", count)
        return count

    @classmethod
    async def get_user_active_sessions(
        cls, session: AsyncSession, user_id: str
    ) -> list:
        """List the user's live refresh tokens, newest first, as plain dicts.

        Tokens already consumed by rotation (``is_used``) are excluded. Each
        rotation leaves the old row non-revoked and unexpired, so including
        them would report one extra "session" per refresh. ``IS NOT true`` is
        used so that a NULL ``is_used`` still counts as unused.
        """
        result = await session.execute(
            select(RefreshToken)
            .where(
                and_(
                    RefreshToken.user_id == user_id,
                    RefreshToken.is_revoked.is_(false()),
                    RefreshToken.is_used.is_not(true()),
                    RefreshToken.expires_at > datetime.now(),
                )
            )
            .order_by(RefreshToken.created_at.desc())
        )
        return [
            {
                "id": t.id,
                "device_info": t.device_info,
                "ip_address": t.ip_address,
                "created_at": t.created_at.isoformat(),
                "last_used_at": t.last_used_at.isoformat() if t.last_used_at else None,
            }
            for t in result.scalars().all()
        ]