Commit 2b531dd

Merge branch 'main' of github.com:triton-inference-server/vllm_backend into jacky-vllm-additional-outputs

2 parents dae3c13 + 6c066f6

File tree

2 files changed (+46 additions, -13 deletions)

src/model.py

Lines changed: 30 additions & 1 deletion
@@ -25,21 +25,25 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 import asyncio
+import base64
 import gc
 import json
 import os
 import queue
 import threading
+from io import BytesIO
 from typing import Dict, List

 import numpy as np
 import torch
 import triton_python_backend_utils as pb_utils
+from PIL import Image
 from vllm.engine.arg_utils import AsyncEngineArgs
 from vllm.engine.async_llm_engine import AsyncLLMEngine
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
+from vllm.version import __version__ as _VLLM_VERSION

 from utils.metrics import VllmStatLogger

@@ -67,7 +71,7 @@ def auto_complete_config(cls, auto_complete_model_config):

     @staticmethod
     def _auto_complete_inputs_and_outputs(auto_complete_model_config):
-        # Inputs/Outputs expected by the backend.
+        # Inputs expected by the backend.
         inputs = [
             {"name": "text_input", "data_type": "TYPE_STRING", "dims": [1]},
             {
@@ -107,6 +111,16 @@ def _auto_complete_inputs_and_outputs(auto_complete_model_config):
                 "optional": True,
             },
         ]
+        if _VLLM_VERSION >= "0.6.3.post1":
+            inputs.append(
+                {
+                    "name": "image",
+                    "data_type": "TYPE_STRING",
+                    "dims": [-1],  # can be multiple images as separate elements
+                    "optional": True,
+                }
+            )
+        # Outputs expected by the backend.
         outputs = [
             {"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]},
             {"name": "finish_reason", "data_type": "TYPE_STRING", "dims": [-1]},
@@ -313,6 +327,21 @@ def _get_input_tensors(self, request):
         if isinstance(prompt, bytes):
             prompt = prompt.decode("utf-8")

+        # image
+        if _VLLM_VERSION >= "0.6.3.post1":
+            images = pb_utils.get_input_tensor_by_name(request, "image")
+            if images:
+                images_vllm = []
+                for image_np in images.as_numpy():
+                    image_b = base64.b64decode(image_np.decode("utf-8"))
+                    image_rgb = Image.open(BytesIO(image_b)).convert("RGB")
+                    images_vllm.append(image_rgb)
+                if len(images_vllm) > 0:
+                    prompt = {
+                        "prompt": prompt,
+                        "multi_modal_data": {"image": images_vllm},
+                    }
+
         # stream
         stream = pb_utils.get_input_tensor_by_name(request, "stream")
         if stream:
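
Since the decode path above expects each element of the "image" tensor to be a base64-encoded image, a client has to encode images before sending them. A minimal client sketch, assuming Triton's HTTP frontend on localhost:8000, a model named vllm_model, and a local file example.jpg (all three are illustrative assumptions, not taken from the diff); note that multimodal models in vLLM generally require model-specific image placeholder tokens in the prompt, which this sketch omits:

import base64

import numpy as np
import tritonclient.http as httpclient

# Base64-encode the image, mirroring the server-side decode
# (base64 string -> raw bytes -> PIL RGB image).
with open("example.jpg", "rb") as f:  # hypothetical local file
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

client = httpclient.InferenceServerClient(url="localhost:8000")

text = np.array(["What is shown in this image?"], dtype=object)  # dims [1]
images = np.array([image_b64], dtype=object)  # dims [-1], one element per image

inputs = [
    httpclient.InferInput("text_input", [1], "BYTES"),
    httpclient.InferInput("image", [1], "BYTES"),
]
inputs[0].set_data_from_numpy(text)
inputs[1].set_data_from_numpy(images)

result = client.infer(model_name="vllm_model", inputs=inputs)
print(result.as_numpy("text_output"))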

src/utils/metrics.py

Lines changed: 16 additions & 12 deletions
@@ -32,7 +32,7 @@
 from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase
 from vllm.engine.metrics import Stats as VllmStats
 from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets
-
+from vllm.version import __version__ as _VLLM_VERSION

 class TritonMetrics:
     def __init__(self, labels: List[str], max_model_len: int):
@@ -76,11 +76,14 @@ def __init__(self, labels: List[str], max_model_len: int):
             description="Number of generation tokens processed.",
             kind=pb_utils.MetricFamily.HISTOGRAM,
         )
-        self.histogram_best_of_request_family = pb_utils.MetricFamily(
-            name="vllm:request_params_best_of",
-            description="Histogram of the best_of request parameter.",
-            kind=pb_utils.MetricFamily.HISTOGRAM,
-        )
+        # 'best_of' metric has been hidden since vllm 0.6.3
+        # https://github.com/vllm-project/vllm/commit/cbc2ef55292b2af6ff742095c030e8425124c005
+        if _VLLM_VERSION < "0.6.3":
+            self.histogram_best_of_request_family = pb_utils.MetricFamily(
+                name="vllm:request_params_best_of",
+                description="Histogram of the best_of request parameter.",
+                kind=pb_utils.MetricFamily.HISTOGRAM,
+            )
         self.histogram_n_request_family = pb_utils.MetricFamily(
             name="vllm:request_params_n",
             description="Histogram of the n request parameter.",
@@ -159,10 +162,11 @@ def __init__(self, labels: List[str], max_model_len: int):
                 buckets=build_1_2_5_buckets(max_model_len),
             )
         )
-        self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
-            labels=labels,
-            buckets=[1, 2, 5, 10, 20],
-        )
+        if _VLLM_VERSION < "0.6.3":
+            self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
+                labels=labels,
+                buckets=[1, 2, 5, 10, 20],
+            )
         self.histogram_n_request = self.histogram_n_request_family.Metric(
             labels=labels,
             buckets=[1, 2, 5, 10, 20],
@@ -247,10 +251,10 @@ def log(self, stats: VllmStats) -> None:
                 self.metrics.histogram_num_generation_tokens_request,
                 stats.num_generation_tokens_requests,
             ),
-            (self.metrics.histogram_best_of_request, stats.best_of_requests),
             (self.metrics.histogram_n_request, stats.n_requests),
         ]
-
+        if _VLLM_VERSION < "0.6.3":
+            histogram_metrics.append((self.metrics.histogram_best_of_request, stats.best_of_requests))
         for metric, data in counter_metrics:
             self._log_counter(metric, data)
         for metric, data in histogram_metrics:
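
To confirm the gating behaves as expected at runtime, the best_of histogram can be checked against Triton's Prometheus metrics endpoint. A small verification sketch, assuming the default metrics port 8002 and the requests library (both assumptions, not part of the diff):

import requests

# Scrape Triton's Prometheus metrics (default port 8002, an assumption).
metrics_text = requests.get("http://localhost:8002/metrics").text

# With vLLM >= 0.6.3 the best_of histogram is no longer registered,
# so this should print False; with older vLLM it prints True.
print("vllm:request_params_best_of" in metrics_text)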
