import os
import re
import tempfile
import unittest
from unittest.mock import AsyncMock, patch

from fastapi import HTTPException
from starlette.requests import Request

from backend.app import db
from backend.app.config import settings
from backend.app.routers import auth as auth_router
from backend.app.security import PASSWORD_POLICY_MESSAGE, validate_password_policy
from backend.app.services import password_reset


def _build_request(ip: str = "127.0.0.1", user_agent: str = "backend-test") -> Request:
    """Build a minimal Starlette ``Request`` for invoking auth handlers directly.

    The request is a bodiless ``POST /auth/password/forgot`` whose client
    address is ``ip`` and whose ``User-Agent`` header is ``user_agent``.

    Args:
        ip: Client IP placed in the ASGI scope (used by per-IP rate limiting).
        user_agent: Value sent as the ``User-Agent`` header.

    Returns:
        A ``starlette.requests.Request`` backed by a hand-built ASGI scope.
    """
    scope = {
        "type": "http",
        "http_version": "1.1",
        "method": "POST",
        "scheme": "http",
        "path": "/auth/password/forgot",
        "raw_path": b"/auth/password/forgot",
        "query_string": b"",
        "headers": [(b"user-agent", user_agent.encode("utf-8"))],
        "client": (ip, 12345),
        "server": ("testserver", 8000),
    }

    async def receive() -> dict:
        # Report a single empty, complete body chunk.
        return {"type": "http.request", "body": b"", "more_body": False}

    return Request(scope, receive)


def _clear_auth_rate_limiters() -> None:
    """Empty the auth router's in-memory login and password-reset attempt trackers."""
    auth_router._LOGIN_ATTEMPTS_BY_IP.clear()
    auth_router._LOGIN_ATTEMPTS_BY_USER.clear()
    auth_router._RESET_ATTEMPTS_BY_IP.clear()
    auth_router._RESET_ATTEMPTS_BY_IDENTIFIER.clear()


class TempDatabaseMixin:
    """Give each test a throwaway SQLite database and fresh rate-limit state.

    Mix in ahead of a ``unittest.TestCase`` subclass. ``setUp``:

    - creates a temporary directory;
    - points ``settings.sqlite_path`` at ``test.db`` inside it;
    - forces ``settings.sqlite_journal_mode`` to ``"DELETE"``;
    - clears the auth router's attempt trackers;
    - runs ``db.init_db()``.

    Every change is undone through ``addCleanup``, which unittest runs after
    ``tearDown`` and also when ``setUp`` itself raises part-way. Global
    settings therefore never leak into later tests.
    """

    def setUp(self) -> None:
        super_method = getattr(super(), "setUp", None)
        if callable(super_method):
            super_method()

        # ignore_cleanup_errors: presumably so a still-open SQLite handle
        # cannot turn directory removal into a test error.
        self._tempdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
        self.addCleanup(self._tempdir.cleanup)

        self._original_sqlite_path = settings.sqlite_path
        self._original_journal_mode = getattr(settings, "sqlite_journal_mode", "DELETE")
        # Cleanups run LIFO, so this restore happens before the directory
        # above is removed (same order as the previous tearDown).
        self.addCleanup(self._restore_database_settings)

        settings.sqlite_path = os.path.join(self._tempdir.name, "test.db")
        settings.sqlite_journal_mode = "DELETE"

        _clear_auth_rate_limiters()
        self.addCleanup(_clear_auth_rate_limiters)

        db.init_db()

    def _restore_database_settings(self) -> None:
        """Put back the SQLite settings captured at the start of ``setUp``."""
        settings.sqlite_path = self._original_sqlite_path
        settings.sqlite_journal_mode = self._original_journal_mode

    def tearDown(self) -> None:
        # State restoration lives in the cleanups registered by setUp; this
        # override only keeps the cooperative super() chain intact.
        super_method = getattr(super(), "tearDown", None)
        if callable(super_method):
            super_method()


class PasswordPolicyTests(unittest.TestCase):
    """Unit tests for ``validate_password_policy``."""

    def test_validate_password_policy_rejects_short_passwords(self) -> None:
        # The policy message is plain text, not a regex: escape it so any
        # metacharacters in it ('.', '(', '+', ...) are matched literally.
        with self.assertRaisesRegex(ValueError, re.escape(PASSWORD_POLICY_MESSAGE)):
            validate_password_policy("short")

    def test_validate_password_policy_trims_whitespace(self) -> None:
        self.assertEqual(validate_password_policy(" password123 "), "password123")


class DatabaseEmailTests(TempDatabaseMixin, unittest.TestCase):
    """Tests for user e-mail persistence in ``db``."""

    def test_set_user_email_is_case_insensitive(self) -> None:
        created = db.create_user_if_missing(
            "MixedCaseUser",
            "password123",
            email=None,
            auth_provider="local",
        )
        self.assertTrue(created)

        # Update and lookup use different casing than the stored username.
        updated = db.set_user_email("mixedcaseuser", "mixed@example.com")
        self.assertTrue(updated)

        stored = db.get_user_by_username("MIXEDCASEUSER")
        self.assertIsNotNone(stored)
        self.assertEqual(stored.get("email"), "mixed@example.com")


class AuthFlowTests(TempDatabaseMixin, unittest.IsolatedAsyncioTestCase):
    """Async tests that call auth router handlers and password-reset services."""

    async def test_forgot_password_is_rate_limited(self) -> None:
        request = _build_request(ip="10.1.2.3")
        payload = {"identifier": "resetuser@example.com"}

        with patch.object(
            auth_router, "smtp_email_config_ready", return_value=(True, "")
        ), patch.object(
            auth_router,
            "request_password_reset",
            new=AsyncMock(return_value={"status": "ok", "issued": False}),
        ):
            for _ in range(3):
                result = await auth_router.forgot_password(payload, request)
                self.assertEqual(result["status"], "ok")

            # Stay inside the patches so the only thing that can reject this
            # fourth call is the rate limiter, not SMTP or reset-service state.
            with self.assertRaises(HTTPException) as context:
                await auth_router.forgot_password(payload, request)

        self.assertEqual(context.exception.status_code, 429)
        self.assertEqual(
            context.exception.detail,
            "Too many password reset attempts. Try again shortly.",
        )

    async def test_request_password_reset_prefers_local_user_email(self) -> None:
        db.create_user_if_missing(
            "ResetUser",
            "password123",
            email="local@example.com",
            auth_provider="local",
        )

        with patch.object(
            password_reset,
            "send_password_reset_email",
            new=AsyncMock(return_value={"status": "ok"}),
        ) as send_email:
            result = await password_reset.request_password_reset("ResetUser")

        self.assertTrue(result["issued"])
        self.assertEqual(result["recipient_email"], "local@example.com")
        send_email.assert_awaited_once()
        self.assertEqual(send_email.await_args.kwargs["recipient_email"], "local@example.com")

    async def test_profile_invite_requires_recipient_email(self) -> None:
        current_user = {
            "username": "invite-owner",
            "role": "user",
            "invite_management_enabled": True,
            "profile_id": None,
        }

        with self.assertRaises(HTTPException) as context:
            await auth_router.create_profile_invite({"label": "Missing email"}, current_user)

        self.assertEqual(context.exception.status_code, 400)
        self.assertEqual(
            context.exception.detail,
            "recipient_email is required and must be a valid email address.",
        )