import os
import re
import tempfile
import unittest
from unittest.mock import AsyncMock, patch

from fastapi import HTTPException
from starlette.requests import Request

from backend.app import db
from backend.app.config import settings
from backend.app.network_security import (
    request_trusts_forwarded_headers,
    validate_notification_target_url,
)
from backend.app.routers import auth as auth_router
from backend.app.routers import portal as portal_router
from backend.app.security import PASSWORD_POLICY_MESSAGE, validate_password_policy
from backend.app.services import password_reset


def _build_request(ip: str = "127.0.0.1", user_agent: str = "backend-test") -> Request:
    """Build a minimal Starlette POST request for ``/auth/password/forgot``.

    The ASGI scope carries ``ip`` as the client address and ``user_agent`` as
    the only header. The body is always empty. This lets route handlers be
    called directly, without an HTTP test client.
    """
    scope = {
        "type": "http",
        "http_version": "1.1",
        "method": "POST",
        "scheme": "http",
        "path": "/auth/password/forgot",
        "raw_path": b"/auth/password/forgot",
        "query_string": b"",
        "headers": [(b"user-agent", user_agent.encode("utf-8"))],
        "client": (ip, 12345),
        "server": ("testserver", 8000),
    }

    async def receive() -> dict:
        # A single empty, final body chunk.
        return {"type": "http.request", "body": b"", "more_body": False}

    return Request(scope, receive)


def _clear_auth_rate_limits() -> None:
    """Empty the auth router's module-level login and password-reset attempt trackers."""
    auth_router._LOGIN_ATTEMPTS_BY_IP.clear()
    auth_router._LOGIN_ATTEMPTS_BY_USER.clear()
    auth_router._RESET_ATTEMPTS_BY_IP.clear()
    auth_router._RESET_ATTEMPTS_BY_IDENTIFIER.clear()


class TempDatabaseMixin:
    """Run each test against a fresh SQLite database in a temporary directory.

    Mix in ahead of ``unittest.TestCase`` or ``IsolatedAsyncioTestCase``.

    Every piece of global state changed in ``setUp`` is restored through
    ``addCleanup``. Cleanups run even when ``setUp`` itself fails, for example
    if ``db.init_db()`` raises. ``tearDown`` does not run in that case, so
    putting the restoration in ``tearDown`` would leak the temp directory and
    leave ``settings`` pointing at a deleted database.
    """

    def setUp(self) -> None:
        super().setUp()

        self._tempdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
        # Registered first so it runs last (cleanups are LIFO): settings are
        # restored before the directory holding the database is removed.
        self.addCleanup(self._tempdir.cleanup)

        self._original_sqlite_path = settings.sqlite_path
        self._original_journal_mode = getattr(settings, "sqlite_journal_mode", "DELETE")
        self.addCleanup(self._restore_database_settings)
        settings.sqlite_path = os.path.join(self._tempdir.name, "test.db")
        settings.sqlite_journal_mode = "DELETE"

        # Attempt counters are module-level, so reset them on both sides of
        # every test to keep tests independent.
        _clear_auth_rate_limits()
        self.addCleanup(_clear_auth_rate_limits)

        db.init_db()

    def _restore_database_settings(self) -> None:
        """Put back the SQLite settings captured in ``setUp``."""
        settings.sqlite_path = self._original_sqlite_path
        settings.sqlite_journal_mode = self._original_journal_mode


class PasswordPolicyTests(unittest.TestCase):
    """Tests for ``validate_password_policy``."""

    def test_validate_password_policy_rejects_short_passwords(self) -> None:
        # The message is literal text, not a pattern; escape it so regex
        # metacharacters in it cannot break the match.
        with self.assertRaisesRegex(ValueError, re.escape(PASSWORD_POLICY_MESSAGE)):
            validate_password_policy("short")

    def test_validate_password_policy_trims_whitespace(self) -> None:
        self.assertEqual(validate_password_policy(" password123 "), "password123")


class NetworkSecurityTests(unittest.TestCase):
    """Tests for notification-target validation and forwarded-header trust."""

    def test_notification_targets_reject_loopback(self) -> None:
        with self.assertRaisesRegex(
            ValueError,
            re.escape("Private or local notification targets are not allowed."),
        ):
            validate_notification_target_url("http://127.0.0.1:8080/webhook")

    def test_forwarded_headers_require_trusted_proxy(self) -> None:
        # Forwarded headers are trusted only when they come from a listed proxy.
        original_enabled = settings.magent_proxy_enabled
        original_trust = settings.magent_proxy_trust_forwarded_headers
        original_proxies = settings.magent_proxy_trusted_proxies
        settings.magent_proxy_enabled = True
        settings.magent_proxy_trust_forwarded_headers = True
        settings.magent_proxy_trusted_proxies = "127.0.0.1,::1"
        try:
            self.assertTrue(request_trusts_forwarded_headers("127.0.0.1"))
            self.assertFalse(request_trusts_forwarded_headers("203.0.113.10"))
        finally:
            settings.magent_proxy_enabled = original_enabled
            settings.magent_proxy_trust_forwarded_headers = original_trust
            settings.magent_proxy_trusted_proxies = original_proxies


class DatabaseEmailTests(TempDatabaseMixin, unittest.TestCase):
    """Tests for user e-mail persistence in ``db``."""

    def test_set_user_email_is_case_insensitive(self) -> None:
        created = db.create_user_if_missing(
            "MixedCaseUser",
            "password123",
            email=None,
            auth_provider="local",
        )
        self.assertTrue(created)

        # Write and read the user under different casings of the same username.
        updated = db.set_user_email("mixedcaseuser", "mixed@example.com")
        self.assertTrue(updated)

        stored = db.get_user_by_username("MIXEDCASEUSER")
        self.assertIsNotNone(stored)
        self.assertEqual(stored.get("email"), "mixed@example.com")


class AuthFlowTests(TempDatabaseMixin, unittest.IsolatedAsyncioTestCase):
    """Async tests for password-reset and invite endpoints in the auth router."""

    async def test_forgot_password_is_rate_limited(self) -> None:
        request = _build_request(ip="10.1.2.3")
        payload = {"identifier": "resetuser@example.com"}

        with patch.object(
            auth_router, "smtp_email_config_ready", return_value=(True, "")
        ), patch.object(
            auth_router,
            "request_password_reset",
            new=AsyncMock(return_value={"status": "ok", "issued": False}),
        ):
            # NOTE(review): presumably the router allows three reset attempts
            # per window for this IP/identifier; confirm against auth_router.
            for _ in range(3):
                result = await auth_router.forgot_password(payload, request)
                self.assertEqual(result["status"], "ok")

            with self.assertRaises(HTTPException) as context:
                await auth_router.forgot_password(payload, request)

        self.assertEqual(context.exception.status_code, 429)
        self.assertEqual(
            context.exception.detail,
            "Too many password reset attempts. Try again shortly.",
        )

    async def test_request_password_reset_prefers_local_user_email(self) -> None:
        db.create_user_if_missing(
            "ResetUser",
            "password123",
            email="local@example.com",
            auth_provider="local",
        )

        with patch.object(
            password_reset,
            "send_password_reset_email",
            new=AsyncMock(return_value={"status": "ok"}),
        ) as send_email:
            result = await password_reset.request_password_reset("ResetUser")

        self.assertTrue(result["issued"])
        self.assertEqual(result["recipient_email"], "local@example.com")
        send_email.assert_awaited_once()
        self.assertEqual(send_email.await_args.kwargs["recipient_email"], "local@example.com")

    async def test_profile_invite_requires_recipient_email(self) -> None:
        current_user = {
            "username": "invite-owner",
            "role": "user",
            "invite_management_enabled": True,
            "profile_id": None,
        }

        with self.assertRaises(HTTPException) as context:
            await auth_router.create_profile_invite({"label": "Missing email"}, current_user)

        self.assertEqual(context.exception.status_code, 400)
        self.assertEqual(
            context.exception.detail,
            "recipient_email is required and must be a valid email address.",
        )


class PortalWorkflowTests(TempDatabaseMixin, unittest.TestCase):
    """Tests for portal item workflow serialization, transitions, and filters."""

    def test_legacy_request_status_maps_to_workflow(self) -> None:
        # A bare legacy status such as "in_progress" is expanded into separate
        # request/media workflow statuses.
        item = {"kind": "request", "status": "in_progress"}
        serialized = portal_router._serialize_item(item, {"username": "tester", "role": "user"})
        workflow = serialized.get("workflow") or {}
        self.assertEqual(workflow.get("request_status"), "approved")
        self.assertEqual(workflow.get("media_status"), "processing")

    def test_invalid_pipeline_transition_is_rejected(self) -> None:
        # NOTE(review): arguments are presumably (current request status,
        # current media status, new request status, new media status);
        # confirm against portal_router._validate_pipeline_transition.
        with self.assertRaises(HTTPException) as context:
            portal_router._validate_pipeline_transition(
                "approved",
                "processing",
                "pending",
                "pending",
            )
        self.assertEqual(context.exception.status_code, 400)

    def test_portal_workflow_filters(self) -> None:
        db.create_portal_item(
            kind="request",
            title="Request A",
            description="A",
            created_by_username="alpha",
            created_by_id=None,
            status="processing",
            workflow_request_status="approved",
            workflow_media_status="processing",
        )
        db.create_portal_item(
            kind="request",
            title="Request B",
            description="B",
            created_by_username="bravo",
            created_by_id=None,
            status="pending",
            workflow_request_status="pending",
            workflow_media_status="pending",
        )

        processing = db.list_portal_items(
            kind="request",
            workflow_request_status="approved",
            workflow_media_status="processing",
            limit=10,
            offset=0,
        )
        pending_count = db.count_portal_items(
            kind="request",
            workflow_request_status="pending",
            workflow_media_status="pending",
        )

        self.assertEqual(len(processing), 1)
        self.assertEqual(pending_count, 1)