Integrating resharding API to training pipeline #3289

Open · wants to merge 1 commit into base: main
170 changes: 170 additions & 0 deletions torchrec/distributed/benchmark/benchmark_resharding_handler.py
@@ -0,0 +1,170 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

import logging
import random
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from torchrec.distributed.embeddingbag import EmbeddingBagCollection

from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta

from torchrec.distributed.sharding_plan import (
    column_wise,
    construct_module_sharding_plan,
    table_wise,
)

from torchrec.distributed.test_utils.test_sharding import generate_rank_placements
from torchrec.distributed.types import EmbeddingModuleShardingPlan

logger: logging.Logger = logging.getLogger(__name__)


class ReshardingHandler:
    """
    Handles the resharding of a training module by generating and applying different sharding plans.
    """

    def __init__(self, train_module: nn.Module, num_plans: int) -> None:
        """
        Initializes the ReshardingHandler with a training module and the number of sharding plans to generate.

        Args:
            train_module (nn.Module): The training module to be resharded.
            num_plans (int): The number of sharding plans to generate.
        """
        self._train_module = train_module
        if not hasattr(train_module, "_module"):
            raise RuntimeError("Incorrect train module")

        if not hasattr(train_module._module, "plan"):
            raise RuntimeError("sharding plan cannot be found")

        # Pyre-ignore
        plan = train_module._module.plan.plan
        key = "main_module.sparse_arch.embedding_bag_collection"
        module = (
            # Pyre-ignore
            train_module._module.module.main_module.sparse_arch.embedding_bag_collection
        )
        self._resharding_plans: List[EmbeddingModuleShardingPlan] = []
        world_size = dist.get_world_size()

        if key in plan:
            ebc = plan[key]
            num_tables = len(ebc)
            ranks_per_tables = [1 for _ in range(num_tables)]
            ranks_per_tables_for_CW = []
            for index, table_config in enumerate(module._embedding_bag_configs):
                # CW sharding: pick a rank count that evenly divides the embedding dim
                valid_candidates = [
                    i
                    for i in range(1, world_size + 1)
                    if table_config.embedding_dim % i == 0
                ]
                rng = random.Random(index)
                ranks_per_tables_for_CW.append(rng.choice(valid_candidates))

            for i in range(num_plans):
                new_ranks = generate_rank_placements(
                    world_size, num_tables, ranks_per_tables, i
                )
                new_ranks_cw = generate_rank_placements(
                    world_size, num_tables, ranks_per_tables_for_CW, i
                )
                new_per_param_sharding = {}
                # Separate loop variable so the plan index `i` above is not shadowed.
                for table_idx, (table_id, param) in enumerate(ebc.items()):
                    if param.sharding_type == "column_wise":
                        cw_gen = column_wise(
                            ranks=new_ranks_cw[table_idx],
                            compute_kernel=param.compute_kernel,
                        )
                        new_per_param_sharding[table_id] = cw_gen
                    else:
                        tw_gen = table_wise(
                            rank=new_ranks[table_idx][0],
                            compute_kernel=param.compute_kernel,
                        )
                        new_per_param_sharding[table_id] = tw_gen

                lightweight_ebc = EmbeddingBagCollection(
                    tables=module._embedding_bag_configs,
                    device=torch.device(
                        "meta"
                    ),  # Use meta device to avoid actual memory allocation
                )

                meta_device = torch.device("meta")
                new_plan = construct_module_sharding_plan(
                    lightweight_ebc,
                    per_param_sharding=new_per_param_sharding,
                    local_size=world_size,
                    world_size=world_size,
                    # Pyre-ignore
                    device_type=meta_device,
                )
                self._resharding_plans.append(new_plan)
        else:
            raise RuntimeError(f"Plan does not have key: {key}")

    def step(self, batch_no: int) -> float:
        """
        Executes a step in the training process by selecting and applying a sharding plan.

        Args:
            batch_no (int): The current batch number.

        Returns:
            float: The data volume of the sharding plan delta.
        """
        # Pyre-ignore
        plan = self._train_module._module.plan.plan
        key = "main_module.sparse_arch.embedding_bag_collection"

        # Use the current batch number as the seed so every rank selects the same
        # plan, while the selection still changes from call to call.
        rng = random.Random(batch_no)
        index = rng.randint(0, len(self._resharding_plans) - 1)
        logger.info(f"Selected resharding plan index {index} for step {batch_no}")
        # Get the selected plan
        selected_plan = self._resharding_plans[index]

        # Fix the device mismatch by updating the placement device in the sharding spec.
        # This is necessary because the plan was created on the meta device but needs
        # to be applied on CUDA.
        for _, param_sharding in selected_plan.items():
            sharding_spec = param_sharding.sharding_spec
            if sharding_spec is not None:
                # pyre-ignore
                for shard in sharding_spec.shards:
                    placement = shard.placement
                    rank: Optional[int] = placement.rank()
                    assert rank is not None
                    current_device = (
                        torch.cuda.current_device()
                        if rank == torch.distributed.get_rank()
                        else rank % torch.cuda.device_count()
                    )
                    shard.placement = torch.distributed._remote_device(
                        f"rank:{rank}/cuda:{current_device}"
                    )

        data_volume, delta_plan = output_sharding_plan_delta(
            plan[key], selected_plan, True
        )
        # Pyre-ignore
        self._train_module.module.reshard(
            sharded_module_fqn="main_module.sparse_arch.embedding_bag_collection",
            changed_shard_to_params=delta_plan,
        )
        return data_volume
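
For context, a minimal sketch of how this handler might be driven from a training loop. This is not part of the PR: `train_module` is assumed to be the same pipelined module the handler expects (exposing `_module.plan` and `module.reshard`), the forward/backward call is a placeholder, and the resharding cadence of every 100 batches is an arbitrary choice.

# Hypothetical usage sketch (not part of this PR).
import torch.distributed as dist

def train_with_resharding(train_module, dataloader, num_plans: int = 4) -> None:
    handler = ReshardingHandler(train_module, num_plans)
    for batch_no, batch in enumerate(dataloader):
        loss = train_module(batch)  # placeholder for the real forward/backward step
        loss.backward()
        # Reshard periodically; every rank derives the same plan index from batch_no.
        if batch_no > 0 and batch_no % 100 == 0:
            data_volume = handler.step(batch_no)
            if dist.get_rank() == 0:
                logger.info(f"Resharded at batch {batch_no}, delta volume: {data_volume}")
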
54 changes: 35 additions & 19 deletions torchrec/distributed/embeddingbag.py
@@ -57,10 +57,11 @@
from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
from torchrec.distributed.sharding.dynamic_sharding import (
    get_largest_dims_from_sharding_plan_updates,
    move_sharded_tensors_to_cpu,
    shards_all_to_all,
    update_module_sharding_plan,
    update_optimizer_state_post_resharding,
    update_state_dict_post_resharding,
    update_state_post_resharding,
)
from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
@@ -1377,6 +1378,25 @@ def _init_mean_pooling_callback(
            device=self._device,
        )

    def _purge_lookups(self) -> None:
        # Purge old lookups
        for lookup in self._lookups:
            # Call purge method if available (for TBE modules)
            if hasattr(lookup, "purge") and callable(lookup.purge):
                # Pyre-ignore
                lookup.purge()

            # For DDP modules, get the underlying module
            while isinstance(lookup, DistributedDataParallel):
                lookup = lookup.module
                if hasattr(lookup, "purge") and callable(lookup.purge):
                    lookup.purge()

        # Clear the lookups list
        self._lookups.clear()
        # Return freed blocks held by the CUDA caching allocator
        torch.cuda.empty_cache()

    def _create_lookups(
        self,
    ) -> None:
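
A quick way to sanity-check the new purge path during development (a debugging sketch only, not part of this change): compare allocated CUDA memory before and after calling `_purge_lookups` on a sharded module.

# Debugging sketch (not part of this PR): measure how much CUDA memory the purge
# releases. `sharded_ebc` is assumed to be a sharded module exposing _purge_lookups.
import torch

def log_purge_savings(sharded_ebc) -> int:
    torch.cuda.synchronize()
    before = torch.cuda.memory_allocated()
    sharded_ebc._purge_lookups()
    torch.cuda.synchronize()
    freed = before - torch.cuda.memory_allocated()
    print(f"_purge_lookups released ~{freed / (1024 ** 2):.1f} MiB")
    return freed
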
@@ -1723,12 +1743,13 @@ def update_shards(
            env (ShardingEnv): The sharding environment for the module.
            device (Optional[torch.device]): The device to place the updated module on.
        """

        if env.output_dtensor:
            raise RuntimeError("We do not support DTensor for resharding yet")
            return

        current_state = self.state_dict()
        current_state = move_sharded_tensors_to_cpu(current_state)
        # TODO: improve, checking one would be enough
        has_local_optimizer = len(self._optim._optims) > 0 and all(
            len(i) > 0 for i in self._optim.state_dict()["state"].values()
        )
@@ -1740,22 +1761,18 @@

        has_optimizer = self._is_optimizer_enabled(has_local_optimizer, env, device)

        # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again
        # TODO: Ensure lookup tensors are actually being deleted
        for _, lookup in enumerate(self._lookups):
            # pyre-ignore
            lookup.purge()

        # Deleting all lookups
        self._lookups.clear()
        # TODO: make sure this is clearing all lookups
        self._purge_lookups()

        # Get max dim size to enable padding for all_to_all
        max_dim_0, max_dim_1 = get_largest_dims_from_sharding_plan_updates(
            changed_sharding_params
        )
        old_optimizer_state = self._optim.state_dict() if has_local_optimizer else None
        if old_optimizer_state is not None:
            move_sharded_tensors_to_cpu(old_optimizer_state)

        local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all(
        local_shard_names_by_src_rank, local_output_tensor_cpu = shards_all_to_all(
            module=self,
            state_dict=current_state,
            device=device,  # pyre-ignore
@@ -1832,22 +1849,21 @@ def update_shards(
        if has_optimizer:
            optimizer_state = update_optimizer_state_post_resharding(
                old_opt_state=old_optimizer_state,  # pyre-ignore
                new_opt_state=copy.deepcopy(self._optim.state_dict()),
                new_opt_state=self._optim.state_dict(),
                ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
                output_tensor=local_output_tensor,
                output_tensor=local_output_tensor_cpu,
                max_dim_0=max_dim_0,
                extend_shard_name=self.extend_shard_name,
            )
            self._optim.load_state_dict(optimizer_state)

        current_state = update_state_dict_post_resharding(
            state_dict=current_state,
        new_state = self.state_dict()
        current_state = update_state_post_resharding(
            old_state=current_state,
            new_state=new_state,
            ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
            output_tensor=local_output_tensor,
            new_sharding_params=changed_sharding_params,
            curr_rank=dist.get_rank(),
            output_tensor=local_output_tensor_cpu,
            extend_shard_name=self.extend_shard_name,
            max_dim_0=max_dim_0,
            has_optimizer=has_optimizer,
        )

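The state handling above now stages shard tensors on CPU (`move_sharded_tensors_to_cpu`, `local_output_tensor_cpu`) before the redistribution step. As a toy illustration of that pattern only (this is not torchrec's helper), offloading a plain state dict of tensors looks like:

# Toy illustration of the CPU-staging pattern used above; NOT torchrec's
# move_sharded_tensors_to_cpu. It detaches tensors from the GPU so the
# redistribution step does not hold duplicate device copies.
from typing import Dict
import torch

def offload_state_to_cpu(state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Copy every tensor to CPU; non-tensor or sharded entries would need extra handling.
    return {name: tensor.detach().cpu() for name, tensor in state.items()}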