@@ -4,9 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from collections.abc import Sequence
+from collections.abc import Callable, Generator, Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
+from typing import Any
 
 import torch
 import torch.nn as nn
@@ -33,7 +34,7 @@
 
 
 @contextmanager
-def disable_active_parametrization():
+def disable_active_parametrization() -> Generator[None, None, None]:
     global _active_parametrization
     try:
         _active_parametrization = False
@@ -183,9 +184,11 @@ def _register_parametrization(
     module.__class__ = module_cls
 
 
-def fsdp_policy():
-    def _fsdp_recomp_policy():
-        def _custom_policy(ctx, func, *args, **kwargs):
+def fsdp_policy() -> tuple[Callable[..., Any], Callable[..., Any]]:
+    def _fsdp_recomp_policy() -> Callable[..., CheckpointPolicy]:
+        def _custom_policy(
+            ctx: Any, func: Any, *args: Any, **kwargs: Any
+        ) -> CheckpointPolicy:
             to_recompute = func in {
                 torch.ops._c10d_functional.all_gather_into_tensor.default,
                 torch.ops._c10d_functional.wait_tensor.default,
@@ -205,20 +208,20 @@ def _custom_policy(ctx, func, *args, **kwargs):
 class ReplicateComputation(torch.nn.Module):
     def __init__(
         self,
-        device_mesh,
-        param_sharding,
-        mode,
-        regional_ac,
-        mp_policy,
-        reshard_after_forward,
-        reduction_divide_factor,
-    ):
+        device_mesh: DeviceMesh,
+        param_sharding: tuple[Placement, ...],
+        mode: str,
+        regional_ac: bool,
+        mp_policy: MixedPrecisionPolicy | None,
+        reshard_after_forward: bool,
+        reduction_divide_factor: float | None,
+    ) -> None:
         super().__init__()
         self.device_mesh = device_mesh
         self.param_sharding = param_sharding
         self.mode = mode
-        self.compute_placements = [Replicate()] * self.device_mesh.ndim
-        self.grad_placements = [
+        self.compute_placements: list[Placement] = [Replicate()] * self.device_mesh.ndim
+        self.grad_placements: list[Placement] = [
             _ScaledPartial(
                 reduction_divide_factor=reduction_divide_factor,
             )
@@ -227,8 +230,8 @@ def __init__(
         ] * self.device_mesh.ndim
         self.regional_ac = regional_ac
         mp_policy = mp_policy or MixedPrecisionPolicy()
-        self.param_dtype = mp_policy.param_dtype
-        self.reduce_dtype = mp_policy.reduce_dtype
+        self.param_dtype: torch.dtype | None = mp_policy.param_dtype
+        self.reduce_dtype: torch.dtype | None = mp_policy.reduce_dtype
         self.reshard_after_forward = reshard_after_forward
 
     def replicate_compute(self, x: DTensor) -> torch.Tensor:
@@ -319,7 +322,8 @@ def data_parallel(
     reshard_after_forward: bool = True,
     shard_dim: int = 0,
     reduction_divide_factor: float | None = None,
-):
+) -> nn.Module:
+    param_sharding: tuple[Placement, ...]
     if mode == "replicate":
         param_sharding = (Replicate(),)
     elif mode == "fully_shard":
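
For context (not part of this diff): the `(ctx, func, *args, **kwargs) -> CheckpointPolicy` shape annotated on `_custom_policy` is the policy signature that `torch.utils.checkpoint`'s selective activation checkpointing consumes through `create_selective_checkpoint_contexts`. Below is a minimal, self-contained sketch of that wiring, assuming a PyTorch build with distributed support so the `_c10d_functional` ops are registered; `_example_policy`, `block`, and `x` are hypothetical placeholders, not names from this codebase.

```python
from functools import partial

import torch
import torch.nn as nn
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)


def _example_policy(ctx, func, *args, **kwargs) -> CheckpointPolicy:
    # Same idea as _custom_policy above: prefer to recompute the cheap
    # FSDP collectives and keep every other activation saved.
    to_recompute = func in {
        torch.ops._c10d_functional.all_gather_into_tensor.default,
        torch.ops._c10d_functional.wait_tensor.default,
    }
    if to_recompute:
        return CheckpointPolicy.PREFER_RECOMPUTE
    return CheckpointPolicy.MUST_SAVE


block = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
x = torch.randn(2, 8, requires_grad=True)

# create_selective_checkpoint_contexts turns the policy into the pair of
# context managers that torch.utils.checkpoint.checkpoint expects.
context_fn = partial(create_selective_checkpoint_contexts, _example_policy)
out = checkpoint(block, x, use_reentrant=False, context_fn=context_fn)
out.sum().backward()
```

Annotating the policy's return type as `CheckpointPolicy` (rather than leaving it implicit) matches what this context manager machinery actually expects from the callable.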