
Commit ceb5961

feat: Support sending additional outputs from vLLM inference (#70)
1 parent 6c066f6 commit ceb5961

13 files changed: +610 -135 lines changed

README.md

Lines changed: 5 additions & 0 deletions

```diff
@@ -203,6 +203,11 @@ you need to specify a different `shm-region-prefix-name` for each server. See
 [here](https://github.com/triton-inference-server/python_backend#running-multiple-instances-of-triton-server)
 for more information.
 
+## Additional vLLM outputs
+
+Additional vLLM outputs may be requested on a per-request basis. See
+[these docs](docs/additional_outputs.md) for more information.
+
 ## Triton Metrics
 Starting with the 24.08 release of Triton, users can now obtain specific
 vLLM metrics by querying the Triton metrics endpoint (see complete vLLM metrics
```
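As a quick smoke test of the new flags without a gRPC client, something like the following should work through Triton's HTTP generate endpoint, which maps non-reserved top-level JSON fields to input tensors. This is a sketch, not part of the commit; the model name, port, and prompt are assumptions:

```bash
# Illustrative only: assumes a running server with a vLLM model named "vllm_opt".
curl -X POST localhost:8000/v2/models/vllm_opt/generate \
  -d '{"text_input": "In this example,", "stream": false, "return_finish_reason": true}'
```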
ci/L0_additional_outputs_vllm/additional_outputs_test.py

Lines changed: 189 additions & 0 deletions

```python
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json

import numpy as np
import pytest
import tritonclient.grpc as grpcclient


class TestAdditionalOutputs:
    _grpc_url = "localhost:8001"
    _model_name = "vllm_opt"
    _sampling_parameters = {"temperature": "0", "top_p": "1"}
    _prompt = "In this example,"

    def _get_inputs(
        self,
        prompt,
        stream=True,
        sampling_parameters=None,
        return_finish_reason=None,
        return_cumulative_logprob=None,
        return_num_output_tokens=None,
    ):
        inputs = []

        inputs.append(grpcclient.InferInput("text_input", [1], "BYTES"))
        inputs[-1].set_data_from_numpy(
            np.array([prompt.encode("utf-8")], dtype=np.object_)
        )

        inputs.append(grpcclient.InferInput("stream", [1], "BOOL"))
        inputs[-1].set_data_from_numpy(np.array([stream], dtype=bool))

        if sampling_parameters is not None:
            inputs.append(grpcclient.InferInput("sampling_parameters", [1], "BYTES"))
            inputs[-1].set_data_from_numpy(
                np.array(
                    [json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_
                )
            )

        if return_finish_reason is not None:
            inputs.append(grpcclient.InferInput("return_finish_reason", [1], "BOOL"))
            inputs[-1].set_data_from_numpy(np.array([return_finish_reason], dtype=bool))

        if return_cumulative_logprob is not None:
            inputs.append(
                grpcclient.InferInput("return_cumulative_logprob", [1], "BOOL")
            )
            inputs[-1].set_data_from_numpy(
                np.array([return_cumulative_logprob], dtype=bool)
            )

        if return_num_output_tokens is not None:
            inputs.append(
                grpcclient.InferInput("return_num_output_tokens", [1], "BOOL")
            )
            inputs[-1].set_data_from_numpy(
                np.array([return_num_output_tokens], dtype=bool)
            )

        return inputs

    def _callback(self, result, error):
        self._responses.append({"result": result, "error": error})

    def _llm_infer(self, inputs):
        self._responses = []
        with grpcclient.InferenceServerClient(self._grpc_url) as client:
            client.start_stream(self._callback)
            client.async_stream_infer(
                self._model_name, inputs=inputs, parameters=self._sampling_parameters
            )
            client.stop_stream()
        assert len(self._responses) > 0

    def _assert_text_output_valid(self):
        text_output = ""
        for response in self._responses:
            result, error = response["result"], response["error"]
            assert error is None
            text_output += result.as_numpy(name="text_output")[0].decode("utf-8")
        assert len(text_output) > 0, "output is empty"
        assert text_output.count(" ") > 4, "output is not a sentence"

    def _assert_finish_reason(self, return_finish_reason):
        for i in range(len(self._responses)):
            result, error = self._responses[i]["result"], self._responses[i]["error"]
            assert error is None
            finish_reason_np = result.as_numpy(name="finish_reason")
            if return_finish_reason is None or return_finish_reason == False:
                assert finish_reason_np is None
                continue
            finish_reason = finish_reason_np[0].decode("utf-8")
            if i < len(self._responses) - 1:
                assert finish_reason == "None"
            else:
                assert finish_reason == "length"

    def _assert_cumulative_logprob(self, return_cumulative_logprob):
        prev_cumulative_logprob = 0.0
        for response in self._responses:
            result, error = response["result"], response["error"]
            assert error is None
            cumulative_logprob_np = result.as_numpy(name="cumulative_logprob")
            if return_cumulative_logprob is None or return_cumulative_logprob == False:
                assert cumulative_logprob_np is None
                continue
            cumulative_logprob = cumulative_logprob_np[0].astype(float)
            assert cumulative_logprob != prev_cumulative_logprob
            prev_cumulative_logprob = cumulative_logprob

    def _assert_num_output_tokens(self, return_num_output_tokens):
        for response in self._responses:
            result, error = response["result"], response["error"]
            assert error is None
            num_output_tokens_np = result.as_numpy(name="num_output_tokens")
            if return_num_output_tokens is None or return_num_output_tokens == False:
                assert num_output_tokens_np is None
                continue
            num_output_tokens = num_output_tokens_np[0].astype(int)
            # TODO: vLLM may return token ids identical to the previous one when
            # streaming, for example:
            #
            # prev: None
            # curr: text=' the', token_ids=array('l', [5])
            #
            # prev: text=' the', token_ids=array('l', [5, 1385])
            # curr: text=' the term', token_ids=array('l', [5, 1385])
            #
            # prev: text=' the term', token_ids=array('l', [5, 1385, 44])
            # curr: text=' the term', token_ids=array('l', [5, 1385, 44])
            #
            # prev: text=' the term', token_ids=array('l', [5, 1385, 44, 48])
            # curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48])
            #
            # If this is no longer the case in a future release, change the assert
            # to assert num_output_tokens > 0.
            assert num_output_tokens >= 0

    @pytest.mark.parametrize("stream", [True, False])
    @pytest.mark.parametrize("return_finish_reason", [None, True, False])
    @pytest.mark.parametrize("return_cumulative_logprob", [None, True, False])
    @pytest.mark.parametrize("return_num_output_tokens", [None, True, False])
    def test_additional_outputs(
        self,
        stream,
        return_finish_reason,
        return_cumulative_logprob,
        return_num_output_tokens,
    ):
        inputs = self._get_inputs(
            self._prompt,
            stream=stream,
            sampling_parameters=self._sampling_parameters,
            return_finish_reason=return_finish_reason,
            return_cumulative_logprob=return_cumulative_logprob,
            return_num_output_tokens=return_num_output_tokens,
        )
        self._llm_infer(inputs)
        self._assert_text_output_valid()
        self._assert_finish_reason(return_finish_reason)
        self._assert_cumulative_logprob(return_cumulative_logprob)
        self._assert_num_output_tokens(return_num_output_tokens)
```

ci/L0_additional_outputs_vllm/test.sh

Lines changed: 66 additions & 0 deletions

```bash
#!/bin/bash
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

export CUDA_VISIBLE_DEVICES=0
source ../common/util.sh

pip3 install pytest==8.1.1
pip3 install tritonclient[grpc]

# Prepare Model
rm -rf models vllm_baseline_output.pkl && mkdir -p models
SAMPLE_MODELS_REPO="../../samples/model_repository"
cp -r $SAMPLE_MODELS_REPO/vllm_model models/vllm_opt
sed -i 's/"gpu_memory_utilization": 0.5/"gpu_memory_utilization": 0.3/' models/vllm_opt/1/model.json

RET=0

# Test
SERVER_LOG="vllm_opt.server.log"
SERVER_ARGS="--model-repository=models"
run_server
if [ "$SERVER_PID" == "0" ]; then
    echo -e "\n***\n*** Failed to start $SERVER\n***"
    cat $SERVER_LOG
    exit 1
fi
set +e
python3 -m pytest --junitxml=test_additional_outputs.xml -s -v additional_outputs_test.py
if [ $? -ne 0 ]; then
    echo -e "\n***\n*** additional_outputs_test FAILED. \n***"
    RET=1
fi
set -e
kill $SERVER_PID
wait $SERVER_PID

if [ $RET -eq 0 ]; then
    echo -e "\n***\n*** Test Passed\n***"
else
    echo -e "\n***\n*** Test FAILED\n***"
fi
exit $RET
```
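For local verification, a minimal sketch of how this script could be invoked, assuming a Triton container with the vLLM backend and this repository mounted (the image tag and mount path are assumptions, not part of this commit):

```bash
# Illustrative only: container tag and mount point are assumptions.
docker run --gpus all --rm -v "$(pwd)":/workspace \
  -w /workspace/ci/L0_additional_outputs_vllm \
  nvcr.io/nvidia/tritonserver:24.12-vllm-python-py3 \
  bash test.sh  # writes vllm_opt.server.log and test_additional_outputs.xml
```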
File renamed without changes.

ci/common/util.sh

Lines changed: 2 additions & 2 deletions

```diff
@@ -1,5 +1,5 @@
 #!/bin/bash
-# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Redistribution and use in source and binary forms, with or without
 # modification, are permitted provided that the following conditions
@@ -25,7 +25,7 @@
 # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 
-
+
 SERVER=${SERVER:=/opt/tritonserver/bin/tritonserver}
 SERVER_IPADDR=${TRITONSERVER_IPADDR:=localhost}
 SERVER_LOG=${SERVER_LOG:=./server.log}
 SERVER_TIMEOUT=${SERVER_TIMEOUT:=120}
```

docs/additional_outputs.md

Lines changed: 107 additions & 0 deletions

<!--
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->

# Additional Outputs from vLLM

The vLLM backend supports sending additional outputs from vLLM on top of the
usual `text_output` when requested.

All additional outputs are disabled by default and need to be enabled on a
per-request basis. If enabled, the corresponding output tensor will be set on
all responses from the request.

## Supported Additional Outputs

### Finish Reason

The reason why the sequence is finished. See
[here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L26)
for more details.

To enable, set the `return_finish_reason` input tensor to `True`. The reason
will be sent as a string on the `finish_reason` output tensor.

Supported since r24.12.

### Cumulative Log Probabilities

The cumulative log probability of the generated output text. See
[here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L22)
for more details.

To enable, set the `return_cumulative_logprob` input tensor to `True`. The
floating-point value will be sent on the `cumulative_logprob` output tensor.

Supported since r24.12.

### Number of Output Tokens

The number of token IDs of the generated output text sent on this response. It
is the difference in the length of the token IDs generated between the last
response and this response; if this is the first response, the last response
length is presumed to be zero. For example, if a stream produces cumulative
token ID lists of lengths 1, 2, 2, and 4, then `num_output_tokens` on the four
responses is 1, 1, 0, and 2. See
[here](https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/outputs.py#L21)
for more details on the token IDs of the generated output text.

To enable, set the `return_num_output_tokens` input tensor to `True`. The
unsigned integer value will be sent on the `num_output_tokens` output tensor.

Supported since r24.12.

## Examples

### Add Finish Reason to Outputs

```python
import numpy as np
import tritonclient.grpc as grpcclient

inputs = []

inputs.append(grpcclient.InferInput("text_input", [1], "BYTES"))
inputs[-1].set_data_from_numpy(
    np.array(["example prompt".encode("utf-8")], dtype=np.object_)
)

inputs.append(grpcclient.InferInput("return_finish_reason", [1], "BOOL"))
inputs[-1].set_data_from_numpy(np.array([True], dtype=bool))

def callback(result, error):
    ...
    print(result.as_numpy(name="finish_reason"))

with grpcclient.InferenceServerClient("localhost:8001") as client:
    client.start_stream(callback)
    client.async_stream_infer("vLLM_model_name", inputs=inputs, ...)
    client.stop_stream()
```
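### Request All Additional Outputs

The other outputs follow the same pattern as the finish-reason example above.
The sketch below (the model name `vLLM_model_name` and the prompt are
placeholders, as above) enables all three outputs on a streaming request and
sums `num_output_tokens` across responses to recover the total output length:

```python
import numpy as np
import tritonclient.grpc as grpcclient

inputs = []

inputs.append(grpcclient.InferInput("text_input", [1], "BYTES"))
inputs[-1].set_data_from_numpy(
    np.array(["example prompt".encode("utf-8")], dtype=np.object_)
)

inputs.append(grpcclient.InferInput("stream", [1], "BOOL"))
inputs[-1].set_data_from_numpy(np.array([True], dtype=bool))

# Enable all three additional outputs for this request.
for name in (
    "return_finish_reason",
    "return_cumulative_logprob",
    "return_num_output_tokens",
):
    inputs.append(grpcclient.InferInput(name, [1], "BOOL"))
    inputs[-1].set_data_from_numpy(np.array([True], dtype=bool))

total_output_tokens = 0

def callback(result, error):
    global total_output_tokens
    if error is not None:
        print(error)
        return
    # Each response carries the output tensors enabled above.
    print(result.as_numpy(name="finish_reason")[0].decode("utf-8"))
    print(result.as_numpy(name="cumulative_logprob")[0])
    total_output_tokens += int(result.as_numpy(name="num_output_tokens")[0])

with grpcclient.InferenceServerClient("localhost:8001") as client:
    client.start_stream(callback)
    client.async_stream_infer("vLLM_model_name", inputs=inputs)
    client.stop_stream()

print(f"total output tokens: {total_output_tokens}")
```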
## Notes

* Enabling additional outputs may impact performance; only request additional
outputs when necessary.
