# Magent/backend/app/db.py
# (source listing: 776 lines, 23 KiB, Python)

import json
import os
import sqlite3
import logging
from datetime import datetime, timezone, timedelta
from typing import Any, Dict, Optional
from .config import settings
from .models import Snapshot
from .security import hash_password, verify_password
# Module-level logger; the requests_cache helpers below emit debug/warning records through it.
logger = logging.getLogger(__name__)
def _db_path() -> str:
    """Resolve the SQLite database file path and ensure its directory exists.

    Uses ``settings.sqlite_path`` when set, otherwise ``data/magent.db``.
    A relative path is anchored at the current working directory.
    """
    configured = settings.sqlite_path or "data/magent.db"
    resolved = configured if os.path.isabs(configured) else os.path.join(os.getcwd(), configured)
    os.makedirs(os.path.dirname(resolved), exist_ok=True)
    return resolved
class _ClosingConnection(sqlite3.Connection):
    """SQLite connection whose ``with`` block also closes the connection.

    The stock ``sqlite3.Connection`` context manager only commits (on success)
    or rolls back (on error); it never closes the connection. The helpers in
    this module all use ``with _connect() as conn:``, so without this subclass
    each call would leave its connection open until garbage collection.
    """

    def __exit__(self, exc_type, exc_value, traceback):
        try:
            # Commit / roll back exactly as the base class does.
            return super().__exit__(exc_type, exc_value, traceback)
        finally:
            self.close()


def _connect() -> sqlite3.Connection:
    """Open a connection to the application database.

    Returns:
        A ``sqlite3.Connection`` subclass. Used as a context manager, it
        commits (or rolls back) the pending transaction and then closes itself.
    """
    return sqlite3.connect(_db_path(), factory=_ClosingConnection)
# DDL run on every start-up; each statement is idempotent (IF NOT EXISTS).
_SCHEMA_STATEMENTS = (
    """
    CREATE TABLE IF NOT EXISTS snapshots (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        request_id TEXT NOT NULL,
        state TEXT NOT NULL,
        state_reason TEXT,
        created_at TEXT NOT NULL,
        payload_json TEXT NOT NULL
    )
    """,
    """
    CREATE TABLE IF NOT EXISTS actions (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        request_id TEXT NOT NULL,
        action_id TEXT NOT NULL,
        label TEXT NOT NULL,
        status TEXT NOT NULL,
        message TEXT,
        created_at TEXT NOT NULL
    )
    """,
    """
    CREATE TABLE IF NOT EXISTS users (
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        username TEXT NOT NULL UNIQUE,
        password_hash TEXT NOT NULL,
        role TEXT NOT NULL,
        auth_provider TEXT NOT NULL DEFAULT 'local',
        created_at TEXT NOT NULL,
        last_login_at TEXT,
        is_blocked INTEGER NOT NULL DEFAULT 0
    )
    """,
    """
    CREATE TABLE IF NOT EXISTS settings (
        key TEXT PRIMARY KEY,
        value TEXT,
        updated_at TEXT NOT NULL
    )
    """,
    """
    CREATE TABLE IF NOT EXISTS requests_cache (
        request_id INTEGER PRIMARY KEY,
        media_id INTEGER,
        media_type TEXT,
        status INTEGER,
        title TEXT,
        year INTEGER,
        requested_by TEXT,
        requested_by_norm TEXT,
        created_at TEXT,
        updated_at TEXT,
        payload_json TEXT NOT NULL
    )
    """,
    """
    CREATE INDEX IF NOT EXISTS idx_requests_cache_created_at
    ON requests_cache (created_at)
    """,
    """
    CREATE INDEX IF NOT EXISTS idx_requests_cache_requested_by_norm
    ON requests_cache (requested_by_norm)
    """,
)

# Columns that init_db adds to an existing table when missing — presumably
# for databases created before these columns existed in the CREATE statement.
# Entries are (table, column, column definition for ALTER TABLE ... ADD COLUMN).
_COLUMN_MIGRATIONS = (
    ("users", "last_login_at", "TEXT"),
    ("users", "is_blocked", "INTEGER NOT NULL DEFAULT 0"),
    ("users", "auth_provider", "TEXT NOT NULL DEFAULT 'local'"),
)


def _table_columns(conn: sqlite3.Connection, table: str) -> set[str]:
    """Return the column names of ``table`` (empty if the table does not exist)."""
    # PRAGMA table_info rows are (cid, name, type, notnull, dflt_value, pk).
    # ``table`` is interpolated: only pass trusted, module-internal identifiers.
    return {row[1] for row in conn.execute(f"PRAGMA table_info({table})")}


def _add_column_if_missing(
    conn: sqlite3.Connection, table: str, column: str, definition: str
) -> None:
    """Add ``column`` to ``table`` unless it already exists.

    All arguments are interpolated into SQL text, so they must be trusted,
    module-internal identifiers (never user input).
    """
    if column in _table_columns(conn, table):
        return
    try:
        conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {definition}")
    except sqlite3.OperationalError:
        # Another process may have added the column between the check and the
        # ALTER. Tolerate only that case; surface any other failure (e.g. a
        # locked database) instead of silently continuing.
        if column not in _table_columns(conn, table):
            raise


def init_db() -> None:
    """Create or upgrade the schema, then run start-up data fixes.

    Creates all tables and indexes if missing and adds any missing ``users``
    columns. It then backfills ``users.auth_provider`` and ensures the
    configured admin account exists.

    Raises:
        sqlite3.OperationalError: if a column migration fails for a reason
            other than the column already being present.
    """
    with _connect() as conn:
        for statement in _SCHEMA_STATEMENTS:
            conn.execute(statement)
        for table, column, definition in _COLUMN_MIGRATIONS:
            _add_column_if_missing(conn, table, column, definition)
    _backfill_auth_providers()
    ensure_admin_user()
def save_snapshot(snapshot: Snapshot) -> None:
    """Append one row to ``snapshots`` for the given snapshot.

    The complete model is stored as ASCII-escaped JSON in ``payload_json``,
    and the row is timestamped with the current UTC time in ISO-8601.
    NOTE(review): ``json.dumps`` raises TypeError if ``model_dump()`` yields
    values that are not JSON-native; confirm Snapshot's fields are all safe.
    """
    serialized = json.dumps(snapshot.model_dump(), ensure_ascii=True)
    recorded_at = datetime.now(timezone.utc).isoformat()
    values = (
        snapshot.request_id,
        snapshot.state.value,
        snapshot.state_reason,
        recorded_at,
        serialized,
    )
    with _connect() as conn:
        conn.execute(
            """
            INSERT INTO snapshots (request_id, state, state_reason, created_at, payload_json)
            VALUES (?, ?, ?, ?, ?)
            """,
            values,
        )
def save_action(
    request_id: str,
    action_id: str,
    label: str,
    status: str,
    message: Optional[str] = None,
) -> None:
    """Record an action outcome for a request in the ``actions`` table.

    The row is stamped with the current UTC time in ISO-8601 format.
    """
    record = (
        request_id,
        action_id,
        label,
        status,
        message,
        datetime.now(timezone.utc).isoformat(),
    )
    insert_sql = """
        INSERT INTO actions (request_id, action_id, label, status, message, created_at)
        VALUES (?, ?, ?, ?, ?, ?)
    """
    with _connect() as conn:
        conn.execute(insert_sql, record)
def get_recent_snapshots(request_id: str, limit: int = 10) -> list[dict[str, Any]]:
    """Return up to ``limit`` snapshots for ``request_id``, newest row first.

    Each entry holds the stored columns plus ``payload``, the decoded
    ``payload_json``.
    """
    with _connect() as conn:
        rows = conn.execute(
            """
            SELECT request_id, state, state_reason, created_at, payload_json
            FROM snapshots
            WHERE request_id = ?
            ORDER BY id DESC
            LIMIT ?
            """,
            (request_id, limit),
        ).fetchall()
    return [
        {
            "request_id": req_id,
            "state": state,
            "state_reason": reason,
            "created_at": created,
            "payload": json.loads(raw_payload),
        }
        for req_id, state, reason, created, raw_payload in rows
    ]
def get_recent_actions(request_id: str, limit: int = 10) -> list[dict[str, Any]]:
    """Return up to ``limit`` recorded actions for ``request_id``, newest row first."""
    columns = ("request_id", "action_id", "label", "status", "message", "created_at")
    with _connect() as conn:
        rows = conn.execute(
            """
            SELECT request_id, action_id, label, status, message, created_at
            FROM actions
            WHERE request_id = ?
            ORDER BY id DESC
            LIMIT ?
            """,
            (request_id, limit),
        ).fetchall()
    # SELECT order matches ``columns``, so zip pairs each value with its key.
    return [dict(zip(columns, row)) for row in rows]
def ensure_admin_user() -> None:
    """Create the configured admin account if it does not exist yet.

    Does nothing unless both ``settings.admin_username`` and
    ``settings.admin_password`` are set. An existing account with that
    username is left untouched; its password and role are not reset.
    """
    username = settings.admin_username
    password = settings.admin_password
    if not username or not password:
        return
    # Cheap pre-check so the password is only hashed when a row may be needed.
    if get_user_by_username(username):
        return
    # INSERT OR IGNORE rather than a plain INSERT. If another process creates
    # the same user between the check above and this call (e.g. several
    # workers running init_db concurrently), the UNIQUE constraint on
    # users.username would otherwise raise sqlite3.IntegrityError.
    create_user_if_missing(username, password, role="admin")
def create_user(username: str, password: str, role: str = "user", auth_provider: str = "local") -> None:
    """Insert a new user with a hashed password.

    Raises:
        sqlite3.IntegrityError: if ``username`` already exists (UNIQUE column).
    """
    timestamp = datetime.now(timezone.utc).isoformat()
    hashed = hash_password(password)
    with _connect() as conn:
        conn.execute(
            """
            INSERT INTO users (username, password_hash, role, auth_provider, created_at)
            VALUES (?, ?, ?, ?, ?)
            """,
            (username, hashed, role, auth_provider, timestamp),
        )
def create_user_if_missing(
    username: str, password: str, role: str = "user", auth_provider: str = "local"
) -> bool:
    """Insert a user unless the insert is ignored (e.g. the username already exists).

    Returns:
        True if a new row was inserted, False if ``INSERT OR IGNORE`` skipped it.
    """
    timestamp = datetime.now(timezone.utc).isoformat()
    hashed = hash_password(password)
    with _connect() as conn:
        inserted = conn.execute(
            """
            INSERT OR IGNORE INTO users (username, password_hash, role, auth_provider, created_at)
            VALUES (?, ?, ?, ?, ?)
            """,
            (username, hashed, role, auth_provider, timestamp),
        ).rowcount
    return inserted > 0
def get_user_by_username(username: str) -> Optional[Dict[str, Any]]:
    """Fetch one user, including the password hash, or ``None`` if not found.

    ``is_blocked`` is converted from its stored integer to ``bool``.
    """
    with _connect() as conn:
        row = conn.execute(
            """
            SELECT id, username, password_hash, role, auth_provider, created_at, last_login_at, is_blocked
            FROM users
            WHERE username = ?
            """,
            (username,),
        ).fetchone()
    if row is None:
        return None
    user_id, name, pw_hash, role, provider, created, last_login, blocked = row
    return {
        "id": user_id,
        "username": name,
        "password_hash": pw_hash,
        "role": role,
        "auth_provider": provider,
        "created_at": created,
        "last_login_at": last_login,
        "is_blocked": bool(blocked),
    }
def get_all_users() -> list[Dict[str, Any]]:
    """List all users (without password hashes), ordered case-insensitively by username."""
    with _connect() as conn:
        rows = conn.execute(
            """
            SELECT id, username, role, auth_provider, created_at, last_login_at, is_blocked
            FROM users
            ORDER BY username COLLATE NOCASE
            """
        ).fetchall()
    return [
        {
            "id": user_id,
            "username": name,
            "role": role,
            "auth_provider": provider,
            "created_at": created,
            "last_login_at": last_login,
            "is_blocked": bool(blocked),
        }
        for user_id, name, role, provider, created, last_login, blocked in rows
    ]
def set_last_login(username: str) -> None:
    """Stamp ``last_login_at`` for ``username`` with the current UTC time (no-op if absent)."""
    now_iso = datetime.now(timezone.utc).isoformat()
    with _connect() as conn:
        conn.execute(
            """
            UPDATE users SET last_login_at = ? WHERE username = ?
            """,
            (now_iso, username),
        )
def set_user_blocked(username: str, blocked: bool) -> None:
    """Set the ``is_blocked`` flag (stored as 0/1) for ``username``."""
    flag = int(bool(blocked))
    with _connect() as conn:
        conn.execute(
            """
            UPDATE users SET is_blocked = ? WHERE username = ?
            """,
            (flag, username),
        )
def set_user_role(username: str, role: str) -> None:
    """Overwrite the role of ``username``; the value is stored unvalidated."""
    params = (role, username)
    with _connect() as conn:
        conn.execute(
            """
            UPDATE users SET role = ? WHERE username = ?
            """,
            params,
        )
def verify_user_password(username: str, password: str) -> Optional[Dict[str, Any]]:
    """Return the user record when ``password`` matches its stored hash, else ``None``.

    NOTE(review): neither ``auth_provider`` nor ``is_blocked`` is checked here.
    ``_backfill_auth_providers`` identifies external accounts by hashes of the
    literal passwords "jellyfin-user" / "jellyseerr-user". Confirm that callers
    of this function reject non-local providers and blocked users.
    """
    user = get_user_by_username(username)
    if user and verify_password(password, user["password_hash"]):
        return user
    return None
def set_user_password(username: str, password: str) -> None:
    """Replace the stored password hash for ``username`` (no-op if the user is absent)."""
    new_hash = hash_password(password)
    with _connect() as conn:
        conn.execute(
            """
            UPDATE users SET password_hash = ? WHERE username = ?
            """,
            (new_hash, username),
        )
def _backfill_auth_providers() -> None:
    """Derive ``auth_provider`` for users still marked local (or with an empty value).

    A "local" user whose hash matches the literal password "jellyfin-user" is
    reassigned to "jellyfin"; one matching "jellyseerr-user" goes to
    "jellyseerr". An empty provider becomes "local". Only changed rows are
    written back.
    """
    # Checked in order; the first matching marker password wins.
    markers = (("jellyfin-user", "jellyfin"), ("jellyseerr-user", "jellyseerr"))
    with _connect() as conn:
        users = conn.execute(
            """
            SELECT username, password_hash, auth_provider
            FROM users
            """
        ).fetchall()
    pending: list[tuple[str, str]] = []
    for username, password_hash, current in users:
        detected = current or "local"
        if detected == "local":
            for marker_password, provider_name in markers:
                if verify_password(marker_password, password_hash):
                    detected = provider_name
                    break
        if detected != current:
            pending.append((detected, username))
    if not pending:
        return
    with _connect() as conn:
        conn.executemany(
            """
            UPDATE users SET auth_provider = ? WHERE username = ?
            """,
            pending,
        )
def upsert_request_cache(
    request_id: int,
    media_id: Optional[int],
    media_type: Optional[str],
    status: Optional[int],
    title: Optional[str],
    year: Optional[int],
    requested_by: Optional[str],
    requested_by_norm: Optional[str],
    created_at: Optional[str],
    updated_at: Optional[str],
    payload_json: str,
) -> None:
    """Insert or refresh the cached copy of a request, keyed by ``request_id``.

    On conflict, every non-key column (including ``created_at``) is
    overwritten with the supplied values.
    """
    record = {
        "request_id": request_id,
        "media_id": media_id,
        "media_type": media_type,
        "status": status,
        "title": title,
        "year": year,
        "requested_by": requested_by,
        "requested_by_norm": requested_by_norm,
        "created_at": created_at,
        "updated_at": updated_at,
        "payload_json": payload_json,
    }
    with _connect() as conn:
        conn.execute(
            """
            INSERT INTO requests_cache (
                request_id,
                media_id,
                media_type,
                status,
                title,
                year,
                requested_by,
                requested_by_norm,
                created_at,
                updated_at,
                payload_json
            )
            VALUES (
                :request_id, :media_id, :media_type, :status, :title, :year,
                :requested_by, :requested_by_norm, :created_at, :updated_at, :payload_json
            )
            ON CONFLICT(request_id) DO UPDATE SET
                media_id = excluded.media_id,
                media_type = excluded.media_type,
                status = excluded.status,
                title = excluded.title,
                year = excluded.year,
                requested_by = excluded.requested_by,
                requested_by_norm = excluded.requested_by_norm,
                created_at = excluded.created_at,
                updated_at = excluded.updated_at,
                payload_json = excluded.payload_json
            """,
            record,
        )
    logger.debug(
        "requests_cache upsert: request_id=%s media_id=%s status=%s updated_at=%s",
        request_id,
        media_id,
        status,
        updated_at,
    )
def get_request_cache_last_updated() -> Optional[str]:
    """Return the greatest ``updated_at`` in the cache, or ``None`` if it is empty.

    ``MAX`` compares the stored TEXT values, so the result is only
    chronological if all timestamps share one format.
    """
    with _connect() as conn:
        row = conn.execute(
            """
            SELECT MAX(updated_at) FROM requests_cache
            """
        ).fetchone()
    return row[0] if row else None
def get_request_cache_by_id(request_id: int) -> Optional[Dict[str, Any]]:
    """Return ``request_id``, ``updated_at`` and ``title`` for a cached request, or ``None``."""
    with _connect() as conn:
        row = conn.execute(
            """
            SELECT request_id, updated_at, title
            FROM requests_cache
            WHERE request_id = ?
            """,
            (request_id,),
        ).fetchone()
    if row is None:
        logger.debug("requests_cache miss: request_id=%s", request_id)
        return None
    cached_id, last_update, cached_title = row
    logger.debug("requests_cache hit: request_id=%s updated_at=%s", cached_id, last_update)
    return {"request_id": cached_id, "updated_at": last_update, "title": cached_title}
def get_request_cache_payload(request_id: int) -> Optional[Dict[str, Any]]:
    """Return the decoded ``payload_json`` for a cached request.

    Returns ``None`` when the row is missing, the payload is empty, or the
    stored text is not valid JSON (the last case is logged as a warning).
    """
    with _connect() as conn:
        row = conn.execute(
            """
            SELECT payload_json
            FROM requests_cache
            WHERE request_id = ?
            """,
            (request_id,),
        ).fetchone()
    raw = row[0] if row else None
    if not raw:
        logger.debug("requests_cache payload miss: request_id=%s", request_id)
        return None
    try:
        decoded = json.loads(raw)
    except json.JSONDecodeError:
        logger.warning("requests_cache payload invalid json: request_id=%s", request_id)
        return None
    logger.debug("requests_cache payload hit: request_id=%s", request_id)
    return decoded
def get_cached_requests(
    limit: int,
    offset: int,
    requested_by_norm: Optional[str] = None,
    since_iso: Optional[str] = None,
) -> list[Dict[str, Any]]:
    """Page through cached requests, newest ``created_at`` first.

    Optional filters (ignored when falsy): exact ``requested_by_norm``
    match, and ``created_at >= since_iso`` (string comparison).
    """
    filters: list[tuple[str, Any]] = []
    if requested_by_norm:
        filters.append(("requested_by_norm = ?", requested_by_norm))
    if since_iso:
        filters.append(("created_at >= ?", since_iso))

    sql = """
        SELECT request_id, media_id, media_type, status, title, year, requested_by, created_at
        FROM requests_cache
    """
    if filters:
        sql += " WHERE " + " AND ".join(clause for clause, _ in filters)
    sql += " ORDER BY created_at DESC, request_id DESC LIMIT ? OFFSET ?"
    bound = tuple(value for _, value in filters) + (limit, offset)

    with _connect() as conn:
        rows = conn.execute(sql, bound).fetchall()
    logger.debug(
        "requests_cache list: count=%s requested_by_norm=%s since_iso=%s",
        len(rows),
        requested_by_norm,
        since_iso,
    )
    keys = (
        "request_id",
        "media_id",
        "media_type",
        "status",
        "title",
        "year",
        "requested_by",
        "created_at",
    )
    return [dict(zip(keys, row)) for row in rows]
def _overview_title(stored_title: Any, raw_payload: Optional[str]) -> Any:
    """Pick a display title, falling back to the JSON payload when the column is empty.

    The payload is consulted only if ``stored_title`` is falsy. The lookup
    order is ``media.title``, ``media.name``, ``title``, ``name``. Invalid
    JSON or a non-object payload keeps ``stored_title``.
    """
    if stored_title or not raw_payload:
        return stored_title
    try:
        payload = json.loads(raw_payload)
    except json.JSONDecodeError:
        return stored_title
    if not isinstance(payload, dict):
        return stored_title
    media = payload.get("media") or {}
    if not isinstance(media, dict):
        media = {}
    return media.get("title") or media.get("name") or payload.get("title") or payload.get("name")


def get_request_cache_overview(limit: int = 50) -> list[Dict[str, Any]]:
    """Return the most recently updated cached requests for an overview listing.

    ``limit`` is clamped to [1, 200]. Missing titles are filled from the
    stored payload where possible; the payload itself is not returned.
    """
    capped = min(max(limit, 1), 200)
    with _connect() as conn:
        rows = conn.execute(
            """
            SELECT request_id, media_id, media_type, status, title, year, requested_by, created_at, updated_at, payload_json
            FROM requests_cache
            ORDER BY updated_at DESC, request_id DESC
            LIMIT ?
            """,
            (capped,),
        ).fetchall()
    overview: list[Dict[str, Any]] = []
    for (req_id, media_id, media_type, status, title, year,
         requested_by, created, updated, raw_payload) in rows:
        overview.append(
            {
                "request_id": req_id,
                "media_id": media_id,
                "media_type": media_type,
                "status": status,
                "title": _overview_title(title, raw_payload),
                "year": year,
                "requested_by": requested_by,
                "created_at": created,
                "updated_at": updated,
            }
        )
    return overview
def get_request_cache_count() -> int:
    """Return the number of rows in ``requests_cache``."""
    with _connect() as conn:
        (total,) = conn.execute("SELECT COUNT(*) FROM requests_cache").fetchone()
    return int(total or 0)
def prune_duplicate_requests_cache() -> int:
    """Delete older duplicate cache rows and return how many were removed.

    For each ``(media_id, requested_by_norm)`` pair, only the row with the
    highest ``request_id`` is kept. A NULL ``requested_by_norm`` is grouped
    together with the empty string. Rows whose ``media_id`` is NULL are
    never touched.
    """
    with _connect() as conn:
        removed = conn.execute(
            """
            DELETE FROM requests_cache
            WHERE media_id IS NOT NULL
            AND request_id NOT IN (
                SELECT MAX(request_id)
                FROM requests_cache
                WHERE media_id IS NOT NULL
                GROUP BY media_id, COALESCE(requested_by_norm, '')
            )
            """
        ).rowcount
    return removed
def get_request_cache_payloads(limit: int = 200, offset: int = 0) -> list[Dict[str, Any]]:
    """Page through decoded cache payloads in ascending ``request_id`` order.

    ``limit`` is clamped to [1, 1000] and ``offset`` to >= 0. An empty or
    invalid-JSON payload is returned as ``None``.
    """
    page_size = min(max(limit, 1), 1000)
    start = max(offset, 0)

    def decode(raw: Optional[str]) -> Any:
        if not raw:
            return None
        try:
            return json.loads(raw)
        except json.JSONDecodeError:
            return None

    with _connect() as conn:
        rows = conn.execute(
            """
            SELECT request_id, payload_json
            FROM requests_cache
            ORDER BY request_id ASC
            LIMIT ? OFFSET ?
            """,
            (page_size, start),
        ).fetchall()
    return [{"request_id": req_id, "payload": decode(raw)} for req_id, raw in rows]
def get_cached_requests_since(since_iso: str) -> list[Dict[str, Any]]:
    """Return every cached request with ``created_at >= since_iso``, newest first.

    Unlike ``get_cached_requests`` this is unpaginated and also includes
    ``requested_by_norm``.
    """
    fields = (
        "request_id",
        "media_id",
        "media_type",
        "status",
        "title",
        "year",
        "requested_by",
        "requested_by_norm",
        "created_at",
    )
    with _connect() as conn:
        rows = conn.execute(
            """
            SELECT request_id, media_id, media_type, status, title, year, requested_by, requested_by_norm, created_at
            FROM requests_cache
            WHERE created_at >= ?
            ORDER BY created_at DESC, request_id DESC
            """,
            (since_iso,),
        ).fetchall()
    return [dict(zip(fields, row)) for row in rows]
def get_cached_request_by_media_id(
    media_id: int, requested_by_norm: Optional[str] = None
) -> Optional[Dict[str, Any]]:
    """Return ``request_id``/``status`` of the newest cached request for ``media_id``.

    When ``requested_by_norm`` is truthy, only that requester's rows are
    considered. Returns ``None`` if nothing matches.
    """
    base_sql = """
        SELECT request_id, status
        FROM requests_cache
        WHERE media_id = ?
    """
    if requested_by_norm:
        sql = base_sql + " AND requested_by_norm = ?"
        args: tuple[Any, ...] = (media_id, requested_by_norm)
    else:
        sql = base_sql
        args = (media_id,)
    sql += " ORDER BY created_at DESC, request_id DESC LIMIT 1"
    with _connect() as conn:
        row = conn.execute(sql, args).fetchone()
    if row is None:
        return None
    found_id, found_status = row
    return {"request_id": found_id, "status": found_status}
def get_setting(key: str) -> Optional[str]:
    """Return the stored value for ``key``.

    Returns ``None`` both when the key is absent and when its value is NULL.
    """
    with _connect() as conn:
        row = conn.execute(
            """
            SELECT value FROM settings WHERE key = ?
            """,
            (key,),
        ).fetchone()
    return row[0] if row else None
def set_setting(key: str, value: Optional[str]) -> None:
    """Insert or overwrite a setting, refreshing its UTC ``updated_at`` stamp."""
    entry = {
        "key": key,
        "value": value,
        "updated_at": datetime.now(timezone.utc).isoformat(),
    }
    with _connect() as conn:
        conn.execute(
            """
            INSERT INTO settings (key, value, updated_at)
            VALUES (:key, :value, :updated_at)
            ON CONFLICT(key) DO UPDATE SET value = excluded.value, updated_at = excluded.updated_at
            """,
            entry,
        )
def delete_setting(key: str) -> None:
    """Remove the setting ``key``; deleting a missing key is a no-op."""
    statement = """
        DELETE FROM settings WHERE key = ?
    """
    with _connect() as conn:
        conn.execute(statement, (key,))
def get_settings_overrides() -> Dict[str, Optional[str]]:
    """Return every stored setting as a ``{key: value}`` mapping.

    Values may be ``None``: the ``value`` column is nullable and
    ``set_setting`` accepts ``None``. The previous ``Dict[str, str]``
    annotation hid that from callers. Rows with an empty or NULL key are
    skipped.
    """
    with _connect() as conn:
        rows = conn.execute(
            """
            SELECT key, value FROM settings
            """
        ).fetchall()
    return {key: value for key, value in rows if key}
def run_integrity_check() -> str:
    """Run ``PRAGMA integrity_check`` and return its first result row as text.

    Returns "unknown" if the pragma yields no rows.
    """
    with _connect() as conn:
        first = conn.execute("PRAGMA integrity_check").fetchone()
    return "unknown" if first is None else str(first[0])
def vacuum_db() -> None:
    """Rebuild the database file with SQLite's ``VACUUM`` command."""
    with _connect() as connection:
        connection.execute("VACUUM")
def clear_requests_cache() -> int:
    """Delete every row from ``requests_cache`` and return the number removed."""
    with _connect() as conn:
        deleted = conn.execute("DELETE FROM requests_cache").rowcount
    return deleted
def clear_history() -> Dict[str, int]:
    """Delete all actions and snapshots; return per-table deleted row counts.

    Both deletes run on one connection and are committed together.
    """
    removed: Dict[str, int] = {}
    with _connect() as conn:
        for table in ("actions", "snapshots"):
            removed[table] = conn.execute(f"DELETE FROM {table}").rowcount
    return removed
def cleanup_history(days: int) -> Dict[str, int]:
    """Delete actions and snapshots older than ``days`` days; return per-table counts.

    ``days <= 0`` is a no-op that reports zero deletions. The cutoff is
    compared as text against ``created_at``, which this module writes as
    UTC ISO-8601 strings.
    """
    counts: Dict[str, int] = {"actions": 0, "snapshots": 0}
    if days <= 0:
        return counts
    threshold = (datetime.now(timezone.utc) - timedelta(days=days)).isoformat()
    with _connect() as conn:
        for table in counts:
            counts[table] = conn.execute(
                f"DELETE FROM {table} WHERE created_at < ?",
                (threshold,),
            ).rowcount
    return counts