torchtitan/experiments/simple_fsdp/simple_fsdp.py: 40 changes (22 additions, 18 deletions)
@@ -4,9 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from collections.abc import Sequence
+from collections.abc import Callable, Generator, Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
+from typing import Any
 
 import torch
 import torch.nn as nn
@@ -33,7 +34,7 @@
 
 
 @contextmanager
-def disable_active_parametrization():
+def disable_active_parametrization() -> Generator[None, None, None]:
     global _active_parametrization
     try:
         _active_parametrization = False
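The context manager above now declares Generator[None, None, None], the conventional annotation for a @contextmanager function that yields nothing and returns nothing. A minimal standalone sketch of the same typing pattern (the flag name here is illustrative, not the module's actual global):

from collections.abc import Generator
from contextlib import contextmanager

_active = True  # illustrative module-level flag standing in for _active_parametrization

@contextmanager
def _disable_flag() -> Generator[None, None, None]:
    # Flip the flag off for the duration of the with-block and restore it
    # afterwards, even if the body raises.
    global _active
    try:
        _active = False
        yield
    finally:
        _active = True

# Usage:
# with _disable_flag():
#     ...  # code that should see the flag disabled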
@@ -183,9 +184,11 @@ def _register_parametrization(
     module.__class__ = module_cls
 
 
-def fsdp_policy():
-    def _fsdp_recomp_policy():
-        def _custom_policy(ctx, func, *args, **kwargs):
+def fsdp_policy() -> tuple[Callable[..., Any], Callable[..., Any]]:
+    def _fsdp_recomp_policy() -> Callable[..., CheckpointPolicy]:
+        def _custom_policy(
+            ctx: Any, func: Any, *args: Any, **kwargs: Any
+        ) -> CheckpointPolicy:
             to_recompute = func in {
                 torch.ops._c10d_functional.all_gather_into_tensor.default,
                 torch.ops._c10d_functional.wait_tensor.default,
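With the annotations above, _custom_policy is typed as returning torch.utils.checkpoint.CheckpointPolicy, the enum consumed by selective activation checkpointing. A hedged sketch of how such a policy function plugs in via create_selective_checkpoint_contexts; the policy mapping and the checkpointed module are illustrative, not the exact logic elided from this hunk:

from functools import partial

import torch
from torch.utils.checkpoint import (
    CheckpointPolicy,
    checkpoint,
    create_selective_checkpoint_contexts,
)

def _recompute_collectives_policy(ctx, func, *args, **kwargs) -> CheckpointPolicy:
    # Recompute FSDP all-gather collectives in backward; save other activations.
    to_recompute = func in {
        torch.ops._c10d_functional.all_gather_into_tensor.default,
        torch.ops._c10d_functional.wait_tensor.default,
    }
    return (
        CheckpointPolicy.MUST_RECOMPUTE if to_recompute else CheckpointPolicy.MUST_SAVE
    )

# Illustrative usage with a module `block` and input `x`:
# out = checkpoint(
#     block,
#     x,
#     use_reentrant=False,
#     context_fn=partial(
#         create_selective_checkpoint_contexts, _recompute_collectives_policy
#     ),
# )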
@@ -205,20 +208,20 @@ def _custom_policy(ctx, func, *args, **kwargs):
 class ReplicateComputation(torch.nn.Module):
     def __init__(
         self,
-        device_mesh,
-        param_sharding,
-        mode,
-        regional_ac,
-        mp_policy,
-        reshard_after_forward,
-        reduction_divide_factor,
-    ):
+        device_mesh: DeviceMesh,
+        param_sharding: tuple[Placement, ...],
+        mode: str,
+        regional_ac: bool,
+        mp_policy: MixedPrecisionPolicy | None,
+        reshard_after_forward: bool,
+        reduction_divide_factor: float | None,
+    ) -> None:
         super().__init__()
         self.device_mesh = device_mesh
         self.param_sharding = param_sharding
         self.mode = mode
-        self.compute_placements = [Replicate()] * self.device_mesh.ndim
-        self.grad_placements = [
+        self.compute_placements: list[Placement] = [Replicate()] * self.device_mesh.ndim
+        self.grad_placements: list[Placement] = [
             _ScaledPartial(
                 reduction_divide_factor=reduction_divide_factor,
             )
@@ -227,8 +230,8 @@ def __init__(
         ] * self.device_mesh.ndim
         self.regional_ac = regional_ac
         mp_policy = mp_policy or MixedPrecisionPolicy()
-        self.param_dtype = mp_policy.param_dtype
-        self.reduce_dtype = mp_policy.reduce_dtype
+        self.param_dtype: torch.dtype | None = mp_policy.param_dtype
+        self.reduce_dtype: torch.dtype | None = mp_policy.reduce_dtype
         self.reshard_after_forward = reshard_after_forward
 
     def replicate_compute(self, x: DTensor) -> torch.Tensor:
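The torch.dtype | None annotations above mirror MixedPrecisionPolicy, whose param_dtype and reduce_dtype fields are optional dtypes. A hedged construction sketch; the torch.distributed.fsdp import path is an assumption based on recent FSDP2 releases:

import torch
from torch.distributed.fsdp import MixedPrecisionPolicy  # assumed FSDP2 import path

# None for either field keeps the module's original dtype.
mp_policy = MixedPrecisionPolicy(
    param_dtype=torch.bfloat16,  # parameters are all-gathered / computed in bf16
    reduce_dtype=torch.float32,  # gradients are reduced in fp32
)
assert mp_policy.param_dtype is torch.bfloat16
assert mp_policy.reduce_dtype is torch.float32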
@@ -319,7 +322,8 @@ def data_parallel(
     reshard_after_forward: bool = True,
     shard_dim: int = 0,
     reduction_divide_factor: float | None = None,
-):
+) -> nn.Module:
+    param_sharding: tuple[Placement, ...]
     if mode == "replicate":
         param_sharding = (Replicate(),)
     elif mode == "fully_shard":