@@ -66,7 +66,11 @@ def auto_complete_config(auto_complete_model_config):
                "optional": True,
            },
        ]
-        outputs = [{"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]}]
+        outputs = [
+            {"name": "text_output", "data_type": "TYPE_STRING", "dims": [-1]},
+            {"name": "input_tokens", "data_type": "TYPE_INT32", "dims": [-1]},
+            {"name": "output_tokens", "data_type": "TYPE_INT32", "dims": [-1]},
+        ]

        # Store the model configuration as a dictionary.
        config = auto_complete_model_config.as_dict()
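Note: the two new `TYPE_INT32` outputs are declared through Triton's auto-complete path, so a hand-written `config.pbtxt` is not strictly required for them. As a minimal sketch, assuming the surrounding function follows the stock vLLM backend pattern of looping over these dicts, the declarations are typically attached like this:

```python
# Sketch under the assumption that auto_complete_config follows the stock
# vLLM backend pattern: register each declared output on the auto-complete
# config object (the real file usually guards against names that are
# already present in a user-supplied config.pbtxt).
for output in outputs:
    auto_complete_model_config.add_output(output)
```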
@@ -151,6 +155,15 @@ def initialize(self, args):
        )
        self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"])

+        output_tokens_config = pb_utils.get_output_config_by_name(
+            self.model_config, "output_tokens"
+        )
+        self.output_tokens_dtype = pb_utils.triton_string_to_numpy(output_tokens_config["data_type"])
+        input_tokens_config = pb_utils.get_output_config_by_name(
+            self.model_config, "input_tokens"
+        )
+        self.input_tokens_dtype = pb_utils.triton_string_to_numpy(input_tokens_config["data_type"])
+
        # Counter to keep track of ongoing request counts
        self.ongoing_request_count = 0
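For context (not part of the diff): `pb_utils.triton_string_to_numpy` resolves the config's type strings into NumPy dtypes. A minimal stand-in covering only the two type strings this change touches would look like the sketch below; the real helper handles Triton's full type set.

```python
import numpy as np

# Illustrative stand-in for pb_utils.triton_string_to_numpy, restricted to
# the two Triton type strings used by this change. Not the backend's code.
TRITON_TO_NUMPY = {
    "TYPE_STRING": np.object_,  # Triton STRING/BYTES tensors use object arrays
    "TYPE_INT32": np.int32,
}

def triton_string_to_numpy(triton_type: str) -> np.dtype:
    return np.dtype(TRITON_TO_NUMPY[triton_type])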
@@ -246,10 +259,17 @@ def create_response(self, vllm_output, prepend_input):
        text_outputs = [
            (prompt + output.text).encode("utf-8") for output in vllm_output.outputs
        ]
+        output_tokens = sum([len(output.token_ids) for output in vllm_output.outputs])
        triton_output_tensor = pb_utils.Tensor(
-            "text_output", np.asarray(text_outputs, dtype=self.output_dtype)
+            "text_output", np.asarray(text_outputs, dtype=self.output_dtype),
        )
-        return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor])
+        triton_tokens_tensor = pb_utils.Tensor(
+            "output_tokens", np.asarray(output_tokens, dtype=self.output_tokens_dtype),
+        )
+        triton_input_tokens_tensor = pb_utils.Tensor(
+            "input_tokens", np.asarray(len(vllm_output.prompt_token_ids), dtype=self.input_tokens_dtype),
+        )
+        return pb_utils.InferenceResponse(output_tensors=[triton_output_tensor, triton_tokens_tensor, triton_input_tokens_tensor])

    def create_stream_response(self, vllm_output, previous_outputs_lengths):
        """
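After this change, clients can request the two token counts alongside `text_output`. A hedged client-side sketch follows, assuming a non-decoupled deployment (a decoupled/streaming deployment would use `stream_infer` instead); the model name `vllm_model`, server address, and prompt are assumptions, not part of this commit.

```python
import numpy as np
import tritonclient.grpc as grpcclient

# Assumed server address and model name; adjust for your deployment.
client = grpcclient.InferenceServerClient(url="localhost:8001")

# text_input is the backend's required prompt input (dims [1], BYTES).
text_input = grpcclient.InferInput("text_input", [1], "BYTES")
text_input.set_data_from_numpy(
    np.array(["What is Triton?".encode("utf-8")], dtype=np.object_)
)

result = client.infer(
    model_name="vllm_model",
    inputs=[text_input],
    outputs=[
        grpcclient.InferRequestedOutput("text_output"),
        grpcclient.InferRequestedOutput("input_tokens"),
        grpcclient.InferRequestedOutput("output_tokens"),
    ],
)

print("input_tokens: ", result.as_numpy("input_tokens"))
print("output_tokens:", result.as_numpy("output_tokens"))
```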