@@ -69,14 +69,68 @@ class Z3CustomOpExecutor : public CustomOpExecutor {
         const at::Tensor& ds_tensor = param.getDSTensor();
 
         if (symm_mem == nullptr) {
-            ncclResult_t result = ncclAllGather(ds_tensor.contiguous().data_ptr(),
-                                                output_buf.data_ptr(),
-                                                ds_tensor.numel(),
+            // Support uneven shard sizes across ranks by padding every shard to the max shard size
+            int world_size = process_group_->getSize();
+
+            int64_t local_count = ds_tensor.numel();
+
+            // Gather the local shard sizes from all ranks
+            auto count_options = at::TensorOptions().dtype(at::kLong).device(at::kCUDA);
+            at::Tensor local_count_tensor = torch::tensor({local_count}, count_options);
+            std::vector<at::Tensor> all_counts(world_size);
+            for (int i = 0; i < world_size; ++i) {
+                all_counts[i] = torch::empty_like(local_count_tensor);
+            }
+            std::vector<std::vector<at::Tensor>> count_outputs{all_counts};
+            std::vector<at::Tensor> count_inputs{local_count_tensor};
+            process_group_->allgather(count_outputs, count_inputs)->wait();
+
+            int64_t max_count = 0;
+            std::vector<int64_t> host_counts(world_size);
+            for (int i = 0; i < world_size; ++i) {
+                host_counts[i] = all_counts[i].to(torch::kCPU).item<int64_t>();
+                if (host_counts[i] > max_count) { max_count = host_counts[i]; }
+            }
+
+            // Prepare the padded send buffer and the gather buffer on the AG stream
+            at::Tensor send_buf;
+            at::Tensor gather_tmp;
+            {
+                at::cuda::CUDAStreamGuard guard(ag_stream_);
+                send_buf = torch::empty({max_count}, ds_tensor.options());
+                // Copy the real shard into the front of the buffer
+                send_buf.index_put_({torch::indexing::Slice(0, local_count)}, ds_tensor.flatten());
+                // Zero-pad the tail if needed
+                if (local_count < max_count) {
+                    auto pad_len = max_count - local_count;
+                    send_buf.index_put_({torch::indexing::Slice(local_count, max_count)},
+                                        torch::zeros({pad_len}, ds_tensor.options()));
+                }
+                gather_tmp = torch::empty({static_cast<int64_t>(world_size) * max_count}, ds_tensor.options());
+            }
+
+            ncclResult_t result = ncclAllGather(send_buf.data_ptr(),
+                                                gather_tmp.data_ptr(),
+                                                max_count,
                                                 get_nccl_data_type(ds_tensor.scalar_type()),
                                                 nccl_comm_,
                                                 ag_stream_);
 
             if (result != ncclSuccess) { throw std::runtime_error("NCCL AllGather failed"); }
+
+            // Reconstruct the full parameter into output_buf (flattened), preserving rank order
+            {
+                at::cuda::CUDAStreamGuard guard(ag_stream_);
+                auto out_flat = output_buf.flatten();
+                int64_t out_offset = 0;
+                for (int i = 0; i < world_size; ++i) {
+                    int64_t len = host_counts[i];
+                    if (len == 0) { continue; }
+                    auto src = gather_tmp.index({torch::indexing::Slice(i * max_count, i * max_count + len)});
+                    out_flat.index_put_({torch::indexing::Slice(out_offset, out_offset + len)}, src);
+                    out_offset += len;
+                }
+            }
         } else {
             at::cuda::CUDAStreamGuard guard(ag_stream_);
             int world_size = process_group_->getSize();
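
For reference, here is a minimal single-process sketch of the same pad-to-max / all-gather / trim pattern. The shard sizes (a 10-element parameter split 4/3/3 across three hypothetical ranks) are made up for illustration, and `torch::cat` stands in for `ncclAllGather`, so the indexing can be checked without a multi-GPU setup:

```cpp
// Single-process illustration of padded all-gather with uneven shards (libtorch, CPU).
#include <torch/torch.h>
#include <algorithm>
#include <iostream>
#include <vector>

int main() {
    // Hypothetical uneven shards of one flattened parameter: 4 + 3 + 3 elements.
    std::vector<at::Tensor> shards = {torch::arange(0, 4, torch::kFloat),
                                      torch::arange(4, 7, torch::kFloat),
                                      torch::arange(7, 10, torch::kFloat)};
    int world_size = static_cast<int>(shards.size());

    // Step 1: find the max shard size (what the allgather of counts computes in the diff).
    std::vector<int64_t> counts(world_size);
    int64_t max_count = 0;
    for (int i = 0; i < world_size; ++i) {
        counts[i] = shards[i].numel();
        max_count = std::max(max_count, counts[i]);
    }

    // Step 2: each "rank" zero-pads its shard to max_count; concatenating the padded
    // buffers plays the role of the fixed-size ncclAllGather.
    std::vector<at::Tensor> padded(world_size);
    for (int i = 0; i < world_size; ++i) {
        at::Tensor buf = torch::zeros({max_count}, shards[i].options());
        buf.narrow(0, 0, counts[i]).copy_(shards[i].flatten());
        padded[i] = buf;
    }
    at::Tensor gather_tmp = torch::cat(padded);  // shape: [world_size * max_count]

    // Step 3: trim each rank's segment back to its true length and rebuild the parameter,
    // preserving rank order, as the reconstruction loop in the diff does.
    std::vector<at::Tensor> pieces(world_size);
    for (int i = 0; i < world_size; ++i) {
        pieces[i] = gather_tmp.narrow(0, i * max_count, counts[i]);
    }
    at::Tensor full = torch::cat(pieces);

    std::cout << full << std::endl;  // prints 0 1 2 ... 9: the original parameter
    return 0;
}
```

The trade-off behind this design: every rank sends max_count elements, so up to (max_count - local_count) elements per rank are wasted padding, but the collective stays a single fixed-size all-gather instead of per-rank variable-size transfers, and the reconstruction loop keeps the padding from leaking into output_buf.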