334 lines
12 KiB
Python
334 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
import secrets
|
|
from datetime import datetime, timedelta, timezone
|
|
from typing import Any, Dict, Optional
|
|
|
|
from ..auth import normalize_user_auth_provider, resolve_user_auth_provider
|
|
from ..clients.jellyfin import JellyfinClient
|
|
from ..clients.jellyseerr import JellyseerrClient
|
|
from ..db import (
|
|
create_password_reset_token,
|
|
delete_expired_password_reset_tokens,
|
|
get_password_reset_token,
|
|
get_user_by_jellyseerr_id,
|
|
get_user_by_username,
|
|
get_users_by_username_ci,
|
|
mark_password_reset_token_used,
|
|
set_user_auth_provider,
|
|
set_user_password,
|
|
sync_jellyfin_password_state,
|
|
)
|
|
from ..runtime import get_runtime_settings
|
|
from .invite_email import send_password_reset_email
|
|
from .user_cache import get_cached_jellyseerr_users, save_jellyseerr_users_cache
|
|
|
|
# Module-level logger for password-reset events.
logger = logging.getLogger(__name__)

# Lifetime of a newly issued password-reset token, in minutes; added to the
# current UTC time to compute the token's ``expires_at`` in request_password_reset.
PASSWORD_RESET_TOKEN_TTL_MINUTES = 30
|
|
|
|
|
|
class PasswordResetUnavailableError(RuntimeError):
    """Raised when a reset cannot be carried out because a backend is unavailable.

    apply_password_reset raises it when the Jellyfin client is not configured.
    """
|
|
|
|
|
|
def _normalize_handles(value: object) -> list[str]:
|
|
if not isinstance(value, str):
|
|
return []
|
|
normalized = value.strip().lower()
|
|
if not normalized:
|
|
return []
|
|
handles = [normalized]
|
|
if "@" in normalized:
|
|
handles.append(normalized.split("@", 1)[0])
|
|
return list(dict.fromkeys(handles))
|
|
|
|
|
|
def _pick_preferred_user(users: list[dict], requested_identifier: str) -> dict | None:
|
|
if not users:
|
|
return None
|
|
requested = str(requested_identifier or "").strip().lower()
|
|
|
|
def _rank(user: dict) -> tuple[int, int, int, int]:
|
|
provider = str(user.get("auth_provider") or "local").strip().lower()
|
|
role = str(user.get("role") or "user").strip().lower()
|
|
username = str(user.get("username") or "").strip().lower()
|
|
return (
|
|
0 if role == "admin" else 1,
|
|
0 if isinstance(user.get("jellyseerr_user_id"), int) else 1,
|
|
0 if provider == "jellyfin" else (1 if provider == "local" else 2),
|
|
0 if username == requested else 1,
|
|
)
|
|
|
|
return sorted(users, key=_rank)[0]
|
|
|
|
|
|
def _find_matching_seerr_user(identifier: str, users: list[dict]) -> dict | None:
    """Return the first Seerr user sharing a lookup handle with *identifier*.

    Handles come from _normalize_handles. Each user's ``username`` is checked
    before its ``email``. Non-dict entries are skipped. Returns ``None`` when
    *identifier* yields no handles or nothing matches.
    """
    wanted = set(_normalize_handles(identifier))
    if not wanted:
        return None
    return next(
        (
            candidate
            for candidate in users
            if isinstance(candidate, dict)
            and any(
                not wanted.isdisjoint(_normalize_handles(candidate.get(field)))
                for field in ("username", "email")
            )
        ),
        None,
    )
|
|
|
|
|
|
async def _fetch_all_seerr_users(max_pages: int = 500) -> list[dict]:
    """Return all Seerr users, preferring the local user cache.

    On a cache miss the Seerr user list is fetched ``take`` = 100 users per
    request. Paging stops on:
      - an empty payload;
      - a payload without a non-empty list;
      - a short page;
      - after *max_pages* pages.
    The cap prevents an endless loop if the server ignores ``skip`` and keeps
    returning full pages; hitting it is logged as a warning.

    A non-empty result is passed to ``save_jellyseerr_users_cache`` and that
    call's return value is returned. Returns ``[]`` when Seerr is not
    configured. Errors raised by the client propagate to the caller.
    """
    cached = get_cached_jellyseerr_users()
    if cached is not None:
        return cached

    runtime = get_runtime_settings()
    client = JellyseerrClient(runtime.jellyseerr_base_url, runtime.jellyseerr_api_key)
    if not client.configured():
        return []

    take = 100
    users: list[dict] = []
    for page in range(max_pages):
        payload = await client.get_users(take=take, skip=page * take)
        if not payload:
            break
        # Accept a bare list or a dict wrapping the list under a common key.
        if isinstance(payload, list):
            batch = payload
        elif isinstance(payload, dict):
            batch = payload.get("results") or payload.get("users") or payload.get("data") or payload.get("items")
        else:
            batch = None
        if not isinstance(batch, list) or not batch:
            break
        users.extend(user for user in batch if isinstance(user, dict))
        if len(batch) < take:
            break
    else:
        # Only reached when every page was full and the cap was exhausted.
        logger.warning(
            "jellyseerr user listing stopped after %s pages; user list may be incomplete",
            max_pages,
        )

    if users:
        return save_jellyseerr_users_cache(users)
    return users
|
|
|
|
|
|
def _resolve_seerr_user_email(seerr_user: Optional[dict], local_user: Optional[dict]) -> Optional[str]:
|
|
if isinstance(local_user, dict):
|
|
stored_email = str(local_user.get("email") or "").strip()
|
|
if "@" in stored_email:
|
|
return stored_email
|
|
username = str(local_user.get("username") or "").strip()
|
|
if "@" in username:
|
|
return username
|
|
if isinstance(seerr_user, dict):
|
|
email = str(seerr_user.get("email") or "").strip()
|
|
if "@" in email:
|
|
return email
|
|
return None
|
|
|
|
|
|
def _seerr_user_id(user: object) -> Optional[int]:
    """Return the integer Seerr id of *user*, or ``None`` if missing or non-numeric.

    Reads the ``id``, ``userId`` and ``Id`` keys (first truthy value wins). A
    malformed id yields ``None`` instead of raising, so one bad Seerr record
    cannot abort a reset lookup.
    """
    if not isinstance(user, dict):
        return None
    raw_id = user.get("id") or user.get("userId") or user.get("Id")
    if raw_id is None:
        return None
    try:
        return int(raw_id)
    except (TypeError, ValueError):
        return None


def _find_seerr_user_by_id(seerr_users: list[dict], seerr_user_id: int) -> Optional[dict]:
    """Return the first Seerr user whose parsed id equals *seerr_user_id*, or ``None``."""
    return next((user for user in seerr_users if _seerr_user_id(user) == seerr_user_id), None)


def _find_seerr_user_by_exact_name(name: str, seerr_users: list[dict]) -> Optional[dict]:
    """Return the first Seerr user whose username or e-mail equals *name* case-insensitively."""
    wanted = str(name or "").strip().lower()
    if not wanted:
        return None
    for user in seerr_users:
        if not isinstance(user, dict):
            continue
        for key in ("username", "email"):
            value = user.get(key)
            if isinstance(value, str) and value.strip().lower() == wanted:
                return user
    return None


def _find_local_user_for_seerr_user(seerr_user: dict) -> Optional[dict]:
    """Map a Seerr user to a local account.

    Tries the account linked to the Seerr id first, then a case-insensitive
    username lookup on the Seerr e-mail and then the Seerr username.
    """
    seerr_user_id = _seerr_user_id(seerr_user)
    if seerr_user_id is not None:
        local_user = normalize_user_auth_provider(get_user_by_jellyseerr_id(seerr_user_id))
        if local_user:
            return local_user
    for candidate in (seerr_user.get("email"), seerr_user.get("username")):
        if not isinstance(candidate, str) or not candidate.strip():
            continue
        local_user = normalize_user_auth_provider(
            _pick_preferred_user(get_users_by_username_ci(candidate), candidate)
        )
        if local_user:
            return local_user
    return None


async def _effective_reset_provider(local_user: dict, username: str) -> str:
    """Return the provider whose password the reset should change.

    A ``jellyseerr`` account becomes ``jellyfin`` when Jellyfin is configured
    and reports a user with the same name. A failing Jellyfin lookup is logged
    and leaves the provider unchanged (best effort).
    """
    auth_provider = resolve_user_auth_provider(local_user)
    if auth_provider != "jellyseerr":
        return auth_provider
    runtime = get_runtime_settings()
    jellyfin_client = JellyfinClient(runtime.jellyfin_base_url, runtime.jellyfin_api_key)
    if not jellyfin_client.configured():
        return auth_provider
    try:
        jellyfin_user = await jellyfin_client.find_user_by_name(username)
    except Exception:
        logger.warning("password reset jellyfin lookup failed username=%s", username, exc_info=True)
        return auth_provider
    return "jellyfin" if isinstance(jellyfin_user, dict) else auth_provider


async def _resolve_reset_target(identifier: str) -> Optional[Dict[str, Any]]:
    """Resolve a reset request to ``{"username", "recipient_email", "auth_provider"}``.

    *identifier* is first matched case-insensitively against local usernames.
    If no local account matches, it is matched against Seerr users (username or
    e-mail, including the e-mail local part), and the match is mapped back to a
    local account. That fuzzy match only locates the account; it never selects
    where the e-mail is sent.

    The recipient address comes from, in order:
      1. the local account (stored e-mail, or a username containing ``@``);
      2. the Seerr user linked through ``jellyseerr_user_id``;
      3. a Seerr user whose username or e-mail equals the local username exactly.

    Returns ``None`` (no token should be issued) when no account, no address,
    or a provider other than ``local``/``jellyfin`` is found.
    """
    normalized_identifier = str(identifier or "").strip()
    if not normalized_identifier:
        return None

    # Fetched lazily: the Seerr directory is only needed for the fallbacks.
    seerr_users: list[dict] | None = None

    local_user = normalize_user_auth_provider(
        _pick_preferred_user(get_users_by_username_ci(normalized_identifier), normalized_identifier)
    )
    if not local_user:
        seerr_users = await _fetch_all_seerr_users()
        matched_seerr_user = _find_matching_seerr_user(normalized_identifier, seerr_users)
        if matched_seerr_user:
            local_user = _find_local_user_for_seerr_user(matched_seerr_user)
    if not local_user:
        return None

    username = str(local_user.get("username") or "").strip()
    recipient_email = _resolve_seerr_user_email(None, local_user)
    if not recipient_email:
        if seerr_users is None:
            seerr_users = await _fetch_all_seerr_users()
        seerr_user: dict | None = None
        linked_id = local_user.get("jellyseerr_user_id")
        if isinstance(linked_id, int):
            seerr_user = _find_seerr_user_by_id(seerr_users, linked_id)
        if not seerr_user:
            # Exact match only. A local-part match ("bob" ~ "bob@anywhere") would
            # let any Seerr user who sets such an e-mail receive this account's
            # reset link.
            # NOTE(review): if Seerr usernames are user-editable, this exact-name
            # fallback can still be spoofed; consider dropping it entirely.
            seerr_user = _find_seerr_user_by_exact_name(username, seerr_users)
        recipient_email = _resolve_seerr_user_email(seerr_user, local_user)
    if not recipient_email:
        return None

    auth_provider = await _effective_reset_provider(local_user, username)
    if auth_provider not in {"local", "jellyfin"}:
        return None

    return {
        "username": username,
        "recipient_email": recipient_email,
        "auth_provider": auth_provider,
    }
|
|
|
|
|
|
def _token_record_is_usable(record: Optional[dict]) -> bool:
|
|
if not isinstance(record, dict):
|
|
return False
|
|
if record.get("is_used"):
|
|
return False
|
|
if record.get("is_expired"):
|
|
return False
|
|
return True
|
|
|
|
|
|
def _mask_email(email: str) -> str:
|
|
candidate = str(email or "").strip()
|
|
if "@" not in candidate:
|
|
return "valid reset link"
|
|
local_part, domain = candidate.split("@", 1)
|
|
if not local_part:
|
|
return f"***@{domain}"
|
|
if len(local_part) == 1:
|
|
return f"{local_part}***@{domain}"
|
|
return f"{local_part[0]}***{local_part[-1]}@{domain}"
|
|
|
|
|
|
async def request_password_reset(
    identifier: str,
    *,
    requested_by_ip: Optional[str] = None,
    requested_user_agent: Optional[str] = None,
) -> Dict[str, Any]:
    """Issue a password-reset token for *identifier* and e-mail it.

    Steps:
      1. Purge expired tokens.
      2. Resolve the target via _resolve_reset_target.
      3. Store a new ``secrets.token_urlsafe(32)`` token valid for
         PASSWORD_RESET_TOKEN_TTL_MINUTES, together with the requester's IP
         and user agent.
      4. Send the reset e-mail.

    Returns ``{"status": "ok", "issued": False}`` when no eligible account
    exists. Otherwise returns a dict with ``issued: True`` plus the username,
    recipient e-mail, provider and ``expires_at`` (ISO-8601, UTC).

    NOTE(review): the result exposes ``issued``, ``username`` and
    ``recipient_email``. If a caller forwards it verbatim to an unauthenticated
    client, it enables account enumeration and e-mail disclosure. Confirm the
    HTTP layer strips these fields.
    """
    delete_expired_password_reset_tokens()
    target = await _resolve_reset_target(identifier)
    if not target:
        # str(... or "") mirrors _resolve_reset_target's tolerance of None;
        # %r escapes control characters so a crafted identifier cannot forge log lines.
        logger.info(
            "password reset requested with no eligible match identifier=%r",
            str(identifier or "").strip().lower()[:256],
        )
        return {"status": "ok", "issued": False}

    token = secrets.token_urlsafe(32)
    expires_at = (datetime.now(timezone.utc) + timedelta(minutes=PASSWORD_RESET_TOKEN_TTL_MINUTES)).isoformat()
    create_password_reset_token(
        token,
        target["username"],
        target["recipient_email"],
        target["auth_provider"],
        expires_at,
        requested_by_ip=requested_by_ip,
        requested_user_agent=requested_user_agent,
    )
    await send_password_reset_email(
        recipient_email=target["recipient_email"],
        username=target["username"],
        token=token,
        expires_at=expires_at,
        auth_provider=target["auth_provider"],
    )
    return {
        "status": "ok",
        "issued": True,
        "username": target["username"],
        "recipient_email": target["recipient_email"],
        "auth_provider": target["auth_provider"],
        "expires_at": expires_at,
    }
|
|
|
|
|
|
def verify_password_reset_token(token: str) -> Dict[str, Any]:
    """Check that *token* can still be used and describe it without consuming it.

    Expired tokens are purged first. Returns ``status``, a masked
    ``recipient_hint``, the stored ``auth_provider`` and ``expires_at``.

    Raises:
        ValueError: if the token is unknown, already used, or expired.
    """
    delete_expired_password_reset_tokens()
    record = get_password_reset_token(token)
    if _token_record_is_usable(record):
        hint = _mask_email(str(record.get("recipient_email") or ""))
        return {
            "status": "ok",
            "recipient_hint": hint,
            "auth_provider": record.get("auth_provider"),
            "expires_at": record.get("expires_at"),
        }
    raise ValueError("Password reset link is invalid or has expired.")
|
|
|
|
|
|
def _stored_auth_provider(user: dict) -> str:
    """Return the user's raw stored ``auth_provider``, stripped and lowercased."""
    return str(user.get("auth_provider") or "").strip().lower()


def _apply_local_reset(token: str, username: str, new_password: str, stored_user: dict) -> Dict[str, Any]:
    """Set a local password, pin the account to ``local``, and consume *token*."""
    set_user_password(username, new_password)
    if _stored_auth_provider(stored_user) != "local":
        set_user_auth_provider(username, "local")
    mark_password_reset_token_used(token)
    logger.info("password reset applied username=%s provider=local", username)
    return {"status": "ok", "provider": "local", "username": username}


async def _apply_jellyfin_reset(
    token: str, username: str, new_password: str, stored_user: dict
) -> Dict[str, Any]:
    """Set the Jellyfin password, sync local state, pin the provider, and consume *token*."""
    runtime = get_runtime_settings()
    client = JellyfinClient(runtime.jellyfin_base_url, runtime.jellyfin_api_key)
    if not client.configured():
        raise PasswordResetUnavailableError("Jellyfin is not configured for password reset.")
    jellyfin_user = await client.find_user_by_name(username)
    # Only hand a real user record to the id extractor; a missing user (e.g. None)
    # is reported as an invalid link rather than surfacing as an unexpected error.
    user_id = client._extract_user_id(jellyfin_user) if isinstance(jellyfin_user, dict) else None
    if not user_id:
        raise ValueError("Password reset link is invalid or has expired.")
    await client.set_user_password(user_id, new_password)
    sync_jellyfin_password_state(username, new_password)
    if _stored_auth_provider(stored_user) != "jellyfin":
        set_user_auth_provider(username, "jellyfin")
    mark_password_reset_token_used(token)
    logger.info("password reset applied username=%s provider=jellyfin", username)
    return {"status": "ok", "provider": "jellyfin", "username": username}


async def apply_password_reset(token: str, new_password: str) -> Dict[str, Any]:
    """Redeem *token* by setting *new_password* for the account it was issued to.

    Expired tokens are purged first. The provider is taken from the account's
    current settings, with ``jellyseerr`` treated as ``jellyfin``. The token is
    marked used only after the password change succeeds, so a failed attempt
    leaves it redeemable.

    Returns ``{"status": "ok", "provider": ..., "username": ...}``.

    Raises:
        ValueError: the token is unusable, its user no longer exists or cannot
            be found in Jellyfin, or the provider is unsupported.
        PasswordResetUnavailableError: Jellyfin is needed but not configured.
    """
    delete_expired_password_reset_tokens()
    record = get_password_reset_token(token)
    if not _token_record_is_usable(record):
        raise ValueError("Password reset link is invalid or has expired.")

    username = str(record.get("username") or "").strip()
    if not username:
        raise ValueError("Password reset link is invalid or has expired.")

    stored_user = normalize_user_auth_provider(get_user_by_username(username))
    if not stored_user:
        raise ValueError("Password reset link is invalid or has expired.")

    auth_provider = resolve_user_auth_provider(stored_user)
    if auth_provider == "jellyseerr":
        # Tokens for Seerr accounts are only issued once the user was found in
        # Jellyfin (see _resolve_reset_target), so the password is reset there.
        auth_provider = "jellyfin"

    if auth_provider == "local":
        return _apply_local_reset(token, username, new_password, stored_user)
    if auth_provider == "jellyfin":
        return await _apply_jellyfin_reset(token, username, new_password, stored_user)

    raise ValueError("Password reset is not available for this sign-in provider.")
|