227 lines
8.4 KiB
Python
227 lines
8.4 KiB
Python
from datetime import datetime, timezone
|
|
from typing import Any, Dict, Optional
|
|
|
|
from fastapi import Depends, HTTPException, Request, Response, status
|
|
from fastapi.security import OAuth2PasswordBearer
|
|
|
|
from .config import settings
|
|
from .db import get_user_by_username, set_user_auth_provider, upsert_user_activity
|
|
from .network_security import request_trusts_forwarded_headers
|
|
from .security import TokenError, safe_decode_token, verify_password
|
|
|
|
# auto_error=False: a missing or non-Bearer Authorization header yields None instead of an
# automatic 401, so _extract_access_token can fall back to the auth cookie (and the
# event-stream dependency can fall back to its ?stream_token= query parameter).
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login", auto_error=False)
|
|
|
|
|
|
def _is_expired(expires_at: str | None) -> bool:
|
|
if not isinstance(expires_at, str) or not expires_at.strip():
|
|
return False
|
|
candidate = expires_at.strip()
|
|
if candidate.endswith("Z"):
|
|
candidate = candidate[:-1] + "+00:00"
|
|
try:
|
|
parsed = datetime.fromisoformat(candidate)
|
|
except ValueError:
|
|
return False
|
|
if parsed.tzinfo is None:
|
|
parsed = parsed.replace(tzinfo=timezone.utc)
|
|
return parsed <= datetime.now(timezone.utc)
|
|
|
|
|
|
def _extract_client_ip(request: Request) -> str:
    """Return the best-known client IP address for activity logging.

    When ``request_trusts_forwarded_headers`` approves the directly connected
    peer, the first non-empty entry of ``X-Forwarded-For`` is used, then a
    non-blank ``X-Real-IP``. Otherwise, or if neither header yields a value,
    the socket peer address is returned, falling back to ``"unknown"``.
    """
    direct_host = request.client.host if request.client else None

    if request_trusts_forwarded_headers(direct_host):
        forwarded = request.headers.get("x-forwarded-for") or ""
        for part in forwarded.split(","):
            candidate = part.strip()
            if candidate:
                # NOTE(review): the left-most entry is whatever the original client
                # sent and can be spoofed unless the trusted proxy overwrites the
                # header; tolerable here because the value only feeds activity logs.
                return candidate

        # Strip before testing so a whitespace-only header does not produce "".
        real_ip = (request.headers.get("x-real-ip") or "").strip()
        if real_ip:
            return real_ip

    return direct_host or "unknown"
|
|
|
|
|
|
def _cookie_settings() -> dict[str, Any]:
    """Build the shared ``set_cookie`` keyword arguments for auth cookies.

    ``samesite`` comes from ``settings.auth_cookie_samesite``. It is matched
    case-insensitively, and missing or unrecognised values fall back to
    ``"lax"``. The cookie is always ``HttpOnly`` and is scoped to path ``/``
    and ``settings.auth_cookie_domain`` (``None`` when unset).

    Browsers reject ``SameSite=None`` cookies that lack the ``Secure``
    attribute, so ``secure`` is forced on in that case regardless of
    ``settings.auth_cookie_secure``.
    """
    samesite = str(settings.auth_cookie_samesite or "lax").strip().lower()
    if samesite not in {"lax", "strict", "none"}:
        samesite = "lax"

    # SameSite=None without Secure would be silently dropped by the browser.
    secure = bool(settings.auth_cookie_secure) or samesite == "none"

    return {
        "secure": secure,
        "httponly": True,
        "samesite": samesite,
        "domain": settings.auth_cookie_domain or None,
        "path": "/",
    }
|
|
|
|
|
|
def _state_cookie_settings() -> dict[str, Any]:
    """Return cookie settings for the companion "session present" state cookie.

    The result is identical to ``_cookie_settings()`` except that ``httponly``
    is False, so client-side script can read this cookie. Presumably this lets
    the frontend detect a session without access to the token; confirm.
    """
    return {**_cookie_settings(), "httponly": False}
|
|
|
|
|
|
def set_auth_cookies(response: Response, token: str) -> None:
    """Attach the token cookie and the state cookie to *response*.

    Both cookies get a lifetime of ``settings.jwt_exp_minutes`` minutes (720
    when unset or zero), with a floor of 60 seconds. The token cookie uses
    ``_cookie_settings()`` (HttpOnly). The state cookie carries the literal
    value ``"1"`` and uses ``_state_cookie_settings()`` (not HttpOnly).
    """
    lifetime_minutes = int(settings.jwt_exp_minutes or 720)
    max_age = max(60, lifetime_minutes * 60)

    cookies = (
        (settings.auth_cookie_name, token, _cookie_settings()),
        (settings.auth_state_cookie_name, "1", _state_cookie_settings()),
    )
    for name, value, options in cookies:
        response.set_cookie(name, value, max_age=max_age, **options)
|
|
|
|
|
|
def clear_auth_cookies(response: Response) -> None:
    """Expire both the token cookie and the state cookie on *response*.

    Deletion targets path ``/`` and ``settings.auth_cookie_domain`` (``None``
    when unset). This is the same scope ``_cookie_settings`` uses when the
    cookies are set.
    """
    domain = settings.auth_cookie_domain or None
    for cookie_name in (settings.auth_cookie_name, settings.auth_state_cookie_name):
        response.delete_cookie(cookie_name, path="/", domain=domain)
|
|
|
|
|
|
def _extract_access_token(request: Request, oauth_token: Optional[str]) -> Optional[str]:
    """Locate the caller's access token, or return None if none is supplied.

    Sources are tried in this order:

    1. an ``Authorization: Bearer <token>`` header (scheme matched
       case-insensitively);
    2. *oauth_token*, as extracted by ``oauth2_scheme``;
    3. the ``settings.auth_cookie_name`` cookie.

    Each candidate is whitespace-stripped. A blank value is skipped so the next
    source is consulted instead of returning an empty token. The cookie branch
    already behaved this way; the header branches now match it.
    """
    auth_header = request.headers.get("authorization", "")
    if auth_header.lower().startswith("bearer "):
        bearer = auth_header.split(" ", 1)[1].strip()
        if bearer:
            return bearer

    # oauth_token comes from the same header and can be whitespace-only for
    # "Bearer   ", so it must be stripped too for the cookie fallback to work.
    if isinstance(oauth_token, str) and oauth_token.strip():
        return oauth_token.strip()

    cookie_token = request.cookies.get(settings.auth_cookie_name)
    if isinstance(cookie_token, str) and cookie_token.strip():
        return cookie_token.strip()

    return None
|
|
|
|
|
|
def resolve_user_auth_provider(user: Optional[Dict[str, Any]]) -> str:
    """Determine which identity provider a user record belongs to.

    A non-dict *user* resolves to ``"local"``. An explicit non-local
    ``auth_provider`` (whitespace-stripped, lower-cased) is returned as is.
    For records marked local, or with no provider, the stored ``password_hash``
    is checked against the fixed strings ``"jellyfin-user"`` and
    ``"jellyseerr-user"``. A match re-classifies the user as ``"jellyfin"`` or
    ``"jellyseerr"`` respectively; otherwise ``"local"`` is returned.

    NOTE(review): every local user triggers up to two ``verify_password`` calls,
    and this function runs on each authenticated request (via
    ``normalize_user_auth_provider``). If ``verify_password`` is a deliberately
    slow hash, that adds per-request latency. Confirm, and consider caching.
    """
    if not isinstance(user, dict):
        return "local"

    provider = str(user.get("auth_provider") or "local").strip().lower() or "local"
    if provider != "local":
        return provider

    password_hash = user.get("password_hash")
    if not (isinstance(password_hash, str) and password_hash):
        return provider

    candidates = (("jellyfin-user", "jellyfin"), ("jellyseerr-user", "jellyseerr"))
    for fixed_password, candidate_provider in candidates:
        if verify_password(fixed_password, password_hash):
            return candidate_provider
    return provider
|
|
|
|
|
|
def normalize_user_auth_provider(user: Optional[Dict[str, Any]]) -> Dict[str, Any]:
    """Return a copy of *user* annotated with its resolved auth provider.

    If ``resolve_user_auth_provider`` disagrees with the stored
    ``auth_provider`` and the record has a non-blank username, the function
    persists the resolved value via ``set_user_auth_provider``. It then
    re-reads the record with ``get_user_by_username``, keeping the original
    dict if the re-read returns nothing. The returned dict also carries:

    * ``auth_provider``: the resolved provider;
    * ``password_change_supported``: True for ``"local"`` and ``"jellyfin"``;
    * ``password_provider``: the resolved provider for those two, else None.

    A non-dict *user* yields an empty dict.
    """
    if not isinstance(user, dict):
        return {}

    resolved = resolve_user_auth_provider(user)
    stored = str(user.get("auth_provider") or "local").strip().lower() or "local"
    username = str(user.get("username") or "").strip()

    if resolved != stored and username:
        set_user_auth_provider(username, resolved)
        user = get_user_by_username(username) or user

    supports_password = resolved in {"local", "jellyfin"}
    normalized = dict(user)
    normalized.update(
        auth_provider=resolved,
        password_change_supported=supports_password,
        password_provider=resolved if supports_password else None,
    )
    return normalized
|
|
|
|
|
|
def _load_current_user_from_token(
    token: str,
    request: Optional[Request] = None,
    allowed_token_types: Optional[set[str]] = None,
) -> Dict[str, Any]:
    """Decode *token* and return the public profile of the user it names.

    Steps:

    1. Decode with ``safe_decode_token``; any ``TokenError`` becomes a 401.
    2. Read the ``typ`` claim (missing or blank claims default to ``"access"``).
       If *allowed_token_types* is non-empty, the type must be in it.
       Otherwise, scoped types that must be opted into explicitly (currently
       ``"sse"``) are rejected and every other type is accepted.
    3. Read the ``sub`` claim. A missing subject or unknown user is a 401.
    4. Blocked users, or users whose ``expires_at`` has passed, get a 403.
    5. Normalise the auth provider, which may persist a correction.
    6. When *request* is given, record the client IP and User-Agent via
       ``upsert_user_activity``.

    Raises:
        HTTPException: 401 or 403 as described above.
    """
    try:
        payload = safe_decode_token(token)
    except TokenError as exc:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token") from exc

    token_type = str(payload.get("typ") or "access").strip().lower()
    if allowed_token_types:
        type_allowed = token_type in allowed_token_types
    else:
        # Stream tokens are only valid where a caller opts in explicitly
        # (get_current_user_event_stream passes {"sse"}). Without this check,
        # a short-lived stream token, which travels in a URL query string,
        # would be accepted by get_current_user like a normal access token.
        stream_only_types = {"sse"}
        type_allowed = token_type not in stream_only_types
    if not type_allowed:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token type")

    username = payload.get("sub")
    if not username:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid token subject")

    user = get_user_by_username(username)
    if not user:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="User not found")
    if user.get("is_blocked"):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is blocked")
    if _is_expired(user.get("expires_at")):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User access has expired")

    user = normalize_user_auth_provider(user)

    if request is not None:
        ip = _extract_client_ip(request)
        user_agent = request.headers.get("user-agent", "unknown")
        upsert_user_activity(user["username"], ip, user_agent)

    return {
        "username": user["username"],
        "email": user.get("email"),
        "role": user["role"],
        "auth_provider": user.get("auth_provider", "local"),
        "jellyseerr_user_id": user.get("jellyseerr_user_id"),
        "auto_search_enabled": bool(user.get("auto_search_enabled", True)),
        "invite_management_enabled": bool(user.get("invite_management_enabled", False)),
        "profile_id": user.get("profile_id"),
        "expires_at": user.get("expires_at"),
        "is_expired": bool(user.get("is_expired", False)),
        "password_change_supported": bool(user.get("password_change_supported", False)),
        "password_provider": user.get("password_provider"),
    }
|
|
|
|
|
|
def get_current_user(
    request: Request,
    token: Optional[str] = Depends(oauth2_scheme),
) -> Dict[str, Any]:
    """FastAPI dependency returning the authenticated user for a request.

    The token is taken from the Authorization header, the OAuth2 scheme, or
    the auth cookie (see ``_extract_access_token``). *request* is forwarded to
    the loader, so client activity is recorded.

    Raises:
        HTTPException: 401 ``"Missing token"`` when no token is present, plus
            any error raised by ``_load_current_user_from_token``.
    """
    access_token = _extract_access_token(request, token)
    if access_token:
        return _load_current_user_from_token(access_token, request)
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing token")
|
|
|
|
|
|
def get_current_user_event_stream(
    request: Request,
    token: Optional[str] = Depends(oauth2_scheme),
) -> Dict[str, Any]:
    """Authenticate an event-stream request.

    EventSource cannot send Authorization headers, so a short-lived stream
    token may be passed as the ``stream_token`` query parameter. A regular
    header or cookie token, when present, takes precedence; this serves
    non-browser EventSource clients. The query token is accepted only if its
    type is ``"sse"``. In both cases ``None`` is passed as the request, so no
    activity is recorded.

    Raises:
        HTTPException: 401 ``"Missing token"`` when neither source supplies a
            token, plus any error raised by ``_load_current_user_from_token``.
    """
    bearer_token = _extract_access_token(request, token)
    if bearer_token:
        return _load_current_user_from_token(bearer_token, None)

    stream_token = request.query_params.get("stream_token")
    if stream_token:
        return _load_current_user_from_token(
            str(stream_token),
            None,
            allowed_token_types={"sse"},
        )

    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Missing token")
|
|
|
|
|
|
def require_admin(user: Dict[str, Any] = Depends(get_current_user)) -> Dict[str, Any]:
    """Dependency that passes the current user through only if their role is ``"admin"``.

    Raises:
        HTTPException: 403 ``"Admin access required"`` for any other role.
    """
    is_admin = user.get("role") == "admin"
    if not is_admin:
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required")
    return user
|
|
|
|
|
|
def require_admin_event_stream(
    user: Dict[str, Any] = Depends(get_current_user_event_stream),
) -> Dict[str, Any]:
    """Admin gate for event-stream endpoints.

    This works like ``require_admin`` but builds on
    ``get_current_user_event_stream``, so the ``stream_token`` query
    parameter is accepted.

    Raises:
        HTTPException: 403 ``"Admin access required"`` for non-admin roles.
    """
    if user.get("role") == "admin":
        return user
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required")
|