import contextlib
import copy
import logging
import os
import signal
import subprocess
import sys
import time
import uuid
import warnings
from collections import Counter
from collections.abc import Iterator, Sequence
from contextlib import redirect_stderr
from datetime import datetime, timedelta
from functools import partial
from io import StringIO
from typing import Any, cast
from unittest import mock, skipIf

import django
from django import VERSION
from django.core.exceptions import ImproperlyConfigured, SuspiciousOperation
from django.core.management import call_command, execute_from_command_line
from django.db import connection, connections, transaction
from django.db.models import QuerySet
from django.db.utils import IntegrityError, OperationalError
from django.test import SimpleTestCase, TransactionTestCase, override_settings
from django.test.testcases import _deferredSkip  # type:ignore[attr-defined]
from django.utils import timezone
from django_tasks import (
    TaskResultStatus,
    default_task_backend,
    task_backends,
)
from django_tasks.base import Task
from django_tasks.exceptions import InvalidTaskError, TaskResultDoesNotExist
from django_tasks.signals import task_enqueued
from django_tasks.utils import get_random_id

from django_tasks_db import DatabaseBackend, compat
from django_tasks_db.management.commands.prune_db_task_results import (
    logger as prune_db_tasks_logger,
)
from django_tasks_db.models import DBTaskResult
from django_tasks_db.utils import (
    connection_requires_manual_exclusive_transaction,
    exclusive_transaction,
    normalize_uuid,
)
from tests import tasks as test_tasks


def skipIfInMemoryDB() -> Any:  # noqa:N802
    """Return a test decorator that skips when the DB is in-memory SQLite.

    The check is handed to Django's ``_deferredSkip``, so the ``connection``
    is inspected when the decorated test runs rather than at import time.
    """

    def _uses_in_memory_sqlite() -> bool:
        # Only SQLite connections expose ``is_in_memory_db``.
        if connection.vendor != "sqlite":
            return False
        return connection.is_in_memory_db()  # type:ignore[attr-defined]

    return _deferredSkip(
        _uses_in_memory_sqlite,
        "Tests cannot run on in-memory DB",
        "skipIfInMemoryDB",
    )


class DatabaseBackendTestCase(TransactionTestCase):
    """Tests for enqueuing, fetching, refreshing and validating tasks with
    ``DatabaseBackend``.

    Uses ``TransactionTestCase`` so tests are not wrapped in an outer
    transaction; ``test_doesnt_wait_until_transaction_commit`` relies on this
    when it counts rows from a separate connection.
    """

    @contextlib.contextmanager
    def _capture_task_enqueued_signal(
        self,
    ) -> Iterator[list[tuple[type[DatabaseBackend], Any]]]:
        """Record ``(sender, task_result)`` for each ``task_enqueued`` signal.

        Yields the list that receives the pairs. The receiver is connected
        with ``weak=False`` and is disconnected on exit, even if the body
        raises.
        """
        received: list[tuple[type[DatabaseBackend], Any]] = []

        def capture_signal(
            sender: type[DatabaseBackend], task_result: Any, **kwargs: Any
        ) -> None:
            received.append((sender, task_result))

        task_enqueued.connect(
            capture_signal, dispatch_uid="db_enqueue_signal", weak=False
        )
        try:
            yield received
        finally:
            task_enqueued.disconnect(capture_signal, dispatch_uid="db_enqueue_signal")

    def get_task_count_in_new_connection(self) -> int:
        """
        Return how many ``DBTaskResult`` rows a brand-new ``default``
        connection can see, i.e. what other connections see.

        The extra connection is always closed before returning.
        """
        new_connection = connections.create_connection("default")
        try:
            with new_connection.cursor() as c:
                # The query has no filters and therefore no parameters, so its
                # string form can be executed directly.
                c.execute(str(DBTaskResult.objects.values("id").query))
                return len(c.fetchall())
        finally:
            new_connection.close()

    def test_using_correct_backend(self) -> None:
        """The default task backend is a ``DatabaseBackend`` with no options."""
        self.assertEqual(default_task_backend, task_backends["default"])
        self.assertIsInstance(task_backends["default"], DatabaseBackend)
        self.assertEqual(default_task_backend.alias, "default")
        self.assertEqual(default_task_backend.options, {})

    def test_enqueue_task(self) -> None:
        """Sync enqueue issues one query, returns a READY result and signals."""
        for task in [test_tasks.noop_task, test_tasks.noop_task_async]:
            with self.subTest(task), self.assertNumQueries(1):
                with self._capture_task_enqueued_signal() as received:
                    result = cast(Task, task).enqueue(1, two=3)

                self.assertEqual(uuid.UUID(result.id).version, 4)
                self.assertEqual(result.status, TaskResultStatus.READY)
                self.assertFalse(result.is_finished)
                self.assertIsNone(result.started_at)
                self.assertIsNone(result.last_attempted_at)
                self.assertIsNone(result.finished_at)
                with self.assertRaisesMessage(ValueError, "Task has not finished yet"):
                    result.return_value  # noqa:B018
                self.assertEqual(result.task, task)
                self.assertEqual(result.args, [1])
                self.assertEqual(result.kwargs, {"two": 3})
                self.assertEqual(result.attempts, 0)

                self.assertEqual(received, [(DatabaseBackend, result)])

    async def test_enqueue_task_async(self) -> None:
        """``aenqueue()`` returns a READY result and sends ``task_enqueued``."""
        for task in [test_tasks.noop_task, test_tasks.noop_task_async]:
            with self.subTest(task):
                with self._capture_task_enqueued_signal() as received:
                    result = await cast(Task, task).aenqueue()

                self.assertEqual(uuid.UUID(result.id).version, 4)
                self.assertEqual(result.status, TaskResultStatus.READY)
                self.assertFalse(result.is_finished)
                self.assertIsNone(result.started_at)
                self.assertIsNone(result.last_attempted_at)
                self.assertIsNone(result.finished_at)
                with self.assertRaisesMessage(ValueError, "Task has not finished yet"):
                    result.return_value  # noqa:B018
                self.assertEqual(result.task, task)
                self.assertEqual(result.args, [])
                self.assertEqual(result.kwargs, {})
                self.assertEqual(result.attempts, 0)

                self.assertEqual(received, [(DatabaseBackend, result)])

    def test_get_result(self) -> None:
        """``get_result()`` fetches an equal result in a single query."""
        with self.assertNumQueries(1):
            result = default_task_backend.enqueue(test_tasks.noop_task, [], {})

        with self.assertNumQueries(1):
            new_result = default_task_backend.get_result(result.id)

        self.assertEqual(result, new_result)

    async def test_get_result_async(self) -> None:
        """``aget_result()`` fetches a result equal to the enqueued one."""
        result = await default_task_backend.aenqueue(test_tasks.noop_task, [], {})

        new_result = await default_task_backend.aget_result(result.id)

        self.assertEqual(result, new_result)

    def test_refresh_result(self) -> None:
        """``refresh()`` reloads state changed directly in the database."""
        result = default_task_backend.enqueue(
            test_tasks.calculate_meaning_of_life, (), {}
        )

        DBTaskResult.objects.all().update(
            status=TaskResultStatus.SUCCESSFUL,
            started_at=timezone.now(),
            finished_at=timezone.now(),
            return_value=42,
            worker_ids=[get_random_id()],
        )

        # The in-memory result is stale until refreshed.
        self.assertEqual(result.status, TaskResultStatus.READY)
        self.assertFalse(result.is_finished)
        self.assertIsNone(result.started_at)
        self.assertIsNone(result.last_attempted_at)
        self.assertIsNone(result.finished_at)
        self.assertEqual(result.attempts, 0)

        with self.assertNumQueries(1):
            result.refresh()

        self.assertIsNotNone(result.started_at)
        self.assertIsNotNone(result.last_attempted_at)
        self.assertIsNotNone(result.finished_at)
        self.assertEqual(result.status, TaskResultStatus.SUCCESSFUL)
        self.assertTrue(result.is_finished)
        self.assertEqual(result.return_value, 42)
        self.assertEqual(result.attempts, 1)

    async def test_refresh_result_async(self) -> None:
        """``arefresh()`` reloads state changed directly in the database."""
        result = await default_task_backend.aenqueue(
            test_tasks.calculate_meaning_of_life, (), {}
        )

        await DBTaskResult.objects.all().aupdate(
            status=TaskResultStatus.SUCCESSFUL,
            started_at=timezone.now(),
            finished_at=timezone.now(),
            return_value=42,
            worker_ids=[get_random_id()],
        )

        # The in-memory result is stale until refreshed.
        self.assertEqual(result.status, TaskResultStatus.READY)
        self.assertFalse(result.is_finished)
        self.assertIsNone(result.started_at)
        self.assertIsNone(result.last_attempted_at)
        self.assertIsNone(result.finished_at)
        self.assertEqual(result.attempts, 0)

        await result.arefresh()

        self.assertIsNotNone(result.started_at)
        self.assertIsNotNone(result.last_attempted_at)
        self.assertIsNotNone(result.finished_at)
        self.assertEqual(result.status, TaskResultStatus.SUCCESSFUL)
        self.assertTrue(result.is_finished)
        self.assertEqual(result.return_value, 42)
        self.assertEqual(result.attempts, 1)

    def test_get_missing_result(self) -> None:
        """An unknown (but well-formed) id raises ``TaskResultDoesNotExist``."""
        with self.assertRaises(TaskResultDoesNotExist):
            default_task_backend.get_result(str(uuid.uuid4()))

    async def test_async_get_missing_result(self) -> None:
        """Async variant of ``test_get_missing_result``."""
        with self.assertRaises(TaskResultDoesNotExist):
            await default_task_backend.aget_result(str(uuid.uuid4()))

    def test_invalid_uuid(self) -> None:
        """A malformed id raises ``TaskResultDoesNotExist``."""
        with self.assertRaises(TaskResultDoesNotExist):
            default_task_backend.get_result("123")

    async def test_async_invalid_uuid(self) -> None:
        """Async variant of ``test_invalid_uuid``."""
        with self.assertRaises(TaskResultDoesNotExist):
            await default_task_backend.aget_result("123")

    def test_invalid_task_path(self) -> None:
        """A ``task_path`` pointing at a non-``Task`` callable is rejected."""
        db_task_result = DBTaskResult.objects.create(
            args_kwargs={"args": [["exit", "1"]], "kwargs": {}},
            task_path="subprocess.check_output",
            backend_name="default",
        )

        with self.assertRaisesMessage(
            SuspiciousOperation,
            f"Task {db_task_result.id} does not point to a Task ({db_task_result.task_path})",
        ):
            _ = db_task_result.task

    def test_missing_task_path(self) -> None:
        """A ``task_path`` that cannot be imported raises ``ImportError``."""
        db_task_result = DBTaskResult.objects.create(
            args_kwargs={"args": [], "kwargs": {}},
            task_path="missing.func",
            backend_name="default",
        )

        with self.assertRaises(ImportError):
            _ = db_task_result.task

    def test_task_name(self) -> None:
        """``task_name`` is the last dotted component, even for unknown paths."""
        for task_path, expected_task_name in [
            ("tests.tasks.noop_task", "noop_task"),
            ("tests.tasks.task_not_found", "task_not_found"),
            ("tests.tasks.module_not_found.module_not_found", "module_not_found"),
            ("unexpected_function", "unexpected_function"),
        ]:
            with self.subTest(task_path):
                db_task_result = DBTaskResult.objects.create(
                    args_kwargs={"args": [], "kwargs": {}},
                    task_path=task_path,
                    backend_name="default",
                )

                self.assertEqual(db_task_result.task_name, expected_task_name)

    def test_check(self) -> None:
        """The correctly configured backend passes its system checks."""
        errors = list(default_task_backend.check())

        self.assertEqual(len(errors), 0, errors)

    @override_settings(INSTALLED_APPS=[])
    def test_database_backend_app_missing(self) -> None:
        """Without the app installed, ``check()`` reports one error with a hint."""
        errors = list(default_task_backend.check())

        self.assertEqual(len(errors), 1)
        self.assertIn("django_tasks_db", errors[0].hint)  # type:ignore[arg-type]

    def test_priority_range_check(self) -> None:
        """The database only accepts priorities from -100 to 100 inclusive."""
        with self.assertRaises(IntegrityError):
            DBTaskResult.objects.create(
                task_path="", backend_name="default", priority=-101, args_kwargs={}
            )

        with self.assertRaises(IntegrityError):
            DBTaskResult.objects.create(
                task_path="", backend_name="default", priority=101, args_kwargs={}
            )

        # Django accepts the float, but only stores an int
        result = DBTaskResult.objects.create(
            task_path="", backend_name="default", priority=3.1, args_kwargs={}
        )
        result.refresh_from_db()
        self.assertEqual(result.priority, 3)

        DBTaskResult.objects.create(
            task_path="", backend_name="default", priority=100, args_kwargs={}
        )
        DBTaskResult.objects.create(
            task_path="", backend_name="default", priority=-100, args_kwargs={}
        )
        DBTaskResult.objects.create(
            task_path="", backend_name="default", priority=0, args_kwargs={}
        )

    @override_settings(
        TASKS={
            "default": {
                "BACKEND": "django_tasks_db.DatabaseBackend",
            }
        }
    )
    def test_doesnt_wait_until_transaction_commit(self) -> None:
        """``enqueue()`` writes inside the caller's open transaction.

        The row is visible to the current connection immediately, but other
        connections only see it once the ``atomic()`` block commits.
        """
        with transaction.atomic():
            result = test_tasks.noop_task.enqueue()

            self.assertIsNotNone(result.enqueued_at)

            self.assertEqual(DBTaskResult.objects.count(), 1)

            # SQLite locks the table during this transaction
            if connection.vendor != "sqlite":
                self.assertEqual(self.get_task_count_in_new_connection(), 0)

        if connection.vendor != "sqlite":
            self.assertEqual(self.get_task_count_in_new_connection(), 1)

    def test_enqueue_logs(self) -> None:
        """Enqueuing emits exactly one log record mentioning the result id."""
        with self.assertLogs("django_tasks", level="DEBUG") as captured_logs:
            result = test_tasks.noop_task.enqueue()

        self.assertEqual(len(captured_logs.output), 1)
        self.assertIn("enqueued", captured_logs.output[0])
        self.assertIn(result.id, captured_logs.output[0])

    def test_index_scan_for_ready(self) -> None:
        """The query plan for ``ready()`` uses ``tasks_db_new_ordering_idx``."""
        test_tasks.noop_task.enqueue()

        # Quickly duplicate tasks
        db_task = DBTaskResult.objects.get()
        db_task.id = None
        DBTaskResult.objects.bulk_create([copy.copy(db_task) for _ in range(5000)])

        # Update query plan for certain databases
        if connection.vendor == "postgresql":
            with connection.cursor() as c:
                c.execute(f"ANALYZE {DBTaskResult._meta.db_table};")
        elif connection.vendor == "mysql":
            with connection.cursor() as c:
                c.execute(f"ANALYZE TABLE {DBTaskResult._meta.db_table};")

        plan = DBTaskResult.objects.ready().explain()

        if connection.vendor == "postgresql":
            self.assertIn("tasks_db_new_ordering_idx", plan)
        elif connection.vendor == "sqlite":
            self.assertIn("USING INDEX tasks_db_new_ordering_idx", plan)
        elif connection.vendor == "mysql":
            self.assertIn("Index lookup", plan)
            self.assertIn("using tasks_db_new_ordering_idx", plan)
        else:
            self.fail("Unknown database engine")

    def test_run_after_tz(self) -> None:
        """Without ``run_after``, the stored task reports ``None`` for any ``USE_TZ``."""
        for use_tz in [True, False]:
            with self.subTest(use_tz=use_tz):
                with override_settings(USE_TZ=use_tz):
                    result = test_tasks.noop_task.enqueue()
                    self.assertIsNone(
                        DBTaskResult.objects.get(id=result.id).task.run_after
                    )

    def test_run_after_null_0016_migration(self) -> None:
        """The far-future ``run_after`` literal from migration 0016 reads back as ``None``."""
        # Alias the stdlib timezone so it doesn't shadow ``django.utils.timezone``,
        # which this module imports as ``timezone``.
        from datetime import timezone as dt_timezone

        for use_tz in [True, False]:
            with self.subTest(use_tz=use_tz):
                with override_settings(USE_TZ=use_tz):
                    result = test_tasks.noop_task.enqueue()

                    db_result = DBTaskResult.objects.get(id=result.id)

                    # Literal taken from migration
                    db_result.run_after = datetime(
                        9999,
                        1,
                        1,
                        tzinfo=dt_timezone.utc if use_tz else None,
                    )

                    # Silence the RuntimeWarning Django emits while saving this
                    # value.
                    with warnings.catch_warnings():
                        warnings.filterwarnings(
                            "ignore", module="django.db", category=RuntimeWarning
                        )
                        db_result.save()

                    self.assertIsNone(
                        DBTaskResult.objects.get(id=result.id).task.run_after
                    )

    def test_validate_on_enqueue(self) -> None:
        """The queue is validated against the active settings at ``enqueue()`` time."""
        with override_settings(
            TASKS={
                "default": {
                    "BACKEND": "django_tasks_db.DatabaseBackend",
                    "QUEUES": ["unknown_queue"],
                }
            }
        ):
            task_with_custom_queue_name = test_tasks.noop_task.using(
                queue_name="unknown_queue"
            )

        # Outside the override the queue is no longer allowed.
        with self.assertRaisesMessage(
            InvalidTaskError, "Queue 'unknown_queue' is not valid for backend"
        ):
            task_with_custom_queue_name.enqueue()

    async def test_validate_on_aenqueue(self) -> None:
        """Async variant of ``test_validate_on_enqueue``."""
        with override_settings(
            TASKS={
                "default": {
                    "BACKEND": "django_tasks_db.DatabaseBackend",
                    "QUEUES": ["unknown_queue"],
                }
            }
        ):
            task_with_custom_queue_name = test_tasks.noop_task.using(
                queue_name="unknown_queue"
            )

        with self.assertRaisesMessage(
            InvalidTaskError, "Queue 'unknown_queue' is not valid for backend"
        ):
            await task_with_custom_queue_name.aenqueue()

    def test_custom_id_function(self) -> None:
        """``id_function`` may be a dotted path or a callable."""
        for id_function in ["uuid.uuid1", uuid.uuid1]:
            with self.subTest(id_function):
                with override_settings(
                    TASKS={
                        "default": {
                            "BACKEND": "django_tasks_db.DatabaseBackend",
                            "OPTIONS": {"id_function": id_function},
                        }
                    }
                ):
                    result = test_tasks.noop_task.enqueue()
                    self.assertEqual(uuid.UUID(result.id).version, 1)

    @override_settings(
        TASKS={
            "default": {
                "BACKEND": "django_tasks_db.DatabaseBackend",
                "OPTIONS": {"id_function": "missing.function"},
            }
        }
    )
    def test_unknown_id_function(self) -> None:
        """An unimportable ``id_function`` path raises ``ImportError`` on enqueue."""
        with self.assertRaises(ImportError):
            test_tasks.noop_task.enqueue()

    @override_settings(
        TASKS={
            "default": {
                "BACKEND": "django_tasks_db.DatabaseBackend",
                "OPTIONS": {
                    "id_function": "django.contrib.postgres.functions.RandomUUID"
                },
            }
        }
    )
    @skipIf(connection.vendor != "postgresql", "RandomUUID only works on postgres")
    @skipIf(VERSION < (6, 0), "DB expressions are only supported on 6.0+")
    def test_postgres_db_id_function(self) -> None:
        """A DB-expression ``id_function`` is evaluated inside the single INSERT."""
        with self.assertNumQueries(1) as c:
            result = test_tasks.noop_task.enqueue()

        self.assertIn("GEN_RANDOM_UUID", c.captured_queries[0]["sql"])
        self.assertEqual(uuid.UUID(result.id).version, 4)

    @override_settings(
        TASKS={
            "default": {
                "BACKEND": "django_tasks_db.DatabaseBackend",
                "OPTIONS": {"id_function": "django.db.models.functions.Now"},
            }
        }
    )
    @skipIf(VERSION >= (6, 0), "DB expressions are supported on 6.0+")
    def test_postgres_id_function_expression(self) -> None:
        """Before Django 6.0, a database expression as ``id_function`` is rejected.

        NOTE(review): despite its name this test is not PostgreSQL-specific; it
        uses ``Now`` and runs on every backend.
        """
        with self.assertRaisesMessage(
            ImproperlyConfigured,
            "id_function cannot be a database expression until Django 6.0",
        ):
            test_tasks.noop_task.enqueue()


@override_settings(
    TASKS={
        "default": {
            "BACKEND": "django_tasks_db.DatabaseBackend",
            "QUEUES": ["default", "queue-1"],
        },
        "dummy": {"BACKEND": "django_tasks.backends.dummy.DummyBackend"},
    }
)
class DatabaseBackendWorkerTestCase(TransactionTestCase):
    # Random id shared by every worker this test case starts; it appears in
    # the worker's verbose log output (see ``test_verbose_logging``).
    worker_id = get_random_id()

    # ``call_command("db_worker", ...)`` pre-bound with options for a quiet
    # batch run: no polling interval, no startup delay, verbosity 0.
    # Keyword arguments given at call time override these (``partial``
    # semantics). ``staticmethod`` ensures ``self.run_worker(...)`` never binds
    # the test instance as the first argument.
    run_worker = staticmethod(
        partial(
            call_command,
            "db_worker",
            worker_id=worker_id,
            batch=True,
            interval=0,
            startup_delay=False,
            verbosity=0,
        )
    )

    def tearDown(self) -> None:
        """Detach every handler from the ``django_tasks_db`` and ``django_tasks`` loggers.

        The loggers are reset after every run so that the next test's worker
        uses the correct ``stdout``.
        """
        for logger_name in ("django_tasks_db", "django_tasks"):
            logger = logging.getLogger(logger_name)
            # Iterate over a snapshot: ``removeHandler`` mutates
            # ``logger.handlers`` in place, so removing while iterating the
            # live list would skip every other handler.
            for handler in list(logger.handlers):
                logger.removeHandler(handler)

        super().tearDown()

    def test_run_enqueued_task(self) -> None:
        """A single READY task is executed by the worker and ends SUCCESSFUL."""
        # MySQL needs one more query than the other backends for one task.
        expected_queries = 9 if connection.vendor == "mysql" else 8
        tasks_to_run = [
            test_tasks.noop_task,
            # test_tasks.noop_task_async,
        ]

        for task in tasks_to_run:
            with self.subTest(task):
                result = cast(Task, task).enqueue()
                self.assertEqual(DBTaskResult.objects.ready().count(), 1)
                self.assertEqual(result.status, TaskResultStatus.READY)

                with self.assertNumQueries(expected_queries):
                    self.run_worker()

                # The local result object is stale until it is refreshed.
                self.assertEqual(result.status, TaskResultStatus.READY)
                self.assertEqual(result.attempts, 0)

                result.refresh()

                for field in ("started_at", "last_attempted_at", "finished_at"):
                    self.assertIsNotNone(getattr(result, field))
                self.assertGreaterEqual(result.started_at, result.enqueued_at)  # type:ignore[arg-type,misc]
                self.assertGreaterEqual(result.finished_at, result.started_at)  # type:ignore[arg-type,misc]
                self.assertEqual(result.status, TaskResultStatus.SUCCESSFUL)
                self.assertEqual(result.attempts, 1)

                self.assertEqual(DBTaskResult.objects.ready().count(), 0)

    def test_batch_processes_all_tasks(self) -> None:
        """In batch mode the worker drains every READY task, passing or failing."""
        noop_count, failing_count = 3, 1
        expected_queries = 27 if connection.vendor == "mysql" else 23

        for _ in range(noop_count):
            test_tasks.noop_task.enqueue()
        for _ in range(failing_count):
            test_tasks.failing_task_value_error.enqueue()

        self.assertEqual(
            DBTaskResult.objects.ready().count(), noop_count + failing_count
        )

        with self.assertNumQueries(expected_queries):
            self.run_worker()

        self.assertEqual(DBTaskResult.objects.ready().count(), 0)
        self.assertEqual(DBTaskResult.objects.successful().count(), noop_count)
        self.assertEqual(DBTaskResult.objects.failed().count(), failing_count)

    def test_no_tasks(self) -> None:
        """With nothing queued, a batch worker exits after its baseline queries."""
        idle_worker_queries = 3
        with self.assertNumQueries(idle_worker_queries):
            self.run_worker()

    def test_doesnt_process_different_queue(self) -> None:
        """A default-queue worker ignores tasks queued elsewhere until pointed at them."""
        queued = test_tasks.noop_task.using(queue_name="queue-1").enqueue()
        self.assertEqual(DBTaskResult.objects.ready().count(), 1)

        # Worker on the default queue: nothing to pick up.
        with self.assertNumQueries(3):
            self.run_worker()
        self.assertEqual(DBTaskResult.objects.ready().count(), 1)

        # Worker on the task's own queue: the task runs.
        with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
            self.run_worker(queue_name=queued.task.queue_name)
        self.assertEqual(DBTaskResult.objects.ready().count(), 0)

    def test_process_all_queues(self) -> None:
        """``queue_name="*"`` makes the worker consume tasks from every queue."""
        test_tasks.noop_task.using(queue_name="queue-1").enqueue()
        self.assertEqual(DBTaskResult.objects.ready().count(), 1)

        # Default-queue worker leaves the "queue-1" task alone.
        with self.assertNumQueries(3):
            self.run_worker()
        self.assertEqual(DBTaskResult.objects.ready().count(), 1)

        # Wildcard worker picks it up.
        with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
            self.run_worker(queue_name="*")
        self.assertEqual(DBTaskResult.objects.ready().count(), 0)

    def test_failing_task(self) -> None:
        """A task raising ``ValueError`` is stored as FAILED along with its traceback."""
        result = test_tasks.failing_task_value_error.enqueue()
        self.assertEqual(DBTaskResult.objects.ready().count(), 1)

        with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
            self.run_worker()

        # The local result object is stale until it is refreshed.
        self.assertEqual(result.status, TaskResultStatus.READY)
        result.refresh()

        for field in ("started_at", "last_attempted_at", "finished_at"):
            self.assertIsNotNone(getattr(result, field))

        self.assertGreaterEqual(result.started_at, result.enqueued_at)  # type:ignore[arg-type,misc]
        self.assertGreaterEqual(result.finished_at, result.started_at)  # type:ignore[arg-type,misc]
        self.assertEqual(result.status, TaskResultStatus.FAILED)
        with self.assertRaisesMessage(ValueError, "Task failed"):
            result.return_value  # noqa: B018

        first_error = result.errors[0]
        self.assertEqual(first_error.exception_class, ValueError)
        tb_text = first_error.traceback
        self.assertTrue(
            tb_text
            and tb_text.endswith("ValueError: This task failed due to ValueError\n"),
            tb_text,
        )

        self.assertEqual(DBTaskResult.objects.ready().count(), 0)

    def test_complex_exception(self) -> None:
        """A nested exception is stored with its full ``repr`` in the traceback."""
        result = test_tasks.complex_exception.enqueue()
        self.assertEqual(DBTaskResult.objects.ready().count(), 1)

        with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
            self.run_worker()

        # The local result object is stale until it is refreshed.
        self.assertEqual(result.status, TaskResultStatus.READY)
        result.refresh()

        for field in ("started_at", "last_attempted_at", "finished_at"):
            self.assertIsNotNone(getattr(result, field))

        self.assertGreaterEqual(result.started_at, result.enqueued_at)  # type:ignore[arg-type,misc]
        self.assertGreaterEqual(result.finished_at, result.started_at)  # type:ignore[arg-type,misc]
        self.assertEqual(result.status, TaskResultStatus.FAILED)
        with self.assertRaisesMessage(ValueError, "Task failed"):
            result.return_value  # noqa: B018

        first_error = result.errors[0]
        self.assertEqual(first_error.exception_class, ValueError)
        self.assertIn('ValueError(ValueError("This task failed"))', first_error.traceback)

        self.assertEqual(DBTaskResult.objects.ready().count(), 0)

    def test_complex_return_value(self) -> None:
        """A non-JSON-serializable return value fails the task with ``TypeError``."""
        result = test_tasks.complex_return_value.enqueue()
        self.run_worker()
        result.refresh()

        self.assertEqual(result.status, TaskResultStatus.FAILED)
        for field in ("started_at", "last_attempted_at", "finished_at"):
            self.assertIsNotNone(getattr(result, field))
        self.assertGreaterEqual(result.started_at, result.enqueued_at)  # type:ignore[arg-type,misc]
        self.assertGreaterEqual(result.finished_at, result.started_at)  # type:ignore[arg-type,misc]

        # Nothing is stored as the return value; the failure is recorded instead.
        self.assertIsNone(result._return_value)
        first_error = result.errors[0]
        self.assertEqual(first_error.exception_class, TypeError)
        self.assertIn("is not JSON serializable", first_error.traceback)

    def test_doesnt_process_different_backend(self) -> None:
        """A worker for another backend leaves this backend's tasks untouched."""
        result = test_tasks.failing_task_value_error.enqueue()
        self.assertEqual(DBTaskResult.objects.ready().count(), 1)

        # Worker bound to the "dummy" backend: the database task stays READY.
        with self.assertNumQueries(3):
            self.run_worker(backend_name="dummy")
        self.assertEqual(DBTaskResult.objects.ready().count(), 1)

        # Worker bound to the task's own backend: the task is consumed.
        with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
            self.run_worker(backend_name=result.backend)
        self.assertEqual(DBTaskResult.objects.ready().count(), 0)

    def test_unknown_backend(self) -> None:
        """``--backend`` with an unconfigured alias exits with an error."""
        argv = ["django-admin", "db_worker", "--backend", "unknown"]
        captured_stderr = StringIO()

        with redirect_stderr(captured_stderr), self.assertRaises(SystemExit):
            execute_from_command_line(argv)

        self.assertIn(
            "The connection 'unknown' doesn't exist.", captured_stderr.getvalue()
        )

    def test_incorrect_backend(self) -> None:
        """``--backend`` naming a non-database backend exits with an error."""
        argv = ["django-admin", "db_worker", "--backend", "dummy"]
        captured_stderr = StringIO()

        with redirect_stderr(captured_stderr), self.assertRaises(SystemExit):
            execute_from_command_line(argv)

        self.assertIn(
            "Backend 'dummy' is not a database backend", captured_stderr.getvalue()
        )

    def test_negative_interval(self) -> None:
        """A negative ``--interval`` is rejected on the command line."""
        argv = ["django-admin", "db_worker", "--interval", "-1"]
        captured_stderr = StringIO()

        with redirect_stderr(captured_stderr), self.assertRaises(SystemExit):
            execute_from_command_line(argv)

        self.assertIn("Must be greater than zero", captured_stderr.getvalue())

    def test_infinite_interval(self) -> None:
        """An infinite ``--interval`` is rejected on the command line."""
        argv = ["django-admin", "db_worker", "--interval", "inf"]
        captured_stderr = StringIO()

        with redirect_stderr(captured_stderr), self.assertRaises(SystemExit):
            execute_from_command_line(argv)

        self.assertIn(
            "Must be a finite floating point value", captured_stderr.getvalue()
        )

    def test_fractional_interval(self) -> None:
        """A fractional ``--interval`` reaches the ``Worker`` call unchanged."""
        argv = ["django-admin", "db_worker", "--interval", "0.1"]
        worker_path = "django_tasks_db.management.commands.db_worker.Worker"

        with mock.patch(worker_path) as worker_class:
            execute_from_command_line(argv)

        first_call = worker_class.mock_calls[0]
        self.assertEqual(first_call.kwargs["interval"], 0.1)

    def test_negative_max_tasks(self) -> None:
        """A negative ``--max-tasks`` is rejected on the command line."""
        argv = ["django-admin", "db_worker", "--max-tasks", "-1"]
        captured_stderr = StringIO()

        with redirect_stderr(captured_stderr), self.assertRaises(SystemExit):
            execute_from_command_line(argv)

        self.assertIn("Must be greater than zero", captured_stderr.getvalue())

    def test_too_long_worker_id(self) -> None:
        """A 65-character ``--worker-id`` is rejected on the command line."""
        oversized_id = "A" * 65
        argv = ["django-admin", "db_worker", "--worker-id", oversized_id]
        captured_stderr = StringIO()

        with redirect_stderr(captured_stderr), self.assertRaises(SystemExit):
            execute_from_command_line(argv)

        self.assertIn(
            "Worker ids must be shorter than 64 characters",
            captured_stderr.getvalue(),
        )

    def test_empty_worker_id(self) -> None:
        """An empty ``--worker-id`` is rejected on the command line."""
        argv = ["django-admin", "db_worker", "--worker-id", ""]
        captured_stderr = StringIO()

        with redirect_stderr(captured_stderr), self.assertRaises(SystemExit):
            execute_from_command_line(argv)

        self.assertIn("Worker id must not be empty", captured_stderr.getvalue())

    def test_run_after(self) -> None:
        """A task is only picked up once its ``run_after`` time has passed."""
        result = test_tasks.noop_task.using(
            run_after=timezone.now() + timedelta(hours=10)
        ).enqueue()
        results = DBTaskResult.objects

        self.assertEqual(results.count(), 1)
        self.assertEqual(results.ready().count(), 0)

        # Scheduled ten hours ahead: the worker finds nothing to run.
        with self.assertNumQueries(3):
            self.run_worker()

        self.assertEqual(results.count(), 1)
        self.assertEqual(results.ready().count(), 0)
        self.assertEqual(results.successful().count(), 0)

        # Make the task due now; it should then become READY and run.
        results.filter(id=result.id).update(run_after=timezone.now())
        self.assertEqual(results.ready().count(), 1)

        with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
            self.run_worker()

        self.assertEqual(results.ready().count(), 0)
        self.assertEqual(results.successful().count(), 1)

    def test_run_after_priority(self) -> None:
        """Ordering honours priority and ``run_after``; ``ready()`` hides future tasks.

        The expected lists below put higher priority first and, within one
        priority, earlier ``run_after`` first, with tasks that have no
        ``run_after`` after those that do.
        """

        def _from_now(hours: float) -> datetime:
            # Re-reads the clock on every call, like a fresh timezone.now().
            return timezone.now() + timedelta(hours=hours)

        noop = test_tasks.noop_task

        old_task = noop.using(priority=20, run_after=_from_now(-2)).enqueue()
        very_old_task = noop.using(priority=20, run_after=_from_now(-10)).enqueue()
        far_future_result = noop.using(run_after=_from_now(10)).enqueue()
        high_priority_far_future_result = noop.using(
            priority=10, run_after=_from_now(10)
        ).enqueue()
        future_result = noop.using(run_after=_from_now(2)).enqueue()
        high_priority_result = noop.using(priority=10).enqueue()
        low_priority_result = noop.using(priority=2).enqueue()
        lower_priority_result = noop.using(priority=-2).enqueue()

        expected_all = [
            very_old_task,
            old_task,
            high_priority_far_future_result,
            high_priority_result,
            low_priority_result,
            future_result,
            far_future_result,
            lower_priority_result,
        ]
        self.assertEqual(
            [db_task.task_result for db_task in DBTaskResult.objects.all()],
            expected_all,
        )

        # Tasks scheduled in the future are excluded from ready().
        expected_ready = [
            very_old_task,
            old_task,
            high_priority_result,
            low_priority_result,
            lower_priority_result,
        ]
        self.assertEqual(
            [db_task.task_result for db_task in DBTaskResult.objects.ready()],
            expected_ready,
        )

    def test_verbose_logging(self) -> None:
        """At verbosity 3 the worker logs startup, each state change and shutdown."""
        enqueued = test_tasks.noop_task.enqueue()

        # stdout and stderr share one buffer so the line order is total.
        captured = StringIO()
        self.run_worker(verbosity=3, stdout=captured, stderr=captured)

        expected_lines = [
            f"Starting worker worker_id={self.worker_id} queues=default",
            f"Task id={enqueued.id} path=tests.tasks.noop_task state=RUNNING",
            f"Task id={enqueued.id} path=tests.tasks.noop_task state=SUCCESSFUL",
            f"No more tasks to run for worker_id={self.worker_id} - exiting gracefully.",
        ]
        self.assertEqual(captured.getvalue().splitlines(), expected_lines)

    def test_invalid_task_path(self) -> None:
        """A row pointing at an importable non-task callable ends up FAILED."""
        # ``subprocess.check_output`` imports fine but is not a ``Task``.
        # NOTE(review): ``exit`` is a shell builtin, so invoking it directly
        # would also fail; asserting on the traceback would prove the path was
        # rejected rather than executed.
        bogus_row = DBTaskResult.objects.create(
            task_path="subprocess.check_output",
            args_kwargs={"args": [["exit", "1"]], "kwargs": {}},
            backend_name="default",
        )

        self.run_worker()

        bogus_row.refresh_from_db()
        self.assertEqual(bogus_row.status, TaskResultStatus.FAILED)

    def test_missing_task_path(self) -> None:
        """A row whose task path cannot be resolved ends up FAILED."""
        orphan_row = DBTaskResult.objects.create(
            task_path="missing.func",
            args_kwargs={"args": [], "kwargs": {}},
            backend_name="default",
        )

        self.run_worker()

        orphan_row.refresh_from_db()
        self.assertEqual(orphan_row.status, TaskResultStatus.FAILED)

    def test_worker_doesnt_exit(self) -> None:
        """The worker survives ``exit_task`` and records it as FAILED."""
        # ``exit_task`` presumably tries to exit the interpreter; the worker
        # must record a failure rather than die itself.
        exiting = test_tasks.exit_task.enqueue()

        self.run_worker()

        exiting.refresh()
        self.assertEqual(exiting.status, TaskResultStatus.FAILED)

    @skipIf(connection.vendor == "sqlite", "SQLite locks the entire database")
    def test_worker_with_locked_rows(self) -> None:
        """The worker skips a row locked by another connection and runs the rest.

        A second connection holds a ``SELECT ... FOR UPDATE`` lock on the first
        task; the worker must leave it READY and complete a task enqueued later.
        """
        result_1 = test_tasks.noop_task.enqueue()
        # Independent connection, standing in for a concurrent worker.
        new_connection = connections.create_connection("default")

        # Only the SQL text is needed here; it is executed on ``new_connection``.
        with transaction.atomic():
            locked_tasks_query = str(DBTaskResult.objects.select_for_update().query)

        try:
            # Start a transaction in the other connection
            with new_connection.cursor() as c:
                c.execute("BEGIN")

            # Lock the current rows in the table
            with new_connection.cursor() as c:
                c.execute(locked_tasks_query)
                results = list(c.fetchall())
            # Only ``result_1`` exists at this point, so exactly one row is locked.
            self.assertEqual(len(results), 1)

            # Add another task which isn't locked
            result_2 = test_tasks.noop_task.enqueue()

            self.run_worker()
        finally:
            # Closing the connection ends its open transaction, releasing the lock.
            new_connection.close()

        result_1.refresh()
        result_2.refresh()

        self.assertEqual(result_1.status, TaskResultStatus.READY)
        self.assertEqual(result_2.status, TaskResultStatus.SUCCESSFUL)

    def test_max_tasks(self) -> None:
        """``max_tasks`` stops the worker after that many tasks."""
        enqueued = [test_tasks.noop_task.enqueue() for _ in range(5)]

        log_buffer = StringIO()
        self.run_worker(max_tasks=2, stdout=log_buffer, verbosity=3)
        self.assertIn("Run maximum tasks (2)", log_buffer.getvalue())

        for task_result in enqueued:
            task_result.refresh()
        by_status = Counter(task_result.status for task_result in enqueued)

        # Two were processed; the other three are still waiting.
        self.assertEqual(by_status[TaskResultStatus.SUCCESSFUL], 2)
        self.assertEqual(by_status[TaskResultStatus.READY], 3)

    def test_takes_context(self) -> None:
        """A context-taking task returns its own result id."""
        # ``get_task_id`` presumably returns the id from its task context.
        task_result = test_tasks.get_task_id.enqueue()

        self.run_worker()
        task_result.refresh()

        self.assertEqual(task_result.return_value, task_result.id)

    def test_context(self) -> None:
        """``tests.tasks.test_context`` runs to SUCCESSFUL under the worker."""
        task_result = test_tasks.test_context.enqueue(1)

        self.run_worker()
        task_result.refresh()

        self.assertEqual(task_result.status, TaskResultStatus.SUCCESSFUL)

    def test_task_import_string(self) -> None:
        """An unimportable task path fails and records an ImportError traceback."""
        unresolvable = DBTaskResult.objects.create(
            task_path="tests.tasks.some_test",
            args_kwargs={"args": [], "kwargs": {}},
            backend_name="default",
        )

        self.run_worker()
        unresolvable.refresh_from_db()

        self.assertEqual(unresolvable.status, TaskResultStatus.FAILED)
        self.assertIn("ImportError", unresolvable.traceback)


@override_settings(
    TASKS={
        "default": {
            "BACKEND": "django_tasks_db.DatabaseBackend",
        },
    }
)
class DatabaseTaskResultTestCase(TransactionTestCase):
    """Row-locking behaviour of ``DBTaskResult`` querysets across connections.

    Lock visibility is observed from a second, independently created connection,
    so these tests see what a concurrent worker would see.
    """

    def execute_in_new_connection(self, sql: str | QuerySet) -> Sequence:
        """Run ``sql`` on a freshly created ``default`` connection and return all rows.

        A ``QuerySet`` is first rendered to SQL via ``str(queryset.query)``. The
        temporary connection is always closed, even if the query raises.
        """
        # NOTE(review): ``str(query)`` is not guaranteed to be valid SQL when the
        # query has parameters; the querysets used in this class appear to have none.
        if isinstance(sql, QuerySet):
            sql = str(sql.query)
        new_connection = connections.create_connection("default")
        try:
            with new_connection.cursor() as c:
                c.execute(sql)
                return cast(list, c.fetchall())
        finally:
            new_connection.close()

    def test_cross_connection(self) -> None:
        """Enqueued rows are visible via the ORM, an explicit alias and a new connection."""
        test_tasks.noop_task.enqueue()
        test_tasks.noop_task.enqueue()

        self.assertEqual(DBTaskResult.objects.count(), 2)

        self.assertEqual(DBTaskResult.objects.using("default").count(), 2)

        self.assertEqual(
            len(self.execute_in_new_connection(DBTaskResult.objects.all())),
            2,
        )

    @skipIf(connection.vendor == "sqlite", "SQLite handles locks differently")
    def test_locks_tasks(self) -> None:
        """``get_locked`` hides the locked row from other connections until commit."""
        test_tasks.noop_task.enqueue()
        test_tasks.noop_task.enqueue()

        with transaction.atomic():
            # Nothing is locked yet: the other connection can see both rows.
            self.assertEqual(
                len(
                    self.execute_in_new_connection(
                        DBTaskResult.objects.select_for_update(skip_locked=True)
                    )
                ),
                2,
            )

            self.assertIsNotNone(DBTaskResult.objects.get_locked())

            # One row is now locked by this transaction.
            self.assertEqual(
                len(
                    self.execute_in_new_connection(
                        DBTaskResult.objects.select_for_update(skip_locked=True)
                    )
                ),
                # MySQL likes to lock all the rows
                0 if connection.vendor == "mysql" else 1,
            )

            # NOTE(review): the return value is unused and nothing is asserted
            # afterwards — confirm this second lock is intentional.
            DBTaskResult.objects.get_locked()

        with transaction.atomic():
            # The original transaction has closed, so the result is unlocked
            self.assertEqual(
                len(
                    self.execute_in_new_connection(
                        DBTaskResult.objects.select_for_update(skip_locked=True)
                    )
                ),
                2,
            )

    @skipIf(connection.vendor != "sqlite", "SQLite handles locks differently")
    def test_locks_tasks_sqlite(self) -> None:
        """On SQLite an exclusive transaction blocks other connections entirely."""
        result = test_tasks.noop_task.enqueue()

        with exclusive_transaction():
            locked_result = DBTaskResult.objects.get_locked()

            self.assertEqual(result.id, str(locked_result.id))  # type:ignore[union-attr]

            # The whole database is locked, so even a skip_locked read fails.
            with self.assertRaisesMessage(OperationalError, "is locked"):
                self.execute_in_new_connection(
                    DBTaskResult.objects.select_for_update(skip_locked=True)
                )

        # The original transaction has closed, so the database is unlocked
        self.execute_in_new_connection(
            DBTaskResult.objects.select_for_update(skip_locked=True)
        )

    @skipIf(connection.vendor == "sqlite", "SQLite handles locks differently")
    def test_locks_tasks_filtered(self) -> None:
        """``get_locked`` on a filtered queryset locks only the matching row."""
        result = test_tasks.noop_task.using(priority=10).enqueue()
        test_tasks.noop_task.enqueue()

        with transaction.atomic():
            # Nothing is locked yet.
            self.assertEqual(
                len(
                    self.execute_in_new_connection(
                        DBTaskResult.objects.select_for_update(skip_locked=True)
                    )
                ),
                2,
            )

            locked_result = DBTaskResult.objects.filter(
                priority=result.task.priority
            ).get_locked()
            self.assertEqual(str(locked_result.id), result.id)

            # Only the priority-10 row is locked; the other is still visible.
            self.assertEqual(
                len(
                    self.execute_in_new_connection(
                        DBTaskResult.objects.select_for_update(skip_locked=True)
                    )
                ),
                1,
            )

        with transaction.atomic():
            # The original transaction has closed, so the result is unlocked
            self.assertEqual(
                len(
                    self.execute_in_new_connection(
                        DBTaskResult.objects.select_for_update(skip_locked=True)
                    )
                ),
                2,
            )

    @skipIf(connection.vendor != "sqlite", "SQLite handles locks differently")
    def test_locks_tasks_filtered_sqlite(self) -> None:
        """SQLite variant: a filtered ``get_locked`` still locks the whole database."""
        result = test_tasks.noop_task.using(priority=10).enqueue()
        test_tasks.noop_task.enqueue()

        with exclusive_transaction():
            locked_result = DBTaskResult.objects.filter(
                priority=result.task.priority
            ).get_locked()

            self.assertEqual(result.id, str(locked_result.id))

            with self.assertRaisesMessage(OperationalError, "is locked"):
                self.execute_in_new_connection(
                    DBTaskResult.objects.select_for_update(skip_locked=True)
                )

        # The original transaction has closed, so the database is unlocked
        self.execute_in_new_connection(
            DBTaskResult.objects.select_for_update(skip_locked=True)
        )

    @exclusive_transaction()
    def test_lock_no_rows(self) -> None:
        """``get_locked`` on an empty table returns ``None``."""
        self.assertEqual(DBTaskResult.objects.count(), 0)
        self.assertIsNone(DBTaskResult.objects.all().get_locked())

    @skipIf(connection.vendor == "sqlite", "SQLite handles locks differently")
    def test_get_locked_with_locked_rows(self) -> None:
        """``get_locked`` skips rows already locked by another connection."""
        result_1 = test_tasks.noop_task.enqueue()
        new_connection = connections.create_connection("default")

        # Only the SQL text is needed; it is executed on ``new_connection``.
        with transaction.atomic():
            locked_tasks_query = str(DBTaskResult.objects.select_for_update().query)

        try:
            # Start a transaction in the other connection
            with new_connection.cursor() as c:
                c.execute("BEGIN")

            # Lock the current rows in the table from the other connection
            with new_connection.cursor() as c:
                c.execute(locked_tasks_query)
                results = list(c.fetchall())
            self.assertEqual(len(results), 1)
            self.assertEqual(normalize_uuid(results[0][0]), normalize_uuid(result_1.id))

            with transaction.atomic():
                # .count with skip_locked isn't supported
                self.assertEqual(
                    len(DBTaskResult.objects.select_for_update(skip_locked=True)), 0
                )
                # The only row is locked elsewhere, so nothing can be claimed.
                self.assertIsNone(DBTaskResult.objects.get_locked())

            # Add another task which isn't locked
            result_2 = test_tasks.noop_task.enqueue()

            with transaction.atomic():
                self.assertEqual(
                    normalize_uuid(
                        DBTaskResult.objects.select_for_update(
                            skip_locked=True
                        ).values_list("id", flat=True)[0]
                    ),
                    normalize_uuid(result_2.id),
                )
                self.assertEqual(
                    normalize_uuid(DBTaskResult.objects.get_locked().id),  # type:ignore
                    normalize_uuid(result_2.id),
                )
        finally:
            # Closing the connection ends its open transaction, releasing the lock.
            new_connection.close()


class ConnectionExclusiveTranscationTestCase(TransactionTestCase):
    """Tests for ``connection_requires_manual_exclusive_transaction`` and
    ``exclusive_transaction``."""

    def setUp(self) -> None:
        # A dedicated connection, so per-test changes such as ``transaction_mode``
        # do not leak into the shared default connection.
        self.connection = connections.create_connection("default")

    def tearDown(self) -> None:
        self.connection.close()

    @skipIf(connection.vendor == "sqlite", "SQLite handled separately")
    def test_non_sqlite(self) -> None:
        """Non-SQLite backends never need a manual exclusive transaction."""
        self.assertFalse(
            connection_requires_manual_exclusive_transaction(self.connection)
        )

    @skipIf(
        django.VERSION >= (5, 1),
        "Newer Django versions support custom transaction modes",
    )
    @skipIf(connection.vendor != "sqlite", "SQLite only")
    def test_old_django_requires_manual_transaction(self) -> None:
        """SQLite on Django < 5.1 always needs a manual exclusive transaction."""
        self.assertTrue(
            connection_requires_manual_exclusive_transaction(self.connection)
        )

    @skipIf(django.VERSION < (5, 1), "Old Django versions require manual transactions")
    @skipIf(connection.vendor != "sqlite", "SQLite only")
    def test_explicit_transaction(self) -> None:
        """On Django >= 5.1 the answer depends on the connection's ``transaction_mode``."""
        # HACK: Set the attribute manually
        self.connection.transaction_mode = None  # type:ignore[attr-defined]
        self.assertTrue(
            connection_requires_manual_exclusive_transaction(self.connection)
        )

        self.connection.transaction_mode = "EXCLUSIVE"  # type:ignore[attr-defined]
        self.assertFalse(
            connection_requires_manual_exclusive_transaction(self.connection)
        )

    @skipIf(connection.vendor != "sqlite", "SQLite only")
    def test_exclusive_transaction(self) -> None:
        """``exclusive_transaction`` opens with ``BEGIN EXCLUSIVE`` on SQLite."""
        # Two queries: the BEGIN and the closing statement.
        with self.assertNumQueries(2) as c:
            with exclusive_transaction():
                pass

        self.assertEqual(c.captured_queries[0]["sql"], "BEGIN EXCLUSIVE")


@override_settings(
    TASKS={
        "default": {
            "BACKEND": "django_tasks_db.DatabaseBackend",
            "QUEUES": ["default", "queue-1"],
        },
        "dummy": {"BACKEND": "django_tasks.backends.dummy.DummyBackend"},
    }
)
class DatabaseBackendPruneTaskResultsTestCase(TransactionTestCase):
    """Tests for the ``prune_db_task_results`` management command."""

    prune_task_results = staticmethod(
        partial(call_command, "prune_db_task_results", verbosity=0)
    )

    def tearDown(self) -> None:
        # Reset the logger after every run, to ensure the correct `stdout` is used.
        # Iterate over a copy: ``removeHandler`` mutates ``handlers`` in place, and
        # removing while iterating the live list would skip every other handler.
        for handler in list(prune_db_tasks_logger.handlers):
            prune_db_tasks_logger.removeHandler(handler)

    def test_prunes_tasks(self) -> None:
        """A finished result older than ``min_age_days`` is deleted."""
        result = test_tasks.noop_task.enqueue()

        DBTaskResult.objects.all().update(
            status=TaskResultStatus.SUCCESSFUL, finished_at=timezone.now()
        )

        self.assertEqual(DBTaskResult.objects.finished().count(), 1)

        stdout = StringIO()

        with self.assertNumQueries(3):
            self.prune_task_results(min_age_days=0, stdout=stdout, verbosity=3)

        self.assertEqual(DBTaskResult.objects.finished().count(), 0)

        with self.assertRaises(TaskResultDoesNotExist):
            result.refresh()

        self.assertEqual(stdout.getvalue().strip(), "Deleted 1 task result(s)")

    def test_doesnt_prune_new_tasks(self) -> None:
        """READY results are never pruned."""
        result = test_tasks.noop_task.enqueue()

        self.assertEqual(DBTaskResult.objects.ready().count(), 1)

        stdout = StringIO()
        with self.assertNumQueries(3):
            self.prune_task_results(min_age_days=0, stdout=stdout, verbosity=3)

        self.assertEqual(DBTaskResult.objects.ready().count(), 1)

        result.refresh()

        self.assertEqual(stdout.getvalue().strip(), "Deleted 0 task result(s)")

    def test_doesnt_prune_running_tasks(self) -> None:
        """RUNNING results are never pruned."""
        result = test_tasks.noop_task.enqueue()

        DBTaskResult.objects.all().update(status=TaskResultStatus.RUNNING)

        self.assertEqual(DBTaskResult.objects.running().count(), 1)

        with self.assertNumQueries(3):
            self.prune_task_results(min_age_days=0)

        self.assertEqual(DBTaskResult.objects.running().count(), 1)

        result.refresh()

    def test_only_prunes_specified_queue(self) -> None:
        """``queue_name`` restricts pruning to that queue."""
        result = test_tasks.noop_task.enqueue()
        queue_1_result = test_tasks.noop_task.using(queue_name="queue-1").enqueue()

        DBTaskResult.objects.all().update(
            status=TaskResultStatus.SUCCESSFUL, finished_at=timezone.now()
        )

        self.assertEqual(DBTaskResult.objects.successful().count(), 2)

        with self.assertNumQueries(3):
            self.prune_task_results(queue_name="queue-1", min_age_days=0)

        self.assertEqual(DBTaskResult.objects.successful().count(), 1)

        result.refresh()

        with self.assertRaises(TaskResultDoesNotExist):
            queue_1_result.refresh()

    def test_prune_all_queues(self) -> None:
        """``queue_name="*"`` prunes every queue."""
        test_tasks.noop_task.enqueue()
        test_tasks.noop_task.using(queue_name="queue-1").enqueue()

        DBTaskResult.objects.all().update(
            status=TaskResultStatus.SUCCESSFUL, finished_at=timezone.now()
        )

        self.assertEqual(DBTaskResult.objects.successful().count(), 2)

        with self.assertNumQueries(3):
            self.prune_task_results(queue_name="*", min_age_days=0)

        self.assertEqual(DBTaskResult.objects.successful().count(), 0)

    def test_min_age(self) -> None:
        """Only results finished at least ``min_age_days`` ago are pruned."""
        one_day_result = test_tasks.noop_task.enqueue()

        DBTaskResult.objects.ready().update(
            status=TaskResultStatus.SUCCESSFUL,
            finished_at=timezone.now() - timedelta(days=1),
        )

        three_day_result = test_tasks.noop_task.enqueue()
        DBTaskResult.objects.ready().update(
            status=TaskResultStatus.SUCCESSFUL,
            finished_at=timezone.now() - timedelta(days=3),
        )

        self.assertEqual(DBTaskResult.objects.successful().count(), 2)

        # The default minimum age keeps both results.
        with self.assertNumQueries(3):
            self.prune_task_results()

        self.assertEqual(DBTaskResult.objects.successful().count(), 2)

        with self.assertNumQueries(3):
            self.prune_task_results(min_age_days=3)

        self.assertEqual(DBTaskResult.objects.successful().count(), 1)

        one_day_result.refresh()

        with self.assertRaises(TaskResultDoesNotExist):
            three_day_result.refresh()

        with self.assertNumQueries(3):
            self.prune_task_results(min_age_days=1)

        self.assertEqual(DBTaskResult.objects.successful().count(), 0)

    def test_failed_min_age(self) -> None:
        """``failed_min_age_days`` applies to FAILED results independently."""
        successful_result = test_tasks.noop_task.enqueue()

        DBTaskResult.objects.ready().update(
            status=TaskResultStatus.SUCCESSFUL,
            finished_at=timezone.now() - timedelta(days=3),
        )

        failed_result = test_tasks.noop_task.enqueue()
        DBTaskResult.objects.ready().update(
            status=TaskResultStatus.FAILED,
            finished_at=timezone.now() - timedelta(days=3),
        )

        self.assertEqual(DBTaskResult.objects.finished().count(), 2)

        with self.assertNumQueries(3):
            self.prune_task_results()

        self.assertEqual(DBTaskResult.objects.finished().count(), 2)

        with self.assertNumQueries(3):
            self.prune_task_results(min_age_days=3, failed_min_age_days=5)

        self.assertEqual(DBTaskResult.objects.finished().count(), 1)

        failed_result.refresh()

        with self.assertRaises(TaskResultDoesNotExist):
            successful_result.refresh()

        with self.assertNumQueries(3):
            self.prune_task_results(min_age_days=3, failed_min_age_days=1)

        # Both results are now gone.
        self.assertEqual(DBTaskResult.objects.finished().count(), 0)

        with self.assertRaises(TaskResultDoesNotExist):
            failed_result.refresh()

    def test_dry_run(self) -> None:
        """``dry_run`` reports what would be deleted without deleting it."""
        test_tasks.noop_task.enqueue()

        DBTaskResult.objects.all().update(
            status=TaskResultStatus.SUCCESSFUL, finished_at=timezone.now()
        )

        self.assertEqual(DBTaskResult.objects.count(), 1)

        stdout = StringIO()
        with self.assertNumQueries(1):
            self.prune_task_results(
                min_age_days=0, dry_run=True, stdout=stdout, verbosity=3
            )

        self.assertEqual(DBTaskResult.objects.count(), 1)

        self.assertEqual(stdout.getvalue().strip(), "Would delete 1 task result(s)")

    def test_unknown_backend(self) -> None:
        """An unknown ``--backend`` is rejected with a CLI error."""
        output = StringIO()
        with redirect_stderr(output):
            with self.assertRaises(SystemExit):
                execute_from_command_line(
                    ["django-admin", "prune_db_task_results", "--backend", "unknown"]
                )
        self.assertIn("The connection 'unknown' doesn't exist.", output.getvalue())

    def test_incorrect_backend(self) -> None:
        """A non-database ``--backend`` is rejected with a CLI error."""
        output = StringIO()
        with redirect_stderr(output):
            with self.assertRaises(SystemExit):
                execute_from_command_line(
                    ["django-admin", "prune_db_task_results", "--backend", "dummy"]
                )
        self.assertIn("Backend 'dummy' is not a database backend", output.getvalue())

    def test_negative_age(self) -> None:
        """A negative ``--min-age-days`` is rejected with a CLI error."""
        output = StringIO()
        with redirect_stderr(output):
            with self.assertRaises(SystemExit):
                execute_from_command_line(
                    ["django-admin", "prune_db_task_results", "--min-age-days", "-1"]
                )
        self.assertIn("Must be greater than zero", output.getvalue())


@override_settings(
    TASKS={
        "default": {
            "BACKEND": "django_tasks_db.DatabaseBackend",
        },
    }
)
@skipIfInMemoryDB()
class DatabaseWorkerProcessTestCase(TransactionTestCase):
    """End-to-end tests that run ``db_worker`` in real subprocesses."""

    # Seconds to wait for a freshly spawned worker to pick up work.
    WORKER_STARTUP_TIME = 1

    def setUp(self) -> None:
        self.processes: list[subprocess.Popen] = []

    def tearDown(self) -> None:
        # Try several times to kill any remaining child processes. ``wait`` may
        # time out on a slow-dying child; that must not abort the retry loop.
        for attempt in range(20):
            still_running = [p for p in self.processes if p.poll() is None]
            if not still_running:
                break
            for process in still_running:
                if attempt >= 5:
                    print("Still waiting for process", process.pid, process.args)  # noqa: T201
                process.kill()
                with contextlib.suppress(subprocess.TimeoutExpired):
                    process.wait(1)

        # Close captured output pipes so they aren't leaked (ResourceWarning).
        for process in self.processes:
            if process.stdout is not None:
                process.stdout.close()

    def start_worker(
        self,
        args: list[str] | None = None,
        *,
        debug: bool = False,
        worker_id: str | None = None,
    ) -> subprocess.Popen:
        """Spawn a ``db_worker`` subprocess and register it for cleanup.

        :param args: Extra CLI arguments appended to the command.
        :param debug: When true, the child inherits this process's stdout/stderr
            instead of having them captured in a pipe.
        :param worker_id: Explicit worker id; a random one is generated if omitted.
        """
        if args is None:
            args = []

        if worker_id is None:
            worker_id = get_random_id()

        p = subprocess.Popen(
            [
                sys.executable,
                "-m",
                "manage",
                "db_worker",
                "--verbosity",
                "3",
                "--no-startup-delay",
                "--worker-id",
                worker_id,
                *args,
            ],
            stdout=None if debug else subprocess.PIPE,
            stderr=None if debug else subprocess.STDOUT,
            env={
                **os.environ,
                "DJANGO_SETTINGS_MODULE": "tests.db_worker_test_settings",
                "IN_TEST": "",
            },
            text=True,
        )
        self.processes.append(p)
        return p

    def test_run_subprocess(self) -> None:
        """A ``--batch`` worker runs the queued task and exits cleanly."""
        result = test_tasks.noop_task.enqueue()
        process = self.start_worker(["--batch"])
        process.wait()
        self.assertEqual(process.returncode, 0)

        # The in-memory result is stale until refreshed.
        self.assertEqual(result.status, TaskResultStatus.READY)

        result.refresh()

        self.assertEqual(result.status, TaskResultStatus.SUCCESSFUL)

    @skipIf(sys.platform == "win32", "Terminate is always forceful on Windows")
    def test_interrupt_no_tasks(self) -> None:
        """An idle worker exits with code 0 on terminate."""
        process = self.start_worker()

        time.sleep(self.WORKER_STARTUP_TIME)

        process.terminate()

        process.wait(timeout=0.5)
        self.assertEqual(process.returncode, 0)

    @skipIf(sys.platform == "win32", "Cannot emulate CTRL-C on Windows")
    def test_interrupt_signals(self) -> None:
        """A single SIGINT/SIGTERM lets the running task finish before exiting."""
        for sig in [
            signal.SIGINT,  # ctrl-c
            signal.SIGTERM,
        ]:
            with self.subTest(sig):
                result = test_tasks.sleep_for.enqueue(2)
                self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [])

                # The task must still be running after the startup wait.
                self.assertGreater(result.args[0], self.WORKER_STARTUP_TIME)

                process = self.start_worker()

                # Make sure the task is running by now
                time.sleep(self.WORKER_STARTUP_TIME)

                result.refresh()
                self.assertEqual(result.status, TaskResultStatus.RUNNING)
                self.assertNotEqual(
                    DBTaskResult.objects.get(id=result.id).worker_ids, []
                )

                process.send_signal(sig)

                process.wait(timeout=2)

                self.assertEqual(process.returncode, 0)

                result.refresh()

                self.assertEqual(result.status, TaskResultStatus.SUCCESSFUL)

    @skipIf(sys.platform == "win32", "Cannot emulate CTRL-C on Windows")
    def test_repeat_ctrl_c(self) -> None:
        """A second SIGINT aborts the running task, which is recorded as FAILED."""
        result = test_tasks.hang.enqueue()
        self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [])

        worker_id = get_random_id()

        process = self.start_worker(worker_id=worker_id)

        # Make sure the task is running by now
        time.sleep(self.WORKER_STARTUP_TIME)

        result.refresh()
        self.assertEqual(result.status, TaskResultStatus.RUNNING)
        self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [worker_id])

        process.send_signal(signal.SIGINT)

        time.sleep(0.5)

        # After the first SIGINT the worker keeps running the task.
        self.assertIsNone(process.poll())
        result.refresh()
        self.assertEqual(result.status, TaskResultStatus.RUNNING)
        self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [worker_id])

        process.send_signal(signal.SIGINT)

        process.wait(timeout=2)

        self.assertEqual(process.returncode, 0)

        result.refresh()
        self.assertEqual(result.status, TaskResultStatus.FAILED)
        self.assertEqual(result.errors[0].exception_class, SystemExit)

    @skipIf(sys.platform == "win32", "Windows doesn't support SIGKILL")
    def test_kill(self) -> None:
        """SIGKILL leaves the task stuck in RUNNING (known limitation)."""
        # Required to keep mypy happy
        assert hasattr(signal, "SIGKILL")

        result = test_tasks.hang.enqueue()

        process = self.start_worker()

        # Make sure the task is running by now
        time.sleep(self.WORKER_STARTUP_TIME)

        result.refresh()
        self.assertEqual(result.status, TaskResultStatus.RUNNING)

        process.kill()

        process.wait(timeout=2)

        self.assertEqual(process.returncode, -signal.SIGKILL)

        result.refresh()

        # TODO: https://github.com/RealOrangeOne/django-tasks-db/issues/46
        self.assertEqual(result.status, TaskResultStatus.RUNNING)

    def test_system_exit_task(self) -> None:
        """A task raising SystemExit fails without taking the worker down."""
        result = test_tasks.failing_task_system_exit.enqueue()

        process = self.start_worker(["--batch"])
        process.wait(timeout=2)

        self.assertEqual(process.returncode, 0)

        result.refresh()
        self.assertEqual(result.status, TaskResultStatus.FAILED)
        self.assertEqual(result.errors[0].exception_class, SystemExit)

    def test_keyboard_interrupt_task(self) -> None:
        """A task raising KeyboardInterrupt fails without taking the worker down."""
        result = test_tasks.failing_task_keyboard_interrupt.enqueue()

        process = self.start_worker(["--batch"])
        process.wait(timeout=2)

        self.assertEqual(process.returncode, 0)

        result.refresh()
        self.assertEqual(result.status, TaskResultStatus.FAILED)
        self.assertEqual(result.errors[0].exception_class, KeyboardInterrupt)

    def test_multiple_workers(self) -> None:
        """Concurrent batch workers run every task exactly once between them."""
        results = [test_tasks.sleep_for.enqueue(0.1) for _ in range(10)]

        for _ in range(3):
            self.start_worker(["--batch"])

        time.sleep(self.WORKER_STARTUP_TIME)

        for process in self.processes:
            process.wait(timeout=5)
            # ``returncode`` is always set after ``wait``; assert a clean exit,
            # consistent with the other ``--batch`` tests.
            self.assertEqual(process.returncode, 0)

        for result in results:
            result.refresh()
            self.assertEqual(result.status, TaskResultStatus.SUCCESSFUL)

        all_output = ""

        for process in self.processes:
            stdout_text = process.stdout.read()  # type:ignore[union-attr]
            all_output += stdout_text
            self.assertIn("gracefully", stdout_text)

        for result in results:
            # Running and successful
            self.assertEqual(all_output.count(result.id), 2)


class CompatTestCase(SimpleTestCase):
    """Checks on the ``compat`` shim shared with Django's built-in tasks."""

    def test_compat_has_django_task(self) -> None:
        """``TASK_CLASSES`` includes ``Task``, plus Django's own on Django 6+."""
        registered = compat.TASK_CLASSES
        self.assertIn(Task, registered)

        if VERSION < (6, 0):
            # ``django.tasks`` does not exist before Django 6.0.
            return

        from django.tasks.base import Task as DjangoTask

        self.assertIn(DjangoTask, registered)
