146 lines
5.6 KiB
Python
146 lines
5.6 KiB
Python
import os
|
|
import tempfile
|
|
import unittest
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
from fastapi import HTTPException
|
|
from starlette.requests import Request
|
|
|
|
from backend.app import db
|
|
from backend.app.config import settings
|
|
from backend.app.routers import auth as auth_router
|
|
from backend.app.security import PASSWORD_POLICY_MESSAGE, validate_password_policy
|
|
from backend.app.services import password_reset
|
|
|
|
|
|
def _build_request(ip: str = "127.0.0.1", user_agent: str = "backend-test") -> Request:
    """Construct a minimal Starlette ``Request`` for calling route handlers directly.

    The ASGI scope describes an HTTP/1.1 ``POST`` to ``/auth/password/forgot``
    whose only header is ``User-Agent``; the client address is ``(ip, 12345)``
    and the server is ``("testserver", 8000)``.

    Args:
        ip: Client host placed in ``scope["client"]``.
        user_agent: Value sent as the ``User-Agent`` header (UTF-8 encoded).

    Returns:
        A ``Request`` whose receive channel yields a single, final, empty body message.
    """
    route = "/auth/password/forgot"
    header_list = [(b"user-agent", user_agent.encode("utf-8"))]

    async def _empty_body() -> dict:
        # One terminal message with no payload, so a handler reading the body gets b"".
        return {"type": "http.request", "body": b"", "more_body": False}

    asgi_scope = dict(
        type="http",
        http_version="1.1",
        method="POST",
        scheme="http",
        path=route,
        raw_path=route.encode("ascii"),
        query_string=b"",
        headers=header_list,
        client=(ip, 12345),
        server=("testserver", 8000),
    )
    return Request(asgi_scope, _empty_body)
|
|
|
|
|
|
class TempDatabaseMixin:
    """Point the app at a throwaway SQLite database for the duration of each test.

    List this mixin *before* a ``unittest.TestCase`` subclass (for example
    ``class T(TempDatabaseMixin, unittest.TestCase)``) so that ``super()``
    reaches the test case and ``self.addCleanup`` is available.

    For each test, ``setUp`` does four things:

    * redirects ``settings.sqlite_path`` to ``test.db`` inside a fresh
      temporary directory;
    * forces ``settings.sqlite_journal_mode`` to ``"DELETE"``;
    * empties the auth router's module-level attempt-tracking tables;
    * calls ``db.init_db()``.

    Each change is undone through ``addCleanup`` as soon as it is made.
    unittest runs registered cleanups even when ``setUp`` itself raises, in
    which case ``tearDown`` is skipped. So the global ``settings`` are
    restored and the temporary directory is removed even if, say,
    ``db.init_db()`` fails.
    """

    def setUp(self) -> None:
        parent_setup = getattr(super(), "setUp", None)
        if callable(parent_setup):
            parent_setup()

        self._tempdir = tempfile.TemporaryDirectory(ignore_cleanup_errors=True)
        # Registered first so it runs last (cleanups are LIFO): the directory is
        # removed only after settings no longer point into it.
        self.addCleanup(self._tempdir.cleanup)

        self._original_sqlite_path = settings.sqlite_path
        self._original_journal_mode = getattr(settings, "sqlite_journal_mode", "DELETE")
        # Register the restore before mutating, so a failure anywhere below
        # still puts the global settings back.
        self.addCleanup(self._restore_settings)
        settings.sqlite_path = os.path.join(self._tempdir.name, "test.db")
        settings.sqlite_journal_mode = "DELETE"

        self._clear_rate_limits()
        self.addCleanup(self._clear_rate_limits)

        db.init_db()

    def tearDown(self) -> None:
        # All restoration lives in the cleanups registered by setUp, which
        # unittest runs after tearDown; here we only chain to the next class.
        parent_teardown = getattr(super(), "tearDown", None)
        if callable(parent_teardown):
            parent_teardown()

    def _restore_settings(self) -> None:
        """Put back the sqlite settings captured at the start of setUp."""
        settings.sqlite_path = self._original_sqlite_path
        settings.sqlite_journal_mode = self._original_journal_mode

    @staticmethod
    def _clear_rate_limits() -> None:
        """Empty the auth router's login and password-reset attempt trackers."""
        auth_router._LOGIN_ATTEMPTS_BY_IP.clear()
        auth_router._LOGIN_ATTEMPTS_BY_USER.clear()
        auth_router._RESET_ATTEMPTS_BY_IP.clear()
        auth_router._RESET_ATTEMPTS_BY_IDENTIFIER.clear()
|
|
|
|
|
|
class PasswordPolicyTests(unittest.TestCase):
    """Unit tests for ``validate_password_policy``; no database is involved."""

    def test_validate_password_policy_rejects_short_passwords(self) -> None:
        import re  # local import: only this test needs it

        # assertRaisesRegex treats its second argument as a regular expression.
        # Escape the human-readable message so any punctuation in it
        # ("(", "+", "?", ".") is matched literally, not as regex syntax.
        with self.assertRaisesRegex(ValueError, re.escape(PASSWORD_POLICY_MESSAGE)):
            validate_password_policy("short")

    def test_validate_password_policy_trims_whitespace(self) -> None:
        # The validator returns the password with surrounding whitespace stripped.
        self.assertEqual(validate_password_policy(" password123 "), "password123")
|
|
|
|
|
|
class DatabaseEmailTests(TempDatabaseMixin, unittest.TestCase):
    """Email persistence in ``db``, run against a per-test temporary SQLite file."""

    def test_set_user_email_is_case_insensitive(self) -> None:
        # The account is created as "MixedCaseUser", updated through the
        # all-lowercase spelling and read back through the all-uppercase one;
        # every spelling must resolve to the same stored user.
        was_created = db.create_user_if_missing(
            "MixedCaseUser",
            "password123",
            email=None,
            auth_provider="local",
        )
        self.assertTrue(was_created)

        self.assertTrue(db.set_user_email("mixedcaseuser", "mixed@example.com"))

        row = db.get_user_by_username("MIXEDCASEUSER")
        self.assertIsNotNone(row)
        self.assertEqual(row.get("email"), "mixed@example.com")
|
|
|
|
|
|
class AuthFlowTests(TempDatabaseMixin, unittest.IsolatedAsyncioTestCase):
    """Async tests that call auth route handlers and the password-reset service directly."""

    async def test_forgot_password_is_rate_limited(self) -> None:
        # Three forgot-password calls from one IP succeed; the fourth is
        # rejected with HTTP 429. SMTP readiness and the reset service are
        # stubbed so the test focuses on the router's own rate limiting.
        client_request = _build_request(ip="10.1.2.3")
        body = {"identifier": "resetuser@example.com"}
        smtp_ready = patch.object(
            auth_router,
            "smtp_email_config_ready",
            return_value=(True, ""),
        )
        reset_stub = patch.object(
            auth_router,
            "request_password_reset",
            new=AsyncMock(return_value={"status": "ok", "issued": False}),
        )

        with smtp_ready, reset_stub:
            for _attempt in range(3):
                response = await auth_router.forgot_password(body, client_request)
                self.assertEqual(response["status"], "ok")

            with self.assertRaises(HTTPException) as raised:
                await auth_router.forgot_password(body, client_request)

        self.assertEqual(raised.exception.status_code, 429)
        self.assertEqual(
            raised.exception.detail,
            "Too many password reset attempts. Try again shortly.",
        )

    async def test_request_password_reset_prefers_local_user_email(self) -> None:
        # A local account with a stored email: the reset must be issued to that
        # address, and the stubbed mail sender awaited exactly once with it.
        db.create_user_if_missing(
            "ResetUser",
            "password123",
            email="local@example.com",
            auth_provider="local",
        )
        mailer = AsyncMock(return_value={"status": "ok"})
        with patch.object(password_reset, "send_password_reset_email", new=mailer):
            outcome = await password_reset.request_password_reset("ResetUser")

        self.assertTrue(outcome["issued"])
        self.assertEqual(outcome["recipient_email"], "local@example.com")
        mailer.assert_awaited_once()
        self.assertEqual(mailer.await_args.kwargs["recipient_email"], "local@example.com")

    async def test_profile_invite_requires_recipient_email(self) -> None:
        # An invite payload carrying only a label is rejected with HTTP 400.
        owner = dict(
            username="invite-owner",
            role="user",
            invite_management_enabled=True,
            profile_id=None,
        )
        with self.assertRaises(HTTPException) as raised:
            await auth_router.create_profile_invite({"label": "Missing email"}, owner)

        self.assertEqual(raised.exception.status_code, 400)
        self.assertEqual(
            raised.exception.detail,
            "recipient_email is required and must be a valid email address.",
        )
|