Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 116 additions & 6 deletions django_grpc/management/commands/grpcserver.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import datetime
import asyncio
import signal
import threading
import time

from django.core.management.base import BaseCommand
from django.utils import autoreload
Expand All @@ -13,6 +16,13 @@ class Command(BaseCommand):
help = "Run gRPC server"
config = getattr(settings, "GRPCSERVER", dict())

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Graceful shutdown을 위한 상태 관리
self._shutdown_event = threading.Event()
self._server = None
self._original_sigterm_handler = None

def add_arguments(self, parser):
parser.add_argument("--max_workers", type=int, help="Number of workers")
parser.add_argument("--port", type=int, default=50051, help="Port number to listen")
Expand Down Expand Up @@ -40,49 +50,141 @@ def handle(self, *args, **options):
else:
self._serve(**options)

def _setup_signal_handlers(self):
"""시그널 핸들러를 설정합니다 (Gunicorn arbiter.py 참고)"""
# SIGTERM 핸들러 저장
self._original_sigterm_handler = signal.signal(signal.SIGTERM, self._handle_sigterm)

# SIGINT 핸들러도 설정 (Ctrl+C)
signal.signal(signal.SIGINT, self._handle_sigterm)

self.stdout.write("Signal handlers registered for graceful shutdown")

def _handle_sigterm(self, signum, frame):
"""SIGTERM 시그널을 처리하여 graceful shutdown을 시작합니다"""
self.stdout.write(f"Received signal {signum}. Starting graceful shutdown...")
self._shutdown_event.set()

def _graceful_shutdown(self, server):
"""서버를 gracefully하게 종료합니다"""
try:
# 새로운 연결 수락을 중지
self.stdout.write("Stopping server from accepting new connections...")

# gRPC 서버 종료 (graceful=True로 설정하여 진행 중인 요청 완료 대기)
if hasattr(server, 'stop'):
# 동기 서버의 경우
server.stop(grace=True)
else:
# 비동기 서버의 경우
asyncio.create_task(server.stop(grace=True))

# Django 시그널 전송
grpc_shutdown.send(None)

self.stdout.write("Graceful shutdown completed")

except Exception as e:
self.stderr.write(f"Error during graceful shutdown: {e}")

async def _graceful_shutdown_async(self, server):
"""비동기 서버를 gracefully하게 종료합니다"""
try:
# 새로운 연결 수락을 중지
self.stdout.write("Stopping async server from accepting new connections...")

# gRPC 비동기 서버 종료
await server.stop(grace=True)

# Django 시그널 전송
grpc_shutdown.send(None)

self.stdout.write("Async graceful shutdown completed")

except Exception as e:
self.stderr.write(f"Error during async graceful shutdown: {e}")

def _serve(self, max_workers, port, *args, **kwargs):
"""
Run gRPC server
"""
autoreload.raise_last_exception()
self.stdout.write("gRPC server starting at %s" % datetime.datetime.now())

# autoreload 모드가 아닐 때만 시그널 핸들러 설정
# autoreload는 별도 스레드에서 실행되므로 메인 스레드가 아니어서 시그널 핸들러를 등록할 수 없음
if not kwargs.get("autoreload", False):
self._setup_signal_handlers()

server = create_server(max_workers, port)
self._server = server

server.start()

self.stdout.write("gRPC server is listening port %s" % port)

if kwargs["list_handlers"] is True:
# list_handlers 옵션이 있으면 핸들러 목록 출력 (기본값 False)
if kwargs.get("list_handlers", False):
self.stdout.write("Registered handlers:")
for handler in extract_handlers(server):
self.stdout.write("* %s" % handler)

server.wait_for_termination()
# Send shutdown signal to all connected receivers
grpc_shutdown.send(None)
# autoreload 모드가 아닐 때만 graceful shutdown 로직 실행
if not kwargs.get("autoreload", False):
# Graceful shutdown을 위한 대기 루프
try:
while not self._shutdown_event.is_set():
time.sleep(0.1)
except KeyboardInterrupt:
self.stdout.write("Received keyboard interrupt, starting graceful shutdown...")
self._shutdown_event.set()

# Graceful shutdown 수행
self._graceful_shutdown(server)
else:
# autoreload 모드에서는 기존 방식대로 wait_for_termination 사용
server.wait_for_termination()
# Send shutdown signal to all connected receivers
grpc_shutdown.send(None)

def _serve_async(self, max_workers, port, *args, **kwargs):
"""
Run gRPC server in async mode
"""
self.stdout.write("gRPC async server starting at %s" % datetime.datetime.now())

# autoreload 모드가 아닐 때만 시그널 핸들러 설정
# autoreload는 별도 스레드에서 실행되므로 메인 스레드가 아니어서 시그널 핸들러를 등록할 수 없음
if not kwargs.get("autoreload", False):
self._setup_signal_handlers()

# Coroutines to be invoked when the event loop is shutting down.
_cleanup_coroutines = []

server = create_server(max_workers, port)
self._server = server

async def _main_routine():
await server.start()
self.stdout.write("gRPC async server is listening port %s" % port)

if kwargs["list_handlers"] is True:
# list_handlers 옵션이 있으면 핸들러 목록 출력 (기본값 False)
if kwargs.get("list_handlers", False):
self.stdout.write("Registered handlers:")
for handler in extract_handlers(server):
self.stdout.write("* %s" % handler)

await server.wait_for_termination()
# autoreload 모드가 아닐 때만 graceful shutdown 로직 실행
if not kwargs.get("autoreload", False):
# Graceful shutdown을 위한 대기
while not self._shutdown_event.is_set():
await asyncio.sleep(0.1)

# Graceful shutdown 수행
await self._graceful_shutdown_async(server)
else:
# autoreload 모드에서는 기존 방식대로 wait_for_termination 사용
await server.wait_for_termination()

async def _graceful_shutdown():
# Send the signal to all connected receivers on server shutdown.
Expand All @@ -92,6 +194,14 @@ async def _graceful_shutdown():
loop = asyncio.get_event_loop()
try:
loop.run_until_complete(_main_routine())
except KeyboardInterrupt:
if not kwargs.get("autoreload", False):
self.stdout.write("Received keyboard interrupt, starting graceful shutdown...")
self._shutdown_event.set()
loop.run_until_complete(_main_routine())
else:
# autoreload 모드에서는 KeyboardInterrupt를 무시하고 정상 종료
pass
finally:
loop.run_until_complete(*_cleanup_coroutines)
loop.close()
173 changes: 173 additions & 0 deletions tests/test_graceful_shutdown.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
import os
import signal
import subprocess
import time
import pytest
from django.test import TestCase
from django.core.management import call_command
from django.core.management.base import CommandError
from unittest.mock import patch, MagicMock


class GracefulShutdownTestCase(TestCase):
"""Graceful shutdown 기능을 테스트하는 클래스"""

def setUp(self):
"""테스트 설정"""
super().setUp()
self.port = 50052 # 테스트용 포트

def test_signal_handler_registration(self):
"""시그널 핸들러가 올바르게 등록되는지 테스트"""
from django_grpc.management.commands.grpcserver import Command

command = Command()

# 시그널 핸들러 설정 전 상태 확인
self.assertIsNone(command._original_sigterm_handler)

# 시그널 핸들러 설정
with patch('signal.signal') as mock_signal:
command._setup_signal_handlers()

# signal.signal이 두 번 호출되었는지 확인 (SIGTERM, SIGINT)
self.assertEqual(mock_signal.call_count, 2)

# SIGTERM 핸들러가 저장되었는지 확인
self.assertIsNotNone(command._original_sigterm_handler)

def test_sigterm_handler(self):
"""SIGTERM 핸들러가 올바르게 작동하는지 테스트"""
from django_grpc.management.commands.grpcserver import Command

command = Command()

# 초기 상태 확인
self.assertFalse(command._shutdown_event.is_set())

# SIGTERM 핸들러 호출
command._handle_sigterm(signal.SIGTERM, None)

# shutdown 이벤트가 설정되었는지 확인
self.assertTrue(command._shutdown_event.is_set())

@patch('django_grpc.management.commands.grpcserver.create_server')
def test_graceful_shutdown_sync_server(self, mock_create_server):
"""동기 서버의 graceful shutdown 테스트"""
from django_grpc.management.commands.grpcserver import Command

# Mock 서버 생성
mock_server = MagicMock()
mock_create_server.return_value = mock_server

command = Command()

# graceful shutdown 호출
command._graceful_shutdown(mock_server)

# 서버의 stop 메서드가 grace=True로 호출되었는지 확인
mock_server.stop.assert_called_once_with(grace=True)

@patch('django_grpc.management.commands.grpcserver.create_server')
def test_graceful_shutdown_async_server(self, mock_create_server):
"""비동기 서버의 graceful shutdown 테스트"""
from django_grpc.management.commands.grpcserver import Command

# Mock 서버 생성 (stop 메서드가 없는 경우)
mock_server = MagicMock()
del mock_server.stop
mock_create_server.return_value = mock_server

command = Command()

# graceful shutdown 호출
command._graceful_shutdown(mock_server)

# asyncio.create_task가 호출되었는지 확인
# (실제로는 mock을 통해 확인하기 어려우므로 예외 처리만 확인)

def test_command_initialization(self):
"""Command 초기화가 올바르게 되는지 테스트"""
from django_grpc.management.commands.grpcserver import Command

command = Command()

# 초기 상태 확인
self.assertIsNotNone(command._shutdown_event)
self.assertIsNone(command._server)
self.assertIsNone(command._original_sigterm_handler)

@patch('django_grpc.management.commands.grpcserver.create_server')
@patch('django_grpc.management.commands.grpcserver.signal.signal')
def test_serve_method_signal_setup(self, mock_signal, mock_create_server):
"""_serve 메서드에서 시그널 핸들러가 설정되는지 테스트"""
from django_grpc.management.commands.grpcserver import Command

# Mock 서버 생성
mock_server = MagicMock()
mock_create_server.return_value = mock_server

command = Command()

# _serve 메서드 호출 (실제로는 무한 루프에 빠지므로 일부만 테스트)
with patch.object(command, '_setup_signal_handlers') as mock_setup:
with patch.object(command, '_graceful_shutdown'):
# shutdown 이벤트를 미리 설정하여 루프를 빠져나오도록 함
command._shutdown_event.set()
command._serve(max_workers=1, port=self.port)

# 시그널 핸들러 설정이 호출되었는지 확인
mock_setup.assert_called_once()


class GracefulShutdownIntegrationTestCase(TestCase):
"""통합 테스트: 실제 프로세스에서 graceful shutdown 테스트"""

def setUp(self):
"""테스트 설정"""
super().setUp()
self.port = 50053 # 통합 테스트용 포트

@pytest.mark.skipif(
os.name == 'nt', # Windows에서는 signal 처리가 다르므로 스킵
reason="Windows에서는 signal 처리가 다르므로 스킵"
)
def test_sigterm_integration(self):
"""실제 SIGTERM 시그널을 보내서 graceful shutdown 테스트"""
# 이 테스트는 실제 프로세스를 시작하고 SIGTERM을 보내는 통합 테스트입니다.
# 실제 환경에서만 실행해야 합니다.

# 테스트용 Django 설정
os.environ.setdefault('DJANGO_SETTINGS_MODULE', 'tests.settings')

# 프로세스 시작
process = subprocess.Popen([
'python', 'manage.py', 'grpcserver',
'--port', str(self.port),
'--max_workers', '1'
], stdout=subprocess.PIPE, stderr=subprocess.PIPE)

try:
# 서버가 시작될 때까지 잠시 대기
time.sleep(2)

# SIGTERM 시그널 전송
process.send_signal(signal.SIGTERM)

# graceful shutdown을 위한 대기
process.wait(timeout=10)

# 프로세스가 정상적으로 종료되었는지 확인
self.assertEqual(process.returncode, 0)

except subprocess.TimeoutExpired:
# 타임아웃 발생 시 프로세스 강제 종료
process.kill()
process.wait()
self.fail("Graceful shutdown이 타임아웃되었습니다")

finally:
# 프로세스가 아직 실행 중이면 강제 종료
if process.poll() is None:
process.kill()
process.wait()