from datetime import datetime, timezone
from typing import Dict, Any, Optional

from fastapi import Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer

from .db import get_user_by_username, set_user_auth_provider, upsert_user_activity
from .security import safe_decode_token, TokenError, verify_password

oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")


def _is_expired(expires_at: str | None) -> bool:
    """Return True when ``expires_at`` is an ISO-8601 timestamp at or before now (UTC).

    Missing, blank, non-string, or unparseable values count as "not expired".
    NOTE(review): a malformed ``expires_at`` therefore fails open (the account
    never expires) — confirm this is intended.
    Naive timestamps are interpreted as UTC; a trailing ``Z`` is accepted.
    """
    if not isinstance(expires_at, str) or not expires_at.strip():
        return False
    candidate = expires_at.strip()
    if candidate.endswith("Z"):
        # Older datetime.fromisoformat versions reject the "Z" suffix.
        candidate = candidate[:-1] + "+00:00"
    try:
        parsed = datetime.fromisoformat(candidate)
    except ValueError:
        return False
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed <= datetime.now(timezone.utc)


def _extract_client_ip(request: Request) -> str:
    """Best-effort client IP for activity tracking.

    Sources, in order: the first non-empty entry of ``X-Forwarded-For``, a
    non-blank ``X-Real-IP``, the socket peer address, and finally ``"unknown"``.

    NOTE(review): both headers are client-supplied unless a trusted reverse
    proxy overwrites them, so the value can be spoofed. It is suitable for
    activity display, not for access control or rate limiting.
    """
    forwarded = request.headers.get("x-forwarded-for")
    if forwarded:
        parts = [part.strip() for part in forwarded.split(",") if part.strip()]
        if parts:
            return parts[0]
    # Strip before testing so a whitespace-only header falls through to the
    # peer address instead of yielding an empty string.
    real_ip = (request.headers.get("x-real-ip") or "").strip()
    if real_ip:
        return real_ip
    if request.client and request.client.host:
        return request.client.host
    return "unknown"


def resolve_user_auth_provider(user: Optional[Dict[str, Any]]) -> str:
    """Determine the effective authentication provider for ``user``.

    A stored ``auth_provider`` other than "local" is returned as-is (lower-cased).
    Otherwise the stored password hash is tested against the fixed strings
    ``"jellyfin-user"`` and ``"jellyseerr-user"``. A match maps the account to
    "jellyfin" / "jellyseerr" respectively; presumably these are placeholder
    passwords for externally-authenticated accounts (TODO confirm).
    Non-dict input resolves to "local".

    NOTE(review): for genuine local users this performs up to two
    ``verify_password`` calls every time it runs, which is once per
    authenticated request via ``normalize_user_auth_provider``. If
    ``verify_password`` is a slow KDF, consider caching the result per hash.
    """
    if not isinstance(user, dict):
        return "local"
    provider = str(user.get("auth_provider") or "local").strip().lower() or "local"
    if provider != "local":
        return provider
    password_hash = user.get("password_hash")
    if isinstance(password_hash, str) and password_hash:
        if verify_password("jellyfin-user", password_hash):
            return "jellyfin"
        if verify_password("jellyseerr-user", password_hash):
            return "jellyseerr"
    return provider


def normalize_user_auth_provider(user: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Return a copy of ``user`` with provider-related fields filled in.

    Behavior when the resolved provider differs from the stored
    ``auth_provider`` and the user has a non-blank username:

    - the new provider is persisted via ``set_user_auth_provider``;
    - the user is re-read with ``get_user_by_username`` before the copy is built.

    This is a write performed on a read path.

    Adds or overwrites ``auth_provider``, ``password_change_supported`` and
    ``password_provider``. Non-dict input yields ``{}``.
    """
    if not isinstance(user, dict):
        return {}
    resolved_provider = resolve_user_auth_provider(user)
    stored_provider = str(user.get("auth_provider") or "local").strip().lower() or "local"
    if resolved_provider != stored_provider:
        username = str(user.get("username") or "").strip()
        if username:
            set_user_auth_provider(username, resolved_provider)
            refreshed_user = get_user_by_username(username)
            if refreshed_user:
                user = refreshed_user

    # Providers for which this code advertises password changes.
    password_providers = {"local", "jellyfin"}
    normalized = dict(user)
    normalized["auth_provider"] = resolved_provider
    normalized["password_change_supported"] = resolved_provider in password_providers
    normalized["password_provider"] = (
        resolved_provider if resolved_provider in password_providers else None
    )
    return normalized


def _load_current_user_from_token(
    token: str,
    request: Optional[Request] = None,
    allowed_token_types: Optional[set[str]] = None,
) -> Dict[str, Any]:
    """Decode ``token``, load and validate its user, and return a profile dict.

    Args:
        token: Encoded token, verified with ``safe_decode_token``.
        request: When given, the caller's IP and User-Agent are recorded via
            ``upsert_user_activity``.
        allowed_token_types: When non-empty, the token's ``typ`` claim must be
            in this set. The claim is lower-cased and defaults to "access" if absent.

    Returns:
        A dict with the user's username, role, provider, and permission fields.

    Raises:
        HTTPException: 401 for any of:
            - an undecodable token;
            - a disallowed token type;
            - a missing or non-string subject;
            - an unknown user.

            403 for a blocked or expired user.
    """
    try:
        payload = safe_decode_token(token)
    except TokenError as exc:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") from exc

    token_type = str(payload.get("typ") or "access").strip().lower()
    if allowed_token_types and token_type not in allowed_token_types:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type")

    username = payload.get("sub")
    # The subject is used directly as a username lookup key, so anything other
    # than a non-empty string is rejected rather than passed to the database.
    if not isinstance(username, str) or not username:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token subject")

    user = get_user_by_username(username)
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
    if user.get("is_blocked"):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is blocked")
    if _is_expired(user.get("expires_at")):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User access has expired")

    user = normalize_user_auth_provider(user)

    if request is not None:
        ip = _extract_client_ip(request)
        user_agent = request.headers.get("user-agent", "unknown")
        upsert_user_activity(user["username"], ip, user_agent)

    return {
        "username": user["username"],
        "role": user["role"],
        "auth_provider": user.get("auth_provider", "local"),
        "jellyseerr_user_id": user.get("jellyseerr_user_id"),
        "auto_search_enabled": bool(user.get("auto_search_enabled", True)),
        "invite_management_enabled": bool(user.get("invite_management_enabled", False)),
        "profile_id": user.get("profile_id"),
        "expires_at": user.get("expires_at"),
        "is_expired": bool(user.get("is_expired", False)),
        "password_change_supported": bool(user.get("password_change_supported", False)),
        "password_provider": user.get("password_provider"),
    }


def get_current_user(token: str = Depends(oauth2_scheme), request: Request = None) -> Dict[str, Any]:
    """FastAPI dependency: authenticate the bearer token and record user activity.

    ``request`` is deliberately annotated as plain ``Request``. FastAPI injects
    the request based on that annotation, and the ``None`` default only
    matters for direct calls.

    NOTE(review): no ``allowed_token_types`` is passed, so tokens of any
    ``typ`` (including "sse" stream tokens) are accepted here. Confirm this
    is intended.
    """
    return _load_current_user_from_token(token, request)


def get_current_user_event_stream(request: Request) -> Dict[str, Any]:
    """EventSource cannot send Authorization headers, so allow a short-lived stream token via query.

    Token sources:

    - A ``Bearer`` Authorization header takes precedence and accepts any token type.
    - Otherwise the ``stream_token`` query parameter is required, and its
      ``typ`` must be "sse".

    Neither path records user activity.
    """
    token = None
    auth_header = request.headers.get("authorization", "")
    if auth_header.lower().startswith("bearer "):
        token = auth_header.split(" ", 1)[1].strip()
    if token:
        # Allow standard bearer tokens in Authorization for non-browser EventSource clients.
        return _load_current_user_from_token(token, None)

    stream_query_token = request.query_params.get("stream_token")
    if not stream_query_token:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing token")
    return _load_current_user_from_token(
        str(stream_query_token),
        None,
        allowed_token_types={"sse"},
    )


def _ensure_admin(user: Dict[str, Any]) -> Dict[str, Any]:
    """Return ``user`` unchanged if its role is "admin"; otherwise raise 403."""
    if user.get("role") != "admin":
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required")
    return user


def require_admin(user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
    """FastAPI dependency: the authenticated user, who must have the "admin" role."""
    return _ensure_admin(user)


def require_admin_event_stream(
    user: Dict[str, Any] = Depends(get_current_user_event_stream),
) -> Dict[str, Any]:
    """Event-stream variant of ``require_admin``: authenticated via header or stream token."""
    return _ensure_admin(user)