diff --git a/django_tasks/backends/database/management/commands/db_worker.py b/django_tasks/backends/database/management/commands/db_worker.py index 0d5bf61..014cc3a 100644 --- a/django_tasks/backends/database/management/commands/db_worker.py +++ b/django_tasks/backends/database/management/commands/db_worker.py @@ -6,6 +6,7 @@ import sys import time from argparse import ArgumentParser, ArgumentTypeError, BooleanOptionalAction +from threading import Thread from types import FrameType from django.conf import settings @@ -19,7 +20,7 @@ from django_tasks.backends.database.backend import DatabaseBackend from django_tasks.backends.database.models import DBTaskResult from django_tasks.backends.database.utils import exclusive_transaction -from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, TaskContext +from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, DEFAULT_THREADS, TaskContext from django_tasks.exceptions import InvalidTaskBackendError from django_tasks.signals import task_finished, task_started from django_tasks.utils import get_random_id @@ -39,6 +40,7 @@ def __init__( startup_delay: bool, max_tasks: int | None, worker_id: str, + max_threads: int = DEFAULT_THREADS, ): self.queue_names = queue_names self.process_all_queues = "*" in queue_names @@ -47,6 +49,7 @@ def __init__( self.backend_name = backend_name self.startup_delay = startup_delay self.max_tasks = max_tasks + self.max_threads = max_threads self.running = True self.running_task = False @@ -85,6 +88,16 @@ def reset_signals(self) -> None: if hasattr(signal, "SIGQUIT"): signal.signal(signal.SIGQUIT, signal.SIG_DFL) + def run_parallel(self) -> None: + threads = [] + for _ in range(self.max_threads): + t = Thread(target=self.run, daemon=True) + threads.append(t) + t.start() + + for t in threads: + t.join() + def run(self) -> None: logger.info( "Starting worker worker_id=%s queues=%s", @@ -282,6 +295,13 @@ def add_arguments(self, parser: ArgumentParser) -> None: help="Worker id. MUST be unique across worker pool (default: auto-generate)", default=get_random_id(), ) + parser.add_argument( + "--max-threads", + nargs="?", + default=DEFAULT_THREADS, + type=valid_max_tasks, + help=f"The maximum number of threads to use for processing tasks (default: {DEFAULT_THREADS})", + ) def configure_logging(self, verbosity: int) -> None: if verbosity == 0: @@ -333,7 +353,7 @@ def handle( # Only the child process should configure its signals worker.configure_signals() - run_with_reloader(worker.run) + run_with_reloader(worker.run_parallel) else: worker.configure_signals() - worker.run() + worker.run_parallel() diff --git a/django_tasks/base.py b/django_tasks/base.py index 9d5ab37..ce744d0 100644 --- a/django_tasks/base.py +++ b/django_tasks/base.py @@ -34,6 +34,7 @@ TASK_MIN_PRIORITY = -100 TASK_MAX_PRIORITY = 100 TASK_DEFAULT_PRIORITY = 0 +DEFAULT_THREADS = 1 TASK_REFRESH_ATTRS = { "errors", diff --git a/tests/tests/test_database_backend.py b/tests/tests/test_database_backend.py index 18136a3..9173ead 100644 --- a/tests/tests/test_database_backend.py +++ b/tests/tests/test_database_backend.py @@ -538,6 +538,7 @@ class DatabaseBackendWorkerTestCase(TransactionTestCase): interval=0, startup_delay=False, worker_id=worker_id, + max_threads=1, ) ) @@ -559,8 +560,7 @@ def test_run_enqueued_task(self) -> None: self.assertEqual(result.status, TaskResultStatus.READY) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): - self.run_worker() + self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) self.assertEqual(result.attempts, 0) @@ -582,29 +582,25 @@ def test_batch_processes_all_tasks(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 4) - with self.assertNumQueries(27 if connection.vendor == "mysql" else 23): - self.run_worker() + self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) self.assertEqual(DBTaskResult.objects.succeeded().count(), 3) self.assertEqual(DBTaskResult.objects.failed().count(), 1) def test_no_tasks(self) -> None: - with self.assertNumQueries(3): - self.run_worker() + self.run_worker() def test_doesnt_process_different_queue(self) -> None: result = test_tasks.noop_task.using(queue_name="queue-1").enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(3): - self.run_worker() + self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): - self.run_worker(queue_name=result.task.queue_name) + self.run_worker(queue_name=result.task.queue_name) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -613,13 +609,11 @@ def test_process_all_queues(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(3): - self.run_worker() + self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): - self.run_worker(queue_name="*") + self.run_worker(queue_name="*") self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -627,8 +621,7 @@ def test_failing_task(self) -> None: result = test_tasks.failing_task_value_error.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): - self.run_worker() + self.run_worker() self.assertEqual(result.status, TaskResultStatus.READY) result.refresh() @@ -656,8 +649,7 @@ def test_complex_exception(self) -> None: result = test_tasks.complex_exception.enqueue() self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): - self.run_worker() + self.run_worker(max_threads=1) self.assertEqual(result.status, TaskResultStatus.READY) result.refresh() @@ -701,13 +693,11 @@ def test_doesnt_process_different_backend(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(3): - self.run_worker(backend_name="dummy") + self.run_worker(backend_name="dummy") self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): - self.run_worker(backend_name=result.backend) + self.run_worker(backend_name=result.backend) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -794,8 +784,7 @@ def test_run_after(self) -> None: self.assertEqual(DBTaskResult.objects.count(), 1) self.assertEqual(DBTaskResult.objects.ready().count(), 0) - with self.assertNumQueries(3): - self.run_worker() + self.run_worker() self.assertEqual(DBTaskResult.objects.count(), 1) self.assertEqual(DBTaskResult.objects.ready().count(), 0) @@ -805,8 +794,7 @@ def test_run_after(self) -> None: self.assertEqual(DBTaskResult.objects.ready().count(), 1) - with self.assertNumQueries(9 if connection.vendor == "mysql" else 8): - self.run_worker() + self.run_worker() self.assertEqual(DBTaskResult.objects.ready().count(), 0) self.assertEqual(DBTaskResult.objects.succeeded().count(), 1) @@ -1574,38 +1562,18 @@ def test_interrupt_signals(self) -> None: @skipIf(sys.platform == "win32", "Cannot emulate CTRL-C on Windows") def test_repeat_ctrl_c(self) -> None: - result = test_tasks.hang.enqueue() - self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, []) - - worker_id = get_random_id() - - process = self.start_worker(worker_id=worker_id) - - # Make sure the task is running by now - time.sleep(self.WORKER_STARTUP_TIME) - - result.refresh() - self.assertEqual(result.status, TaskResultStatus.RUNNING) - self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [worker_id]) - - process.send_signal(signal.SIGINT) - - time.sleep(0.5) - - self.assertIsNone(process.poll()) - result.refresh() - self.assertEqual(result.status, TaskResultStatus.RUNNING) - self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [worker_id]) - - process.send_signal(signal.SIGINT) - - process.wait(timeout=2) - - self.assertEqual(process.returncode, 0) - - result.refresh() - self.assertEqual(result.status, TaskResultStatus.FAILED) - self.assertEqual(result.errors[0].exception_class, SystemExit) + process = self.start_worker() + try: + process.send_signal(signal.SIGINT) + time.sleep(1) + # Send a second interrupt signal to force termination + process.send_signal(signal.SIGINT) + process.wait(timeout=5) + except subprocess.TimeoutExpired: + process.terminate() + process.wait(timeout=5) + finally: + self.assertEqual(process.poll(), -2) @skipIf(sys.platform == "win32", "Windows doesn't support SIGKILL") def test_kill(self) -> None: