Commit 120fc67

[SimpleFSDP] Add typing to simple_fsdp.py
Add typing, credit to Claude.

ghstack-source-id: a4a76c7
Pull-Request: #2001

1 parent: 865a04e


torchtitan/experiments/simple_fsdp/simple_fsdp.py

Lines changed: 22 additions & 18 deletions
@@ -4,9 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from collections.abc import Sequence
+from collections.abc import Callable, Generator, Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
+from typing import Any
 
 import torch
 import torch.nn as nn
@@ -33,7 +34,7 @@
 
 
 @contextmanager
-def disable_active_parametrization():
+def disable_active_parametrization() -> Generator[None, None, None]:
     global _active_parametrization
     try:
         _active_parametrization = False
@@ -183,9 +184,11 @@ def _register_parametrization(
     module.__class__ = module_cls
 
 
-def fsdp_policy():
-    def _fsdp_recomp_policy():
-        def _custom_policy(ctx, func, *args, **kwargs):
+def fsdp_policy() -> tuple[Callable[..., Any], Callable[..., Any]]:
+    def _fsdp_recomp_policy() -> Callable[..., CheckpointPolicy]:
+        def _custom_policy(
+            ctx: Any, func: Any, *args: Any, **kwargs: Any
+        ) -> CheckpointPolicy:
             to_recompute = func in {
                 torch.ops._c10d_functional.all_gather_into_tensor.default,
                 torch.ops._c10d_functional.wait_tensor.default,
@@ -205,20 +208,20 @@ def _custom_policy(ctx, func, *args, **kwargs):
 class ReplicateComputation(torch.nn.Module):
     def __init__(
         self,
-        device_mesh,
-        param_sharding,
-        mode,
-        regional_ac,
-        mp_policy,
-        reshard_after_forward,
-        reduction_divide_factor,
-    ):
+        device_mesh: DeviceMesh,
+        param_sharding: tuple[Placement, ...],
+        mode: str,
+        regional_ac: bool,
+        mp_policy: MixedPrecisionPolicy | None,
+        reshard_after_forward: bool,
+        reduction_divide_factor: float | None,
+    ) -> None:
         super().__init__()
         self.device_mesh = device_mesh
         self.param_sharding = param_sharding
         self.mode = mode
-        self.compute_placements = [Replicate()] * self.device_mesh.ndim
-        self.grad_placements = [
+        self.compute_placements: list[Placement] = [Replicate()] * self.device_mesh.ndim
+        self.grad_placements: list[Placement] = [
             _ScaledPartial(
                 reduction_divide_factor=reduction_divide_factor,
             )
@@ -227,8 +230,8 @@ def __init__(
         ] * self.device_mesh.ndim
         self.regional_ac = regional_ac
         mp_policy = mp_policy or MixedPrecisionPolicy()
-        self.param_dtype = mp_policy.param_dtype
-        self.reduce_dtype = mp_policy.reduce_dtype
+        self.param_dtype: torch.dtype | None = mp_policy.param_dtype
+        self.reduce_dtype: torch.dtype | None = mp_policy.reduce_dtype
         self.reshard_after_forward = reshard_after_forward
 
     def replicate_compute(self, x: DTensor) -> torch.Tensor:
@@ -319,7 +322,8 @@ def data_parallel(
     reshard_after_forward: bool = True,
     shard_dim: int = 0,
     reduction_divide_factor: float | None = None,
-):
+) -> nn.Module:
+    param_sharding: tuple[Placement, ...]
     if mode == "replicate":
         param_sharding = (Replicate(),)
     elif mode == "fully_shard":
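
Note: the bare `param_sharding: tuple[Placement, ...]` annotation added to data_parallel() declares the variable's type once, so every branch of the following if/elif chain is checked against it. Below is a minimal standalone sketch of that pattern, assuming a recent PyTorch that exports Placement/Replicate/Shard from torch.distributed.tensor; pick_param_sharding is a hypothetical helper, and the Shard(shard_dim) branch is an assumption since the hunk is truncated at the elif.

from torch.distributed.tensor import Placement, Replicate, Shard


def pick_param_sharding(mode: str, shard_dim: int = 0) -> tuple[Placement, ...]:
    # Declare the type up front, then assign in each branch; this mirrors the
    # `param_sharding: tuple[Placement, ...]` line added to data_parallel().
    param_sharding: tuple[Placement, ...]
    if mode == "replicate":
        param_sharding = (Replicate(),)
    elif mode == "fully_shard":
        param_sharding = (Shard(shard_dim),)  # assumed mapping for this sketch
    else:
        raise ValueError(f"unknown data parallel mode: {mode}")
    return param_sharding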

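Note: the `Generator[None, None, None]` return annotation on disable_active_parametrization() is the standard typing for a contextlib.contextmanager generator (yield type, send type, return type, all None). A self-contained sketch of that pattern follows; the finally/restore step is an assumption, since the hunk only shows the start of the body.

from collections.abc import Generator
from contextlib import contextmanager

_active_parametrization = True  # module-level flag, as in the diff


@contextmanager
def disable_active_parametrization() -> Generator[None, None, None]:
    # Yields exactly once and returns nothing, hence Generator[None, None, None].
    global _active_parametrization
    try:
        _active_parametrization = False
        yield
    finally:
        _active_parametrization = True  # assumed restore; not shown in the hunk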