357 lines
12 KiB
Python
357 lines
12 KiB
Python
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlalchemy import false, select, and_, or_, delete, true
|
|
from datetime import datetime, timedelta
|
|
from typing import Optional, Tuple, Dict, Any
|
|
import uuid
|
|
import logging
|
|
import time
|
|
|
|
from config.config import settings
|
|
from database import RefreshToken, User
|
|
from services.redis_service import redis_service
|
|
from security.auth import (
|
|
create_access_token,
|
|
create_refresh_token,
|
|
create_csrf_token,
|
|
decode_token,
|
|
hash_token,
|
|
generate_session_id,
|
|
)
|
|
|
|
# Module-scoped logger, named after this module's import path.
logger: logging.Logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
class TokenService:
    """Issue, rotate, revoke and validate JWT access/refresh tokens.

    Refresh tokens are persisted as ``RefreshToken`` rows (only a hash of the
    JWT is stored) and mirrored in Redis (metadata, blacklist, "used" marks)
    through ``redis_service``. Methods that touch the database take an
    ``AsyncSession`` and only add/flush/mutate rows; none of them commits, so
    the caller owns the transaction.
    """

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _refresh_ttl_seconds() -> int:
        """Return the configured refresh-token lifetime in seconds."""
        return settings.refresh_token_expire_days * 24 * 60 * 60

    @staticmethod
    def _issue_access_and_csrf(
        user: User, fingerprint_hash: Optional[str]
    ) -> Tuple[str, str, str]:
        """Create a new session id plus the access and CSRF tokens bound to it.

        Returns:
            ``(access_token, csrf_token, session_id)``.
        """
        session_id = generate_session_id()
        access_token = create_access_token(
            data={
                "sub": user.id,
                "email": user.email,
                "role": user.role,
                "sid": session_id,
            },
            fingerprint_hash=fingerprint_hash,
        )
        return access_token, create_csrf_token(session_id), session_id

    @staticmethod
    def _build_refresh_record(
        *,
        user_id: str,
        refresh_jwt: str,
        token_id: str,
        fingerprint_hash: Optional[str],
        device_info: Optional[str],
        ip_address: str,
        now: datetime,
    ) -> RefreshToken:
        """Build (but do not add to the session) the row persisting a refresh token.

        Only a hash of the JWT is stored. ``device_info`` is truncated to 500
        characters (``None`` when empty), and ``expires_at`` is ``now`` plus
        the configured refresh-token lifetime.
        """
        return RefreshToken(
            id=str(uuid.uuid4()),
            user_id=user_id,
            token_hash=hash_token(refresh_jwt),
            token_id=token_id,
            fingerprint_hash=fingerprint_hash,
            device_info=device_info[:500] if device_info else None,
            ip_address=ip_address,
            expires_at=now + timedelta(days=settings.refresh_token_expire_days),
            created_at=now,
        )

    @classmethod
    async def _store_refresh_metadata(
        cls, token_id: str, user_id: str, fingerprint_hash: Optional[str]
    ) -> None:
        """Mirror a newly issued refresh token in Redis for its full lifetime."""
        await redis_service.store_refresh_token_metadata(
            token_id=token_id,
            user_id=user_id,
            fingerprint_hash=fingerprint_hash,
            ttl_seconds=cls._refresh_ttl_seconds(),
        )

    # ------------------------------------------------------------- public API

    @classmethod
    async def create_token_pair(
        cls,
        session: AsyncSession,
        user: User,
        fingerprint_hash: str,
        device_info: str,
        ip_address: str,
    ) -> Tuple[str, str, str, str]:
        """Open a new session for ``user`` and issue its tokens.

        The refresh-token row is added to ``session`` and flushed (not
        committed), then its metadata is stored in Redis.

        Returns:
            ``(access_token, refresh_token, csrf_token, session_id)``.
        """
        access_token, csrf_token, session_id = cls._issue_access_and_csrf(
            user, fingerprint_hash
        )
        refresh_jwt, token_id = create_refresh_token(
            user_id=user.id, fingerprint_hash=fingerprint_hash
        )

        session.add(
            cls._build_refresh_record(
                user_id=user.id,
                refresh_jwt=refresh_jwt,
                token_id=token_id,
                fingerprint_hash=fingerprint_hash,
                device_info=device_info,
                ip_address=ip_address,
                now=datetime.now(),
            )
        )
        await session.flush()

        await cls._store_refresh_metadata(token_id, user.id, fingerprint_hash)

        logger.info("Token pair cree pour utilisateur %s", user.email)
        return access_token, refresh_jwt, csrf_token, session_id

    @classmethod
    async def refresh_tokens(
        cls,
        session: AsyncSession,
        refresh_token: str,
        fingerprint_hash: str,
        device_info: str,
        ip_address: str,
    ) -> Optional[Tuple[str, str, str, str]]:
        """Exchange a refresh token for a new access token.

        When ``settings.refresh_token_rotation_enabled`` is set, a new refresh
        token is also issued and the presented one is marked used. Otherwise
        the presented refresh token is returned unchanged.

        Returns:
            ``(access_token, refresh_token, csrf_token, session_id)``, or
            ``None`` if the token does not decode, lacks ``sub``/``jti``, is
            blacklisted, is unknown/revoked/expired in the DB, was reused
            outside the grace window, is bound to a different fingerprint, or
            belongs to a missing/inactive user. The DB-miss and
            out-of-window-reuse cases also revoke every token of the user.
        """
        payload = decode_token(refresh_token, expected_type="refresh")
        if not payload:
            logger.warning("Refresh token invalide ou expire")
            return None

        user_id = payload.get("sub")
        token_id = payload.get("jti")
        token_fingerprint = payload.get("fph")

        if not user_id or not token_id:
            logger.warning("Refresh token malformed")
            return None

        if await redis_service.is_token_blacklisted(token_id):
            logger.warning("Refresh token %s... est blackliste", token_id[:8])
            return None

        result = await session.execute(
            select(RefreshToken).where(
                and_(
                    RefreshToken.token_hash == hash_token(refresh_token),
                    RefreshToken.user_id == user_id,
                    RefreshToken.is_revoked.is_(false()),
                    RefreshToken.expires_at > datetime.now(),
                )
            )
        )
        token_record = result.scalar_one_or_none()
        if not token_record:
            logger.warning("Refresh token non trouve en DB pour user %s", user_id)
            # A correctly signed token with no live DB row is treated as a
            # possible theft: every token of the user is revoked.
            await cls._handle_potential_token_theft(session, user_id, token_id)
            return None

        rotation = settings.refresh_token_rotation_enabled
        if rotation and token_record.is_used and token_record.used_at:
            # Reuse shortly after the first use is tolerated (concurrent
            # refreshes); reuse after the window is treated as token theft.
            elapsed = (datetime.now() - token_record.used_at).total_seconds()
            if elapsed > settings.refresh_token_reuse_window_seconds:
                logger.warning(
                    "Reutilisation de refresh token detectee pour user %s", user_id
                )
                await cls._handle_potential_token_theft(session, user_id, token_id)
                return None

        # A token bound to a fingerprint must be presented with that exact
        # fingerprint. Skipping the check when the request omits it would let
        # a stolen token be refreshed (and re-issued unbound) just by leaving
        # the fingerprint out.
        if token_fingerprint and token_fingerprint != fingerprint_hash:
            logger.warning("Fingerprint mismatch pour user %s", user_id)
            return None

        result = await session.execute(
            select(User).where(and_(User.id == user_id, User.is_active.is_(true())))
        )
        user = result.scalar_one_or_none()
        if not user:
            logger.warning("Utilisateur %s introuvable ou inactif", user_id)
            return None

        access_token, csrf_token, session_id = cls._issue_access_and_csrf(
            user, fingerprint_hash
        )
        now = datetime.now()

        if not rotation:
            token_record.last_used_at = now
            return access_token, refresh_token, csrf_token, session_id

        new_refresh_jwt, new_token_id = create_refresh_token(
            user_id=user.id, fingerprint_hash=fingerprint_hash
        )
        new_record = cls._build_refresh_record(
            user_id=user.id,
            refresh_jwt=new_refresh_jwt,
            token_id=new_token_id,
            fingerprint_hash=fingerprint_hash,
            device_info=device_info,
            ip_address=ip_address,
            now=now,
        )

        token_record.is_used = True
        # Anchor the reuse grace window on the FIRST use only. Re-stamping it
        # on every in-window reuse would let a replayed token extend the
        # window indefinitely and never trip the theft check above.
        if token_record.used_at is None:
            token_record.used_at = now
        token_record.replaced_by = new_record.id
        session.add(new_record)
        # Flush before the Redis writes so a failing INSERT/UPDATE surfaces
        # before Redis state is changed (same order as create_token_pair).
        await session.flush()

        await redis_service.mark_refresh_token_used(token_id)
        await cls._store_refresh_metadata(new_token_id, user.id, fingerprint_hash)

        logger.info("Refresh token rotation pour user %s", user.email)
        return access_token, new_refresh_jwt, csrf_token, session_id

    @classmethod
    async def revoke_token(
        cls, session: AsyncSession, refresh_token: str, reason: str = "user_logout"
    ) -> bool:
        """Revoke a single refresh token.

        Blacklists its ``jti`` in Redis until the JWT's ``exp``, marks the
        matching DB row revoked with ``reason`` (if a row exists), and deletes
        the token's Redis metadata.

        Returns:
            ``False`` if the token does not decode as a refresh token,
            ``True`` otherwise.
        """
        payload = decode_token(refresh_token, expected_type="refresh")
        if not payload:
            return False

        token_id = payload.get("jti")
        user_id = payload.get("sub")

        # Redis state is keyed by jti; without one there is nothing to
        # blacklist or delete (and a "None" key must not be written), so only
        # the DB row is revoked.
        if token_id:
            ttl_seconds = max(0, payload.get("exp", 0) - int(time.time()))
            await redis_service.blacklist_token(token_id, ttl_seconds)

        result = await session.execute(
            select(RefreshToken).where(
                RefreshToken.token_hash == hash_token(refresh_token)
            )
        )
        token_record = result.scalar_one_or_none()
        if token_record:
            token_record.is_revoked = True
            token_record.revoked_at = datetime.now()
            token_record.revoked_reason = reason

        if token_id:
            await redis_service.delete_refresh_token(token_id)

        logger.info("Token revoque pour user %s: %s", user_id, reason)
        return True

    @classmethod
    async def revoke_all_user_tokens(
        cls, session: AsyncSession, user_id: str, reason: str = "security_action"
    ) -> int:
        """Revoke every not-yet-revoked refresh token of ``user_id``.

        Each row is marked revoked with ``reason``; its ``jti`` is blacklisted
        and its metadata deleted in Redis. Finally a per-user blacklist entry
        is written via ``redis_service.blacklist_user_tokens``.

        Returns:
            The number of rows revoked.
        """
        result = await session.execute(
            select(RefreshToken).where(
                and_(
                    RefreshToken.user_id == user_id,
                    RefreshToken.is_revoked.is_(false()),
                )
            )
        )
        tokens = result.scalars().all()

        now = datetime.now()
        ttl_seconds = cls._refresh_ttl_seconds()
        for token in tokens:
            token.is_revoked = True
            token.revoked_at = now
            token.revoked_reason = reason
            # The JWT may still be valid for up to the full configured
            # lifetime, so keep its jti blacklisted for that long.
            await redis_service.blacklist_token(token.token_id, ttl_seconds)
            await redis_service.delete_refresh_token(token.token_id)

        # NOTE(review): presumably this feeds the invalidation time read by
        # get_user_token_invalidation_time in validate_access_token — confirm
        # in redis_service.
        await redis_service.blacklist_user_tokens(user_id, ttl_seconds)

        count = len(tokens)
        logger.info("%s tokens revoques pour user %s: %s", count, user_id, reason)
        return count

    @classmethod
    async def _handle_potential_token_theft(
        cls, session: AsyncSession, user_id: str, token_id: str
    ) -> None:
        """Log a suspected refresh-token theft and revoke all the user's tokens."""
        logger.warning(
            "Potentiel vol de token detecte pour user %s, token %s...",
            user_id,
            token_id[:8],
        )
        await cls.revoke_all_user_tokens(
            session, user_id, reason="potential_token_theft"
        )

    @classmethod
    async def validate_access_token(
        cls, token: str, fingerprint_hash: Optional[str] = None
    ) -> Optional[Dict[str, Any]]:
        """Return the decoded payload of a still-valid access token, else ``None``.

        Rejects the token when it does not decode as an access token, when its
        ``jti`` is blacklisted, when its ``iat`` predates the user's global
        invalidation time, or when a ``fingerprint_hash`` is given and differs
        from the token's ``fph`` claim. The fingerprint check is skipped when
        ``fingerprint_hash`` is not supplied.
        """
        payload = decode_token(token, expected_type="access")
        if not payload:
            return None

        token_id = payload.get("jti")
        if token_id and await redis_service.is_token_blacklisted(token_id):
            logger.debug("Access token %s... est blackliste", token_id[:8])
            return None

        user_id = payload.get("sub")
        if user_id:
            invalidation_time = await redis_service.get_user_token_invalidation_time(
                user_id
            )
            # A token without iat counts as issued at epoch, so it is rejected
            # once any invalidation exists.
            if invalidation_time and payload.get("iat", 0) < invalidation_time:
                logger.debug("Access token emis avant invalidation globale")
                return None

        if fingerprint_hash:
            token_fingerprint = payload.get("fph")
            if token_fingerprint and token_fingerprint != fingerprint_hash:
                logger.warning("Fingerprint mismatch sur access token")
                return None

        return payload

    @classmethod
    async def cleanup_expired_tokens(cls, session: AsyncSession) -> int:
        """Delete expired rows and rows revoked more than 7 days ago.

        Returns:
            The ``rowcount`` reported by the DELETE.
        """
        now = datetime.now()
        result = await session.execute(
            delete(RefreshToken).where(
                or_(
                    RefreshToken.expires_at < now,
                    and_(
                        RefreshToken.is_revoked.is_(true()),
                        RefreshToken.revoked_at < now - timedelta(days=7),
                    ),
                )
            )
        )

        count = result.rowcount
        logger.info("%s tokens expires nettoyes", count)
        return count

    @classmethod
    async def get_user_active_sessions(
        cls, session: AsyncSession, user_id: str
    ) -> list:
        """List the user's live sessions, newest first.

        A session is a refresh-token row that is not revoked, not expired and
        not superseded by rotation (``replaced_by`` is NULL). Without the last
        condition, every rotation would leave the previous, still-unexpired
        row behind and show it as an extra session.

        Returns:
            A list of dicts with ``id``, ``device_info``, ``ip_address``,
            ``created_at`` and ``last_used_at`` (ISO-8601 strings or ``None``).
        """
        result = await session.execute(
            select(RefreshToken)
            .where(
                and_(
                    RefreshToken.user_id == user_id,
                    RefreshToken.is_revoked.is_(false()),
                    RefreshToken.expires_at > datetime.now(),
                    RefreshToken.replaced_by.is_(None),
                )
            )
            .order_by(RefreshToken.created_at.desc())
        )

        return [
            {
                "id": t.id,
                "device_info": t.device_info,
                "ip_address": t.ip_address,
                "created_at": t.created_at.isoformat(),
                "last_used_at": t.last_used_at.isoformat() if t.last_used_at else None,
            }
            for t in result.scalars().all()
        ]
|