Skip to content

Commit 5e605ca

Browse files
committed
Add test for additional outputs
* Add additional outputs test
* Update copyright
* Some test enhancement and notes
1 parent 58ee481 commit 5e605ca

File tree

3 files changed

+267
-2
lines changed

3 files changed

+267
-2
lines changed
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
#
3+
# Redistribution and use in source and binary forms, with or without
4+
# modification, are permitted provided that the following conditions
5+
# are met:
6+
# * Redistributions of source code must retain the above copyright
7+
# notice, this list of conditions and the following disclaimer.
8+
# * Redistributions in binary form must reproduce the above copyright
9+
# notice, this list of conditions and the following disclaimer in the
10+
# documentation and/or other materials provided with the distribution.
11+
# * Neither the name of NVIDIA CORPORATION nor the names of its
12+
# contributors may be used to endorse or promote products derived
13+
# from this software without specific prior written permission.
14+
#
15+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
16+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
18+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
19+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
20+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
21+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
22+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
23+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
24+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26+
27+
import json
28+
import unittest
29+
30+
import numpy as np
31+
import tritonclient.grpc as grpcclient
32+
33+
34+
class InferTest(unittest.TestCase):
35+
_grpc_url = "localhost:8001"
36+
_model_name = "vllm_opt"
37+
_sampling_parameters = {"temperature": "0", "top_p": "1"}
38+
_prompt = "In this example,"
39+
40+
def _get_inputs(
41+
self,
42+
prompt,
43+
stream=True,
44+
sampling_parameters=None,
45+
output_finish_reason=None,
46+
output_cumulative_logprob=None,
47+
output_num_token_ids=None,
48+
):
49+
inputs = []
50+
51+
inputs.append(grpcclient.InferInput("text_input", [1], "BYTES"))
52+
inputs[-1].set_data_from_numpy(
53+
np.array([prompt.encode("utf-8")], dtype=np.object_)
54+
)
55+
56+
inputs.append(grpcclient.InferInput("stream", [1], "BOOL"))
57+
inputs[-1].set_data_from_numpy(np.array([stream], dtype=bool))
58+
59+
if sampling_parameters is not None:
60+
inputs.append(grpcclient.InferInput("sampling_parameters", [1], "BYTES"))
61+
inputs[-1].set_data_from_numpy(
62+
np.array(
63+
[json.dumps(sampling_parameters).encode("utf-8")], dtype=np.object_
64+
)
65+
)
66+
67+
if output_finish_reason is not None:
68+
inputs.append(grpcclient.InferInput("output_finish_reason", [1], "BOOL"))
69+
inputs[-1].set_data_from_numpy(np.array([output_finish_reason], dtype=bool))
70+
71+
if output_cumulative_logprob is not None:
72+
inputs.append(
73+
grpcclient.InferInput("output_cumulative_logprob", [1], "BOOL")
74+
)
75+
inputs[-1].set_data_from_numpy(
76+
np.array([output_cumulative_logprob], dtype=bool)
77+
)
78+
79+
if output_num_token_ids is not None:
80+
inputs.append(grpcclient.InferInput("output_num_token_ids", [1], "BOOL"))
81+
inputs[-1].set_data_from_numpy(np.array([output_num_token_ids], dtype=bool))
82+
83+
return inputs
84+
85+
def _callback(self, result, error):
86+
self._responses.append({"result": result, "error": error})
87+
88+
def _llm_infer(self, inputs):
89+
self._responses = []
90+
with grpcclient.InferenceServerClient(self._grpc_url) as client:
91+
client.start_stream(self._callback)
92+
client.async_stream_infer(
93+
self._model_name, inputs=inputs, parameters=self._sampling_parameters
94+
)
95+
client.stop_stream()
96+
self.assertGreater(len(self._responses), 0)
97+
98+
def _assert_text_output_valid(self):
99+
text_output = ""
100+
for response in self._responses:
101+
result, error = response["result"], response["error"]
102+
self.assertIsNone(error)
103+
text_output += result.as_numpy(name="text_output")[0].decode("utf-8")
104+
self.assertGreater(len(text_output), 0, "output is empty")
105+
self.assertGreater(text_output.count(" "), 4, "output is not a sentence")
106+
107+
def _assert_finish_reason(self, output_finish_reason):
108+
for i in range(len(self._responses)):
109+
result, error = self._responses[i]["result"], self._responses[i]["error"]
110+
self.assertIsNone(error)
111+
finish_reason_np = result.as_numpy(name="finish_reason")
112+
if output_finish_reason is None or output_finish_reason == False:
113+
self.assertIsNone(finish_reason_np)
114+
continue
115+
finish_reason = finish_reason_np[0].decode("utf-8")
116+
if i < len(self._responses) - 1:
117+
self.assertEqual(finish_reason, "None")
118+
else:
119+
self.assertEqual(finish_reason, "length")
120+
121+
def _assert_cumulative_logprob(self, output_cumulative_logprob):
122+
prev_cumulative_logprob = 0.0
123+
for response in self._responses:
124+
result, error = response["result"], response["error"]
125+
self.assertIsNone(error)
126+
cumulative_logprob_np = result.as_numpy(name="cumulative_logprob")
127+
if output_cumulative_logprob is None or output_cumulative_logprob == False:
128+
self.assertIsNone(cumulative_logprob_np)
129+
continue
130+
cumulative_logprob = cumulative_logprob_np[0].astype(float)
131+
self.assertNotEqual(cumulative_logprob, prev_cumulative_logprob)
132+
prev_cumulative_logprob = cumulative_logprob
133+
134+
def _assert_num_token_ids(self, output_num_token_ids):
135+
for response in self._responses:
136+
result, error = response["result"], response["error"]
137+
self.assertIsNone(error)
138+
num_token_ids_np = result.as_numpy(name="num_token_ids")
139+
if output_num_token_ids is None or output_num_token_ids == False:
140+
self.assertIsNone(num_token_ids_np)
141+
continue
142+
num_token_ids = num_token_ids_np[0].astype(int)
143+
# TODO: vLLM may return token ids identical to the previous one when
144+
# streaming, for example:
145+
#
146+
# prev: None
147+
# curr: text=' the', token_ids=array('l', [5])
148+
#
149+
# prev: text=' the', token_ids=array('l', [5, 1385])
150+
# curr: text=' the term', token_ids=array('l', [5, 1385])
151+
#
152+
# prev: text=' the term', token_ids=array('l', [5, 1385, 44])
153+
# curr: text=' the term', token_ids=array('l', [5, 1385, 44])
154+
#
155+
# prev: text=' the term', token_ids=array('l', [5, 1385, 44, 48])
156+
# curr: text=' the term “', token_ids=array('l', [5, 1385, 44, 48])
157+
#
158+
# If this is no longer the case in a future release, change the assert
159+
# to assertGreater().
160+
self.assertGreaterEqual(num_token_ids, 0)
161+
162+
def _assert_additional_outputs_valid(
163+
self,
164+
stream,
165+
output_finish_reason,
166+
output_cumulative_logprob,
167+
output_num_token_ids,
168+
):
169+
inputs = self._get_inputs(
170+
self._prompt,
171+
stream=stream,
172+
sampling_parameters=self._sampling_parameters,
173+
output_finish_reason=output_finish_reason,
174+
output_cumulative_logprob=output_cumulative_logprob,
175+
output_num_token_ids=output_num_token_ids,
176+
)
177+
self._llm_infer(inputs)
178+
self._assert_text_output_valid()
179+
self._assert_finish_reason(output_finish_reason)
180+
self._assert_cumulative_logprob(output_cumulative_logprob)
181+
self._assert_num_token_ids(output_num_token_ids)
182+
183+
def test_additional_outputs(self):
184+
for stream in [True, False]:
185+
choices = [None, False, True]
186+
for output_finish_reason in choices:
187+
for output_cumulative_logprob in choices:
188+
for output_num_token_ids in choices:
189+
self._assert_additional_outputs_valid(
190+
stream,
191+
output_finish_reason,
192+
output_cumulative_logprob,
193+
output_num_token_ids,
194+
)
195+
196+
197+
# Allow running this file directly: python3 additional_outputs_test.py
if __name__ == "__main__":
    unittest.main()

ci/L0_vllm_additional_outputs/test.sh

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
#!/bin/bash
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

export CUDA_VISIBLE_DEVICES=0
source ../common/util.sh

# Quoted so the shell never treats [grpc] as a glob character class.
pip3 install "tritonclient[grpc]"

# Prepare Model
rm -rf models vllm_baseline_output.pkl && mkdir -p models
SAMPLE_MODELS_REPO="../../samples/model_repository"
cp -r "$SAMPLE_MODELS_REPO/vllm_model" models/vllm_opt
# Lower GPU memory utilization so the model fits alongside other CI jobs.
sed -i 's/"gpu_memory_utilization": 0.5/"gpu_memory_utilization": 0.3/' models/vllm_opt/1/model.json

RET=0

# Infer Test
CLIENT_LOG="vllm_opt.log"
SERVER_LOG="vllm_opt.server.log"
SERVER_ARGS="--model-repository=models"
run_server
if [ "$SERVER_PID" == "0" ]; then
    echo -e "\n***\n*** Failed to start $SERVER\n***"
    cat "$SERVER_LOG"
    exit 1
fi
set +e
# Record the failure but keep going so the server is always shut down.
if ! python3 additional_outputs_test.py > "$CLIENT_LOG" 2>&1; then
    cat "$CLIENT_LOG"
    echo -e "\n***\n*** additional_outputs_test FAILED. \n***"
    RET=1
fi
set -e
kill "$SERVER_PID"
wait "$SERVER_PID"

if [ "$RET" -eq 0 ]; then
    echo -e "\n***\n*** Test Passed\n***"
else
    echo -e "\n***\n*** Test FAILED\n***"
fi
exit "$RET"

ci/common/util.sh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/bin/bash
2-
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# Copyright 2023-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
#
44
# Redistribution and use in source and binary forms, with or without
55
# modification, are permitted provided that the following conditions
@@ -25,7 +25,7 @@
2525
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
2626
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
2727

28-
28+
SERVER=${SERVER:=/opt/tritonserver/bin/tritonserver}
2929
SERVER_IPADDR=${TRITONSERVER_IPADDR:=localhost}
3030
SERVER_LOG=${SERVER_LOG:=./server.log}
3131
SERVER_TIMEOUT=${SERVER_TIMEOUT:=120}

0 commit comments

Comments
 (0)