161 lines
6.4 KiB
Python
161 lines
6.4 KiB
Python
from datetime import datetime, timezone
from functools import lru_cache
from typing import Dict, Any, Optional

from fastapi import Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer

from .db import get_user_by_username, set_user_auth_provider, upsert_user_activity
from .security import safe_decode_token, TokenError, verify_password
|
|
|
|
# Extracts the bearer token from the Authorization header for get_current_user.
# tokenUrl is the login endpoint advertised as the token endpoint
# (e.g. in the generated OpenAPI docs).
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
|
|
|
|
|
|
def _is_expired(expires_at: str | None) -> bool:
|
|
if not isinstance(expires_at, str) or not expires_at.strip():
|
|
return False
|
|
candidate = expires_at.strip()
|
|
if candidate.endswith("Z"):
|
|
candidate = candidate[:-1] + "+00:00"
|
|
try:
|
|
parsed = datetime.fromisoformat(candidate)
|
|
except ValueError:
|
|
return False
|
|
if parsed.tzinfo is None:
|
|
parsed = parsed.replace(tzinfo=timezone.utc)
|
|
return parsed <= datetime.now(timezone.utc)
|
|
|
|
def _extract_client_ip(request: Request) -> str:
|
|
forwarded = request.headers.get("x-forwarded-for")
|
|
if forwarded:
|
|
parts = [part.strip() for part in forwarded.split(",") if part.strip()]
|
|
if parts:
|
|
return parts[0]
|
|
real_ip = request.headers.get("x-real-ip")
|
|
if real_ip:
|
|
return real_ip.strip()
|
|
if request.client and request.client.host:
|
|
return request.client.host
|
|
return "unknown"
|
|
|
|
|
|
@lru_cache(maxsize=1024)
def _legacy_provider_for_password_hash(password_hash: str) -> Optional[str]:
    """Map a stored password hash to the external provider whose sentinel it encodes.

    Some accounts are stored with the hash of a fixed sentinel password
    ("jellyfin-user" / "jellyseerr-user") instead of a real one.
    ``resolve_user_auth_provider`` runs on every authenticated request. Without
    caching, each request from a genuine local user would pay for two password
    verifications, which are presumably a deliberately slow KDF. The answer
    depends only on the hash, so it is memoized. A password change produces a
    new hash and therefore a new cache key.

    Returns:
        "jellyfin", "jellyseerr", or None when neither sentinel matches.
    """
    if verify_password("jellyfin-user", password_hash):
        return "jellyfin"
    if verify_password("jellyseerr-user", password_hash):
        return "jellyseerr"
    return None


def resolve_user_auth_provider(user: Optional[Dict[str, Any]]) -> str:
    """Determine which authentication provider a user record belongs to.

    Args:
        user: A user record (dict) or anything else. Non-dicts resolve to "local".

    Returns:
        The lower-cased, stripped ``auth_provider`` when it is not "local".
        Otherwise it returns "jellyfin"/"jellyseerr" if the stored password
        hash matches that provider's sentinel password, and "local" if not.
    """
    if not isinstance(user, dict):
        return "local"

    provider = str(user.get("auth_provider") or "local").strip().lower() or "local"
    if provider != "local":
        return provider

    password_hash = user.get("password_hash")
    if isinstance(password_hash, str) and password_hash:
        # Rows recorded as "local" may really be external accounts that were
        # stored with a sentinel password; detect them from the hash.
        return _legacy_provider_for_password_hash(password_hash) or provider

    return provider
|
|
|
|
|
|
def normalize_user_auth_provider(user: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Return a copy of *user* annotated with its resolved auth provider.

    When the resolved provider differs from the stored ``auth_provider`` and the
    record has a non-blank username, the new provider is persisted through
    ``set_user_auth_provider`` and the record is re-read from the database. The
    freshly loaded row is used only if the lookup returns something truthy.

    The returned copy also carries:
      * ``auth_provider``: the resolved provider;
      * ``password_change_supported``: True for "local" and "jellyfin";
      * ``password_provider``: the provider for those two, otherwise None.

    Non-dict input yields an empty dict.
    """
    if not isinstance(user, dict):
        return {}

    provider = resolve_user_auth_provider(user)
    persisted = str(user.get("auth_provider") or "local").strip().lower() or "local"

    record = user
    if provider != persisted:
        name = str(user.get("username") or "").strip()
        if name:
            set_user_auth_provider(name, provider)
            record = get_user_by_username(name) or user

    supports_password = provider in {"local", "jellyfin"}

    result = dict(record)
    result["auth_provider"] = provider
    result["password_change_supported"] = supports_password
    result["password_provider"] = provider if supports_password else None
    return result
|
|
|
|
|
|
# Token types that are scoped to a single purpose (e.g. the short-lived stream
# token that get_current_user_event_stream reads from the query string). They
# are accepted only when a caller lists them explicitly in allowed_token_types.
_SCOPED_TOKEN_TYPES = frozenset({"sse"})


def _load_current_user_from_token(
    token: str,
    request: Optional[Request] = None,
    allowed_token_types: Optional[set[str]] = None,
) -> Dict[str, Any]:
    """Authenticate *token* and return a public view of the user it names.

    Args:
        token: Encoded token, decoded with ``safe_decode_token``.
        request: When given, the caller's IP and User-Agent are recorded via
            ``upsert_user_activity``.
        allowed_token_types: If non-empty, the token's ``typ`` claim (default
            "access") must be one of these. If None or empty, any type is
            accepted except the purpose-scoped ones in ``_SCOPED_TOKEN_TYPES``.
            This stops a stream token that leaked through a URL from being
            replayed as a general bearer token.

    Returns:
        A dict with username, role, provider and feature flags for the user.

    Raises:
        HTTPException(401): undecodable token, disallowed token type, missing
            subject, or unknown user.
        HTTPException(403): the user is blocked or their access has expired.
    """
    try:
        payload = safe_decode_token(token)
    except TokenError as exc:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") from exc

    token_type = str(payload.get("typ") or "access").strip().lower()
    if allowed_token_types:
        if token_type not in allowed_token_types:
            raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type")
    elif token_type in _SCOPED_TOKEN_TYPES:
        # Scoped tokens must never authenticate general API calls.
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type")

    username = payload.get("sub")
    if not username:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token subject")

    user = get_user_by_username(username)
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
    if user.get("is_blocked"):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is blocked")
    if _is_expired(user.get("expires_at")):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User access has expired")

    # May persist a corrected auth_provider and re-read the user record.
    user = normalize_user_auth_provider(user)

    if request is not None:
        ip = _extract_client_ip(request)
        user_agent = request.headers.get("user-agent", "unknown")
        upsert_user_activity(user["username"], ip, user_agent)

    return {
        "username": user["username"],
        "role": user["role"],
        "auth_provider": user.get("auth_provider", "local"),
        "jellyseerr_user_id": user.get("jellyseerr_user_id"),
        "auto_search_enabled": bool(user.get("auto_search_enabled", True)),
        "invite_management_enabled": bool(user.get("invite_management_enabled", False)),
        "profile_id": user.get("profile_id"),
        "expires_at": user.get("expires_at"),
        "is_expired": bool(user.get("is_expired", False)),
        "password_change_supported": bool(user.get("password_change_supported", False)),
        "password_provider": user.get("password_provider"),
    }
|
|
|
|
|
|
def get_current_user(token: str = Depends(oauth2_scheme), request: Request = None) -> Dict[str, Any]:
    """FastAPI dependency: authenticate the Authorization bearer token.

    The incoming request is forwarded so that the user's activity (IP and
    User-Agent) is recorded alongside authentication.
    """
    current_user = _load_current_user_from_token(token, request=request)
    return current_user
|
|
|
|
|
|
def get_current_user_event_stream(request: Request) -> Dict[str, Any]:
    """EventSource cannot send Authorization headers, so allow a short-lived stream token via query."""
    # A regular bearer token in the header takes precedence. Non-browser
    # EventSource clients can send one.
    authorization = request.headers.get("authorization", "")
    bearer = None
    if authorization.lower().startswith("bearer "):
        bearer = authorization.split(" ", 1)[1].strip()

    if bearer:
        return _load_current_user_from_token(bearer, None)

    # Browsers fall back to the ?stream_token= query parameter, which is only
    # accepted when it carries the "sse" token type.
    query_token = request.query_params.get("stream_token")
    if not query_token:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing token")

    return _load_current_user_from_token(
        str(query_token),
        None,
        allowed_token_types={"sse"},
    )
|
|
|
|
|
|
def require_admin(user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
    """FastAPI dependency: pass the authenticated user through only if they are an admin.

    Raises:
        HTTPException(403): the user's role is not "admin".
    """
    is_admin = user.get("role") == "admin"
    if is_admin:
        return user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required")
|
|
|
|
|
|
def require_admin_event_stream(
    user: Dict[str, Any] = Depends(get_current_user_event_stream),
) -> Dict[str, Any]:
    """Admin gate for event-stream endpoints (header or ``stream_token`` auth).

    Raises:
        HTTPException(403): the user's role is not "admin".
    """
    is_admin = user.get("role") == "admin"
    if is_admin:
        return user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required")
|