Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
2e6f1b8
DeepCompile ZeRO-3: robust allgather for uneven shards; fix profiling…
adalakoti90 Aug 15, 2025
24ceb0b
DeepCompile ZeRO-3: remove size-allgather; enforce uniform shard size…
adalakoti90 Aug 18, 2025
94e3dd6
Z3: allgatherParam pre-allocates padded buffer and returns a sliced v…
adalakoti90 Aug 21, 2025
ffa2aba
Z3: use .contiguous() for NCCL allgather send buffer; add comment
adalakoti90 Aug 21, 2025
04f9bd4
Merge branch 'master' into fix/dc-zero3-allgather-uneven-shards
juyterman1000 Aug 26, 2025
80ff735
Merge branch 'master' into fix/dc-zero3-allgather-uneven-shards
juyterman1000 Aug 27, 2025
6ccf869
Fix compilation error and add test for uneven shard assertion
adalakoti90 Aug 28, 2025
04e7617
Merge branch 'master' into fix/dc-zero3-allgather-uneven-shards
juyterman1000 Aug 29, 2025
daa3e73
Merge branch 'master' into fix/dc-zero3-allgather-uneven-shards
juyterman1000 Aug 31, 2025
4320c5a
Fix allgather lvalue reference error
adalakoti90 Aug 31, 2025
0c13f9e
Merge branch 'master' into fix/dc-zero3-allgather-uneven-shards
juyterman1000 Sep 1, 2025
d038200
Merge branch 'master' into fix/dc-zero3-allgather-uneven-shards
juyterman1000 Sep 7, 2025
9ce693c
Fix ZeRO-3 DeepCompile to handle padded parameters correctly - Update…
adalakoti90 Sep 12, 2025
a01d899
Merge branch 'master' into fix/dc-zero3-allgather-uneven-shards
juyterman1000 Sep 12, 2025
9e6ea01
Merge branch 'fix/dc-zero3-allgather-uneven-shards' of github.com:juy…
adalakoti90 Sep 13, 2025
4f322d4
Z3 registration: compute padded-per-rank from ds_shape (total_numel) …
adalakoti90 Sep 13, 2025
05a13ed
Z3 allgatherParam: derive padded_numel via padded_per_rank from ds_shape
adalakoti90 Sep 13, 2025
c8f11bb
Merge remote-tracking branch 'upstream/master' into fix/dc-zero3-allg…
adalakoti90 Sep 13, 2025
c844010
Merge branch 'master' into fix/dc-zero3-allgather-uneven-shards
juyterman1000 Sep 17, 2025
ea3ad68
Merge branch 'master' into fix/dc-zero3-allgather-uneven-shards
juyterman1000 Sep 20, 2025
f3f7cf5
z3: fix registration allgather output tensor shape (ProcessGroup::all…
adalakoti90 Sep 20, 2025
09c1be3
fix formatting
tohtana Sep 22, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 81 additions & 10 deletions csrc/compile/z3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,15 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
const at::Tensor& ds_tensor = param.getDSTensor();

if (symm_mem == nullptr) {
// Fast path: assume uniform shard sizes (ZeRO-3 partitions are padded to uniform size)
const int world_size = process_group_->getSize();
const int64_t shard_elems = ds_tensor.numel();

// Perform all-gather directly into the pre-allocated padded output buffer
// NCCL requires contiguous storage; use .contiguous() explicitly
ncclResult_t result = ncclAllGather(ds_tensor.contiguous().data_ptr(),
output_buf.data_ptr(),
ds_tensor.numel(),
shard_elems,
get_nccl_data_type(ds_tensor.scalar_type()),
nccl_comm_,
ag_stream_);
Expand Down Expand Up @@ -104,13 +110,30 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
at::Tensor allgatherParam(long ds_id,
c10::intrusive_ptr<c10d::symmetric_memory::SymmetricMemory> symm_mem)
{
if (param_registry_->isValid(ds_id)) { return param_registry_->getGatheredParam(ds_id); }

const DSParam& param = param_registry_->getParam(ds_id);
const at::Tensor& ds_tensor = param.getDSTensor();
at::Tensor output_buf = param_registry_->hasGatheredParam(ds_id)
? param_registry_->getGatheredParam(ds_id)
: torch::empty(param.getShape(), ds_tensor.options());
const int world_size = process_group_->getSize();
const int64_t true_numel = static_cast<int64_t>(productDim(param.getShape()));
const int64_t padded_per_rank = (true_numel + world_size - 1) / world_size;
const int64_t padded_numel = static_cast<int64_t>(world_size) * padded_per_rank;

if (param_registry_->isValid(ds_id)) {
// Return a view sliced to the true size with the original shape
auto base = param_registry_->getGatheredParam(ds_id);
return base.flatten()
.index({torch::indexing::Slice(0, true_numel)})
.view(param.getShape());
}

at::Tensor output_buf;
if (param_registry_->hasGatheredParam(ds_id)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure when isValid(ds_id) is false while hasGatheredParam(ds_id) is true. They are both set at the end of launchAllGather(), and releasing a gathered param will unset the valid flag in unregisterGatheredParam().

auto existing = param_registry_->getGatheredParam(ds_id);
if (existing.defined() && existing.numel() == padded_numel) { output_buf = existing; }
}
if (!output_buf.defined()) {
at::cuda::CUDAStreamGuard guard(ag_stream_);
output_buf = torch::empty({padded_numel}, ds_tensor.options());
}

assert(hasKey(ag_comp_done_events_, ds_id));
ag_comp_done_events_[ds_id]->record();
Expand All @@ -119,7 +142,10 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
launchAllGather(output_buf, ds_id, symm_mem);

ag_comm_done_events_[ds_id]->record(ag_stream_);
return output_buf;
// Return a view of the gathered padded buffer matching the true param shape
return output_buf.flatten()
.index({torch::indexing::Slice(0, true_numel)})
.view(param.getShape());
}

void prefetchParamsFused(std::vector<int64_t> ds_ids,
Expand All @@ -133,11 +159,19 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
std::unordered_map<long, at::Tensor> output_bufs;
for (long ds_id : invalid_ds_ids) {
const DSParam& param = param_registry_->getParam(ds_id);
const at::Tensor& ds_tensor = param.getDSTensor();
const int world_size = process_group_->getSize();
const int64_t shard_elems = ds_tensor.numel();
const int64_t padded_numel = static_cast<int64_t>(world_size) * shard_elems;

if (param_registry_->hasGatheredParam(ds_id)) {
output_bufs[ds_id] = param_registry_->getGatheredParam(ds_id);
} else {
output_bufs[ds_id] = torch::empty(param.getShape(), param.getDSTensor().options());
auto existing = param_registry_->getGatheredParam(ds_id);
if (existing.defined() && existing.numel() == padded_numel) {
output_bufs[ds_id] = existing;
continue;
}
}
output_bufs[ds_id] = torch::empty({padded_numel}, ds_tensor.options());
}

for (long ds_id : invalid_ds_ids) {
Expand Down Expand Up @@ -383,6 +417,43 @@ void register_z3_param(long ds_id,
{
param_registry->registerParam(ds_id, ds_shape, ds_tensor, grad_buffer, true, 0, persistent);
if (persistent) { param_registry->registerGatheredParam(ds_id, ds_tensor); }

// Validate that padded shard sizes are uniform across ranks at registration time
// DeepSpeed pads parameters to ensure even division, so we check the padded size
// which should be uniform across all ranks for correct allgather behavior
const int64_t local_count = ds_tensor.numel();
const int world_size = process_group->getSize();

// Calculate padded size (aligned to world_size)
// Use ds_shape to compute the full (unpartitioned) parameter size
int64_t total_numel = 1;
for (const auto dim : ds_shape) { total_numel *= dim; }
const int64_t padded_per_rank = (total_numel + world_size - 1) / world_size;

// For verification: all ranks should have the same padded size
auto count_options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA);
at::Tensor local_padded_tensor = torch::tensor({padded_per_rank}, count_options);
std::vector<at::Tensor> all_padded_counts(world_size);
for (int i = 0; i < world_size; ++i) {
all_padded_counts[i] = torch::empty_like(local_padded_tensor);
}

// Build lvalue buffers for output and input as required by ProcessGroup::allgather
// The first argument must be a single-element vector containing a vector of WORLD_SIZE tensors
std::vector<std::vector<at::Tensor>> output_tensors(1);
output_tensors[0] = all_padded_counts;
std::vector<at::Tensor> input_tensors = {local_padded_tensor};
process_group->allgather(output_tensors, input_tensors)->wait();

// Verify all ranks agree on the padded size
for (int i = 0; i < world_size; ++i) {
int64_t padded_count = all_padded_counts[i].to(torch::kCPU).item<int64_t>();
if (padded_count != padded_per_rank) {
throw std::runtime_error(
"ZeRO-3 registration error: inconsistent padded shard sizes across ranks. "
"This is an internal error - please report this issue.");
}
}
}

at::Tensor allgather_param(at::Tensor param_tensor, long graph_id, long ds_id)
Expand Down
2 changes: 1 addition & 1 deletion deepspeed/compile/profilers/graph_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def run_node(self, n: torch.fx.Node) -> Any:
n.meta["device_time"] = 0.0
n.meta["wall_time"] = 0.0
n.meta["alloc_mem"] = 0
n.meta["max_memory"] = 0
n.meta["max_mem"] = 0
n.meta["tensor_size"] = _node_size(n)
return super().run_node(n)

Expand Down
33 changes: 33 additions & 0 deletions tests/unit/v1/compile/test_compile_zero.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,36 @@ def test(self, zero_stage, dtype, deepcompile):

# Need warmup steps
compare_loss(self, config_dict, dtype, iteration=10)

@pytest.mark.parametrize('dtype', [torch.float32])
@pytest.mark.parametrize('zero_stage', [3])
def test_padded_shard_handling(self, zero_stage, dtype):
"""Test that parameters with padding (uneven division) work correctly with DeepCompile"""
if not required_torch_version(min_version=2.6):
pytest.skip("DeepCompile requires PyTorch >= v2.6")

if get_accelerator().device_name() == "cpu":
pytest.skip("CPU does not support this test yet")

# Use a hidden dimension that requires padding when divided across ranks
# With world_size=2, a hidden_dim of 13 creates parameters that need padding
config_dict = {
"train_micro_batch_size_per_gpu": 1,
"steps_per_print": 1,
"optimizer": {
"type": "Adam",
"params": {
"lr": 0.00015
}
},
"zero_optimization": {
"stage": zero_stage,
},
"compile": {
"deepcompile": True
}
}

# This should work correctly with our padding-aware implementation
# The test verifies that padded parameters are handled properly
compare_loss(self, config_dict, dtype, iteration=1, hidden_dim_override=13)
4 changes: 2 additions & 2 deletions tests/unit/v1/compile/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@


@enable_determinism(123)
def compare_loss(self, config, dtype, iteration=5):
hidden_dim = 10
def compare_loss(self, config, dtype, iteration=5, hidden_dim_override=None):
hidden_dim = hidden_dim_override if hidden_dim_override is not None else 10
RTOL = 5e-1
ATOL = 1e-2

Expand Down