Skip to content

Commit 96cb28e

Browse files
committed
Add input and output tokens to response
1 parent 861a198 commit 96cb28e

File tree

1 file changed: +23 additions, −3 deletions

src/model.py

Lines changed: 23 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -66,7 +66,11 @@ def auto_complete_config(auto_complete_model_config):
6666
"optional": True,
6767
},
6868
]
69-
outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}]
69+
outputs = [
70+
{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]},
71+
{"name": "input_tokens", "data_type": "TYPE_INT32", "dims": [-1]},
72+
{"name": "output_tokens", "data_type": "TYPE_INT32", "dims": [-1]},
73+
]
7074

7175
# Store the model configuration as a dictionary.
7276
config = auto_complete_model_config.as_dict()
@@ -151,6 +155,15 @@ def initialize(self, args):
151155
)
152156
self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])
153157

158+
output_tokens_config = pb_utils.get_output_config_by_name(
159+
self.model_config, "output_tokens"
160+
)
161+
self.output_tokens_dtype = pb_utils.triton_string_to_numpy(output_tokens_config["data_type"])
162+
input_tokens_config = pb_utils.get_output_config_by_name(
163+
self.model_config, "input_tokens"
164+
)
165+
self.input_tokens_dtype = pb_utils.triton_string_to_numpy(input_tokens_config["data_type"])
166+
154167
# Counter to keep track of ongoing request counts
155168
self.ongoing_request_count = 0
156169

@@ -246,10 +259,17 @@ def create_response(self, vllm_output, prepend_input):
246259
text_outputs = [
247260
(prompt + output.text).encode("utf-8") for output in vllm_output.outputs
248261
]
262+
output_tokens = sum([len(output.token_ids) for output in vllm_output.outputs])
249263
triton_output_tensor = pb_utils.Tensor(
250-
"text_output", np.asarray(text_outputs, dtype=self.output_dtype)
264+
"text_output", np.asarray(text_outputs, dtype=self.output_dtype),
251265
)
252-
return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor])
266+
triton_tokens_tensor = pb_utils.Tensor(
267+
"output_tokens", np.asarray(output_tokens, dtype=self.output_tokens_dtype),
268+
)
269+
triton_input_tokens_tensor = pb_utils.Tensor(
270+
"input_tokens", np.asarray(len(vllm_output.prompt_token_ids), dtype=self.input_tokens_dtype),
271+
)
272+
return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor, triton_tokens_tensor, triton_input_tokens_tensor])
253273

254274
def create_stream_response(self, vllm_output, previous_outputs_lengths):
255275
"""

0 commit comments

Comments (0)