@@ -66,7 +66,7 @@ void WorkerService::set_worker(std::unique_ptr<Worker> worker) {
   initialized_ = true;
 }
 
-void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs,
+void WorkerService::step(ForwardInput& fwd_input,
                          torch::Tensor& next_tokens,
                          torch::Tensor& logprobs,
                          torch::Tensor& top_tokens,
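
(For orientation: the shapes of the two input types, as far as this diff alone reveals them. The field names below all appear in the patch; the ModelInputParams and SamplingParameters type names are guesses, so treat this as a sketch rather than the real definitions.)

    struct ForwardInput {
      ModelInputParams input_params;       // num_sequences, prefill_seq_len, ...
      SamplingParameters sampling_params;  // per-request sampling config
      torch::Tensor acc_logprob;           // beam-search log-prob accumulator
    };

    struct BatchedForwardInputs {  // the wrapper this patch removes from the worker path
      std::vector<ForwardInput> micro_inputs;
      SamplingParameters concated_sampling_params;  // micro-batch params concatenated
      torch::Tensor acc_logprob;  // per-batch accumulators, torch::cat along dim -1
    };
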
@@ -78,7 +78,7 @@ void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs,
                          torch::Tensor& out_tokens,
                          torch::Tensor& out_logprobs) {
   // execute model
-  auto future = worker_->step_async(batched_fwd_inputs);
+  auto future = worker_->step_async(fwd_input);
 
   if (!options_.enable_schedule_overlap()) {
     auto forward_outputs = std::move(future).get();
@@ -142,10 +142,10 @@ void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs,
         torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
     auto total_prefill_seq_len = 0;
     auto total_num_sequences = 0;
-    for (auto& input : batched_fwd_inputs.micro_inputs) {
-      total_num_sequences += input.input_params.num_sequences;
-      total_prefill_seq_len += input.input_params.prefill_seq_len;
-    }
+
+    total_num_sequences += fwd_input.input_params.num_sequences;
+    total_prefill_seq_len += fwd_input.input_params.prefill_seq_len;
+
     next_tokens =
         torch::arange(-1,
                       -1 * (total_num_sequences - total_prefill_seq_len + 1),
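
(When schedule overlap is enabled the worker does not block on the async step, and this branch apparently fabricates one distinct negative placeholder id per decode position. The remaining arange arguments are cut off by the hunk; a standalone sketch with made-up counts, assuming a step of -1 and the int32/CPU options built at the top of this hunk:)

    #include <torch/torch.h>
    #include <iostream>

    int main() {
      // e.g. 5 sequences with a prefill share of 2 -> 3 placeholder tokens
      int64_t total_num_sequences = 5;
      int64_t total_prefill_seq_len = 2;
      auto tensor_options =
          torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
      auto next_tokens =
          torch::arange(-1,
                        -1 * (total_num_sequences - total_prefill_seq_len + 1),
                        /*step=*/-1,
                        tensor_options);
      std::cout << next_tokens << std::endl;  // -1, -2, -3
    }
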
@@ -166,7 +166,7 @@ void WorkerService::create_polling_shm_thread(
       output_shm_manager = std::move(output_shm_manager)]() mutable {
     Timer timer;
     while (true) {
-      BatchedForwardInputs batched_fwd_inputs;
+      ForwardInput fwd_input;
       std::vector<ForwardInput> inputs;
       input_shm_manager->raw_input_read(inputs);
       timer.reset();
@@ -184,31 +184,9 @@ void WorkerService::create_polling_shm_thread(
       torch::Tensor out_tokens;
       torch::Tensor out_logprobs;
 
-      auto micro_batches_num = inputs.size();
-      batched_fwd_inputs.micro_inputs = std::move(inputs);
-      batched_fwd_inputs.concated_sampling_params =
-          batched_fwd_inputs.micro_inputs[0].sampling_params;
-      for (auto i = 1; i < micro_batches_num; ++i) {
-        batched_fwd_inputs.concated_sampling_params.concat(
-            batched_fwd_inputs.micro_inputs[i].sampling_params);
-      }
-
-      // concat acc_logprob here for beam search together
-      if (micro_batches_num > 1) {
-        std::vector<torch::Tensor> acc_logprob_vec;
-        acc_logprob_vec.reserve(micro_batches_num);
-        for (auto i = 0; i < micro_batches_num; ++i) {
-          acc_logprob_vec.push_back(
-              batched_fwd_inputs.micro_inputs[i].acc_logprob);
-        }
-        batched_fwd_inputs.acc_logprob =
-            torch::cat(acc_logprob_vec, /*dim=*/-1);
-      } else {
-        batched_fwd_inputs.acc_logprob =
-            batched_fwd_inputs.micro_inputs[0].acc_logprob;
-      }
+      fwd_input = std::move(inputs[0]);
 
-      step(batched_fwd_inputs,
+      step(fwd_input,
            next_tokens,
            logprobs,
            top_tokens,
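
(This hunk and the ExecuteModel hunk below delete the same micro-batch stitching; what remains assumes the scheduler now delivers exactly one ForwardInput per shared-memory read, since only inputs[0] is consumed. The core of the removed logic, reduced to a sketch; concat_acc_logprob is a name invented here, not a helper from the codebase:)

    #include <torch/torch.h>
    #include <vector>

    // With N micro-batches the worker stitched the per-batch beam-search
    // accumulators back together; with a single input it was a passthrough.
    torch::Tensor concat_acc_logprob(const std::vector<torch::Tensor>& per_batch) {
      if (per_batch.size() == 1) {
        return per_batch[0];
      }
      return torch::cat(per_batch, /*dim=*/-1);
    }
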
@@ -592,90 +570,58 @@ void WorkerService::UnlinkCluster(::google::protobuf::RpcController* controller,
   return;
 }
 
-void WorkerService::ExecuteModel(
-    ::google::protobuf::RpcController* controller,
-    const proto::BatchedForwardInputs* pb_batched_fwd_inputs,
-    proto::ForwardOutput* pb_forward_output,
-    ::google::protobuf::Closure* done) {
-  threadpool_->schedule([this,
-                         controller,
-                         pb_batched_fwd_inputs,
-                         pb_forward_output,
-                         done]() mutable {
-    brpc::ClosureGuard done_guard(done);
-    Timer timer;
-    // convert proto::BatchedForwardInputs to BatchedForwardInputs
-    auto micro_batches_num = pb_batched_fwd_inputs->micro_inputs().size();
-    BatchedForwardInputs batched_fwd_inputs;
-    batched_fwd_inputs.micro_inputs.reserve(micro_batches_num);
-    for (auto i = 0; i < micro_batches_num; ++i) {
-      ForwardInput forward_input;
-      proto_to_forward_input(&(pb_batched_fwd_inputs->micro_inputs()[i]),
-                             forward_input,
-                             options_.num_decoding_tokens());
-      batched_fwd_inputs.micro_inputs.push_back(std::move(forward_input));
-    }
-
-    // concat sampling parameters
-    batched_fwd_inputs.concated_sampling_params =
-        batched_fwd_inputs.micro_inputs[0].sampling_params;
-    for (auto i = 1; i < micro_batches_num; ++i) {
-      batched_fwd_inputs.concated_sampling_params.concat(
-          batched_fwd_inputs.micro_inputs[i].sampling_params);
-    }
-
-    // concat acc_logprob here for beam search together
-    if (micro_batches_num > 1) {
-      std::vector<torch::Tensor> acc_logprob_vec;
-      acc_logprob_vec.reserve(micro_batches_num);
-      for (auto i = 0; i < micro_batches_num; ++i) {
-        acc_logprob_vec.push_back(
-            batched_fwd_inputs.micro_inputs[i].acc_logprob);
-      }
-      batched_fwd_inputs.acc_logprob = torch::cat(acc_logprob_vec, /*dim=*/-1);
-    } else {
-      batched_fwd_inputs.acc_logprob =
-          batched_fwd_inputs.micro_inputs[0].acc_logprob;
-    }
+void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller,
+                                 const proto::ForwardInput* pb_forward_input,
+                                 proto::ForwardOutput* pb_forward_output,
+                                 ::google::protobuf::Closure* done) {
+  threadpool_->schedule(
+      [this, controller, pb_forward_input, pb_forward_output, done]() mutable {
+        brpc::ClosureGuard done_guard(done);
+        // convert proto::ForwardInput to ForwardInput
 
-    // model output
-    torch::Tensor next_tokens;
-    torch::Tensor logprobs;
-    torch::Tensor top_tokens;
-    torch::Tensor top_logprobs;
-    torch::Tensor embeddings;
-    torch::Tensor expert_load_data;
-    int32_t prepared_layer_id = -1;
-    // beam search kernel output
-    torch::Tensor src_seq_idxes;
-    torch::Tensor out_tokens;
-    torch::Tensor out_logprobs;
-
-    step(batched_fwd_inputs,
-         next_tokens,
-         logprobs,
-         top_tokens,
-         top_logprobs,
-         embeddings,
-         expert_load_data,
-         prepared_layer_id,
-         src_seq_idxes,
-         out_tokens,
-         out_logprobs);
-    // convert to proto output
-    forward_output_to_proto(next_tokens,
-                            logprobs,
-                            top_tokens,
-                            top_logprobs,
-                            embeddings,
-                            expert_load_data,
-                            prepared_layer_id,
-                            src_seq_idxes,
-                            out_tokens,
-                            out_logprobs,
-                            pb_forward_output);
-    COUNTER_ADD(worker_service_latency_seconds, timer.elapsed_seconds());
-  });
+        Timer timer;
+        ForwardInput forward_input;
+        proto_to_forward_input(
+            pb_forward_input, forward_input, options_.num_decoding_tokens());
+
+        // model output
+        torch::Tensor next_tokens;
+        torch::Tensor logprobs;
+        torch::Tensor top_tokens;
+        torch::Tensor top_logprobs;
+        torch::Tensor embeddings;
+        torch::Tensor expert_load_data;
+        int32_t prepared_layer_id = -1;
+        // beam search kernel output
+        torch::Tensor src_seq_idxes;
+        torch::Tensor out_tokens;
+        torch::Tensor out_logprobs;
+
+        step(forward_input,
+             next_tokens,
+             logprobs,
+             top_tokens,
+             top_logprobs,
+             embeddings,
+             expert_load_data,
+             prepared_layer_id,
+             src_seq_idxes,
+             out_tokens,
+             out_logprobs);
+        // convert to proto output
+        forward_output_to_proto(next_tokens,
+                                logprobs,
+                                top_tokens,
+                                top_logprobs,
+                                embeddings,
+                                expert_load_data,
+                                prepared_layer_id,
+                                src_seq_idxes,
+                                out_tokens,
+                                out_logprobs,
+                                pb_forward_output);
+        COUNTER_ADD(worker_service_latency_seconds, timer.elapsed_seconds());
+      });
 }
 
 void WorkerService::GetLastStepResult(
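
(The rewritten handler keeps the standard brpc asynchronous-service shape: hop onto a thread pool, and let a ClosureGuard fire `done` exactly once when the work finishes. Stripped of the model plumbing; the service, message, and member names here are placeholders, not types from this repository:)

    #include <brpc/closure_guard.h>

    void MyService::Handle(google::protobuf::RpcController* controller,
                           const MyRequest* request,
                           MyResponse* response,
                           google::protobuf::Closure* done) {
      threadpool_->schedule([this, request, response, done]() mutable {
        // ClosureGuard invokes done->Run() when it goes out of scope,
        // telling brpc the response is filled; without it the RPC hangs.
        brpc::ClosureGuard done_guard(done);
        // ... convert the request, run the model, fill the response ...
      });
    }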