
Commit 8585909
[Full DTensor] Add full_dtensor flag
When full_dtensor is True, the compute placements are preserved, meaning `to_local()` won't be called in the FSDP-only case. The nD-parallelism case (FSDP + TP) errors out, as we have not implemented it yet. This argument does not affect the current simple_fsdp behavior. We have verified the `full_dtensor=True` case with the full-DTensor skeleton PR, which will be published once it is ready.

ghstack-source-id: 9f9efce
Pull-Request: #2002
1 parent 120fc67 commit 8585909
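
For context, here is a hedged usage sketch (not part of this commit) showing how the new flag could be threaded through `data_parallel`. The positional arguments, the "fully_shard" mode string, the mesh layout, and the toy model are illustrative assumptions; only the trailing keyword arguments appear in the diff below.

# Hypothetical usage sketch; assumes a distributed launch (e.g. torchrun with
# 8 ranks) and that data_parallel takes (model, device_mesh, mode=...) ahead
# of the keyword arguments visible in the diff.
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh

from torchtitan.experiments.simple_fsdp.simple_fsdp import data_parallel

dp_mesh = init_device_mesh("cuda", (8,), mesh_dim_names=("dp_shard",))
model = nn.Linear(1024, 1024)

# full_dtensor=True preserves the compute placements: replicate_compute skips
# to_local() in the FSDP-only case, so the module computes directly on DTensors.
model = data_parallel(
    model,
    dp_mesh,
    mode="fully_shard",
    full_dtensor=True,
)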


1 file changed: +12 −1 lines changed


torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 12 additions & 1 deletion
@@ -215,6 +215,7 @@ def __init__(
         mp_policy: MixedPrecisionPolicy | None,
         reshard_after_forward: bool,
         reduction_divide_factor: float | None,
+        full_dtensor: bool = False,
     ) -> None:
         super().__init__()
         self.device_mesh = device_mesh
@@ -233,6 +234,7 @@ def __init__(
         self.param_dtype: torch.dtype | None = mp_policy.param_dtype
         self.reduce_dtype: torch.dtype | None = mp_policy.reduce_dtype
         self.reshard_after_forward = reshard_after_forward
+        self.full_dtensor = full_dtensor

     def replicate_compute(self, x: DTensor) -> torch.Tensor:
         # data parallel runtime replicate parameters and do local compute
@@ -242,6 +244,10 @@ def replicate_compute(self, x: DTensor) -> torch.Tensor:
         non_dp_mesh_dims = x._spec.mesh.ndim - self.device_mesh.ndim
         assert non_dp_mesh_dims <= 2, "Only DP + EP/TP/EP+TP is supported"
         if non_dp_mesh_dims > 0:
+            if self.full_dtensor:
+                raise NotImplementedError(
+                    "full_dtensor not implemented for nD parallelisms"
+                )
             dp_mesh = self.device_mesh
             # re-wrap 2D DTensor to 1D DTensor on dp_mesh for efficient FSDP all-gather
             sharded_local_tensor = x.to_local()
@@ -277,7 +283,10 @@ def replicate_compute(self, x: DTensor) -> torch.Tensor:
                 placements=self.compute_placements,
                 forward_dtype=self.param_dtype,
                 backward_dtype=self.reduce_dtype,
-            ).to_local(grad_placements=self.grad_placements)
+            )
+
+            if not self.full_dtensor:
+                output = output.to_local(grad_placements=self.grad_placements)
         else:
             raise AssertionError(
                 f"Unsupported replicate compute on placement {x._spec.placements} for DTensor {x}"
@@ -322,6 +331,7 @@ def data_parallel(
     reshard_after_forward: bool = True,
     shard_dim: int = 0,
     reduction_divide_factor: float | None = None,
+    full_dtensor: bool = False,
 ) -> nn.Module:
     param_sharding: tuple[Placement, ...]
     if mode == "replicate":
@@ -387,6 +397,7 @@ def data_parallel(
             mp_policy=mp_policy,
             reshard_after_forward=reshard_after_forward,
             reduction_divide_factor=reduction_divide_factor,
+            full_dtensor=full_dtensor,
         ),
     )
     return model
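
As a hedged follow-up to the usage sketch above: the stored (sharded) parameters are DTensors with or without the flag; full_dtensor=True only changes what `replicate_compute` hands to the module at compute time, keeping it a DTensor instead of calling `to_local()`.

# Hypothetical sanity check, continuing the assumptions of the earlier sketch.
# Parametrized originals stay sharded DTensors regardless of full_dtensor.
from torch.distributed.tensor import DTensor

for name, param in model.named_parameters():
    assert isinstance(param, DTensor), f"{name} is not a DTensor"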
