Commit 1f39153

Author: Abhishek (committed)
DeepCompile ZeRO-3: robust allgather for uneven shards; fix profiling meta key (max_mem)
Signed-off-by: Abhishek <[email protected]>
1 parent: 64ac13f

File tree: 2 files changed, +58 −4 lines

csrc/compile/z3.cpp

Lines changed: 57 additions & 3 deletions

@@ -69,14 +69,68 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
     const at::Tensor& ds_tensor = param.getDSTensor();

     if (symm_mem == nullptr) {
-        ncclResult_t result = ncclAllGather(ds_tensor.contiguous().data_ptr(),
-                                            output_buf.data_ptr(),
-                                            ds_tensor.numel(),
+        // Support uneven shard sizes across ranks by padding to max shard size
+        int world_size = process_group_->getSize();
+        int rank = process_group_->getRank();
+
+        int64_t local_count = ds_tensor.numel();
+
+        // Gather local shard sizes from all ranks
+        auto count_options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA);
+        at::Tensor local_count_tensor = torch::tensor({local_count}, count_options);
+        std::vector<at::Tensor> all_counts(world_size);
+        for (int i = 0; i < world_size; ++i) {
+            all_counts[i] = torch::empty_like(local_count_tensor);
+        }
+        process_group_->allgather(all_counts, local_count_tensor)->wait();
+
+        int64_t max_count = 0;
+        std::vector<int64_t> host_counts(world_size);
+        for (int i = 0; i < world_size; ++i) {
+            host_counts[i] = all_counts[i].to(torch::kCPU).item<int64_t>();
+            if (host_counts[i] > max_count) { max_count = host_counts[i]; }
+        }
+
+        // Prepare padded send buffer and gather buffer on AG stream
+        at::Tensor send_buf;
+        at::Tensor gather_tmp;
+        {
+            at::cuda::CUDAStreamGuard guard(ag_stream_);
+            send_buf = torch::empty({max_count}, ds_tensor.options());
+            // Copy real shard
+            send_buf.index_put_({torch::indexing::Slice(0, local_count)}, ds_tensor.flatten(), true);
+            // Zero-pad the tail if needed
+            if (local_count < max_count) {
+                auto pad_len = max_count - local_count;
+                send_buf.index_put_({torch::indexing::Slice(local_count, max_count)},
+                                    torch::zeros({pad_len}, ds_tensor.options()),
+                                    true);
+            }
+            gather_tmp = torch::empty({static_cast<long>(world_size) * max_count}, ds_tensor.options());
+        }
+
+        ncclResult_t result = ncclAllGather(send_buf.data_ptr(),
+                                            gather_tmp.data_ptr(),
+                                            max_count,
                                             get_nccl_data_type(ds_tensor.scalar_type()),
                                             nccl_comm_,
                                             ag_stream_);

         if (result != ncclSuccess) { throw std::runtime_error("NCCL AllGather failed"); }
+
+        // Reconstruct full parameter into output_buf (flattened), then shape
+        {
+            at::cuda::CUDAStreamGuard guard(ag_stream_);
+            auto out_flat = output_buf.flatten();
+            int64_t out_offset = 0;
+            for (int i = 0; i < world_size; ++i) {
+                int64_t len = host_counts[i];
+                if (len == 0) { continue; }
+                auto src = gather_tmp.index({torch::indexing::Slice(i * max_count, i * max_count + len)});
+                out_flat.index_put_({torch::indexing::Slice(out_offset, out_offset + len)}, src, true);
+                out_offset += len;
+            }
+        }
     } else {
         at::cuda::CUDAStreamGuard guard(ag_stream_);
         int world_size = process_group_->getSize();
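
The hunk above handles uneven ZeRO-3 shards by first all-gathering each rank's shard length, padding every shard to the largest length, running a single fixed-size ncclAllGather, and then copying only each rank's real elements back into output_buf in rank order. A minimal Python sketch of the same pad-to-max pattern, using torch.distributed on an already-initialized process group (the helper name and setup below are illustrative, not part of this commit):

    # Sketch of the pad-to-max allgather pattern for uneven 1-D shards; assumes
    # torch.distributed is already initialized. Names here are illustrative.
    import torch
    import torch.distributed as dist

    def allgather_uneven(shard: torch.Tensor, group=None) -> torch.Tensor:
        world_size = dist.get_world_size(group)

        # 1) Exchange shard lengths so every rank knows the maximum.
        local_count = torch.tensor([shard.numel()], dtype=torch.long, device=shard.device)
        counts = [torch.empty_like(local_count) for _ in range(world_size)]
        dist.all_gather(counts, local_count, group=group)
        counts = [int(c.item()) for c in counts]
        max_count = max(counts)

        # 2) Zero-pad the local shard and all-gather equal-sized chunks.
        send_buf = torch.zeros(max_count, dtype=shard.dtype, device=shard.device)
        send_buf[:shard.numel()] = shard.flatten()
        chunks = [torch.empty_like(send_buf) for _ in range(world_size)]
        dist.all_gather(chunks, send_buf, group=group)

        # 3) Trim each rank's padding and concatenate the real elements in rank order.
        return torch.cat([chunk[:n] for chunk, n in zip(chunks, counts)])

For example, with rank 0 holding 5 elements and rank 1 holding 3, both ranks send 5-element buffers and recover an 8-element result; the C++ hunk performs the same trimming when it copies gather_tmp back into output_buf.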

deepspeed/compile/profilers/graph_profile.py

Lines changed: 1 addition & 1 deletion

@@ -122,7 +122,7 @@ def run_node(self, n: torch.fx.Node) -> Any:
         n.meta["device_time"] = 0.0
         n.meta["wall_time"] = 0.0
         n.meta["alloc_mem"] = 0
-        n.meta["max_memory"] = 0
+        n.meta["max_mem"] = 0
         n.meta["tensor_size"] = _node_size(n)
         return super().run_node(n)
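
The one-line rename changes the stored key from "max_memory" to "max_mem", presumably to match what downstream consumers of the profiling metadata look up. A minimal sketch of reading that entry from a profiled torch.fx graph (the helper below is hypothetical and only shows the n.meta lookup, not DeepSpeed's own scheduling code):

    # Hypothetical helper: report the largest per-node "max_mem" estimate that the
    # profiler records; only meta keys shown in this commit are assumed to exist.
    import torch.fx

    def peak_node_memory(gm: torch.fx.GraphModule) -> int:
        peak = 0
        for node in gm.graph.nodes:
            peak = max(peak, node.meta.get("max_mem", 0))
        return peak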
