358 lines
16 KiB
Python
358 lines
16 KiB
Python
import logging
import secrets
from datetime import datetime, timedelta, timezone
from typing import Optional

from fastapi import APIRouter, HTTPException, status, Depends, Request
from fastapi.security import OAuth2PasswordRequestForm

from ..db import (
    verify_user_password,
    create_user_if_missing,
    create_user,
    set_last_login,
    get_user_by_username,
    get_user_by_email,
    set_user_password,
    get_invite_by_code,
    get_invite_profile,
    increment_invite_use,
    list_invites_by_creator,
    create_invite,
    upsert_user_contact,
    get_user_contact,
    set_user_expiry,
    create_password_reset,
    get_password_reset,
    mark_password_reset_used,
)
from ..runtime import get_runtime_settings
from ..clients.jellyfin import JellyfinClient
from ..clients.jellyseerr import JellyseerrClient
from ..security import create_access_token
from ..auth import get_current_user
from ..services.captcha import verify_captcha
from ..services.notifications import send_notification
|
|
|
|
# All authentication, registration, password and referral endpoints live under /auth.
router = APIRouter(tags=["auth"], prefix="/auth")
|
|
|
|
def _validate_password(password: str, rules: dict | None = None) -> Optional[str]:
    """Check ``password`` against the active password policy.

    Each policy key is read from ``rules`` when present there and otherwise
    from the runtime settings (``password_min_length`` and the
    ``password_require_*`` flags). A missing or falsy minimum length falls
    back to 8.

    Args:
        password: Candidate password, checked exactly as given.
        rules: Optional overrides with the keys ``min_length``,
            ``require_upper``, ``require_lower``, ``require_number`` and
            ``require_symbol``.

    Returns:
        A user-facing message for the first unmet requirement (checked in the
        order length, uppercase, lowercase, number, symbol), or ``None`` when
        the password satisfies the policy.
    """
    settings = get_runtime_settings()
    policy = rules or {}
    minimum = int(policy.get("min_length") or settings.password_min_length or 8)
    need_upper = bool(policy.get("require_upper", settings.password_require_upper))
    need_lower = bool(policy.get("require_lower", settings.password_require_lower))
    need_digit = bool(policy.get("require_number", settings.password_require_number))
    need_symbol = bool(policy.get("require_symbol", settings.password_require_symbol))

    if len(password) < minimum:
        return f"Password must be at least {minimum} characters."

    # (enabled, satisfied, message) — the first enabled-but-unsatisfied entry wins.
    # Case checks work by comparison: a password with no uppercase letter is
    # unchanged by lower(), and one with no lowercase letter is unchanged by upper().
    requirements = (
        (need_upper, password.lower() != password, "Password must include an uppercase letter."),
        (need_lower, password.upper() != password, "Password must include a lowercase letter."),
        (need_digit, any(ch.isdigit() for ch in password), "Password must include a number."),
        (need_symbol, not password.isalnum(), "Password must include a symbol."),
    )
    for enabled, satisfied, message in requirements:
        if enabled and not satisfied:
            return message
    return None
|
|
|
|
|
|
@router.post("/login")
async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict:
    """Authenticate a local account by username/password and issue a bearer token.

    Records the login time via ``set_last_login`` on success.

    Raises:
        HTTPException: 401 when ``verify_user_password`` rejects the
            credentials, 403 when the account has ``is_blocked`` set.
    """
    account = verify_user_password(form_data.username, form_data.password)
    if not account:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid credentials")
    if account.get("is_blocked"):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is blocked")

    username, role = account["username"], account["role"]
    access_token = create_access_token(username, role)
    set_last_login(username)
    return {
        "access_token": access_token,
        "token_type": "bearer",
        "user": {"username": username, "role": role},
    }
|
|
|
|
|
|
async def _sync_jellyfin_users(client: "JellyfinClient") -> None:
    """Best-effort: create a local account for every user name Jellyfin reports.

    Any failure (network error, unexpected payload, DB error) is logged and
    swallowed so that it never blocks the login that triggered the sync.
    """
    try:
        remote_users = await client.get_users()
        if not isinstance(remote_users, list):
            return
        for remote_user in remote_users:
            if not isinstance(remote_user, dict):
                continue
            name = remote_user.get("Name")
            if isinstance(name, str) and name:
                create_user_if_missing(name, "jellyfin-user", role="user", auth_provider="jellyfin")
    except Exception:
        # Previously `pass`: keep it non-fatal, but leave a trace for operators.
        logging.getLogger(__name__).warning("Jellyfin user sync failed during login", exc_info=True)


@router.post("/jellyfin/login")
async def jellyfin_login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict:
    """Authenticate against Jellyfin and issue a bearer token with role ``"user"``.

    On success a local account with ``auth_provider="jellyfin"`` is created if
    missing (the literal ``"jellyfin-user"`` is passed as the second argument —
    NOTE(review): presumably a placeholder credential; confirm in ``db``).
    The other Jellyfin users are then provisioned as a best-effort side step.

    Raises:
        HTTPException: 400 if Jellyfin is not configured, 502 if the Jellyfin
            call itself fails, 401 if Jellyfin does not return a ``User``,
            403 if the local account has ``is_blocked`` set.
    """
    runtime = get_runtime_settings()
    client = JellyfinClient(runtime.jellyfin_base_url, runtime.jellyfin_api_key)
    if not client.configured():
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Jellyfin not configured")

    try:
        response = await client.authenticate_by_name(form_data.username, form_data.password)
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
    if not isinstance(response, dict) or not response.get("User"):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Jellyfin credentials")

    create_user_if_missing(form_data.username, "jellyfin-user", role="user", auth_provider="jellyfin")
    local_user = get_user_by_username(form_data.username)
    if local_user and local_user.get("is_blocked"):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is blocked")

    await _sync_jellyfin_users(client)

    token = create_access_token(form_data.username, "user")
    set_last_login(form_data.username)
    return {"access_token": token, "token_type": "bearer", "user": {"username": form_data.username, "role": "user"}}
|
|
|
|
|
|
@router.post("/jellyseerr/login")
async def jellyseerr_login(form_data: OAuth2PasswordRequestForm = Depends()) -> dict:
    """Authenticate against Jellyseerr and issue a bearer token with role ``"user"``.

    The form's ``username`` field is sent to Jellyseerr as ``email``. Any dict
    response counts as success (NOTE(review): assumes the client raises on
    HTTP errors — confirm in ``JellyseerrClient.post``). A local account with
    ``auth_provider="jellyseerr"`` is created if missing.

    Raises:
        HTTPException: 400 if Jellyseerr is not configured, 502 if the call
            fails, 401 for a non-dict response, 403 if the local account has
            ``is_blocked`` set.
    """
    settings = get_runtime_settings()
    seerr = JellyseerrClient(settings.jellyseerr_base_url, settings.jellyseerr_api_key)
    if not seerr.configured():
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Jellyseerr not configured")

    login_name = form_data.username
    credentials = {"email": login_name, "password": form_data.password}
    try:
        reply = await seerr.post("/api/v1/auth/login", payload=credentials)
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail=str(exc)) from exc
    if not isinstance(reply, dict):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid Jellyseerr credentials")

    create_user_if_missing(login_name, "jellyseerr-user", role="user", auth_provider="jellyseerr")
    local_user = get_user_by_username(login_name)
    if local_user and local_user.get("is_blocked"):
        raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="User is blocked")

    access_token = create_access_token(login_name, "user")
    set_last_login(login_name)
    return {"access_token": access_token, "token_type": "bearer", "user": {"username": login_name, "role": "user"}}
|
|
|
|
|
|
@router.get("/me")
async def me(current_user: dict = Depends(get_current_user)) -> dict:
    """Return the user record resolved by the ``get_current_user`` dependency, unchanged."""
    resolved = current_user
    return resolved
|
|
|
|
|
|
@router.post("/password")
async def change_password(payload: dict, current_user: dict = Depends(get_current_user)) -> dict:
    """Change the caller's password (local accounts only).

    Expects ``current_password`` and ``new_password`` strings in ``payload``.
    The new password is stripped, checked against the policy, and stored only
    after the current password is verified.

    Raises:
        HTTPException: 400 for non-local accounts, a malformed payload or a
            policy violation; 401 if the current password is wrong.
    """
    if current_user.get("auth_provider") != "local":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Password changes are only available for local users.",
        )

    body = payload if isinstance(payload, dict) else {}
    old_secret = body.get("current_password")
    new_secret = body.get("new_password")
    if not (isinstance(old_secret, str) and isinstance(new_secret, str)):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid payload")

    candidate = new_secret.strip()
    problem = _validate_password(candidate)
    if problem:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=problem)

    if not verify_user_password(current_user["username"], old_secret):
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Current password is incorrect")

    set_user_password(current_user["username"], candidate)
    return {"status": "ok"}
|
|
|
|
|
|
@router.get("/contact")
async def get_contact(current_user: dict = Depends(get_current_user)) -> dict:
    """Return the caller's stored contact details, or an empty dict when none exist."""
    stored = get_user_contact(current_user["username"])
    if not stored:
        stored = {}
    return {"contact": stored}
|
|
|
|
|
|
@router.post("/contact")
async def update_contact(payload: dict, current_user: dict = Depends(get_current_user)) -> dict:
    """Overwrite the caller's contact details.

    Each of ``email``, ``discord``, ``telegram`` and ``matrix`` is stringified
    and stripped; missing or blank values are passed on as ``None``.

    Raises:
        HTTPException: 400 if ``payload`` is not a dict.
    """
    if not isinstance(payload, dict):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid payload")

    fields = {
        channel: (str(payload.get(channel) or "").strip() or None)
        for channel in ("email", "discord", "telegram", "matrix")
    }
    upsert_user_contact(current_user["username"], **fields)
    return {"status": "ok"}
|
|
|
|
|
|
def _parse_iso_timestamp(value: object) -> datetime | None:
    """Parse a stored ISO-8601 timestamp into a timezone-aware datetime.

    ``datetime.fromisoformat`` rejects a trailing ``Z`` before Python 3.11, and
    comparing a naive result with ``datetime.now(timezone.utc)`` raises
    ``TypeError``. This helper normalises both cases. Naive values are assumed
    to be UTC (NOTE(review): matches the UTC timestamps this module writes —
    confirm for admin-created invites).

    Returns:
        The aware datetime, or ``None`` if ``value`` is not valid ISO-8601.
    """
    text = str(value).strip()
    if text[-1:] in ("Z", "z"):
        text = text[:-1] + "+00:00"
    try:
        parsed = datetime.fromisoformat(text)
    except ValueError:
        return None
    if parsed.tzinfo is None:
        parsed = parsed.replace(tzinfo=timezone.utc)
    return parsed


def _invite_or_profile_value(invite: dict, profile: dict | None, key: str):
    """Return ``invite[key]`` if truthy, otherwise the profile's value for ``key`` (or ``None``)."""
    return invite.get(key) or (profile.get(key) if profile else None)


@router.post("/register")
async def register(payload: dict, request: Request) -> dict:
    """Create a local account from an invite code and return a bearer token.

    Validation order:
    1. Invites enabled and required fields present.
    2. Invite exists, is enabled, has uses left and has not expired.
    3. Captcha, when the invite, its profile or the runtime requires it.
    4. Username availability.
    5. Password policy.

    Optional contact details and a user expiry (from the invite, its profile,
    or the runtime defaults) are stored, then the invite use is counted.

    Raises:
        HTTPException: 400 for every validation failure above, and when
            ``create_user`` raises.
    """
    runtime = get_runtime_settings()
    if not runtime.invites_enabled:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invites are disabled")

    invite_code = str(payload.get("invite_code") or "").strip()
    username = str(payload.get("username") or "").strip()
    password = str(payload.get("password") or "").strip()
    contact = payload.get("contact")
    captcha_token = str(payload.get("captcha_token") or "").strip()
    if not invite_code or not username or not password:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invite, username, and password required")

    invite = get_invite_by_code(invite_code)
    if not invite or invite.get("disabled"):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invite not found or disabled")
    profile = get_invite_profile(int(invite["profile_id"])) if invite.get("profile_id") else None

    max_uses = invite.get("max_uses")
    # A NULL uses_count means "never used"; comparing None >= int would raise TypeError.
    if max_uses is not None and (invite.get("uses_count") or 0) >= max_uses:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invite has been fully used")

    invite_expires_raw = invite.get("expires_at")
    if invite_expires_raw:
        invite_expires = _parse_iso_timestamp(invite_expires_raw)
        # Unparseable values keep the previous lenient behaviour (treated as no expiry).
        if invite_expires is not None and invite_expires <= datetime.now(timezone.utc):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invite has expired")

    require_captcha = (
        bool(_invite_or_profile_value(invite, profile, "require_captcha"))
        or runtime.invites_require_captcha
    )
    if require_captcha:
        client_host = request.client.host if request.client else None
        if not await verify_captcha(captcha_token, client_host):
            raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Captcha failed")

    # Checked only after the invite and captcha pass, so anonymous callers
    # cannot use this endpoint to probe which usernames exist.
    if get_user_by_username(username):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username already exists")

    rules = _invite_or_profile_value(invite, profile, "password_rules")
    error = _validate_password(password, rules)
    if error:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=error)

    try:
        create_user(username, password, role="user", auth_provider="local")
    except Exception as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc

    if isinstance(contact, dict):
        upsert_user_contact(
            username,
            email=str(contact.get("email") or "").strip() or None,
            discord=str(contact.get("discord") or "").strip() or None,
            telegram=str(contact.get("telegram") or "").strip() or None,
            matrix=str(contact.get("matrix") or "").strip() or None,
        )

    expiry_days = _invite_or_profile_value(invite, profile, "user_expiry_days") or runtime.expiry_default_days
    expiry_action = _invite_or_profile_value(invite, profile, "user_expiry_action") or runtime.expiry_default_action
    if expiry_days and expiry_action:
        try:
            expiry_days_float = float(expiry_days)
        except (TypeError, ValueError):
            expiry_days_float = 0
        if expiry_days_float > 0:
            user_expires_at = (datetime.now(timezone.utc) + timedelta(days=expiry_days_float)).isoformat()
            set_user_expiry(username, user_expires_at, str(expiry_action))

    increment_invite_use(invite_code)
    token = create_access_token(username, "user")
    return {"access_token": token, "token_type": "bearer", "user": {"username": username, "role": "user"}}
|
|
|
|
|
|
# Lifetime of a password-reset token from the moment it is issued.
_PASSWORD_RESET_TTL = timedelta(hours=2)


@router.post("/password/reset")
async def request_password_reset(payload: dict) -> dict:
    """Issue a password-reset token for a local account and email it to the user.

    ``payload["identifier"]`` is looked up first as a username, then as an
    email address. The token is only persisted once it is known that it can
    be delivered (email notifications enabled and an email on file).

    Raises:
        HTTPException: 400 if resets are disabled, the identifier is missing,
            the account is not local, or email delivery is unavailable; 404 if
            no account matches. NOTE(review): the distinct 404 lets callers
            enumerate accounts.
    """
    runtime = get_runtime_settings()
    if not runtime.password_reset_enabled:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Password reset disabled")

    identifier = str(payload.get("identifier") or "").strip()
    if not identifier:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Username or email required")

    user = get_user_by_username(identifier) or get_user_by_email(identifier)
    if not user:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="User not found")
    if user.get("auth_provider") != "local":
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Password reset for local users only")

    contact = get_user_contact(user["username"])
    email = contact.get("email") if isinstance(contact, dict) else None
    # Check deliverability BEFORE creating the token; previously a token row
    # was stored even when the request was then rejected here.
    if not runtime.notify_email_enabled or not email:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Email notifications are not configured for password resets.",
        )

    token = secrets.token_urlsafe(32)
    expires_at = (datetime.now(timezone.utc) + _PASSWORD_RESET_TTL).isoformat()
    create_password_reset(token, user["username"], expires_at)

    await send_notification(
        "Password reset request",
        f"Your reset token is: {token}",
        channels=["email"],
        email=email,
    )
    return {"status": "ok"}
|
|
|
|
|
|
@router.post("/password/reset/confirm")
async def confirm_password_reset(payload: dict) -> dict:
    """Set a new password using a reset token, then mark the token as used.

    The token must exist, be unused and carry a parseable, unexpired
    ``expires_at``. Expiry is checked fail-closed: a missing or unparseable
    expiry is treated as expired. Previously such a token never expired.
    Naive timestamps are assumed to be UTC.

    Raises:
        HTTPException: 400 for missing fields, a used or expired token, or a
            policy violation; 404 if the token is unknown.
    """
    token = str(payload.get("token") or "").strip()
    new_password = str(payload.get("new_password") or "").strip()
    if not token or not new_password:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Token and new password required")

    reset = get_password_reset(token)
    if not reset:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Reset token not found")
    if reset.get("used_at"):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Reset token already used")

    raw_expiry = str(reset.get("expires_at") or "").strip()
    if raw_expiry[-1:] in ("Z", "z"):
        # fromisoformat only accepts a "Z" suffix from Python 3.11 onward.
        raw_expiry = raw_expiry[:-1] + "+00:00"
    try:
        expires_at = datetime.fromisoformat(raw_expiry)
    except ValueError:
        expires_at = None
    if expires_at is not None and expires_at.tzinfo is None:
        # Avoid a TypeError when comparing a naive value with an aware "now".
        expires_at = expires_at.replace(tzinfo=timezone.utc)
    if expires_at is None or expires_at <= datetime.now(timezone.utc):
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Reset token expired")

    error = _validate_password(new_password)
    if error:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=error)

    set_user_password(reset["username"], new_password)
    mark_password_reset_used(token)
    return {"status": "ok"}
|
|
|
|
|
|
@router.get("/signup/config")
async def signup_config() -> dict:
    """Expose the public signup settings (invites, captcha site keys, password policy).

    Each returned key mirrors the runtime-settings attribute of the same name.
    """
    settings = get_runtime_settings()
    exposed_fields = (
        "invites_enabled",
        "captcha_provider",
        "hcaptcha_site_key",
        "recaptcha_site_key",
        "turnstile_site_key",
        "password_min_length",
        "password_require_upper",
        "password_require_lower",
        "password_require_number",
        "password_require_symbol",
    )
    return {field: getattr(settings, field) for field in exposed_fields}
|
|
|
|
|
|
@router.get("/referrals")
async def list_referrals(current_user: dict = Depends(get_current_user)) -> dict:
    """List the referral invites created by the calling user."""
    return {"invites": list_invites_by_creator(current_user["username"], is_referral=True)}
|
|
|
|
|
|
@router.post("/referrals")
async def create_referral(current_user: dict = Depends(get_current_user)) -> dict:
    """Create a referral invite owned by the caller and return its code.

    The invite uses the runtime's default invite profile and captcha setting.
    It allows ``referral_default_uses`` uses (1 when that setting is unset or
    falsy) and has no expiry, password rules or user-expiry settings of its own.

    Raises:
        HTTPException: 400 when referrals are disabled at runtime.
    """
    settings = get_runtime_settings()
    if not settings.signup_allow_referrals:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Referrals are disabled")

    referral_code = secrets.token_urlsafe(8)
    owner = current_user["username"]
    create_invite(
        code=referral_code,
        created_by=owner,
        profile_id=settings.invite_default_profile_id,
        expires_at=None,
        max_uses=int(settings.referral_default_uses or 1),
        require_captcha=settings.invites_require_captcha,
        password_rules=None,
        allow_referrals=False,
        referral_uses=None,
        user_expiry_days=None,
        user_expiry_action=None,
        is_referral=True,
    )
    return {"status": "ok", "code": referral_code}
|