diff --git a/ci/L0_backend_vllm/metrics_test/test.sh b/ci/L0_backend_vllm/metrics_test/test.sh
index 5564fb12..a9a4db90 100755
--- a/ci/L0_backend_vllm/metrics_test/test.sh
+++ b/ci/L0_backend_vllm/metrics_test/test.sh
@@ -74,8 +74,10 @@ run_test() {
             RET=1
         fi
     fi
+    set -e
 
+    # TODO: Non-graceful shutdown when metrics are enabled.
     kill $SERVER_PID
     wait $SERVER_PID
 }
diff --git a/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
index 6bef1746..1f8514e3 100644
--- a/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
+++ b/ci/L0_backend_vllm/metrics_test/vllm_metrics_test.py
@@ -170,6 +170,7 @@ def test_vllm_metrics(self):
             total_prompts,
         )
 
+    # TODO: Revisit this test due to the removal of best_of
     def test_custom_sampling_params(self):
         # Adding sampling parameters for testing metrics.
         # Definitions can be found here https://docs.vllm.ai/en/latest/dev/sampling_params.html
@@ -191,6 +192,7 @@ def test_custom_sampling_params(self):
         total_prompts = len(self.prompts)
 
         # vllm:request_params_best_of
+        """
         self.assertEqual(
             metrics_dict["vllm:request_params_best_of_count"], total_prompts
         )
@@ -200,9 +202,10 @@ def test_custom_sampling_params(self):
         self.assertEqual(
             metrics_dict["vllm:request_params_best_of_bucket"], total_prompts
         )
+        """
         # vllm:request_params_n
         self.assertEqual(metrics_dict["vllm:request_params_n_count"], total_prompts)
-        self.assertEqual(metrics_dict["vllm:request_params_n_sum"], n * total_prompts)
+        # self.assertEqual(metrics_dict["vllm:request_params_n_sum"], n * total_prompts)
         self.assertEqual(metrics_dict["vllm:request_params_n_bucket"], total_prompts)
 
     def test_vllm_metrics_disabled(self):
diff --git a/ci/L0_check_health_vllm/mock_async_llm_engine.py b/ci/L0_check_health_vllm/mock_async_llm_engine.py
deleted file mode 100644
index d8d9f038..00000000
--- a/ci/L0_check_health_vllm/mock_async_llm_engine.py
+++ /dev/null
@@ -1,36 +0,0 @@
-# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
-#
-# Redistribution and use in source and binary forms, with or without
-# modification, are permitted provided that the following conditions
-# are met:
-#  * Redistributions of source code must retain the above copyright
-#    notice, this list of conditions and the following disclaimer.
-#  * Redistributions in binary form must reproduce the above copyright
-#    notice, this list of conditions and the following disclaimer in the
-#    documentation and/or other materials provided with the distribution.
-#  * Neither the name of NVIDIA CORPORATION nor the names of its
-#    contributors may be used to endorse or promote products derived
-#    from this software without specific prior written permission.
-#
-# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
-# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
-# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
-# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
-# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
-# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
-# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
-# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
-# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
-# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
-# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-
-from vllm.engine.async_llm_engine import AsyncLLMEngine as real_AsyncLLMEngine
-
-
-class mock_AsyncLLMEngine(real_AsyncLLMEngine):
-    _mock_check_health_count = 0
-
-    async def check_health(self) -> None:
-        self._mock_check_health_count += 1
-        if self._mock_check_health_count > 1:
-            raise RuntimeError("Simulated vLLM check_health() failure")
diff --git a/ci/L0_check_health_vllm/test.sh b/ci/L0_check_health_vllm/test.sh
index 9c3b4eec..50c1a097 100755
--- a/ci/L0_check_health_vllm/test.sh
+++ b/ci/L0_check_health_vllm/test.sh
@@ -47,16 +47,24 @@ function enable_health_check {
     echo -e "}" >> models/vllm_opt/config.pbtxt
 }
 
+VLLM_INSTALL_PATH="/usr/local/lib/python3.12/dist-packages/vllm"
+
 function mock_vllm_async_llm_engine {
-    mv /opt/tritonserver/backends/vllm/model.py /opt/tritonserver/backends/vllm/.model.py.backup
-    cp /opt/tritonserver/backends/vllm/.model.py.backup /opt/tritonserver/backends/vllm/model.py
-    sed -i 's/from vllm.engine.async_llm_engine import AsyncLLMEngine/from mock_async_llm_engine import mock_AsyncLLMEngine as AsyncLLMEngine/' /opt/tritonserver/backends/vllm/model.py
-    cp mock_async_llm_engine.py /opt/tritonserver/backends/vllm
+    # backup original file
+    mv $VLLM_INSTALL_PATH/engine/multiprocessing/client.py $VLLM_INSTALL_PATH/engine/multiprocessing/client.py.backup
+    cp $VLLM_INSTALL_PATH/engine/multiprocessing/client.py.backup $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
+    # overwrite the original check_health method
+    echo -e "" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
+    echo -e "    async def check_health(self, check_count=[0]):" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
+    echo -e "        check_count[0] += 1" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
+    echo -e "        if check_count[0] > 1:" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
+    echo -e "            raise RuntimeError(\"Simulated vLLM check_health() failure\")" >> $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
 }
 
 function unmock_vllm_async_llm_engine {
-    rm -f /opt/tritonserver/backends/vllm/mock_async_llm_engine.py /opt/tritonserver/backends/vllm/model.py
-    mv /opt/tritonserver/backends/vllm/.model.py.backup /opt/tritonserver/backends/vllm/model.py
+    # restore from backup
+    rm -f $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
+    mv $VLLM_INSTALL_PATH/engine/multiprocessing/client.py.backup $VLLM_INSTALL_PATH/engine/multiprocessing/client.py
 }
 
 function test_check_health {
diff --git a/src/model.py b/src/model.py
index c3d54479..ad2a5c88 100644
--- a/src/model.py
+++ b/src/model.py
@@ -39,11 +39,12 @@
 import triton_python_backend_utils as pb_utils
 from PIL import Image
 from vllm.engine.arg_utils import AsyncEngineArgs
-from vllm.engine.async_llm_engine import AsyncLLMEngine
+from vllm.entrypoints.openai.api_server import (
+    build_async_engine_client_from_engine_args,
+)
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
-from vllm.version import __version__ as _VLLM_VERSION
 
 from utils.metrics import VllmStatLogger
 
@@ -74,6 +75,12 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config):
         # Inputs expected by the backend.
         inputs = [
             {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]},
+            {
+                "name": "image",
+                "data_type": "TYPE_STRING",
+                "dims": [-1],  # can be multiple images as separate elements
+                "optional": True,
+            },
             {
                 "name": "stream",
                 "data_type": "TYPE_BOOL",
@@ -123,15 +130,6 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config):
                 "optional": True,
             },
         ]
-        if _VLLM_VERSION >= "0.6.3.post1":
-            inputs.append(
-                {
-                    "name": "image",
-                    "data_type": "TYPE_STRING",
-                    "dims": [-1],  # can be multiple images as separate elements
-                    "optional": True,
-                }
-            )
         # Outputs expected by the backend.
         outputs = [
             {"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]},
@@ -174,27 +172,31 @@ def initialize(self, args):
         )
         self._is_healthy = True
 
-        # Prepare vLLM engine
-        self.init_engine()
+        # Initialize engine arguments
+        # TODO: Move this into _init_engine(), after moving the metrics-enabled check.
+        self._init_engine_args()
 
-        # Counter to keep track of ongoing request counts
-        self.ongoing_request_count = 0
+        # Check if metrics are enabled. The ZMQ process cannot be used when metrics are
+        # enabled.
+        # TODO: Move the check into _setup_metrics().
+        self._enable_metrics = (
+            self._get_bool_config_param("REPORT_CUSTOM_METRICS")
+            and not self._async_engine_args.disable_log_stats
+        )
+
+        # Starting the vLLM engine and its event thread running the AsyncIO event loop.
+        self._init_engine()
+
+        # Set up vLLM metrics
+        self._setup_metrics()
 
         # Starting the response thread. It allows vLLM to keep making progress while
         # response sender(s) are sending responses to server frontend.
         self._response_queue = queue.Queue()
-        self._response_thread = threading.Thread(target=self.response_loop)
+        self._response_thread = threading.Thread(target=self._response_loop)
         self._response_thread.start()
 
-        # Starting asyncio event loop to process the received requests asynchronously.
-        self._loop = asyncio.get_event_loop()
-        self._event_thread = threading.Thread(
-            target=self.engine_loop, args=(self._loop,)
-        )
-        self._shutdown_event = asyncio.Event()
-        self._event_thread.start()
-
-    def init_engine(self):
+    def _init_engine_args(self):
         # Currently, Triton needs to use decoupled policy for asynchronously
         # forwarding requests to vLLM engine, so assert it.
         self.using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
@@ -214,71 +216,88 @@ def init_engine(self):
             self.vllm_engine_config = json.load(file)
 
         # Validate device and multi-processing settings are currently set based on model/configs.
-        self.validate_device_config()
+        self._validate_device_config()
 
         # Check for LoRA config and set it up if enabled
-        self.setup_lora()
-
-        # Create an AsyncLLMEngine from the config from JSON
-        aync_engine_args = AsyncEngineArgs(**self.vllm_engine_config)
-        self.llm_engine = AsyncLLMEngine.from_engine_args(aync_engine_args)
+        self._setup_lora()
 
-        # Create vLLM custom metrics
-        self.vllm_metrics = None
-        if (
-            self._get_bool_config_param("REPORT_CUSTOM_METRICS")
-            and not aync_engine_args.disable_log_stats
-        ):
-            try:
-                labels = {
-                    "model": self.args["model_name"],
-                    "version": self.args["model_version"],
-                }
-                # Add vLLM custom metrics
-                engine_config = self.llm_engine.engine.model_config
-                self.vllm_metrics = VllmStatLogger(
-                    labels, engine_config.max_model_len, self.logger
-                )
-                self.llm_engine.add_logger("triton", self.vllm_metrics)
-            except pb_utils.TritonModelException as e:
-                if "metrics not supported" in str(e):
-                    # Metrics are disabled at the server
-                    self.logger.log_info("[vllm] Metrics not supported")
-                else:
-                    raise e
+        # Create an AsyncEngineArgs from the config from JSON
+        self._async_engine_args = AsyncEngineArgs(**self.vllm_engine_config)
 
-    def _get_bool_config_param(self, param_name: str) -> bool:
-        return (param_name in self.model_config["parameters"]) and (
-            self.model_config["parameters"][param_name]["string_value"].lower()
-            == "true"
+    def _init_engine(self):
+        # Run the engine in a separate thread running the AsyncIO event loop.
+        self._llm_engine = None
+        self._llm_engine_start_cv = threading.Condition()
+        self._llm_engine_shutdown_event = threading.Event()
+        self._event_thread = threading.Thread(
+            target=asyncio.run, args=(self._run_llm_engine(),)
         )
+        self._event_thread.start()
+        with self._llm_engine_start_cv:
+            while self._llm_engine is None:
+                self._llm_engine_start_cv.wait()
+
+        # The 'threading.Thread()' will not raise the exception here should the engine
+        # fail to start, so the exception is passed back via the engine variable.
+        if isinstance(self._llm_engine, Exception):
+            e = self._llm_engine
+            self.logger.log_error(f"[vllm] Failed to start engine: {e}")
+            if self._event_thread is not None:
+                self._event_thread.join()
+                self._event_thread = None
+            raise e
 
-    def setup_lora(self):
-        self.enable_lora = False
+    async def _run_llm_engine(self):
+        # Counter to keep track of ongoing request counts.
+        self._ongoing_request_count = 0
 
-        # Check if `enable_lora` field is in the `model.json`,
-        # and if it is, read its contents, which can be string or bool.
-        if (
-            "enable_lora" in self.vllm_engine_config.keys()
-            and str(self.vllm_engine_config["enable_lora"]).lower() == "true"
-        ):
-            # create Triton LoRA weights repository
-            multi_lora_args_filepath = os.path.join(
-                pb_utils.get_model_dir(), _MULTI_LORA_ARGS_FILENAME
-            )
-            try:
-                with open(multi_lora_args_filepath) as lora_file:
-                    lora_repository: Dict[str, str] = json.load(lora_file)
-                self.lora_repository = lora_repository
-                self.supported_loras: List[str] = list(self.lora_repository.keys())
-                self.supported_loras_len = len(self.supported_loras)
-                self.enable_lora = True
-            except FileNotFoundError:
-                raise FileNotFoundError(
-                    f"Triton backend cannot find {multi_lora_args_filepath}."
-                )
+        try:
+            # Start the vLLM engine. The engine lives for the scope of this with
+            # statement.
+            # TODO: Metrics should work with ZMQ enabled.
+            async with build_async_engine_client_from_engine_args(
+                engine_args=self._async_engine_args,
+                disable_frontend_multiprocessing=self._enable_metrics,
+            ) as engine:
+                # Capture the engine event loop and make it visible to other threads.
+                self._event_loop = asyncio.get_running_loop()
+
+                # Signal the engine is started and make it visible to other threads.
+                with self._llm_engine_start_cv:
+                    self._llm_engine = engine
+                    self._llm_engine_start_cv.notify_all()
+
+                # Wait for the engine shutdown signal.
+                while not self._llm_engine_shutdown_event.is_set():
+                    await asyncio.sleep(0.1)  # Prevent busy-waiting
+
+                # Wait for the ongoing requests to complete.
+                while self._ongoing_request_count > 0:
+                    self.logger.log_info(
+                        "[vllm] Awaiting remaining {} requests".format(
+                            self._ongoing_request_count
+                        )
+                    )
+                    await asyncio.sleep(1)
+
+                # Cancel all tasks in the event loop.
+                for task in asyncio.all_tasks(loop=self._event_loop):
+                    if task is not asyncio.current_task():
+                        task.cancel()
+        except Exception as e:
+            # Signal and pass the exception back via the engine variable if the engine
+            # failed to start. If the engine has started, re-raise the exception.
+            with self._llm_engine_start_cv:
+                if self._llm_engine is None:
+                    self._llm_engine = e
+                    self._llm_engine_start_cv.notify_all()
+                    return
+            raise e
+
+        self._llm_engine = None
+        self.logger.log_info("[vllm] Shutdown complete")
 
-    def validate_device_config(self):
+    def _validate_device_config(self):
         triton_kind = self.args["model_instance_kind"]
         triton_device_id = int(self.args["model_instance_device_id"])
         triton_instance = f"{self.args['model_name']}_{triton_device_id}"
@@ -304,46 +323,202 @@ def validate_device_config(self):
         # vLLM doesn't currently (v0.4.2) expose device selection in the APIs
         torch.cuda.set_device(triton_device_id)
 
-    def create_task(self, coro):
-        """
-        Creates a task on the engine's event loop which is running on a separate thread.
-        """
-        assert (
-            self._shutdown_event.is_set() is False
-        ), "Cannot create tasks after shutdown has been requested"
-
-        return asyncio.run_coroutine_threadsafe(coro, self._loop)
-
-    def engine_loop(self, loop):
-        """
-        Runs the engine's event loop on a separate thread.
-        """
-        asyncio.set_event_loop(loop)
-        self._loop.run_until_complete(self.await_shutdown())
-
-    async def await_shutdown(self):
-        """
-        Primary coroutine running on the engine event loop. This coroutine is responsible for
-        keeping the engine alive until a shutdown is requested.
-        """
-        # first await the shutdown signal
-        while self._shutdown_event.is_set() is False:
-            await asyncio.sleep(5)
-
-        # Wait for the ongoing_requests
-        while self.ongoing_request_count > 0:
-            self.logger.log_info(
-                "[vllm] Awaiting remaining {} requests".format(
-                    self.ongoing_request_count
+    def _setup_lora(self):
+        self.enable_lora = False
+
+        # Check if `enable_lora` field is in the `model.json`,
+        # and if it is, read its contents, which can be string or bool.
+        if (
+            "enable_lora" in self.vllm_engine_config.keys()
+            and str(self.vllm_engine_config["enable_lora"]).lower() == "true"
+        ):
+            # create Triton LoRA weights repository
+            multi_lora_args_filepath = os.path.join(
+                pb_utils.get_model_dir(), _MULTI_LORA_ARGS_FILENAME
+            )
+            try:
+                with open(multi_lora_args_filepath) as lora_file:
+                    lora_repository: Dict[str, str] = json.load(lora_file)
+                self.lora_repository = lora_repository
+                self.supported_loras: List[str] = list(self.lora_repository.keys())
+                self.supported_loras_len = len(self.supported_loras)
+                self.enable_lora = True
+            except FileNotFoundError:
+                raise FileNotFoundError(
+                    f"Triton backend cannot find {multi_lora_args_filepath}."
+                )
+
+    def _setup_metrics(self):
+        self._vllm_metrics = None
+        # TODO: Do not read metrics directly from the vLLM engine, read from prometheus
+        # client to allow the use of ZMQ process when metrics are enabled. See
+        # https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/entrypoints/openai/api_server.py#L222-L245
+        if self._enable_metrics:
+            try:
+                labels = {
+                    "model": self.args["model_name"],
+                    "version": self.args["model_version"],
+                }
+                # Add vLLM custom metrics
+                engine_config = self._llm_engine.engine.model_config
+                self._vllm_metrics = VllmStatLogger(
+                    labels, engine_config.max_model_len, self.logger
                 )
+                self._llm_engine.add_logger("triton", self._vllm_metrics)
+            except pb_utils.TritonModelException as e:
+                if "metrics not supported" in str(e):
+                    # Metrics are disabled at the server
+                    self.logger.log_info("[vllm] Metrics not supported")
+                else:
+                    raise e
+
+    def _get_bool_config_param(self, param_name: str) -> bool:
+        return (param_name in self.model_config["parameters"]) and (
+            self.model_config["parameters"][param_name]["string_value"].lower()
+            == "true"
+        )
+
+    def _response_loop(self):
+        while True:
+            item = self._response_queue.get()
+            # To signal shutdown a None item will be added to the queue.
+            if item is None:
+                break
+            response_state, response, response_flag = item
+            response_sender = response_state["response_sender"]
+            try:
+                response_sender.send(response, response_flag)
+                # Stop checking for cancellation if the last response is generated.
+                if not response_state["last_response_generated"]:
+                    response_state["is_cancelled"] = response_sender.is_cancelled()
+            except Exception as e:
+                self.logger.log_error(
+                    f"An error occurred while sending a response: {e}"
+                )
+            finally:
+                if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
+                    self._ongoing_request_count -= 1
+
+    def execute(self, requests):
+        if self._enable_health_check and not self._check_health(requests):
+            return None
+        for request in requests:
+            request = self._verify_loras(request)
+            if request is not None:
+                assert (
+                    self._llm_engine_shutdown_event.is_set() is False
+                ), "Cannot create tasks after shutdown has been requested"
+                coro = self._generate(request)
+                asyncio.run_coroutine_threadsafe(coro, self._event_loop)
+        return None
+
+    async def _generate(self, request):
+        response_sender = request.get_response_sender()
+        response_state = {
+            "response_sender": response_sender,
+            "is_cancelled": False,
+            "last_response_generated": False,  # last response ready but not yet sent
+        }
+        self._ongoing_request_count += 1
+        decrement_ongoing_request_count = True
+        try:
+            request_id = random_uuid()
+            (
+                prompt,
+                stream,
+                prepend_input,
+                parameters,
+                additional_outputs,
+            ) = self._get_input_tensors(request)
+
+            sampling_params_dict = self._get_sampling_params_dict(parameters)
+            lora_name = sampling_params_dict.pop("lora_name", None)
+            sampling_params = SamplingParams(**sampling_params_dict)
+            lora_request = None
+            if lora_name is not None:
+                lora_id = str(self.supported_loras.index(lora_name) + 1)
+                lora_int_id = int(lora_id)
+                lora_local_path = self.lora_repository[lora_name]
+                lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path)
+
+            response_iterator = self._llm_engine.generate(
+                prompt, sampling_params, request_id, lora_request=lora_request
             )
-            await asyncio.sleep(5)
 
-        for task in asyncio.all_tasks(loop=self._loop):
-            if task is not asyncio.current_task():
-                task.cancel()
+            request_output_state = {}
+            async for request_output in response_iterator:
+                # Cancellation state will be checked by the response loop and written to
+                # the response state if streaming. If not streaming, cancellation state
+                # needs to be checked here.
+                is_cancelled = response_state["is_cancelled"]
+                if not stream:
+                    is_cancelled = response_sender.is_cancelled()
+                if is_cancelled:
+                    self.logger.log_info("[vllm] Cancelling the request")
+                    await self._llm_engine.abort(request_id)
+                    self.logger.log_info("[vllm] Successfully cancelled the request")
 
-        self.logger.log_info("[vllm] Shutdown complete")
+                    if stream:
+                        # Add cancelled final response to response loop.
+                        response_state["last_response_generated"] = True
+                        response = pb_utils.InferenceResponse(
+                            error=pb_utils.TritonError(
+                                message="Request was cancelled",
+                                code=pb_utils.TritonError.CANCELLED,
+                            )
+                        )
+                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                        decrement_ongoing_request_count = False
+                        self._response_queue.put_nowait(
+                            (response_state, response, flags)
+                        )
+
+                    break
+
+                # Send each response if streaming.
+                if stream:
+                    response = self._create_response(
+                        request_output_state,
+                        request_output,
+                        prepend_input=False,
+                        additional_outputs=additional_outputs,
+                    )
+                    flags = 0
+                    if request_output.finished:
+                        response_state["last_response_generated"] = True
+                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+                        decrement_ongoing_request_count = False
+                    self._response_queue.put_nowait((response_state, response, flags))
+
+            # Send the last response which contains all the outputs if not streaming.
+            if not stream:
+                response_sender.send(
+                    self._create_response(
+                        request_output_state={},
+                        request_output=request_output,
+                        prepend_input=prepend_input,
+                        additional_outputs=additional_outputs,
+                    ),
+                    flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
+                )
+
+        except Exception as e:
+            self.logger.log_error(f"[vllm] Error generating stream: {e}")
+            error = pb_utils.TritonError(f"Error generating stream: {e}")
+            text_output_tensor = pb_utils.Tensor(
+                "text_output", np.asarray(["N/A"], dtype=self.output_dtype)
+            )
+            response = pb_utils.InferenceResponse(
+                output_tensors=[text_output_tensor], error=error
+            )
+            response_sender.send(
+                response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
+            )
+            raise e
+
+        finally:
+            if decrement_ongoing_request_count:
+                self._ongoing_request_count -= 1
 
     def _get_input_tensors(self, request):
         # prompt
@@ -352,19 +527,18 @@ def _get_input_tensors(self, request):
         prompt = prompt.decode("utf-8")
 
         # image
-        if _VLLM_VERSION >= "0.6.3.post1":
-            images = pb_utils.get_input_tensor_by_name(request, "image")
-            if images:
-                images_vllm = []
-                for image_np in images.as_numpy():
-                    image_b = base64.b64decode(image_np.decode("utf-8"))
-                    image_rgb = Image.open(BytesIO(image_b)).convert("RGB")
-                    images_vllm.append(image_rgb)
-                if len(images_vllm) > 0:
-                    prompt = {
-                        "prompt": prompt,
-                        "multi_modal_data": {"image": images_vllm},
-                    }
+        images = pb_utils.get_input_tensor_by_name(request, "image")
+        if images:
+            images_vllm = []
+            for image_np in images.as_numpy():
+                image_b = base64.b64decode(image_np.decode("utf-8"))
+                image_rgb = Image.open(BytesIO(image_b)).convert("RGB")
+                images_vllm.append(image_rgb)
+            if len(images_vllm) > 0:
+                prompt = {
+                    "prompt": prompt,
+                    "multi_modal_data": {"image": images_vllm},
+                }
 
         # stream
         stream = pb_utils.get_input_tensor_by_name(request, "stream")
@@ -419,59 +593,6 @@ def _get_input_tensors(self, request):
         return prompt, stream, prepend_input, parameters, additional_outputs
 
-    def get_sampling_params_dict(self, params_json):
-        """
-        This functions parses the dictionary values into their
-        expected format.
-        """
-
-        params_dict = json.loads(params_json)
-
-        # Special parsing for the supported sampling parameters
-        bool_keys = ["ignore_eos", "skip_special_tokens", "use_beam_search"]
-        for k in bool_keys:
-            if k in params_dict:
-                params_dict[k] = bool(params_dict[k])
-
-        float_keys = [
-            "frequency_penalty",
-            "length_penalty",
-            "presence_penalty",
-            "temperature",
-            "top_p",
-        ]
-        for k in float_keys:
-            if k in params_dict:
-                params_dict[k] = float(params_dict[k])
-
-        int_keys = ["best_of", "max_tokens", "min_tokens", "n", "top_k"]
-        for k in int_keys:
-            if k in params_dict:
-                params_dict[k] = int(params_dict[k])
-
-        return params_dict
-
-    def response_loop(self):
-        while True:
-            item = self._response_queue.get()
-            # To signal shutdown a None item will be added to the queue.
-            if item is None:
-                break
-            response_state, response, response_flag = item
-            response_sender = response_state["response_sender"]
-            try:
-                response_sender.send(response, response_flag)
-                # Stop checking for cancellation if the last response is generated.
-                if not response_state["last_response_generated"]:
-                    response_state["is_cancelled"] = response_sender.is_cancelled()
-            except Exception as e:
-                self.logger.log_error(
-                    f"An error occurred while sending a response: {e}"
-                )
-            finally:
-                if response_flag == pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL:
-                    self.ongoing_request_count -= 1
-
     def _create_response(
         self, request_output_state, request_output, prepend_input, additional_outputs
     ):
@@ -584,118 +705,34 @@ def _create_response(
         return pb_utils.InferenceResponse(output_tensors=output_tensors)
 
-    async def generate(self, request):
-        """
-        Forwards single request to LLM engine and returns responses.
-        """
-        response_sender = request.get_response_sender()
-        response_state = {
-            "response_sender": response_sender,
-            "is_cancelled": False,
-            "last_response_generated": False,  # last response ready but not yet sent
-        }
-        self.ongoing_request_count += 1
-        decrement_ongoing_request_count = True
-        try:
-            request_id = random_uuid()
-            (
-                prompt,
-                stream,
-                prepend_input,
-                parameters,
-                additional_outputs,
-            ) = self._get_input_tensors(request)
-
-            sampling_params_dict = self.get_sampling_params_dict(parameters)
-            lora_name = sampling_params_dict.pop("lora_name", None)
-            sampling_params = SamplingParams(**sampling_params_dict)
-            lora_request = None
-            if lora_name is not None:
-                lora_id = str(self.supported_loras.index(lora_name) + 1)
-                lora_int_id = int(lora_id)
-                lora_local_path = self.lora_repository[lora_name]
-                lora_request = LoRARequest(lora_id, lora_int_id, lora_local_path)
-
-            response_iterator = await self.llm_engine.add_request(
-                request_id, prompt, sampling_params, lora_request=lora_request
-            )
-
-            request_output_state = {}
-            async for request_output in response_iterator:
-                # Cancellation state will be checked by the response loop and written to
-                # the response state if streaming. If not streaming, cancellation state
-                # needs to be checked here.
-                is_cancelled = response_state["is_cancelled"]
-                if not stream:
-                    is_cancelled = response_sender.is_cancelled()
-                if is_cancelled:
-                    self.logger.log_info("[vllm] Cancelling the request")
-                    await self.llm_engine.abort(request_id)
-                    self.logger.log_info("[vllm] Successfully cancelled the request")
-
-                    if stream:
-                        # Add cancelled final response to response loop.
-                        response_state["last_response_generated"] = True
-                        response = pb_utils.InferenceResponse(
-                            error=pb_utils.TritonError(
-                                message="Request was cancelled",
-                                code=pb_utils.TritonError.CANCELLED,
-                            )
-                        )
-                        flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
-                        decrement_ongoing_request_count = False
-                        self._response_queue.put_nowait(
-                            (response_state, response, flags)
-                        )
-
-                    break
+    def _get_sampling_params_dict(self, params_json):
+        params_dict = json.loads(params_json)
 
-            # Send each response if streaming.
-            if stream:
-                response = self._create_response(
-                    request_output_state,
-                    request_output,
-                    prepend_input=False,
-                    additional_outputs=additional_outputs,
-                )
-                flags = 0
-                if request_output.finished:
-                    response_state["last_response_generated"] = True
-                    flags = pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
-                    decrement_ongoing_request_count = False
-                self._response_queue.put_nowait((response_state, response, flags))
+        # Special parsing for the supported sampling parameters
+        bool_keys = ["ignore_eos", "skip_special_tokens", "use_beam_search"]
+        for k in bool_keys:
+            if k in params_dict:
+                params_dict[k] = bool(params_dict[k])
 
-            # Send the last response which contains all the outputs if not streaming.
-            if not stream:
-                response_sender.send(
-                    self._create_response(
-                        request_output_state={},
-                        request_output=request_output,
-                        prepend_input=prepend_input,
-                        additional_outputs=additional_outputs,
-                    ),
-                    flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL,
-                )
+        float_keys = [
+            "frequency_penalty",
+            "length_penalty",
+            "presence_penalty",
+            "temperature",
+            "top_p",
+        ]
+        for k in float_keys:
+            if k in params_dict:
+                params_dict[k] = float(params_dict[k])
 
-        except Exception as e:
-            self.logger.log_error(f"[vllm] Error generating stream: {e}")
-            error = pb_utils.TritonError(f"Error generating stream: {e}")
-            text_output_tensor = pb_utils.Tensor(
-                "text_output", np.asarray(["N/A"], dtype=self.output_dtype)
-            )
-            response = pb_utils.InferenceResponse(
-                output_tensors=[text_output_tensor], error=error
-            )
-            response_sender.send(
-                response, flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL
-            )
-            raise e
+        int_keys = ["best_of", "max_tokens", "min_tokens", "n", "top_k"]
+        for k in int_keys:
+            if k in params_dict:
+                params_dict[k] = int(params_dict[k])
 
-        finally:
-            if decrement_ongoing_request_count:
-                self.ongoing_request_count -= 1
+        return params_dict
 
-    def verify_loras(self, request):
+    def _verify_loras(self, request):
         # We will check if the requested lora exists here, if not we will send a
         # response with `LoRA not found` information. In this way we may avoid
         # further processing.
@@ -707,7 +744,7 @@ def verify_loras(self, request):
         )
         if parameters_input_tensor:
             parameters = parameters_input_tensor.as_numpy()[0].decode("utf-8")
-            sampling_params_dict = self.get_sampling_params_dict(parameters)
+            sampling_params_dict = self._get_sampling_params_dict(parameters)
             lora_name = sampling_params_dict.pop("lora_name", None)
 
             if lora_name is not None:
@@ -739,8 +776,8 @@ def verify_loras(self, request):
         return verified_request
 
     def _check_health(self, requests):
-        coro = self.llm_engine.check_health()
-        future = asyncio.run_coroutine_threadsafe(coro, self._loop)
+        coro = self._llm_engine.check_health()
+        future = asyncio.run_coroutine_threadsafe(coro, self._event_loop)
         try:
             future.result()
         except Exception as e:
@@ -762,30 +799,9 @@ def _check_health(self, requests):
             )
         return self._is_healthy
 
-    def execute(self, requests):
-        """
-        Triton core issues requests to the backend via this method.
-
-        When this method returns, new requests can be issued to the backend. Blocking
-        this function would prevent the backend from pulling additional requests from
-        Triton into the vLLM engine. This can be done if the kv cache within vLLM engine
-        is too loaded.
-        We are pushing all the requests on vllm and let it handle the full traffic.
-        """
-        if self._enable_health_check and not self._check_health(requests):
-            return None
-        for request in requests:
-            request = self.verify_loras(request)
-            if request is not None:
-                self.create_task(self.generate(request))
-        return None
-
     def finalize(self):
-        """
-        Triton virtual method; called when the model is unloaded.
-        """
         self.logger.log_info("[vllm] Issuing finalize to vllm backend")
-        self._shutdown_event.set()
+        self._llm_engine_shutdown_event.set()
 
         # Shutdown the event thread.
         if self._event_thread is not None:
@@ -798,9 +814,9 @@ def finalize(self):
             self._response_thread.join()
             self._response_thread = None
 
-        # Shutdown the logger thread.
-        if self.vllm_metrics is not None:
-            self.vllm_metrics.finalize()
+        # Shutdown the metrics thread.
+        if self._vllm_metrics is not None:
+            self._vllm_metrics.finalize()
 
         # When using parallel tensors, the stub process may not shutdown due to
         # unreleased references, so manually run the garbage collector once.
diff --git a/src/utils/metrics.py b/src/utils/metrics.py
index 48b77a2c..c251e941 100644
--- a/src/utils/metrics.py
+++ b/src/utils/metrics.py
@@ -32,7 +32,6 @@
 from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase
 from vllm.engine.metrics import Stats as VllmStats
 from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets
-from vllm.version import __version__ as _VLLM_VERSION
 
 
 class TritonMetrics:
@@ -77,14 +76,6 @@ def __init__(self, labels: List[str], max_model_len: int):
             description="Number of generation tokens processed.",
             kind=pb_utils.MetricFamily.HISTOGRAM,
         )
-        # 'best_of' metric has been hidden since vllm 0.6.3
-        # https://github.com/vllm-project/vllm/commit/cbc2ef55292b2af6ff742095c030e8425124c005
-        if _VLLM_VERSION < "0.6.3":
-            self.histogram_best_of_request_family = pb_utils.MetricFamily(
-                name="vllm:request_params_best_of",
-                description="Histogram of the best_of request parameter.",
-                kind=pb_utils.MetricFamily.HISTOGRAM,
-            )
         self.histogram_n_request_family = pb_utils.MetricFamily(
             name="vllm:request_params_n",
             description="Histogram of the n request parameter.",
@@ -163,13 +154,6 @@ def __init__(self, labels: List[str], max_model_len: int):
                 buckets=build_1_2_5_buckets(max_model_len),
             )
         )
-        if _VLLM_VERSION < "0.6.3":
-            self.histogram_best_of_request = (
-                self.histogram_best_of_request_family.Metric(
-                    labels=labels,
-                    buckets=[1, 2, 5, 10, 20],
-                )
-            )
         self.histogram_n_request = self.histogram_n_request_family.Metric(
             labels=labels,
             buckets=[1, 2, 5, 10, 20],
@@ -256,10 +240,6 @@ def log(self, stats: VllmStats) -> None:
             ),
             (self.metrics.histogram_n_request, stats.n_requests),
         ]
-        if _VLLM_VERSION < "0.6.3":
-            histogram_metrics.append(
-                (self.metrics.histogram_best_of_request, stats.best_of_requests)
-            )
        for metric, data in counter_metrics:
             self._log_counter(metric, data)
         for metric, data in histogram_metrics:
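
Illustrative sketch (not part of the patch): the central change in src/model.py above is that the backend no longer builds an AsyncLLMEngine directly. Instead it enters vLLM's build_async_engine_client_from_engine_args() async context manager inside asyncio.run() on a dedicated thread, hands the resulting engine client back to the main thread through a condition variable, and schedules request coroutines onto that thread's event loop. The standalone sketch below mirrors that pattern, assuming a vLLM 0.6.3.post1-style API; the EngineRunner name and its methods are hypothetical and only echo what _init_engine()/_run_llm_engine() do in the patch.

    import asyncio
    import threading

    from vllm.engine.arg_utils import AsyncEngineArgs
    from vllm.entrypoints.openai.api_server import (
        build_async_engine_client_from_engine_args,
    )


    class EngineRunner:
        """Hypothetical helper mirroring _init_engine()/_run_llm_engine() above."""

        def __init__(self, engine_args: AsyncEngineArgs, disable_zmq: bool = False):
            self._engine_args = engine_args
            self._disable_zmq = disable_zmq
            self._engine = None  # becomes the engine client, or an Exception on failure
            self._start_cv = threading.Condition()
            self._shutdown = threading.Event()
            self._loop = None
            self._thread = threading.Thread(target=asyncio.run, args=(self._run(),))

        def start(self):
            self._thread.start()
            with self._start_cv:
                while self._engine is None:
                    self._start_cv.wait()  # block until the engine is up or has failed
            if isinstance(self._engine, Exception):
                raise self._engine  # surface startup failures to the caller

        async def _run(self):
            try:
                # The engine client only lives for the scope of this context manager.
                async with build_async_engine_client_from_engine_args(
                    engine_args=self._engine_args,
                    disable_frontend_multiprocessing=self._disable_zmq,
                ) as engine:
                    self._loop = asyncio.get_running_loop()
                    with self._start_cv:
                        self._engine = engine
                        self._start_cv.notify_all()
                    while not self._shutdown.is_set():
                        await asyncio.sleep(0.1)  # wait for shutdown without busy-waiting
            except Exception as e:
                with self._start_cv:
                    if self._engine is None:
                        self._engine = e  # pass the startup failure back to start()
                        self._start_cv.notify_all()
                        return
                raise

        def submit(self, coro):
            # Schedule a coroutine (e.g. engine.generate(...)) onto the engine's loop
            # from another thread, as execute() does in the patch.
            return asyncio.run_coroutine_threadsafe(coro, self._loop)

        def stop(self):
            self._shutdown.set()
            self._thread.join()

The metrics caveat in the patch carries over to this sketch: passing disable_frontend_multiprocessing=True keeps the engine in-process (no ZMQ client), which is what the backend currently requires when REPORT_CUSTOM_METRICS is enabled.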