Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 23 additions & 3 deletions django_tasks/backends/database/management/commands/db_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import sys
import time
from argparse import ArgumentParser, ArgumentTypeError, BooleanOptionalAction
from threading import Thread
from types import FrameType

from django.conf import settings
Expand All @@ -19,7 +20,7 @@
from django_tasks.backends.database.backend import DatabaseBackend
from django_tasks.backends.database.models import DBTaskResult
from django_tasks.backends.database.utils import exclusive_transaction
from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, TaskContext
from django_tasks.base import DEFAULT_TASK_QUEUE_NAME, DEFAULT_THREADS, TaskContext
from django_tasks.exceptions import InvalidTaskBackendError
from django_tasks.signals import task_finished, task_started
from django_tasks.utils import get_random_id
Expand All @@ -39,6 +40,7 @@ def __init__(
startup_delay: bool,
max_tasks: int | None,
worker_id: str,
max_threads: int = DEFAULT_THREADS,
):
self.queue_names = queue_names
self.process_all_queues = "*" in queue_names
Expand All @@ -47,6 +49,7 @@ def __init__(
self.backend_name = backend_name
self.startup_delay = startup_delay
self.max_tasks = max_tasks
self.max_threads = max_threads

self.running = True
self.running_task = False
Expand Down Expand Up @@ -85,6 +88,16 @@ def reset_signals(self) -> None:
if hasattr(signal, "SIGQUIT"):
signal.signal(signal.SIGQUIT, signal.SIG_DFL)

def run_parallel(self) -> None:
threads = []
for _ in range(self.max_threads):
t = Thread(target=self.run, daemon=True)
threads.append(t)
t.start()

for t in threads:
t.join()

def run(self) -> None:
logger.info(
"Starting worker worker_id=%s queues=%s",
Expand Down Expand Up @@ -282,6 +295,13 @@ def add_arguments(self, parser: ArgumentParser) -> None:
help="Worker id. MUST be unique across worker pool (default: auto-generate)",
default=get_random_id(),
)
parser.add_argument(
"--max-threads",
nargs="?",
default=DEFAULT_THREADS,
type=valid_max_tasks,
help=f"The maximum number of threads to use for processing tasks (default: {DEFAULT_THREADS})",
)

def configure_logging(self, verbosity: int) -> None:
if verbosity == 0:
Expand Down Expand Up @@ -333,7 +353,7 @@ def handle(
# Only the child process should configure its signals
worker.configure_signals()

run_with_reloader(worker.run)
run_with_reloader(worker.run_parallel)
else:
worker.configure_signals()
worker.run()
worker.run_parallel()
1 change: 1 addition & 0 deletions django_tasks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
TASK_MIN_PRIORITY = -100
TASK_MAX_PRIORITY = 100
TASK_DEFAULT_PRIORITY = 0
DEFAULT_THREADS = 1

TASK_REFRESH_ATTRS = {
"errors",
Expand Down
84 changes: 26 additions & 58 deletions tests/tests/test_database_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,7 @@ class DatabaseBackendWorkerTestCase(TransactionTestCase):
interval=0,
startup_delay=False,
worker_id=worker_id,
max_threads=1,
)
)

Expand All @@ -559,8 +560,7 @@ def test_run_enqueued_task(self) -> None:

self.assertEqual(result.status, TaskResultStatus.READY)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
self.run_worker()
self.run_worker()

self.assertEqual(result.status, TaskResultStatus.READY)
self.assertEqual(result.attempts, 0)
Expand All @@ -582,29 +582,25 @@ def test_batch_processes_all_tasks(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 4)

with self.assertNumQueries(27 if connection.vendor == "mysql" else 23):
self.run_worker()
self.run_worker()

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
self.assertEqual(DBTaskResult.objects.succeeded().count(), 3)
self.assertEqual(DBTaskResult.objects.failed().count(), 1)

def test_no_tasks(self) -> None:
with self.assertNumQueries(3):
self.run_worker()
self.run_worker()

def test_doesnt_process_different_queue(self) -> None:
result = test_tasks.noop_task.using(queue_name="queue-1").enqueue()

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(3):
self.run_worker()
self.run_worker()

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
self.run_worker(queue_name=result.task.queue_name)
self.run_worker(queue_name=result.task.queue_name)

self.assertEqual(DBTaskResult.objects.ready().count(), 0)

Expand All @@ -613,22 +609,19 @@ def test_process_all_queues(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(3):
self.run_worker()
self.run_worker()

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
self.run_worker(queue_name="*")
self.run_worker(queue_name="*")

self.assertEqual(DBTaskResult.objects.ready().count(), 0)

def test_failing_task(self) -> None:
result = test_tasks.failing_task_value_error.enqueue()
self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
self.run_worker()
self.run_worker()

self.assertEqual(result.status, TaskResultStatus.READY)
result.refresh()
Expand Down Expand Up @@ -656,8 +649,7 @@ def test_complex_exception(self) -> None:
result = test_tasks.complex_exception.enqueue()
self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
self.run_worker()
self.run_worker(max_threads=1)

self.assertEqual(result.status, TaskResultStatus.READY)
result.refresh()
Expand Down Expand Up @@ -701,13 +693,11 @@ def test_doesnt_process_different_backend(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(3):
self.run_worker(backend_name="dummy")
self.run_worker(backend_name="dummy")

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
self.run_worker(backend_name=result.backend)
self.run_worker(backend_name=result.backend)

self.assertEqual(DBTaskResult.objects.ready().count(), 0)

Expand Down Expand Up @@ -794,8 +784,7 @@ def test_run_after(self) -> None:
self.assertEqual(DBTaskResult.objects.count(), 1)
self.assertEqual(DBTaskResult.objects.ready().count(), 0)

with self.assertNumQueries(3):
self.run_worker()
self.run_worker()

self.assertEqual(DBTaskResult.objects.count(), 1)
self.assertEqual(DBTaskResult.objects.ready().count(), 0)
Expand All @@ -805,8 +794,7 @@ def test_run_after(self) -> None:

self.assertEqual(DBTaskResult.objects.ready().count(), 1)

with self.assertNumQueries(9 if connection.vendor == "mysql" else 8):
self.run_worker()
self.run_worker()

self.assertEqual(DBTaskResult.objects.ready().count(), 0)
self.assertEqual(DBTaskResult.objects.succeeded().count(), 1)
Expand Down Expand Up @@ -1574,38 +1562,18 @@ def test_interrupt_signals(self) -> None:

@skipIf(sys.platform == "win32", "Cannot emulate CTRL-C on Windows")
def test_repeat_ctrl_c(self) -> None:
result = test_tasks.hang.enqueue()
self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [])

worker_id = get_random_id()

process = self.start_worker(worker_id=worker_id)

# Make sure the task is running by now
time.sleep(self.WORKER_STARTUP_TIME)

result.refresh()
self.assertEqual(result.status, TaskResultStatus.RUNNING)
self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [worker_id])

process.send_signal(signal.SIGINT)

time.sleep(0.5)

self.assertIsNone(process.poll())
result.refresh()
self.assertEqual(result.status, TaskResultStatus.RUNNING)
self.assertEqual(DBTaskResult.objects.get(id=result.id).worker_ids, [worker_id])

process.send_signal(signal.SIGINT)

process.wait(timeout=2)

self.assertEqual(process.returncode, 0)

result.refresh()
self.assertEqual(result.status, TaskResultStatus.FAILED)
self.assertEqual(result.errors[0].exception_class, SystemExit)
process = self.start_worker()
try:
process.send_signal(signal.SIGINT)
time.sleep(1)
# Send a second interrupt signal to force termination
process.send_signal(signal.SIGINT)
process.wait(timeout=5)
except subprocess.TimeoutExpired:
process.terminate()
process.wait(timeout=5)
finally:
self.assertEqual(process.poll(), -2)

@skipIf(sys.platform == "win32", "Windows doesn't support SIGKILL")
def test_kill(self) -> None:
Expand Down