
Commit 502caa7

feat: revert the original code before refactoring the multi-stream[2/2].

Signed-off-by: Tao Peng <[email protected]>
1 parent: 26af848

37 files changed: +591 −828 lines

xllm/core/distributed_runtime/comm_channel.cpp

Lines changed: 6 additions & 14 deletions
@@ -481,22 +481,14 @@ bool CommChannel::get_active_activation_memory_async(
 bool CommChannel::execute_model_with_brpc(
     const std::vector<RawForwardInput>& inputs,
     folly::Promise<std::optional<RawForwardOutput>>& promise) {
-  // convert to proto::BatchedForwardInputs
-  proto::BatchedForwardInputs pb_batched_fwd_inputs;
-  std::vector<proto::ForwardInput> batched_fwd_inputs_vec;
-  batched_fwd_inputs_vec.reserve(inputs.size());
-  for (auto i = 0; i < inputs.size(); ++i) {
-    proto::ForwardInput pb_fwd_input;
-    forward_input_to_proto(inputs[i], &pb_fwd_input);
-    batched_fwd_inputs_vec.push_back(std::move(pb_fwd_input));
-  }
-  ADD_VECTOR_TO_PROTO(pb_batched_fwd_inputs.mutable_micro_inputs(),
-                      batched_fwd_inputs_vec);
+  // convert to proto::ForwardInput
+  proto::ForwardInput pb_forward_input;
+  forward_input_to_proto(inputs[0], &pb_forward_input);
 
   // call ExecuteModel with callback
   auto done = new ExecuteModelClosure();
   done->promise = std::move(promise);
-  stub_->ExecuteModel(
-      &done->cntl, &pb_batched_fwd_inputs, &done->pb_output, done);
+  stub_->ExecuteModel(&done->cntl, &pb_forward_input, &done->pb_output, done);
   return true;
 }

@@ -541,4 +533,4 @@ void TransferBlocksClosure::Run() {
   return;
 }
 
-}  // namespace xllm
+}  // namespace xllm
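Note: the reverted client path serializes only inputs[0], so with micro-batching removed the vector is expected to hold exactly one entry. A minimal defensive sketch (not part of this commit; it assumes glog's CHECK_EQ is available, which the LOG usage elsewhere in xllm suggests) would make that assumption explicit:

  // Sketch only: guard the single-micro-batch assumption before serializing.
  CHECK_EQ(inputs.size(), 1u) << "multi-stream micro-batching was reverted";
  proto::ForwardInput pb_forward_input;
  forward_input_to_proto(inputs[0], &pb_forward_input);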

xllm/core/distributed_runtime/comm_channel.h

Lines changed: 1 addition & 1 deletion
@@ -145,4 +145,4 @@ class TransferBlocksClosure : public google::protobuf::Closure {
   brpc::Controller cntl;
   folly::Promise<uint32_t> promise;
 };
-}  // namespace xllm
+}  // namespace xllm

xllm/core/distributed_runtime/remote_worker.cpp

Lines changed: 6 additions & 5 deletions
@@ -167,13 +167,14 @@ folly::SemiFuture<std::optional<ForwardOutput>> RemoteWorker::step_async(
 }
 
 folly::SemiFuture<std::optional<RawForwardOutput>> RemoteWorker::step_async(
-    const std::vector<RawForwardInput>& inputs) {
+    const RawForwardInput& inputs) {
   folly::Promise<std::optional<RawForwardOutput>> promise;
   auto future = promise.getSemiFuture();
-  threadpool_.schedule(
-      [this, inputs = inputs, promise = std::move(promise)]() mutable {
-        channel_->execute_model_async(inputs, promise);
-      });
+  threadpool_.schedule([this,
+                        inputs = std::move(inputs),
+                        promise = std::move(promise)]() mutable {
+    channel_->execute_model_async({inputs}, promise);
+  });
 
   return future;
 }
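Note: this is the usual folly promise/executor hand-off: the promise moves into the scheduled lambda and is fulfilled on the pool thread while the caller keeps the SemiFuture, and `{inputs}` re-wraps the single input into the vector that the unchanged channel API still takes. Two caveats: `inputs = std::move(inputs)` on a `const RawForwardInput&` parameter still copies, since you cannot move from a const reference. A standalone sketch of the hand-off pattern, using plain folly types rather than xllm's:

  #include <folly/executors/CPUThreadPoolExecutor.h>
  #include <folly/futures/Future.h>

  // Promise/executor hand-off: fulfil on a pool thread, return the future.
  folly::SemiFuture<int> schedule_work(folly::CPUThreadPoolExecutor& pool,
                                       int value) {
    folly::Promise<int> promise;
    auto future = promise.getSemiFuture();
    pool.add([value, promise = std::move(promise)]() mutable {
      promise.setValue(value * 2);  // stand-in for execute_model_async
    });
    return future;
  }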

xllm/core/distributed_runtime/remote_worker.h

Lines changed: 1 addition & 1 deletion
@@ -127,7 +127,7 @@ class RemoteWorker : public WorkerClient {
       const ForwardInput& inputs) override;
 
   virtual folly::SemiFuture<std::optional<RawForwardOutput>> step_async(
-      const std::vector<RawForwardInput>& inputs) override;
+      const RawForwardInput& inputs) override;
 
   virtual folly::SemiFuture<folly::Unit> process_group_test_async() override;

xllm/core/distributed_runtime/worker_service.cpp

File mode changed: 100755 → 100644
Lines changed: 60 additions & 114 deletions
@@ -66,7 +66,7 @@ void WorkerService::set_worker(std::unique_ptr<Worker> worker) {
   initialized_ = true;
 }
 
-void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs,
+void WorkerService::step(ForwardInput& fwd_input,
                          torch::Tensor& next_tokens,
                          torch::Tensor& logprobs,
                          torch::Tensor& top_tokens,

@@ -78,7 +78,7 @@ void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs,
                          torch::Tensor& out_tokens,
                          torch::Tensor& out_logprobs) {
   // execute model
-  auto future = worker_->step_async(batched_fwd_inputs);
+  auto future = worker_->step_async(fwd_input);
 
   if (!options_.enable_schedule_overlap()) {
     auto forward_outputs = std::move(future).get();

@@ -142,10 +142,10 @@ void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs,
       torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
   auto total_prefill_seq_len = 0;
   auto total_num_sequences = 0;
-  for (auto& input : batched_fwd_inputs.micro_inputs) {
-    total_num_sequences += input.input_params.num_sequences;
-    total_prefill_seq_len += input.input_params.prefill_seq_len;
-  }
+
+  total_num_sequences += fwd_input.input_params.num_sequences;
+  total_prefill_seq_len += fwd_input.input_params.prefill_seq_len;
+
   next_tokens =
       torch::arange(-1,
                     -1 * (total_num_sequences - total_prefill_seq_len + 1),

@@ -166,7 +166,7 @@ void WorkerService::create_polling_shm_thread(
       output_shm_manager = std::move(output_shm_manager)]() mutable {
     Timer timer;
     while (true) {
-      BatchedForwardInputs batched_fwd_inputs;
+      ForwardInput fwd_input;
       std::vector<ForwardInput> inputs;
       input_shm_manager->raw_input_read(inputs);
      timer.reset();

@@ -184,31 +184,9 @@ void WorkerService::create_polling_shm_thread(
       torch::Tensor out_tokens;
       torch::Tensor out_logprobs;
 
-      auto micro_batches_num = inputs.size();
-      batched_fwd_inputs.micro_inputs = std::move(inputs);
-      batched_fwd_inputs.concated_sampling_params =
-          batched_fwd_inputs.micro_inputs[0].sampling_params;
-      for (auto i = 1; i < micro_batches_num; ++i) {
-        batched_fwd_inputs.concated_sampling_params.concat(
-            batched_fwd_inputs.micro_inputs[i].sampling_params);
-      }
-
-      // concat acc_logprob here for beam search together
-      if (micro_batches_num > 1) {
-        std::vector<torch::Tensor> acc_logprob_vec;
-        acc_logprob_vec.reserve(micro_batches_num);
-        for (auto i = 0; i < micro_batches_num; ++i) {
-          acc_logprob_vec.push_back(
-              batched_fwd_inputs.micro_inputs[i].acc_logprob);
-        }
-        batched_fwd_inputs.acc_logprob =
-            torch::cat(acc_logprob_vec, /*dim=*/-1);
-      } else {
-        batched_fwd_inputs.acc_logprob =
-            batched_fwd_inputs.micro_inputs[0].acc_logprob;
-      }
+      fwd_input = std::move(inputs[0]);
 
-      step(batched_fwd_inputs,
+      step(fwd_input,
            next_tokens,
            logprobs,
            top_tokens,

@@ -592,90 +570,58 @@ void WorkerService::UnlinkCluster(::google::protobuf::RpcController* controller,
   return;
 }
 
-void WorkerService::ExecuteModel(
-    ::google::protobuf::RpcController* controller,
-    const proto::BatchedForwardInputs* pb_batched_fwd_inputs,
-    proto::ForwardOutput* pb_forward_output,
-    ::google::protobuf::Closure* done) {
-  threadpool_->schedule([this,
-                         controller,
-                         pb_batched_fwd_inputs,
-                         pb_forward_output,
-                         done]() mutable {
-    brpc::ClosureGuard done_guard(done);
-    Timer timer;
-    // convert proto::BatchedForwardInputs to BatchedForwardInputs
-    auto micro_batches_num = pb_batched_fwd_inputs->micro_inputs().size();
-    BatchedForwardInputs batched_fwd_inputs;
-    batched_fwd_inputs.micro_inputs.reserve(micro_batches_num);
-    for (auto i = 0; i < micro_batches_num; ++i) {
-      ForwardInput forward_input;
-      proto_to_forward_input(&(pb_batched_fwd_inputs->micro_inputs()[i]),
-                             forward_input,
-                             options_.num_decoding_tokens());
-      batched_fwd_inputs.micro_inputs.push_back(std::move(forward_input));
-    }
-
-    // concat sampling parameters
-    batched_fwd_inputs.concated_sampling_params =
-        batched_fwd_inputs.micro_inputs[0].sampling_params;
-    for (auto i = 1; i < micro_batches_num; ++i) {
-      batched_fwd_inputs.concated_sampling_params.concat(
-          batched_fwd_inputs.micro_inputs[i].sampling_params);
-    }
-
-    // concat acc_logprob here for beam search together
-    if (micro_batches_num > 1) {
-      std::vector<torch::Tensor> acc_logprob_vec;
-      acc_logprob_vec.reserve(micro_batches_num);
-      for (auto i = 0; i < micro_batches_num; ++i) {
-        acc_logprob_vec.push_back(
-            batched_fwd_inputs.micro_inputs[i].acc_logprob);
-      }
-      batched_fwd_inputs.acc_logprob = torch::cat(acc_logprob_vec, /*dim=*/-1);
-    } else {
-      batched_fwd_inputs.acc_logprob =
-          batched_fwd_inputs.micro_inputs[0].acc_logprob;
-    }
+void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller,
+                                 const proto::ForwardInput* pb_forward_input,
+                                 proto::ForwardOutput* pb_forward_output,
+                                 ::google::protobuf::Closure* done) {
+  threadpool_->schedule(
+      [this, controller, pb_forward_input, pb_forward_output, done]() mutable {
+        brpc::ClosureGuard done_guard(done);
+        // convert proto::ForwardInput to ForwardInput
 
-    // model output
-    torch::Tensor next_tokens;
-    torch::Tensor logprobs;
-    torch::Tensor top_tokens;
-    torch::Tensor top_logprobs;
-    torch::Tensor embeddings;
-    torch::Tensor expert_load_data;
-    int32_t prepared_layer_id = -1;
-    // beam search kernel output
-    torch::Tensor src_seq_idxes;
-    torch::Tensor out_tokens;
-    torch::Tensor out_logprobs;
-
-    step(batched_fwd_inputs,
-         next_tokens,
-         logprobs,
-         top_tokens,
-         top_logprobs,
-         embeddings,
-         expert_load_data,
-         prepared_layer_id,
-         src_seq_idxes,
-         out_tokens,
-         out_logprobs);
-    // convert to proto output
-    forward_output_to_proto(next_tokens,
-                            logprobs,
-                            top_tokens,
-                            top_logprobs,
-                            embeddings,
-                            expert_load_data,
-                            prepared_layer_id,
-                            src_seq_idxes,
-                            out_tokens,
-                            out_logprobs,
-                            pb_forward_output);
-    COUNTER_ADD(worker_service_latency_seconds, timer.elapsed_seconds());
-  });
+        Timer timer;
+        ForwardInput forward_input;
+        proto_to_forward_input(
+            pb_forward_input, forward_input, options_.num_decoding_tokens());
+
+        // model output
+        torch::Tensor next_tokens;
+        torch::Tensor logprobs;
+        torch::Tensor top_tokens;
+        torch::Tensor top_logprobs;
+        torch::Tensor embeddings;
+        torch::Tensor expert_load_data;
+        int32_t prepared_layer_id = -1;
+        // beam search kernel output
+        torch::Tensor src_seq_idxes;
+        torch::Tensor out_tokens;
+        torch::Tensor out_logprobs;
+
+        step(forward_input,
+             next_tokens,
+             logprobs,
+             top_tokens,
+             top_logprobs,
+             embeddings,
+             expert_load_data,
+             prepared_layer_id,
+             src_seq_idxes,
+             out_tokens,
+             out_logprobs);
+        // convert to proto output
+        forward_output_to_proto(next_tokens,
+                                logprobs,
+                                top_tokens,
+                                top_logprobs,
+                                embeddings,
+                                expert_load_data,
+                                prepared_layer_id,
+                                src_seq_idxes,
+                                out_tokens,
+                                out_logprobs,
+                                pb_forward_output);
+        COUNTER_ADD(worker_service_latency_seconds, timer.elapsed_seconds());
+      });
 }
 
 void WorkerService::GetLastStepResult(
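Note: the handler schedules its whole body onto a thread pool and returns immediately; brpc::ClosureGuard then owns `done` inside the lambda, so the response is sent exactly once, after the deserialize/step/serialize sequence finishes rather than when the service method returns. A condensed standalone sketch of that pattern (Request, Response, and fill_response are placeholders, not xllm types):

  #include <brpc/closure_guard.h>

  // Deferred-completion pattern: `done` travels into the lambda, and the
  // guard invokes done->Run() (sending the response) when the lambda exits.
  void Handle(google::protobuf::RpcController* controller,
              const Request* request, Response* response,
              google::protobuf::Closure* done) {
    threadpool_->schedule([request, response, done]() mutable {
      brpc::ClosureGuard done_guard(done);
      fill_response(request, response);  // placeholder for the step() pipeline
    });
  }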

xllm/core/distributed_runtime/worker_service.h

Lines changed: 2 additions & 2 deletions
@@ -111,7 +111,7 @@ class WorkerService : public proto::DistributeWorker {
                     ::google::protobuf::Closure* done) override;
 
   void ExecuteModel(::google::protobuf::RpcController* controller,
-                    const proto::BatchedForwardInputs* pb_batched_fwd_inputs,
+                    const proto::ForwardInput* pb_fwd_input,
                     proto::ForwardOutput* pb_forward_output,
                     ::google::protobuf::Closure* done) override;

@@ -126,7 +126,7 @@ class WorkerService : public proto::DistributeWorker {
                     ::google::protobuf::Closure* done) override;
 
  private:
-  void step(BatchedForwardInputs& batched_fwd_inputs,
+  void step(ForwardInput& fwd_input,
             torch::Tensor& next_tokens,
             torch::Tensor& logprobs,
             torch::Tensor& top_tokens,

xllm/core/runtime/acl_graph_executor_impl.cpp

Lines changed: 11 additions & 15 deletions
@@ -187,15 +187,14 @@ ForwardInput AclGraphExecutorImpl::prepare_inputs(Batch& batch) {
 // tokens: [num_decode_tokens]
 // positions: [num_decode_tokens] token pos in the sequence
 // returns: [num_decode_tokens, hidden_size]
-torch::Tensor AclGraphExecutorImpl::run(
-    const std::vector<torch::Tensor>& tokens,
-    const std::vector<torch::Tensor>& positions,
-    std::vector<KVCache>& kv_caches,
-    const std::vector<ModelInputParams>& params) {
+torch::Tensor AclGraphExecutorImpl::run(const torch::Tensor& tokens,
+                                        const torch::Tensor& positions,
+                                        std::vector<KVCache>& kv_caches,
+                                        const ModelInputParams& params) {
   // no mirco batch in decode phase
-  const torch::Tensor& tokens_tensor = tokens[0];
-  const torch::Tensor& positions_tensor = positions[0];
-  const ModelInputParams& params_single = params[0];
+  const torch::Tensor& tokens_tensor = tokens;
+  const torch::Tensor& positions_tensor = positions;
+  const ModelInputParams& params_single = params;
   // Identify decode phase using q_max_seq_len for precise detection
   // Decode phase: all sequences have q_seq_len == 1 (generating one token at a
   // time) Prefill phase: sequences have q_seq_len > 1 (processing multiple

@@ -207,7 +206,7 @@ torch::Tensor AclGraphExecutorImpl::run(
   // If not in decode phase, use eager mode directly without acl graph
   if (!in_decoding_phase) {
     COUNTER_INC(num_model_execution_total_eager);
-    return model_->forward(tokens[0], positions[0], kv_caches, params[0]);
+    return model_->forward(tokens, positions, kv_caches, params);
   }
 
   // Only use acl graph in decode phase for performance optimization

@@ -229,15 +228,12 @@ torch::Tensor AclGraphExecutorImpl::run(
 
   // Combined condition for graph capture support
   // ACL graph executor only supports single tensor inputs (no micro-batching)
-  const bool single_input =
-      (tokens.size() == 1) && (positions.size() == 1) && (params.size() == 1);
-  const bool capture_supported =
-      single_input && seq_len_supported && same_num_decoding_tokens;
+  const bool capture_supported = seq_len_supported && same_num_decoding_tokens;
 
   // Early return if conditions are not suitable for graph operations
   if (!capture_supported) {
     COUNTER_INC(num_model_execution_total_eager);
-    return model_->forward(tokens[0], positions[0], kv_caches, params[0]);
+    return model_->forward(tokens, positions, kv_caches, params);
   }
 
   // Check if captured graph exists for this bucket size

@@ -273,7 +269,7 @@ torch::Tensor AclGraphExecutorImpl::run(
   // Fallback to eager mode if capture fails
   LOG(ERROR) << "Failed to capture ACL graph for bucket size: " << bucket_size;
   COUNTER_INC(num_model_execution_total_eager);
-  return model_->forward(tokens[0], positions[0], kv_caches, params[0]);
+  return model_->forward(tokens, positions, kv_caches, params);
 }
 
 void AclGraph::copy_data_to_graph_buffer(const torch::Tensor& tokens,
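Note: with the vector wrappers gone, run() dispatches on a single tensor set: eager forward for prefill or unsupported shapes, captured-graph replay for decode. A hypothetical restatement of that dispatch (the field name q_max_seq_len follows the comments above; the real check lives in code not shown in this diff):

  // Assumed sketch: decode phase means every sequence contributes one token.
  const bool in_decoding_phase = (params.q_max_seq_len == 1);
  if (!in_decoding_phase || !capture_supported) {
    COUNTER_INC(num_model_execution_total_eager);  // eager fallback path
    return model_->forward(tokens, positions, kv_caches, params);
  }
  // Otherwise replay (or capture) the ACL graph for this bucket size.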

xllm/core/runtime/acl_graph_executor_impl.h

Lines changed: 4 additions & 4 deletions
@@ -101,10 +101,10 @@ class AclGraphExecutorImpl : public ExecutorImpl {
   ForwardInput prepare_inputs(Batch& batch) override;
 
   // Execute model with graph optimization for decode phase
-  torch::Tensor run(const std::vector<torch::Tensor>& tokens,
-                    const std::vector<torch::Tensor>& positions,
+  torch::Tensor run(const torch::Tensor& tokens,
+                    const torch::Tensor& positions,
                     std::vector<KVCache>& kv_caches,
-                    const std::vector<ModelInputParams>& params) override;
+                    const ModelInputParams& params) override;
 
  private:
   // not own

@@ -123,4 +123,4 @@ class AclGraphExecutorImpl : public ExecutorImpl {
   uint32_t get_bucket_size(uint32_t batch_size) const;
 };
 
-}  // namespace xllm
+}  // namespace xllm
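Note: get_bucket_size is declared above but its body is not part of this diff. A common bucketing scheme (an assumption here, not the confirmed xllm implementation) rounds the batch size up to the next power of two so that one captured graph serves a range of batch sizes:

  // Hypothetical bucketing: round up to the next power of two.
  uint32_t get_bucket_size_sketch(uint32_t batch_size) {
    uint32_t bucket = 1;
    while (bucket < batch_size) bucket <<= 1;
    return bucket;
  }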
