20 changes: 6 additions & 14 deletions xllm/core/distributed_runtime/comm_channel.cpp
@@ -507,22 +507,14 @@ bool CommChannel::get_active_activation_memory_async(
bool CommChannel::execute_model_with_brpc(
const std::vector<RawForwardInput>& inputs,
folly::Promise<std::optional<RawForwardOutput>>& promise) {
// convert to proto::BatchedForwardInputs
proto::BatchedForwardInputs pb_batched_fwd_inputs;
std::vector<proto::ForwardInput> batched_fwd_inputs_vec;
batched_fwd_inputs_vec.reserve(inputs.size());
for (auto i = 0; i < inputs.size(); ++i) {
proto::ForwardInput pb_fwd_input;
forward_input_to_proto(inputs[i], &pb_fwd_input);
batched_fwd_inputs_vec.push_back(std::move(pb_fwd_input));
}
ADD_VECTOR_TO_PROTO(pb_batched_fwd_inputs.mutable_micro_inputs(),
batched_fwd_inputs_vec);
// convert to proto::ForwardInput
proto::ForwardInput pb_forward_input;
forward_input_to_proto(inputs[0], &pb_forward_input);

// call ExecuteModel with callback
auto done = new ExecuteModelClosure();
done->promise = std::move(promise);
stub_->ExecuteModel(
&done->cntl, &pb_batched_fwd_inputs, &done->pb_output, done);
stub_->ExecuteModel(&done->cntl, &pb_forward_input, &done->pb_output, done);
return true;
}

@@ -567,4 +559,4 @@ void TransferBlocksClosure::Run() {
return;
}

} // namespace xllm
} // namespace xllm
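
The closure passed to `stub_->ExecuteModel` above completes the promise once the RPC returns. Its `Run` body is not part of this hunk, so the following is only a minimal sketch of the usual self-deleting brpc closure pattern; the output converter name and the error handling are assumptions, not taken from the PR:

```cpp
// Sketch only: assumed shape of ExecuteModelClosure::Run (its body is not shown in this diff).
void ExecuteModelClosure::Run() {
  std::unique_ptr<ExecuteModelClosure> self_guard(this);  // brpc closures delete themselves after Run
  if (cntl.Failed()) {
    promise.setValue(std::nullopt);  // surface the RPC failure as an empty result
    return;
  }
  RawForwardOutput raw_output;
  proto_to_forward_output(pb_output, raw_output);  // assumed converter; name not confirmed by the PR
  promise.setValue(std::move(raw_output));
}
```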
2 changes: 1 addition & 1 deletion xllm/core/distributed_runtime/comm_channel.h
@@ -145,4 +145,4 @@ class TransferBlocksClosure : public google::protobuf::Closure {
brpc::Controller cntl;
folly::Promise<uint32_t> promise;
};
} // namespace xllm
} // namespace xllm
11 changes: 6 additions & 5 deletions xllm/core/distributed_runtime/remote_worker.cpp
@@ -167,13 +167,14 @@ folly::SemiFuture<std::optional<ForwardOutput>> RemoteWorker::step_async(
}

folly::SemiFuture<std::optional<RawForwardOutput>> RemoteWorker::step_async(
const std::vector<RawForwardInput>& inputs) {
const RawForwardInput& inputs) {
folly::Promise<std::optional<RawForwardOutput>> promise;
auto future = promise.getSemiFuture();
threadpool_.schedule(
[this, inputs = inputs, promise = std::move(promise)]() mutable {
channel_->execute_model_async(inputs, promise);
});
threadpool_.schedule([this,
inputs = std::move(inputs),
promise = std::move(promise)]() mutable {
channel_->execute_model_async({inputs}, promise);
});

return future;
}
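Caller side of the narrowed interface, as a minimal sketch: the scheduler now hands each remote worker a single `RawForwardInput` rather than a vector of micro-batch inputs. The helper and variable names below are illustrative, not from this PR:

```cpp
// Hypothetical call site for the new single-input step_async.
RawForwardInput raw_input = build_raw_forward_input(batch);   // assumed helper, not from this PR
auto output = std::move(worker.step_async(raw_input)).get();  // folly::SemiFuture::get() blocks
if (output.has_value()) {
  // consume sampled tokens / logprobs from *output
}
```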
2 changes: 1 addition & 1 deletion xllm/core/distributed_runtime/remote_worker.h
@@ -127,7 +127,7 @@ class RemoteWorker : public WorkerClient {
const ForwardInput& inputs) override;

virtual folly::SemiFuture<std::optional<RawForwardOutput>> step_async(
const std::vector<RawForwardInput>& inputs) override;
const RawForwardInput& inputs) override;

virtual folly::SemiFuture<folly::Unit> process_group_test_async() override;

174 changes: 60 additions & 114 deletions xllm/core/distributed_runtime/worker_service.cpp
100755 → 100644
@@ -66,7 +66,7 @@ void WorkerService::set_worker(std::unique_ptr<Worker> worker) {
initialized_ = true;
}

void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs,
void WorkerService::step(ForwardInput& fwd_input,
torch::Tensor& next_tokens,
torch::Tensor& logprobs,
torch::Tensor& top_tokens,
@@ -78,7 +78,7 @@ void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs,
torch::Tensor& out_tokens,
torch::Tensor& out_logprobs) {
// execute model
auto future = worker_->step_async(batched_fwd_inputs);
auto future = worker_->step_async(fwd_input);

if (!options_.enable_schedule_overlap()) {
auto forward_outputs = std::move(future).get();
@@ -142,10 +142,10 @@ void WorkerService::step(BatchedForwardInputs& batched_fwd_inputs,
torch::TensorOptions().dtype(torch::kInt32).device(torch::kCPU);
auto total_prefill_seq_len = 0;
auto total_num_sequences = 0;
for (auto& input : batched_fwd_inputs.micro_inputs) {
total_num_sequences += input.input_params.num_sequences;
total_prefill_seq_len += input.input_params.prefill_seq_len;
}

total_num_sequences += fwd_input.input_params.num_sequences;
total_prefill_seq_len += fwd_input.input_params.prefill_seq_len;

next_tokens =
torch::arange(-1,
-1 * (total_num_sequences - total_prefill_seq_len + 1),
@@ -166,7 +166,7 @@ void WorkerService::create_polling_shm_thread(
output_shm_manager = std::move(output_shm_manager)]() mutable {
Timer timer;
while (true) {
BatchedForwardInputs batched_fwd_inputs;
ForwardInput fwd_input;
std::vector<ForwardInput> inputs;
input_shm_manager->raw_input_read(inputs);
timer.reset();
@@ -184,31 +184,9 @@
torch::Tensor out_tokens;
torch::Tensor out_logprobs;

auto micro_batches_num = inputs.size();
batched_fwd_inputs.micro_inputs = std::move(inputs);
batched_fwd_inputs.concated_sampling_params =
batched_fwd_inputs.micro_inputs[0].sampling_params;
for (auto i = 1; i < micro_batches_num; ++i) {
batched_fwd_inputs.concated_sampling_params.concat(
batched_fwd_inputs.micro_inputs[i].sampling_params);
}

// concat acc_logprob here for beam search together
if (micro_batches_num > 1) {
std::vector<torch::Tensor> acc_logprob_vec;
acc_logprob_vec.reserve(micro_batches_num);
for (auto i = 0; i < micro_batches_num; ++i) {
acc_logprob_vec.push_back(
batched_fwd_inputs.micro_inputs[i].acc_logprob);
}
batched_fwd_inputs.acc_logprob =
torch::cat(acc_logprob_vec, /*dim=*/-1);
} else {
batched_fwd_inputs.acc_logprob =
batched_fwd_inputs.micro_inputs[0].acc_logprob;
}
fwd_input = std::move(inputs[0]);

step(batched_fwd_inputs,
step(fwd_input,
next_tokens,
logprobs,
top_tokens,
@@ -598,90 +576,58 @@ void WorkerService::UnlinkCluster(::google::protobuf::RpcController* controller,
return;
}

void WorkerService::ExecuteModel(
::google::protobuf::RpcController* controller,
const proto::BatchedForwardInputs* pb_batched_fwd_inputs,
proto::ForwardOutput* pb_forward_output,
::google::protobuf::Closure* done) {
threadpool_->schedule([this,
controller,
pb_batched_fwd_inputs,
pb_forward_output,
done]() mutable {
brpc::ClosureGuard done_guard(done);
Timer timer;
// convert proto::BatchedForwardInputs to BatchedForwardInputs
auto micro_batches_num = pb_batched_fwd_inputs->micro_inputs().size();
BatchedForwardInputs batched_fwd_inputs;
batched_fwd_inputs.micro_inputs.reserve(micro_batches_num);
for (auto i = 0; i < micro_batches_num; ++i) {
ForwardInput forward_input;
proto_to_forward_input(&(pb_batched_fwd_inputs->micro_inputs()[i]),
forward_input,
options_.num_decoding_tokens());
batched_fwd_inputs.micro_inputs.push_back(std::move(forward_input));
}

// concat sampling parameters
batched_fwd_inputs.concated_sampling_params =
batched_fwd_inputs.micro_inputs[0].sampling_params;
for (auto i = 1; i < micro_batches_num; ++i) {
batched_fwd_inputs.concated_sampling_params.concat(
batched_fwd_inputs.micro_inputs[i].sampling_params);
}

// concat acc_logprob here for beam search together
if (micro_batches_num > 1) {
std::vector<torch::Tensor> acc_logprob_vec;
acc_logprob_vec.reserve(micro_batches_num);
for (auto i = 0; i < micro_batches_num; ++i) {
acc_logprob_vec.push_back(
batched_fwd_inputs.micro_inputs[i].acc_logprob);
}
batched_fwd_inputs.acc_logprob = torch::cat(acc_logprob_vec, /*dim=*/-1);
} else {
batched_fwd_inputs.acc_logprob =
batched_fwd_inputs.micro_inputs[0].acc_logprob;
}
void WorkerService::ExecuteModel(::google::protobuf::RpcController* controller,
const proto::ForwardInput* pb_forward_input,
proto::ForwardOutput* pb_forward_output,
::google::protobuf::Closure* done) {
threadpool_->schedule(
[this, controller, pb_forward_input, pb_forward_output, done]() mutable {
brpc::ClosureGuard done_guard(done);
// convert proto::ForwardInput to ForwardInput

// model output
torch::Tensor next_tokens;
torch::Tensor logprobs;
torch::Tensor top_tokens;
torch::Tensor top_logprobs;
torch::Tensor embeddings;
torch::Tensor expert_load_data;
int32_t prepared_layer_id = -1;
// beam search kernel output
torch::Tensor src_seq_idxes;
torch::Tensor out_tokens;
torch::Tensor out_logprobs;

step(batched_fwd_inputs,
next_tokens,
logprobs,
top_tokens,
top_logprobs,
embeddings,
expert_load_data,
prepared_layer_id,
src_seq_idxes,
out_tokens,
out_logprobs);
// convert to proto output
forward_output_to_proto(next_tokens,
logprobs,
top_tokens,
top_logprobs,
embeddings,
expert_load_data,
prepared_layer_id,
src_seq_idxes,
out_tokens,
out_logprobs,
pb_forward_output);
COUNTER_ADD(worker_service_latency_seconds, timer.elapsed_seconds());
});
Timer timer;
ForwardInput forward_input;
proto_to_forward_input(
pb_forward_input, forward_input, options_.num_decoding_tokens());

// model output
torch::Tensor next_tokens;
torch::Tensor logprobs;
torch::Tensor top_tokens;
torch::Tensor top_logprobs;
torch::Tensor embeddings;
torch::Tensor expert_load_data;
int32_t prepared_layer_id = -1;
// beam search kernel output
torch::Tensor src_seq_idxes;
torch::Tensor out_tokens;
torch::Tensor out_logprobs;

step(forward_input,
next_tokens,
logprobs,
top_tokens,
top_logprobs,
embeddings,
expert_load_data,
prepared_layer_id,
src_seq_idxes,
out_tokens,
out_logprobs);
// convert to proto output
forward_output_to_proto(next_tokens,
logprobs,
top_tokens,
top_logprobs,
embeddings,
expert_load_data,
prepared_layer_id,
src_seq_idxes,
out_tokens,
out_logprobs,
pb_forward_output);
COUNTER_ADD(worker_service_latency_seconds, timer.elapsed_seconds());
});
}

void WorkerService::GetLastStepResult(
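Taken together with the client-side change in comm_channel.cpp, the wire path now carries a single ForwardInput through one encode and one decode. A round-trip sketch using only the two converters that appear in this PR; the test-input helper and the num_decoding_tokens value are illustrative:

```cpp
// Sketch: client-side encode followed by server-side decode, mirroring the call sites in this PR.
RawForwardInput raw = make_test_raw_input();  // hypothetical helper
proto::ForwardInput pb;
forward_input_to_proto(raw, &pb);             // as in CommChannel::execute_model_with_brpc
ForwardInput decoded;
proto_to_forward_input(&pb, decoded, /*num_decoding_tokens=*/1);  // as in WorkerService::ExecuteModel
```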
4 changes: 2 additions & 2 deletions xllm/core/distributed_runtime/worker_service.h
@@ -111,7 +111,7 @@ class WorkerService : public proto::DistributeWorker {
::google::protobuf::Closure* done) override;

void ExecuteModel(::google::protobuf::RpcController* controller,
const proto::BatchedForwardInputs* pb_batched_fwd_inputs,
const proto::ForwardInput* pb_fwd_input,
proto::ForwardOutput* pb_forward_output,
::google::protobuf::Closure* done) override;

@@ -126,7 +126,7 @@ class WorkerService : public proto::DistributeWorker {
::google::protobuf::Closure* done) override;

private:
void step(BatchedForwardInputs& batched_fwd_inputs,
void step(ForwardInput& fwd_input,
torch::Tensor& next_tokens,
torch::Tensor& logprobs,
torch::Tensor& top_tokens,
28 changes: 12 additions & 16 deletions xllm/core/framework/model/causal_lm.h
@@ -43,11 +43,10 @@ class CausalLM : public torch::nn::Module {
// tokens: [num_tokens]
// positions: [num_tokens]
// returns: [num_tokens, hidden_size]
virtual torch::Tensor forward(
const std::vector<torch::Tensor>& tokens,
const std::vector<torch::Tensor>& positions,
std::vector<KVCache>& kv_caches,
const std::vector<ModelInputParams>& parameters) = 0;
virtual torch::Tensor forward(const torch::Tensor& tokens,
const torch::Tensor& positions,
std::vector<KVCache>& kv_caches,
const ModelInputParams& parameters) = 0;

// hidden_states: [num_tokens, hidden_size]
// seleted_idxes: [num_tokens]
@@ -68,9 +67,8 @@

virtual layer::LmHead get_lm_head() = 0;
virtual void set_lm_head(layer::LmHead& head) = 0;
virtual std::vector<layer::WordEmbedding> get_word_embedding() = 0;
virtual void set_word_embedding(
std::vector<layer::WordEmbedding>& embedding) = 0;
virtual layer::WordEmbedding get_word_embedding() = 0;
virtual void set_word_embedding(layer::WordEmbedding& embedding) = 0;
};

template <typename Model>
@@ -79,11 +77,10 @@ class CausalLMImpl : public CausalLM {
CausalLMImpl(Model model, const torch::TensorOptions& options)
: model_(std::move(model)), options_(options) {}

torch::Tensor forward(
const std::vector<torch::Tensor>& tokens,
const std::vector<torch::Tensor>& positions,
std::vector<KVCache>& kv_caches,
const std::vector<ModelInputParams>& parameters) override {
torch::Tensor forward(const torch::Tensor& tokens,
const torch::Tensor& positions,
std::vector<KVCache>& kv_caches,
const ModelInputParams& parameters) override {
return model_->forward(tokens, positions, kv_caches, parameters);
}

@@ -109,12 +106,11 @@

void set_lm_head(layer::LmHead& head) override { model_->set_lm_head(head); };

std::vector<layer::WordEmbedding> get_word_embedding() override {
layer::WordEmbedding get_word_embedding() override {
return model_->get_word_embedding();
};

void set_word_embedding(
std::vector<layer::WordEmbedding>& embedding) override {
void set_word_embedding(layer::WordEmbedding& embedding) override {
model_->set_word_embedding(embedding);
};

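A minimal sketch of driving the narrowed `forward` interface: one flattened token tensor, one position tensor, and a single `ModelInputParams`, instead of per-micro-batch vectors. The model factory, KV-cache helper, and input values are illustrative assumptions:

```cpp
// Hypothetical call site for the single-tensor CausalLM::forward.
std::unique_ptr<CausalLM> model = create_causal_lm(model_args, options);   // assumed factory
torch::Tensor tokens = torch::tensor({151644, 872, 198}, torch::kLong);    // flattened token ids
torch::Tensor positions = torch::tensor({0, 1, 2}, torch::kLong);
std::vector<KVCache> kv_caches = allocate_kv_caches(model_args);           // assumed helper
ModelInputParams params;                                                    // filled from the ForwardInput
torch::Tensor hidden = model->forward(tokens, positions, kv_caches, params);
```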
14 changes: 6 additions & 8 deletions xllm/core/framework/model/causal_vlm.h
@@ -40,11 +40,10 @@ class CausalVLMImpl : public CausalVLM {
CausalVLMImpl(Model model, const torch::TensorOptions& options)
: model_(std::move(model)), options_(options) {}

torch::Tensor forward(
const std::vector<torch::Tensor>& tokens,
const std::vector<torch::Tensor>& positions,
std::vector<KVCache>& kv_caches,
const std::vector<ModelInputParams>& parameters) override {
torch::Tensor forward(const torch::Tensor& tokens,
const torch::Tensor& positions,
std::vector<KVCache>& kv_caches,
const ModelInputParams& parameters) override {
return model_->forward(tokens, positions, kv_caches, parameters);
}

@@ -68,12 +67,11 @@

void set_lm_head(layer::LmHead& head) override { model_->set_lm_head(head); };

std::vector<layer::WordEmbedding> get_word_embedding() override {
layer::WordEmbedding get_word_embedding() override {
return model_->get_word_embedding();
};

void set_word_embedding(
std::vector<layer::WordEmbedding>& embedding) override {
void set_word_embedding(layer::WordEmbedding& embedding) override {
model_->set_word_embedding(embedding);
};
