Files
Magent/backend/app/auth.py
2026-03-02 19:54:14 +13:00

161 lines
6.4 KiB
Python

from datetime import datetime, timezone
from functools import lru_cache
from typing import Dict, Any, Optional

from fastapi import Depends, HTTPException, status, Request
from fastapi.security import OAuth2PasswordBearer

from .db import get_user_by_username, set_user_auth_provider, upsert_user_activity
from .security import safe_decode_token, TokenError, verify_password
# FastAPI security dependency that supplies the bearer token to
# ``get_current_user``; ``tokenUrl`` names the login route ("/auth/login").
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
def _is_expired(expires_at: str | None) -> bool:
if not isinstance(expires_at, str) or not expires_at.strip():
return False
candidate = expires_at.strip()
if candidate.endswith("Z"):
candidate = candidate[:-1] + "+00:00"
try:
parsed = datetime.fromisoformat(candidate)
except ValueError:
return False
if parsed.tzinfo is None:
parsed = parsed.replace(tzinfo=timezone.utc)
return parsed <= datetime.now(timezone.utc)
def _extract_client_ip(request: Request) -> str:
    """Best-effort client address for activity tracking.

    Sources are tried in this order:
    1. the first non-blank entry of ``X-Forwarded-For``;
    2. ``X-Real-IP``, stripped;
    3. the connection peer's host;
    4. the literal ``"unknown"``.

    NOTE(review): both headers come straight from the request and are not
    validated here against a trusted-proxy list. Treat the result as
    informational, not as an access-control input.
    """
    forwarded_for = request.headers.get("x-forwarded-for")
    if forwarded_for:
        for hop in forwarded_for.split(","):
            candidate = hop.strip()
            if candidate:
                return candidate

    real_ip = request.headers.get("x-real-ip")
    if real_ip:
        return real_ip.strip()

    peer = request.client
    if peer and peer.host:
        return peer.host
    return "unknown"
def resolve_user_auth_provider(user: Optional[Dict[str, Any]]) -> str:
if not isinstance(user, dict):
return "local"
provider = str(user.get("auth_provider") or "local").strip().lower() or "local"
if provider != "local":
return provider
password_hash = user.get("password_hash")
if isinstance(password_hash, str) and password_hash:
if verify_password("jellyfin-user", password_hash):
return "jellyfin"
if verify_password("jellyseerr-user", password_hash):
return "jellyseerr"
return provider
def normalize_user_auth_provider(user: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Return a copy of *user* annotated with its effective auth provider.

    The provider comes from ``resolve_user_auth_provider``. When it differs
    from the stored ``auth_provider`` and the record has a username, the
    stored value is updated through ``set_user_auth_provider``. The row is
    then re-read with ``get_user_by_username``, keeping the original record
    if that lookup returns nothing.

    The returned dict additionally carries:
      * ``auth_provider``: the effective provider;
      * ``password_change_supported``: True only for "local" or "jellyfin";
      * ``password_provider``: the provider for those two, otherwise None.

    Non-dict input yields an empty dict.
    """
    if not isinstance(user, dict):
        return {}

    effective = resolve_user_auth_provider(user)
    recorded = str(user.get("auth_provider") or "local").strip().lower() or "local"
    if effective != recorded:
        username = str(user.get("username") or "").strip()
        if username:
            set_user_auth_provider(username, effective)
            user = get_user_by_username(username) or user

    can_change_password = effective in {"local", "jellyfin"}
    normalized = dict(user)
    normalized.update(
        auth_provider=effective,
        password_change_supported=can_change_password,
        password_provider=effective if can_change_password else None,
    )
    return normalized
# Token ``typ`` values that are only valid on the dedicated event-stream
# dependency. Such tokens travel in the query string (EventSource cannot set
# headers), so they must not double as general-purpose API credentials.
_STREAM_ONLY_TOKEN_TYPES = frozenset({"sse"})


def _load_current_user_from_token(
    token: str,
    request: Optional[Request] = None,
    allowed_token_types: Optional[set[str]] = None,
) -> Dict[str, Any]:
    """Authenticate *token* and return a summary of the user it belongs to.

    Args:
        token: Encoded token, decoded with ``safe_decode_token``.
        request: When given, the client IP and user agent are recorded via
            ``upsert_user_activity``.
        allowed_token_types: When non-empty, the token's ``typ`` claim
            (treated as ``"access"`` if missing) must be in this set. When
            ``None`` or empty, any type is accepted except stream-only types
            (``"sse"``).

    Returns:
        A dict with the user's username, role, auth provider, Jellyseerr id,
        feature flags, profile id, expiry fields and password-change support.

    Raises:
        HTTPException: 401 for an undecodable token, a disallowed token type,
            a missing subject, or an unknown user; 403 for a blocked or
            expired user.
    """
    try:
        payload = safe_decode_token(token)
    except TokenError as exc:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") from exc

    token_type = str(payload.get("typ") or "access").strip().lower()
    if allowed_token_types:
        type_permitted = token_type in allowed_token_types
    else:
        # Without an explicit allow-list, still refuse stream-only tokens so a
        # leaked query-string token cannot authenticate ordinary endpoints.
        type_permitted = token_type not in _STREAM_ONLY_TOKEN_TYPES
    if not type_permitted:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type")

    username = payload.get("sub")
    if not username:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token subject")

    user = get_user_by_username(username)
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
    if user.get("is_blocked"):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is blocked")
    if _is_expired(user.get("expires_at")):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User access has expired")

    # May persist a corrected auth_provider and re-read the row.
    user = normalize_user_auth_provider(user)

    if request is not None:
        ip = _extract_client_ip(request)
        user_agent = request.headers.get("user-agent", "unknown")
        upsert_user_activity(user["username"], ip, user_agent)

    return {
        "username": user["username"],
        "role": user["role"],
        "auth_provider": user.get("auth_provider", "local"),
        "jellyseerr_user_id": user.get("jellyseerr_user_id"),
        "auto_search_enabled": bool(user.get("auto_search_enabled", True)),
        "invite_management_enabled": bool(user.get("invite_management_enabled", False)),
        "profile_id": user.get("profile_id"),
        "expires_at": user.get("expires_at"),
        "is_expired": bool(user.get("is_expired", False)),
        "password_change_supported": bool(user.get("password_change_supported", False)),
        "password_provider": user.get("password_provider"),
    }
def get_current_user(token: str = Depends(oauth2_scheme), request: Request = None) -> Dict[str, Any]:
    """FastAPI dependency resolving the authenticated user from ``oauth2_scheme``'s token.

    The request is forwarded so that the caller's activity is recorded.
    No token-type allow-list is passed.
    """
    user_profile = _load_current_user_from_token(token, request=request)
    return user_profile
def get_current_user_event_stream(request: Request) -> Dict[str, Any]:
    """EventSource cannot send Authorization headers, so allow a short-lived stream token via query.

    Resolution order:
      1. A non-empty ``Authorization: Bearer <token>`` header is authenticated
         with no token-type allow-list. This supports non-browser clients.
      2. Otherwise the ``stream_token`` query parameter is used, and it must
         carry token type ``"sse"``.

    Neither path records user activity, because no request is forwarded.

    Raises:
        HTTPException: 401 "Missing token" when neither source yields a token,
            plus any error raised while validating the token.
    """
    header_value = request.headers.get("authorization", "")
    if header_value.lower().startswith("bearer "):
        bearer_token = header_value.split(" ", 1)[1].strip()
        if bearer_token:
            return _load_current_user_from_token(bearer_token, None)

    query_token = request.query_params.get("stream_token")
    if not query_token:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing token")

    return _load_current_user_from_token(
        str(query_token),
        None,
        allowed_token_types={"sse"},
    )
def require_admin(user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
    """FastAPI dependency that passes the current user through only when their role is "admin".

    Raises:
        HTTPException: 403 "Admin access required" for any other role.
    """
    if user.get("role") == "admin":
        return user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required")
def require_admin_event_stream(
    user: Dict[str, Any] = Depends(get_current_user_event_stream),
) -> Dict[str, Any]:
    """Admin gate for event-stream endpoints.

    The user is resolved via ``get_current_user_event_stream``, which accepts
    either a bearer header or a ``stream_token`` query parameter. The user is
    passed through only when their role is "admin".

    Raises:
        HTTPException: 403 "Admin access required" for any other role.
    """
    if user.get("role") == "admin":
        return user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required")