
Commit 6c066f6

xiejibing and jibxie authored
Support input for llama3.2 multi-modal model (#69)
Co-authored-by: jibxie <[email protected]>
1 parent b71088a commit 6c066f6

File tree: 2 files changed, +47 −13 lines

src/model.py

Lines changed: 31 additions & 1 deletion

@@ -31,7 +31,9 @@
 import queue
 import threading
 from typing import Dict, List
-
+import base64
+from PIL import Image
+from io import BytesIO
 import numpy as np
 import torch
 import triton_python_backend_utils as pb_utils
@@ -40,6 +42,7 @@
 from vllm.lora.request import LoRARequest
 from vllm.sampling_params import SamplingParams
 from vllm.utils import random_uuid
+from vllm.version import __version__ as _VLLM_VERSION
 
 from utils.metrics import VllmStatLogger
 
@@ -71,6 +74,14 @@ def auto_complete_config(auto_complete_model_config):
                 "optional": True,
             },
         ]
+        if _VLLM_VERSION >= "0.6.3.post1":
+            inputs.append({
+                "name": "image",
+                "data_type": "TYPE_STRING",
+                "dims": [-1],  # can be multiple images as separate elements
+                "optional": True,
+            })
+
         outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}]
 
         # Store the model configuration as a dictionary.
@@ -385,6 +396,25 @@ async def generate(self, request):
             ).as_numpy()[0]
             if isinstance(prompt, bytes):
                 prompt = prompt.decode("utf-8")
+
+            if _VLLM_VERSION >= "0.6.3.post1":
+                image_input_tensor = pb_utils.get_input_tensor_by_name(
+                    request, "image"
+                )
+                if image_input_tensor:
+                    image_list = []
+                    for image_raw in image_input_tensor.as_numpy():
+                        image_data = base64.b64decode(image_raw.decode("utf-8"))
+                        image = Image.open(BytesIO(image_data)).convert("RGB")
+                        image_list.append(image)
+                    if len(image_list) > 0:
+                        prompt = {
+                            "prompt": prompt,
+                            "multi_modal_data": {
+                                "image": image_list
+                            }
+                        }
+
             stream = pb_utils.get_input_tensor_by_name(request, "stream")
             if stream:
                 stream = stream.as_numpy()[0]
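With this change a client passes one or more base64-encoded images through the new optional "image" input alongside "text_input", and the backend decodes them into PIL images for vLLM's multi_modal_data. Below is a minimal client sketch, not part of the commit, following the streaming gRPC pattern this backend's decoupled mode requires; the model name "vllm_model", the image path, and the "<|image|>" prompt token are assumptions.

# Minimal client sketch (not from this commit). Assumes Triton serving this
# backend as "vllm_model" on localhost:8001; "example.jpg" is a placeholder.
import base64
import queue
from functools import partial

import numpy as np
import tritonclient.grpc as grpcclient

def callback(result_queue, result, error):
    # Streamed responses (or errors) arrive here.
    result_queue.put(error if error is not None else result)

with open("example.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read())

# llama3.2 vision prompts reference the image via the <|image|> token
# (prompt format assumed from vLLM's multi-modal examples).
text = np.array(["<|image|>\nDescribe this picture.".encode("utf-8")], dtype=np.object_)
# "image" has dims [-1], so several base64 strings may be sent as
# separate elements of the same tensor.
image = np.array([image_b64], dtype=np.object_)

inputs = [
    grpcclient.InferInput("text_input", [1], "BYTES"),
    grpcclient.InferInput("image", [1], "BYTES"),
]
inputs[0].set_data_from_numpy(text)
inputs[1].set_data_from_numpy(image)

results = queue.Queue()
with grpcclient.InferenceServerClient(url="localhost:8001") as client:
    client.start_stream(callback=partial(callback, results))
    client.async_stream_infer(model_name="vllm_model", inputs=inputs)
    client.stop_stream()

response = results.get()
print(response.as_numpy("text_output"))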

src/utils/metrics.py

Lines changed: 16 additions & 12 deletions

@@ -32,7 +32,7 @@
 from vllm.engine.metrics import StatLoggerBase as VllmStatLoggerBase
 from vllm.engine.metrics import Stats as VllmStats
 from vllm.engine.metrics import SupportsMetricsInfo, build_1_2_5_buckets
-
+from vllm.version import __version__ as _VLLM_VERSION
 
 class TritonMetrics:
     def __init__(self, labels: List[str], max_model_len: int):
@@ -76,11 +76,14 @@ def __init__(self, labels: List[str], max_model_len: int):
             description="Number of generation tokens processed.",
             kind=pb_utils.MetricFamily.HISTOGRAM,
         )
-        self.histogram_best_of_request_family = pb_utils.MetricFamily(
-            name="vllm:request_params_best_of",
-            description="Histogram of the best_of request parameter.",
-            kind=pb_utils.MetricFamily.HISTOGRAM,
-        )
+        # 'best_of' metric has been hidden since vllm 0.6.3
+        # https://github.com/vllm-project/vllm/commit/cbc2ef55292b2af6ff742095c030e8425124c005
+        if _VLLM_VERSION < "0.6.3":
+            self.histogram_best_of_request_family = pb_utils.MetricFamily(
+                name="vllm:request_params_best_of",
+                description="Histogram of the best_of request parameter.",
+                kind=pb_utils.MetricFamily.HISTOGRAM,
+            )
         self.histogram_n_request_family = pb_utils.MetricFamily(
             name="vllm:request_params_n",
             description="Histogram of the n request parameter.",
@@ -159,10 +162,11 @@ def __init__(self, labels: List[str], max_model_len: int):
                 buckets=build_1_2_5_buckets(max_model_len),
             )
         )
-        self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
-            labels=labels,
-            buckets=[1, 2, 5, 10, 20],
-        )
+        if _VLLM_VERSION < "0.6.3":
+            self.histogram_best_of_request = self.histogram_best_of_request_family.Metric(
+                labels=labels,
+                buckets=[1, 2, 5, 10, 20],
+            )
         self.histogram_n_request = self.histogram_n_request_family.Metric(
             labels=labels,
             buckets=[1, 2, 5, 10, 20],
@@ -247,10 +251,10 @@ def log(self, stats: VllmStats) -> None:
                 self.metrics.histogram_num_generation_tokens_request,
                 stats.num_generation_tokens_requests,
             ),
-            (self.metrics.histogram_best_of_request, stats.best_of_requests),
             (self.metrics.histogram_n_request, stats.n_requests),
         ]
-
+        if _VLLM_VERSION < "0.6.3":
+            histogram_metrics.append((self.metrics.histogram_best_of_request, stats.best_of_requests))
         for metric, data in counter_metrics:
            self._log_counter(metric, data)
        for metric, data in histogram_metrics:
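Both files gate on plain string comparisons of _VLLM_VERSION. That happens to order the releases involved here correctly, but lexicographic comparison misorders versions with multi-digit components such as "0.10.0". A standalone sketch, not part of the commit, of an equivalent check written with packaging, which compares release segments numerically:

# Standalone sketch (not from the commit): the same version gates expressed
# with packaging.version instead of raw string comparison.
from packaging.version import Version

from vllm.version import __version__ as _VLLM_VERSION

# 'best_of' request histogram is hidden from vllm 0.6.3 on.
EXPOSE_BEST_OF = Version(_VLLM_VERSION) < Version("0.6.3")
# The optional "image" input relies on multi-modal support in 0.6.3.post1+.
SUPPORTS_IMAGE_INPUT = Version(_VLLM_VERSION) >= Version("0.6.3.post1")

# Why strings alone are fragile: lexicographic order breaks on "0.10.x".
assert ("0.10.0" < "0.6.3") is True                      # wrong as strings
assert (Version("0.10.0") < Version("0.6.3")) is False   # correct with Version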
