Sage100-vps/services/token_service.py
2026-01-02 17:56:28 +03:00

357 lines
12 KiB
Python

from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import false, select, and_, or_, delete, true
from datetime import datetime, timedelta
from typing import Optional, Tuple, Dict, Any
import uuid
import logging
import time
from config.config import settings
from database import RefreshToken, User
from services.redis_service import redis_service
from security.auth import (
create_access_token,
create_refresh_token,
create_csrf_token,
decode_token,
hash_token,
generate_session_id,
)
# Module-level logger, named after this module so records can be filtered per service.
logger: logging.Logger = logging.getLogger(__name__)
class TokenService:
    """Issue, rotate, validate and revoke JWT access/refresh token pairs.

    Refresh tokens are persisted as hashes in the ``RefreshToken`` table and
    mirrored in Redis (through ``redis_service``) for metadata and
    blacklisting. The class keeps no state; every public method is a
    classmethod. Database changes are made on the caller's ``AsyncSession``
    and are never committed here -- committing is the caller's job.
    """

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _refresh_ttl_seconds() -> int:
        """Return the configured refresh-token lifetime, in seconds."""
        return settings.refresh_token_expire_days * 24 * 60 * 60

    @staticmethod
    def _issue_session_tokens(
        user: "User", fingerprint_hash: Optional[str]
    ) -> Tuple[str, str, str]:
        """Mint an access token and a CSRF token bound to a brand-new session id.

        Returns:
            ``(access_token, csrf_token, session_id)``.
        """
        session_id = generate_session_id()
        access_token = create_access_token(
            data={
                "sub": user.id,
                "email": user.email,
                "role": user.role,
                "sid": session_id,
            },
            fingerprint_hash=fingerprint_hash,
        )
        csrf_token = create_csrf_token(session_id)
        return access_token, csrf_token, session_id

    @staticmethod
    def _build_refresh_record(
        user_id: str,
        refresh_jwt: str,
        token_id: str,
        fingerprint_hash: Optional[str],
        device_info: Optional[str],
        ip_address: str,
        now: datetime,
    ) -> "RefreshToken":
        """Build (without adding it to a session) the DB row for a new refresh token.

        Only the hash of ``refresh_jwt`` is stored. ``device_info`` is
        truncated to 500 characters; empty values are stored as ``None``.
        ``created_at`` and ``expires_at`` derive from the same ``now``.
        """
        return RefreshToken(
            id=str(uuid.uuid4()),
            user_id=user_id,
            token_hash=hash_token(refresh_jwt),
            token_id=token_id,
            fingerprint_hash=fingerprint_hash,
            device_info=device_info[:500] if device_info else None,
            ip_address=ip_address,
            expires_at=now + timedelta(days=settings.refresh_token_expire_days),
            created_at=now,
        )

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    @classmethod
    async def create_token_pair(
        cls,
        session: AsyncSession,
        user: User,
        fingerprint_hash: str,
        device_info: str,
        ip_address: str,
    ) -> Tuple[str, str, str, str]:
        """Open a new authenticated session for ``user``.

        Mints an access token, a refresh token and a CSRF token tied to a fresh
        session id. It persists the refresh token's hash, flushes the session
        and stores the refresh token's metadata in Redis.

        Returns:
            ``(access_token, refresh_token, csrf_token, session_id)``.
        """
        access_token, csrf_token, session_id = cls._issue_session_tokens(
            user, fingerprint_hash
        )
        refresh_token_jwt, token_id = create_refresh_token(
            user_id=user.id, fingerprint_hash=fingerprint_hash
        )
        session.add(
            cls._build_refresh_record(
                user_id=user.id,
                refresh_jwt=refresh_token_jwt,
                token_id=token_id,
                fingerprint_hash=fingerprint_hash,
                device_info=device_info,
                ip_address=ip_address,
                now=datetime.now(),
            )
        )
        # Flush before touching Redis: if the INSERT fails, the exception
        # propagates and no metadata is written for a token that has no row.
        await session.flush()
        await redis_service.store_refresh_token_metadata(
            token_id=token_id,
            user_id=user.id,
            fingerprint_hash=fingerprint_hash,
            ttl_seconds=cls._refresh_ttl_seconds(),
        )
        logger.info(f"Token pair cree pour utilisateur {user.email}")
        return access_token, refresh_token_jwt, csrf_token, session_id

    @classmethod
    async def refresh_tokens(
        cls,
        session: AsyncSession,
        refresh_token: str,
        fingerprint_hash: str,
        device_info: str,
        ip_address: str,
    ) -> Optional[Tuple[str, str, str, str]]:
        """Exchange ``refresh_token`` for a new set of session tokens.

        The presented token must pass every one of these checks:
        - it decodes as a refresh token and carries ``sub`` and ``jti``;
        - it is not blacklisted in Redis;
        - it matches a non-revoked, unexpired DB row for that user;
        - it matches the caller's fingerprint, when both are present;
        - it belongs to an active user.

        With ``settings.refresh_token_rotation_enabled``, the presented token
        is marked used and replaced by a new one. Re-presenting a used token
        later than ``settings.refresh_token_reuse_window_seconds`` revokes
        every token of the user. Without rotation, the same refresh token is
        returned again.

        Returns:
            ``(access_token, refresh_token, csrf_token, session_id)``, or
            ``None`` when any check fails.
        """
        payload = decode_token(refresh_token, expected_type="refresh")
        if not payload:
            logger.warning("Refresh token invalide ou expire")
            return None
        user_id = payload.get("sub")
        token_id = payload.get("jti")
        stored_fingerprint = payload.get("fph")
        if not user_id or not token_id:
            logger.warning("Refresh token malformed")
            return None
        if await redis_service.is_token_blacklisted(token_id):
            logger.warning(f"Refresh token {token_id[:8]}... est blackliste")
            return None

        now = datetime.now()
        result = await session.execute(
            select(RefreshToken).where(
                and_(
                    RefreshToken.token_hash == hash_token(refresh_token),
                    RefreshToken.user_id == user_id,
                    RefreshToken.is_revoked.is_(false()),
                    RefreshToken.expires_at > now,
                )
            )
        )
        token_record = result.scalar_one_or_none()
        if not token_record:
            # The JWT is validly signed, but no live row backs it (revoked,
            # expired in DB or cleaned up). It is treated as possibly stolen.
            logger.warning(f"Refresh token non trouve en DB pour user {user_id}")
            await cls._handle_potential_token_theft(session, user_id, token_id)
            return None

        if settings.refresh_token_rotation_enabled and token_record.is_used:
            used_at = token_record.used_at
            # A used token is still accepted inside the reuse window
            # (presumably to tolerate concurrent refreshes from one client).
            # Past the window, reuse is treated as theft.
            if used_at and (now - used_at).total_seconds() > (
                settings.refresh_token_reuse_window_seconds
            ):
                logger.warning(
                    f"Reutilisation de refresh token detectee pour user {user_id}"
                )
                await cls._handle_potential_token_theft(session, user_id, token_id)
                return None

        # NOTE(review): fingerprint binding is only enforced when the request
        # supplies a fingerprint. A request without one passes this check, and
        # the replacement token is then minted with that empty fingerprint
        # (presumably unbound). Confirm this fallback is intended.
        if (
            stored_fingerprint
            and fingerprint_hash
            and stored_fingerprint != fingerprint_hash
        ):
            logger.warning(f"Fingerprint mismatch pour user {user_id}")
            return None

        result = await session.execute(
            select(User).where(and_(User.id == user_id, User.is_active.is_(true())))
        )
        user = result.scalar_one_or_none()
        if not user:
            logger.warning(f"Utilisateur {user_id} introuvable ou inactif")
            return None

        new_access_token, new_csrf_token, session_id = cls._issue_session_tokens(
            user, fingerprint_hash
        )
        issued_at = datetime.now()

        if not settings.refresh_token_rotation_enabled:
            token_record.last_used_at = issued_at
            return new_access_token, refresh_token, new_csrf_token, session_id

        # Rotation: retire the presented token and chain it to its replacement.
        token_record.is_used = True
        token_record.used_at = issued_at
        new_refresh_jwt, new_token_id = create_refresh_token(
            user_id=user.id, fingerprint_hash=fingerprint_hash
        )
        new_token_record = cls._build_refresh_record(
            user_id=user.id,
            refresh_jwt=new_refresh_jwt,
            token_id=new_token_id,
            fingerprint_hash=fingerprint_hash,
            device_info=device_info,
            ip_address=ip_address,
            now=issued_at,
        )
        token_record.replaced_by = new_token_record.id
        session.add(new_token_record)
        await redis_service.mark_refresh_token_used(token_id)
        await redis_service.store_refresh_token_metadata(
            token_id=new_token_id,
            user_id=user.id,
            fingerprint_hash=fingerprint_hash,
            ttl_seconds=cls._refresh_ttl_seconds(),
        )
        logger.info(f"Refresh token rotation pour user {user.email}")
        return new_access_token, new_refresh_jwt, new_csrf_token, session_id

    @classmethod
    async def revoke_token(
        cls, session: AsyncSession, refresh_token: str, reason: str = "user_logout"
    ) -> bool:
        """Revoke a single refresh token.

        Blacklists its ``jti`` in Redis for the token's remaining lifetime. If
        the token has no ``exp`` claim, the configured refresh lifetime is used
        instead. The matching DB row, if any, is marked revoked with
        ``reason``, and the token's Redis metadata is dropped.

        Returns:
            ``False`` if the token cannot be decoded or carries no ``jti``;
            ``True`` otherwise, even when no DB row matched.
        """
        payload = decode_token(refresh_token, expected_type="refresh")
        if not payload:
            return False
        token_id = payload.get("jti")
        user_id = payload.get("sub")
        if not token_id:
            # Without a jti there is nothing to blacklist or delete in Redis.
            logger.warning("Refresh token malformed")
            return False

        exp = payload.get("exp")
        if exp is None:
            ttl_seconds = cls._refresh_ttl_seconds()
        else:
            # Clamp to at least 1s. A zero/negative expiry is presumably
            # rejected or meaningless for the Redis-backed blacklist.
            ttl_seconds = max(1, exp - int(time.time()))
        await redis_service.blacklist_token(token_id, ttl_seconds)

        result = await session.execute(
            select(RefreshToken).where(
                RefreshToken.token_hash == hash_token(refresh_token)
            )
        )
        token_record = result.scalar_one_or_none()
        if token_record:
            token_record.is_revoked = True
            token_record.revoked_at = datetime.now()
            token_record.revoked_reason = reason
        await redis_service.delete_refresh_token(token_id)
        logger.info(f"Token revoque pour user {user_id}: {reason}")
        return True

    @classmethod
    async def revoke_all_user_tokens(
        cls, session: AsyncSession, user_id: str, reason: str = "security_action"
    ) -> int:
        """Revoke every non-revoked refresh token belonging to ``user_id``.

        Each row is marked revoked with ``reason``. Its ``jti`` is blacklisted
        for the full refresh lifetime, and its Redis metadata is dropped.
        Finally ``redis_service.blacklist_user_tokens`` is called for the user.
        This presumably records the invalidation time that
        ``validate_access_token`` reads back to reject older access tokens.

        Returns:
            The number of refresh tokens revoked.
        """
        result = await session.execute(
            select(RefreshToken).where(
                and_(
                    RefreshToken.user_id == user_id,
                    RefreshToken.is_revoked.is_(false()),
                )
            )
        )
        tokens = result.scalars().all()
        revoked_at = datetime.now()
        ttl_seconds = cls._refresh_ttl_seconds()
        for token in tokens:
            token.is_revoked = True
            token.revoked_at = revoked_at
            token.revoked_reason = reason
            await redis_service.blacklist_token(token.token_id, ttl_seconds)
            await redis_service.delete_refresh_token(token.token_id)
        count = len(tokens)
        await redis_service.blacklist_user_tokens(user_id, ttl_seconds)
        logger.info(f"{count} tokens revoques pour user {user_id}: {reason}")
        return count

    @classmethod
    async def _handle_potential_token_theft(
        cls, session: AsyncSession, user_id: str, token_id: str
    ) -> None:
        """Log a suspected stolen refresh token and revoke all of the user's tokens."""
        logger.warning(
            f"Potentiel vol de token detecte pour user {user_id}, token {token_id[:8]}..."
        )
        await cls.revoke_all_user_tokens(
            session, user_id, reason="potential_token_theft"
        )

    @classmethod
    async def validate_access_token(
        cls, token: str, fingerprint_hash: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Validate an access token and return its decoded payload.

        The token is rejected (``None``) in any of these cases:
        - it does not decode as an access token;
        - its ``jti`` is blacklisted;
        - it was issued before the user's global invalidation time;
        - ``fingerprint_hash`` is given and differs from the token's ``fph``.

        Returns:
            The payload dict, or ``None`` when any check fails.
        """
        payload = decode_token(token, expected_type="access")
        if not payload:
            return None
        token_id = payload.get("jti")
        if token_id and await redis_service.is_token_blacklisted(token_id):
            logger.debug(f"Access token {token_id[:8]}... est blackliste")
            return None
        user_id = payload.get("sub")
        if user_id:
            invalidation_time = await redis_service.get_user_token_invalidation_time(
                user_id
            )
            # When an invalidation time exists, a token lacking "iat" counts as
            # issued at 0 and is therefore rejected.
            if invalidation_time and payload.get("iat", 0) < invalidation_time:
                logger.debug("Access token emis avant invalidation globale")
                return None
        if fingerprint_hash:
            stored_fingerprint = payload.get("fph")
            if stored_fingerprint and stored_fingerprint != fingerprint_hash:
                logger.warning("Fingerprint mismatch sur access token")
                return None
        return payload

    @classmethod
    async def cleanup_expired_tokens(cls, session: AsyncSession) -> int:
        """Delete refresh-token rows that are expired or were revoked over 7 days ago.

        Returns:
            The number of deleted rows, as reported by the result's ``rowcount``.
        """
        now = datetime.now()
        result = await session.execute(
            delete(RefreshToken).where(
                or_(
                    RefreshToken.expires_at < now,
                    and_(
                        RefreshToken.is_revoked.is_(true()),
                        RefreshToken.revoked_at < now - timedelta(days=7),
                    ),
                )
            )
        )
        count = result.rowcount
        logger.info(f"{count} tokens expires nettoyes")
        return count

    @classmethod
    async def get_user_active_sessions(
        cls, session: AsyncSession, user_id: str
    ) -> list:
        """List the user's live refresh tokens, newest first, as session dicts.

        Rows that were rotated out (``replaced_by`` set by ``refresh_tokens``)
        are excluded. Their replacement row represents the same session, so
        listing both would show one duplicate per rotation.

        Returns:
            A list of dicts with keys ``id``, ``device_info``, ``ip_address``,
            ``created_at`` (ISO string) and ``last_used_at`` (ISO string or
            ``None``).
        """
        result = await session.execute(
            select(RefreshToken)
            .where(
                and_(
                    RefreshToken.user_id == user_id,
                    RefreshToken.is_revoked.is_(false()),
                    RefreshToken.expires_at > datetime.now(),
                    RefreshToken.replaced_by.is_(None),
                )
            )
            .order_by(RefreshToken.created_at.desc())
        )
        return [
            {
                "id": t.id,
                "device_info": t.device_info,
                "ip_address": t.ip_address,
                "created_at": t.created_at.isoformat(),
                "last_used_at": t.last_used_at.isoformat() if t.last_used_at else None,
            }
            for t in result.scalars().all()
        ]