diff --git a/torchrec/distributed/benchmark/benchmark_resharding_handler.py b/torchrec/distributed/benchmark/benchmark_resharding_handler.py new file mode 100644 index 000000000..2a5f48293 --- /dev/null +++ b/torchrec/distributed/benchmark/benchmark_resharding_handler.py @@ -0,0 +1,170 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import logging +import random +from typing import List, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +from torchrec.distributed.embeddingbag import EmbeddingBagCollection + +from torchrec.distributed.sharding.dynamic_sharding import output_sharding_plan_delta + +from torchrec.distributed.sharding_plan import ( + column_wise, + construct_module_sharding_plan, + table_wise, +) + +from torchrec.distributed.test_utils.test_sharding import generate_rank_placements +from torchrec.distributed.types import EmbeddingModuleShardingPlan + +logger: logging.Logger = logging.getLogger(__name__) + + +class ReshardingHandler: + """ + Handles the resharding of a training module by generating and applying different sharding plans. + """ + + def __init__(self, train_module: nn.Module, num_plans: int) -> None: + """ + Initializes the ReshardingHandler with a training module and the number of sharding plans to generate. + + Args: + train_module (nn.Module): The training module to be resharded. + num_plans (int): The number of sharding plans to generate. + """ + self._train_module = train_module + if not hasattr(train_module, "_module"): + raise RuntimeError("Incorrect train module") + + if not hasattr(train_module._module, "plan"): + raise RuntimeError("sharding plan cannot be found") + + # Pyre-ignore + plan = train_module._module.plan.plan + key = "main_module.sparse_arch.embedding_bag_collection" + module = ( + # Pyre-ignore + train_module._module.module.main_module.sparse_arch.embedding_bag_collection + ) + self._resharding_plans: List[EmbeddingModuleShardingPlan] = [] + world_size = dist.get_world_size() + + if key in plan: + ebc = plan[key] + num_tables = len(ebc) + ranks_per_tables = [1 for _ in range(num_tables)] + ranks_per_tables_for_CW = [] + for index, table_config in enumerate(module._embedding_bag_configs): + # CW sharding + valid_candidates = [ + i + for i in range(1, world_size + 1) + if table_config.embedding_dim % i == 0 + ] + rng = random.Random(index) + ranks_per_tables_for_CW.append(rng.choice(valid_candidates)) + + for i in range(num_plans): + new_ranks = generate_rank_placements( + world_size, num_tables, ranks_per_tables, i + ) + new_ranks_cw = generate_rank_placements( + world_size, num_tables, ranks_per_tables_for_CW, i + ) + new_per_param_sharding = {} + for i, (talbe_id, param) in enumerate(ebc.items()): + if param.sharding_type == "column_wise": + cw_gen = column_wise( + ranks=new_ranks_cw[i], + compute_kernel=param.compute_kernel, + ) + new_per_param_sharding[talbe_id] = cw_gen + else: + tw_gen = table_wise( + rank=new_ranks[i][0], + compute_kernel=param.compute_kernel, + ) + new_per_param_sharding[talbe_id] = tw_gen + + lightweight_ebc = EmbeddingBagCollection( + tables=module._embedding_bag_configs, + device=torch.device( + "meta" + ), # Use meta device to avoid actual memory allocation + ) + + meta_device = torch.device("meta") + new_plan = construct_module_sharding_plan( + lightweight_ebc, + 
per_param_sharding=new_per_param_sharding, + local_size=world_size, + world_size=world_size, + # Pyre-ignore + device_type=meta_device, + ) + self._resharding_plans.append(new_plan) + else: + raise RuntimeError(f"Plan does not have key: {key}") + + def step(self, batch_no: int) -> float: + """ + Executes a step in the training process by selecting and applying a sharding plan. + + Args: + batch_no (int): The current batch number. + + Returns: + float: The data volume of the sharding plan delta. + """ + # Pyre-ignore + plan = self._train_module._module.plan.plan + key = "main_module.sparse_arch.embedding_bag_collection" + + # Use the current step as a seed to ensure all ranks get the same random number + # but it changes on each call + + rng = random.Random(batch_no) + index = rng.randint(0, len(self._resharding_plans) - 1) + logger.info(f"Selected resharding plan index {index} for step {batch_no}") + # Get the selected plan + selected_plan = self._resharding_plans[index] + + # Fix the device mismatch by updating the placement device in the sharding spec + # This is necessary because the plan was created with meta device but needs to be applied on CUDA + for _, param_sharding in selected_plan.items(): + sharding_spec = param_sharding.sharding_spec + if sharding_spec is not None: + # pyre-ignore + for shard in sharding_spec.shards: + placement = shard.placement + rank: Optional[int] = placement.rank() + assert rank is not None + current_device = ( + torch.cuda.current_device() + if rank == torch.distributed.get_rank() + else rank % torch.cuda.device_count() + ) + shard.placement = torch.distributed._remote_device( + f"rank:{rank}/cuda:{current_device}" + ) + + data_volume, delta_plan = output_sharding_plan_delta( + plan[key], selected_plan, True + ) + # Pyre-ignore + self._train_module.module.reshard( + sharded_module_fqn="main_module.sparse_arch.embedding_bag_collection", + changed_shard_to_params=delta_plan, + ) + return data_volume diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 4fdaac60c..1551b585f 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -57,10 +57,11 @@ from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding from torchrec.distributed.sharding.dynamic_sharding import ( get_largest_dims_from_sharding_plan_updates, + move_sharded_tensors_to_cpu, shards_all_to_all, update_module_sharding_plan, update_optimizer_state_post_resharding, - update_state_dict_post_resharding, + update_state_post_resharding, ) from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding @@ -1377,6 +1378,25 @@ def _init_mean_pooling_callback( device=self._device, ) + def _purge_lookups(self) -> None: + # Purge old lookups + for lookup in self._lookups: + # Call purge method if available (for TBE modules) + if hasattr(lookup, "purge") and callable(lookup.purge): + # Pyre-ignore + lookup.purge() + + # For DDP modules, get the underlying module + while isinstance(lookup, DistributedDataParallel): + lookup = lookup.module + if hasattr(lookup, "purge") and callable(lookup.purge): + lookup.purge() + + # Clear the lookups list + self._lookups.clear() + # Force garbage collection to free memory + torch.cuda.empty_cache() + def _create_lookups( self, ) -> None: @@ -1723,12 +1743,13 @@ def update_shards( env (ShardingEnv): The sharding environment for the module. 
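# Illustrative sketch (not part of the patch) of the lookup-release pattern that
# _purge_lookups above follows, under the assumption that a lookup module may be
# wrapped in DistributedDataParallel and may expose an optional purge() method.
import torch
from torch.nn.parallel import DistributedDataParallel

def release_lookups(lookups: list) -> None:
    for lookup in lookups:
        # Unwrap DDP to reach the underlying module.
        while isinstance(lookup, DistributedDataParallel):
            lookup = lookup.module
        purge = getattr(lookup, "purge", None)
        if callable(purge):
            purge()  # let the module free any buffers it owns (e.g. TBE state)
    lookups.clear()  # drop references so the allocator can reclaim the memory
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached CUDA blocks to the driver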
device (Optional[torch.device]): The device to place the updated module on. """ - if env.output_dtensor: raise RuntimeError("We do not yet support DTensor for resharding yet") return current_state = self.state_dict() + current_state = move_sharded_tensors_to_cpu(current_state) + # TODO: improve, checking one would be enough has_local_optimizer = len(self._optim._optims) > 0 and all( len(i) > 0 for i in self._optim.state_dict()["state"].values() ) @@ -1740,22 +1761,18 @@ def update_shards( has_optimizer = self._is_optimizer_enabled(has_local_optimizer, env, device) - # TODO: Saving lookups tensors to CPU to eventually avoid recreating them completely again - # TODO: Ensure lookup tensors are actually being deleted - for _, lookup in enumerate(self._lookups): - # pyre-ignore - lookup.purge() - - # Deleting all lookups - self._lookups.clear() + # TODO: make sure this is clearing all lookups + self._purge_lookups() # Get max dim size to enable padding for all_to_all max_dim_0, max_dim_1 = get_largest_dims_from_sharding_plan_updates( changed_sharding_params ) old_optimizer_state = self._optim.state_dict() if has_local_optimizer else None + if old_optimizer_state is not None: + move_sharded_tensors_to_cpu(old_optimizer_state) - local_shard_names_by_src_rank, local_output_tensor = shards_all_to_all( + local_shard_names_by_src_rank, local_output_tensor_cpu = shards_all_to_all( module=self, state_dict=current_state, device=device, # pyre-ignore @@ -1832,22 +1849,21 @@ def update_shards( if has_optimizer: optimizer_state = update_optimizer_state_post_resharding( old_opt_state=old_optimizer_state, # pyre-ignore - new_opt_state=copy.deepcopy(self._optim.state_dict()), + new_opt_state=self._optim.state_dict(), ordered_shard_names_and_lengths=local_shard_names_by_src_rank, - output_tensor=local_output_tensor, + output_tensor=local_output_tensor_cpu, max_dim_0=max_dim_0, extend_shard_name=self.extend_shard_name, ) self._optim.load_state_dict(optimizer_state) - current_state = update_state_dict_post_resharding( - state_dict=current_state, + new_state = self.state_dict() + current_state = update_state_post_resharding( + old_state=current_state, + new_state=new_state, ordered_shard_names_and_lengths=local_shard_names_by_src_rank, - output_tensor=local_output_tensor, - new_sharding_params=changed_sharding_params, - curr_rank=dist.get_rank(), + output_tensor=local_output_tensor_cpu, extend_shard_name=self.extend_shard_name, - max_dim_0=max_dim_0, has_optimizer=has_optimizer, ) diff --git a/torchrec/distributed/sharding/dynamic_sharding.py b/torchrec/distributed/sharding/dynamic_sharding.py index 6c5e12e83..a4de9835c 100644 --- a/torchrec/distributed/sharding/dynamic_sharding.py +++ b/torchrec/distributed/sharding/dynamic_sharding.py @@ -13,7 +13,6 @@ import torch import torch.distributed as dist import torch.nn.functional as F -from torch.distributed._shard.sharded_tensor import Shard from torchrec.distributed.types import ( EmbeddingModuleShardingPlan, ParameterSharding, @@ -102,7 +101,7 @@ def _generate_shard_allocation_metadata( destination_params: ParameterSharding, rank: int, world_size: int, -) -> Dict[int, List[Tuple[int, int, int, List[int]]]]: +) -> ShardToRankMapping: """ Generates a mapping of shards to ranks for redistribution of data. @@ -120,7 +119,7 @@ def _generate_shard_allocation_metadata( Dict[int, List[Tuple[int, List[int]]]]: A dictionary mapping source ranks to a list of tuples, where each tuple contains a destination rank and the corresponding shard offsets. 
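# Illustrative sketch (not the torchrec implementation) of the interval walk this
# function performs: each source shard's row range is split across the destination
# shards that overlap it, recording (dst_rank, start, end) pieces per source rank.
# The ranks and lengths below are made up.
from typing import Dict, List, Tuple

def split_ranges(
    src: List[Tuple[int, int]],  # (rank, length) per source shard
    dst: List[Tuple[int, int]],  # (rank, length) per destination shard
) -> Dict[int, List[Tuple[int, int, int]]]:
    mapping: Dict[int, List[Tuple[int, int, int]]] = {}
    si, di, s_off, d_off = 0, 0, 0, 0
    while si < len(src) and di < len(dst):
        take = min(src[si][1] - s_off, dst[di][1] - d_off)
        mapping.setdefault(src[si][0], []).append((dst[di][0], s_off, s_off + take))
        s_off += take
        d_off += take
        if s_off == src[si][1]:
            si, s_off = si + 1, 0
        if d_off == dst[di][1]:
            di, d_off = di + 1, 0
    return mapping

print(split_ranges([(0, 6), (1, 6)], [(2, 4), (3, 8)]))
# {0: [(2, 0, 4), (3, 4, 6)], 1: [(3, 0, 6)]}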
""" - shard_to_rank_mapping: Dict[int, List[Tuple[int, int, int, List[int]]]] = {} + shard_to_rank_mapping: ShardToRankMapping = {} src_rank_index = 0 dst_rank_index = 0 curr_source_offset = 0 @@ -129,7 +128,6 @@ def _generate_shard_allocation_metadata( assert source_params.ranks is not None assert destination_params.ranks is not None - assert source_params.sharding_spec is not None assert destination_params.sharding_spec is not None @@ -244,20 +242,20 @@ def _process_shard_redistribution_metadata( # Update the shard size with new size shard_size = [shard_size[0], split_offsets[1] - split_offsets[0], shard_id] # Update split sizes for communication - input_splits_per_rank[src_rank][dst_rank] += max_dim_0 - output_splits_per_rank[dst_rank][src_rank] += max_dim_0 + input_splits_per_rank[src_rank][dst_rank] += shard_size[0] + output_splits_per_rank[dst_rank][src_rank] += shard_size[0] # Process data being sent from current rank if src_rank == rank: # Handle optimizer state if present + extended_shard_name: str = extend_shard_name(shard_name) if has_local_optimizer: + momentun_name = tmp_momentum_extender(shard_name) + # Pyre-ignore local_optimizer_shards = optimizer_state["state"][ - extend_shard_name(shard_name) - ][tmp_momentum_extender(shard_name)].local_shards() - # assert ( - # len(local_optimizer_shards) == 1 - # ), "Expected exactly one local optimizer shard" + extended_shard_name + ][momentun_name].local_shards() local_optimizer_tensor = local_optimizer_shards[ local_shard_id @@ -272,21 +270,23 @@ def _process_shard_redistribution_metadata( local_optimizer_tensor = local_optimizer_tensor[ :, split_offsets[0] : split_offsets[1] ] + padded_optimizer_tensor = pad_tensor_to_max_dims( - local_optimizer_tensor, max_dim_0, max_dim_1 + local_optimizer_tensor, shard_size[0], max_dim_1 ) local_table_to_opt_by_dst_rank[dst_rank].append( padded_optimizer_tensor ) - input_splits_per_rank[src_rank][dst_rank] += max_dim_0 - # Handle main tensor data - local_shard = sharded_tensor.local_shards()[local_shard_id] + input_splits_per_rank[src_rank][dst_rank] += shard_size[0] + + local_shards = sharded_tensor.local_shards() + + local_tensor = local_shards[local_shard_id].tensor - # cut the tensor based on split points - dst_t = local_shard.tensor[:, split_offsets[0] : split_offsets[1]] + dst_t = local_tensor[:, split_offsets[0] : split_offsets[1]] - padded_tensor = pad_tensor_to_max_dims(dst_t, max_dim_0, max_dim_1) + padded_tensor = pad_tensor_to_max_dims(dst_t, shard_size[0], max_dim_1) local_table_to_input_tensor_by_dst_rank[dst_rank].append(padded_tensor) # Process data being received at current rank @@ -294,10 +294,12 @@ def _process_shard_redistribution_metadata( shard_names_to_lengths_by_src_rank[src_rank].append( (shard_name, shard_size) ) - output_tensor_count += max_dim_0 + + output_tensor_count += shard_size[0] if has_optimizer: - output_optimizer_count += max_dim_0 - output_splits_per_rank[dst_rank][src_rank] += max_dim_0 + + output_optimizer_count += shard_size[0] + output_splits_per_rank[dst_rank][src_rank] += shard_size[0] return output_tensor_count, output_optimizer_count @@ -305,10 +307,9 @@ def _process_shard_redistribution_metadata( def _create_local_shard_tensors( ordered_shard_names_and_lengths: OrderedShardNamesWithSizes, output_tensor: torch.Tensor, - max_dim_0: int, has_optimizer: bool = False, optimizer_mode: bool = False, - new_opt_state_state: Optional[Dict[str, Dict[str, ShardedTensor]]] = None, + new_state: Optional[Dict[str, Dict[str, ShardedTensor]]] = None, 
extend_shard_name: Optional[Callable[[str], str]] = None, ) -> Dict[str, List[torch.Tensor]]: """ @@ -322,7 +323,6 @@ def _create_local_shard_tensors( ordered_shard_names_and_lengths (OrderedShardNamesWithSizes): A list of tuples containing shard names and their corresponding sizes. output_tensor (torch.Tensor): The tensor containing all shards received by the current rank. - max_dim_0 (int): The maximum dimension size of dim 0 for slicing the output tensor. Returns: Dict[str, torch.Tensor]: A dictionary mapping shard names to their corresponding local output tensors. @@ -335,42 +335,52 @@ def _create_local_shard_tensors( shard_name_to_local_output_tensor: Dict[str, List[torch.Tensor]] = {} - slice_index = 0 if not optimizer_mode else max_dim_0 - step_size = max_dim_0 + slice_index = 0 splitted_shards_with_names: Dict[str, List[Tuple[int, torch.Tensor]]] = {} - - for shard_name, shard_size in ordered_shard_names_and_lengths: - + for i, (shard_name, shard_size) in enumerate(ordered_shard_names_and_lengths): + if i == 0: + slice_index = 0 if not optimizer_mode else shard_size[0] shard_id = shard_size[2] - end_slice_index = slice_index + step_size + end_slice_index = slice_index + shard_size[0] cur_t = output_tensor[slice_index:end_slice_index] - cur_t = pad_tensor_to_max_dims(cur_t, shard_size[0], shard_size[1]) + cur_t = cur_t[: shard_size[0], : shard_size[1]] extended_shard_name = ( extend_shard_name(shard_name) if extend_shard_name else shard_name ) - new_opt_state_state = new_opt_state_state if new_opt_state_state else {} + new_state = new_state if new_state else {} momentum_name = tmp_momentum_extender(shard_name) if ( optimizer_mode - and new_opt_state_state is not None - and extended_shard_name in new_opt_state_state.keys() + and new_state is not None + and extended_shard_name in new_state.keys() ): - sharded_t = new_opt_state_state[extended_shard_name][momentum_name] + sharded_t = new_state[extended_shard_name][momentum_name] assert len(sharded_t._local_shards) == 1 if len(sharded_t._local_shards[0].tensor.size()) == 1: - cur_t = cur_t * shard_size[1] # Supporting RowWise Adagrad operation + cur_t.mul_(shard_size[1]) # Supporting RowWise Adagrad operation if shard_name not in splitted_shards_with_names: splitted_shards_with_names[shard_name] = [(shard_id, cur_t)] else: splitted_shards_with_names[shard_name].append((shard_id, cur_t)) + slice_index = ( - end_slice_index if not has_optimizer else end_slice_index + max_dim_0 + end_slice_index + if not has_optimizer + else ( + end_slice_index + shard_size[0] + if not optimizer_mode + else ( + end_slice_index + ordered_shard_names_and_lengths[i + 1][1][0] + if i < len(ordered_shard_names_and_lengths) - 1 + else end_slice_index + ) + ) ) # Assuming splitted_shards_with_names is already populated @@ -396,6 +406,44 @@ def _create_local_shard_tensors( return shard_name_to_local_output_tensor +def move_sharded_tensors_to_cpu(state_dict: Dict[str, Any]) -> Dict[str, Any]: + """ + Recursively traverse a state dictionary and move all local shard tensors to CPU. + This helps reduce GPU memory usage by keeping tensors in CPU memory until needed. 
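# Simplified, standalone analogue (plain tensors rather than ShardedTensor local
# shards) of the recursive CPU offload this helper performs on nested state dicts.
import torch

def nested_to_cpu(obj):
    if isinstance(obj, torch.Tensor):
        return obj.cpu()
    if isinstance(obj, dict):
        return {k: nested_to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(nested_to_cpu(v) for v in obj)
    return obj

state = {"weight": torch.ones(2, 2), "step": 3, "momenta": [torch.zeros(2)]}
print(nested_to_cpu(state)["weight"].device)  # cpu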
+ + Args: + state_dict: The state dictionary to traverse (can be model or optimizer state dict) + + Returns: + The modified state dictionary with tensors moved to CPU + """ + + # Pyre-ignore + def _process_item(item: Any) -> Any: + if isinstance(item, ShardedTensor): + # For ShardedTensor, move all local shards to CPU + for shard in item.local_shards(): + if shard.tensor.device.type == "cuda": + shard.tensor = shard.tensor.cpu() + return item + elif isinstance(item, dict): + # Recursively process dictionaries + return {k: _process_item(v) for k, v in item.items()} + elif isinstance(item, list): + # Recursively process lists + return [_process_item(v) for v in item] + elif isinstance(item, tuple): + # Recursively process tuples + return tuple(_process_item(v) for v in item) + else: + # Return other types unchanged + return item + + processed_dict = _process_item(state_dict) + torch.cuda.empty_cache() + return processed_dict + + def shards_all_to_all( module: ShardedModule[Any, Any, Any, Any], # pyre-ignore state_dict: Dict[str, ShardedTensor], @@ -507,48 +555,79 @@ def shards_all_to_all( local_input_splits = input_splits_per_rank[rank] local_output_splits = output_splits_per_rank[rank] - local_input_tensor = torch.empty([0, max_dim_1], device=device) + total_input_size = sum(local_input_splits) + local_input_tensor = torch.zeros( + [total_input_size, max_dim_1], device=torch.device("cpu") + ) + + current_pos = 0 for i, sub_l in enumerate(local_table_to_input_tensor_by_dst_rank): - for j, shard_info in enumerate(sub_l): - local_input_tensor = torch.cat( - ( - local_input_tensor, - shard_info, - ), - dim=0, - ) + batch_size = local_input_splits[i] + if batch_size == 0: + continue + + # Create a view into the pre-allocated tensor for this destination rank + batch_view = local_input_tensor[current_pos : current_pos + batch_size] + + current_row = 0 + for j, shard_info_cpu in enumerate(sub_l): + if shard_info_cpu is not None: + rows = shard_info_cpu.size(0) + + # Copying data to input tensor uvm->uvm operation + batch_view[current_row : current_row + rows] = shard_info_cpu + current_row += rows + + # Free CPU memory by removing reference + local_table_to_input_tensor_by_dst_rank[i][j] = None if has_local_optimizer: - shard_info = local_table_to_opt_by_dst_rank[i][j] - local_input_tensor = torch.cat( - ( - local_input_tensor, - shard_info, - ), - dim=0, - ) + opt_shard_info_cpu = local_table_to_opt_by_dst_rank[i][j] + if opt_shard_info_cpu is not None: + opt_rows = opt_shard_info_cpu.size(0) + + batch_view[current_row : current_row + opt_rows] = ( + opt_shard_info_cpu + ) + current_row += opt_rows + + # Free CPU memory by removing reference + local_table_to_opt_by_dst_rank[i][j] = None + + # Move position pointer forward + current_pos += batch_size receive_count = output_tensor_tensor_count + output_optimizer_tensor_count max_embedding_size = max_dim_1 - local_output_tensor = torch.empty( + local_output_tensor_cpu = torch.empty( [ receive_count, max_embedding_size, ], - device=device, + device=torch.device("cpu"), ) - assert sum(local_output_splits) == len(local_output_tensor) + assert sum(local_output_splits) == len(local_output_tensor_cpu) assert sum(local_input_splits) == len(local_input_tensor) + # TODO: move this to hireachical process creation if possible for scaling beyond 32 + if not hasattr(env, "cpu_process_group"): + # Create a CPU process group with Gloo backend + # Pyre-ignore + env.cpu_process_group = dist.new_group( + ranks=list(range(env.world_size)), + backend="gloo", # Use Gloo 
backend for CPU operations + ) + dist.all_to_all_single( - output=local_output_tensor, + output=local_output_tensor_cpu, input=local_input_tensor, output_split_sizes=local_output_splits, input_split_sizes=local_input_splits, - group=env.process_group, # TODO: 2D uses env.sharding_pg + group=env.cpu_process_group, # TODO: 2D uses env.sharding_pg ) + del local_input_tensor flattened_output_names_lengths = [ shard_info @@ -556,16 +635,14 @@ def shards_all_to_all( for shard_info in sub_l ] - return flattened_output_names_lengths, local_output_tensor + return flattened_output_names_lengths, local_output_tensor_cpu -def update_state_dict_post_resharding( - state_dict: Dict[str, ShardedTensor], +def update_state_post_resharding( + old_state: Dict[str, ShardedTensor], + new_state: Dict[str, ShardedTensor], ordered_shard_names_and_lengths: OrderedShardNamesWithSizes, output_tensor: torch.Tensor, - new_sharding_params: Dict[str, ParameterSharding], - curr_rank: int, - max_dim_0: int, extend_shard_name: Callable[[str], str] = lambda x: x, has_optimizer: bool = False, ) -> Dict[str, ShardedTensor]: @@ -601,40 +678,40 @@ def update_state_dict_post_resharding( _create_local_shard_tensors( ordered_shard_names_and_lengths, output_tensor, - max_dim_0, has_optimizer=has_optimizer, optimizer_mode=False, ) ) - for shard_name, param in new_sharding_params.items(): - extended_name = extend_shard_name(shard_name) - sharded_t = state_dict[extended_name] - sharded_t.metadata().shards_metadata.clear() + for extended_shard_name, item in new_state.items(): + shard_name = extract_shard_name(extended_shard_name) + if ( + old_state is not None + and extended_shard_name in old_state + and shard_name not in shard_name_to_local_output_tensor.keys() + ): - # pyre-ignore - for i in range(len(param.ranks)): - # pyre-ignore - r = param.ranks[i] - - # Update placements - # pyre-ignore - sharded_t.metadata().shards_metadata.append(param.sharding_spec.shards[i]) - # Update local shards - if r == curr_rank: - assert len(output_tensor) > 0 - # slice output tensor for correct size. 
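# Illustrative sketch (assumes torch.distributed is already initialized; `env` is
# just any object the group can be cached on) of the CPU-side exchange pattern
# used by shards_all_to_all above: create a Gloo group once for CPU tensors and
# drive all_to_all_single with explicit per-rank split sizes.
import torch
import torch.distributed as dist

def cpu_all_to_all(env, send_buf, in_splits, out_splits):
    if not hasattr(env, "cpu_process_group"):
        env.cpu_process_group = dist.new_group(
            ranks=list(range(dist.get_world_size())),
            backend="gloo",  # Gloo operates on CPU tensors; NCCL would need CUDA buffers
        )
    recv_buf = torch.empty(sum(out_splits), send_buf.size(1))
    dist.all_to_all_single(
        output=recv_buf,
        input=send_buf,
        output_split_sizes=out_splits,
        input_split_sizes=in_splits,
        group=env.cpu_process_group,
    )
    return recv_buf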
- sharded_t._local_shards = [ - Shard( - tensor=shard_name_to_local_output_tensor[shard_name][0], - metadata=param.sharding_spec.shards[i], - ) - ] - break - else: - sharded_t._local_shards = [] + sharded_t = new_state[extended_shard_name] + sharded_t_old = old_state[extended_shard_name] + + local_shards = sharded_t._local_shards + for i, shard in enumerate(local_shards): + shard.tensor.copy_( + sharded_t_old._local_shards[i].tensor, non_blocking=True + ) + shard.metadata = sharded_t_old._local_shards[i].metadata + else: + + sharded_t = item + assert len(sharded_t._local_shards) == 1 + # local_tensor is updated in-pace for CW sharding + local_tensor = shard_name_to_local_output_tensor[shard_name][0] - return state_dict + for i, shard in enumerate(sharded_t._local_shards): + shard.tensor.copy_(local_tensor, non_blocking=True) + shard.metadata = sharded_t._local_shards[i].metadata + + return new_state def update_optimizer_state_post_resharding( @@ -653,18 +730,19 @@ def update_optimizer_state_post_resharding( _create_local_shard_tensors( ordered_shard_names_and_lengths, output_tensor, - max_dim_0, has_optimizer=True, optimizer_mode=True, - new_opt_state_state=new_opt_state_state, + new_state=new_opt_state_state, extend_shard_name=extend_shard_name, ) ) + if new_opt_state_state is None or len(new_opt_state_state) == 0: return new_opt_state for extended_shard_name, item in new_opt_state_state.items(): shard_name = extract_shard_name(extended_shard_name) + momentum_name = tmp_momentum_extender(shard_name) if ( old_opt_state_state is not None @@ -672,11 +750,16 @@ def update_optimizer_state_post_resharding( and shard_name not in shard_name_to_local_output_tensor.keys() ): - new_opt_state_state[extended_shard_name] = old_opt_state_state[ - extended_shard_name - ] + sharded_t = new_opt_state_state[extended_shard_name][momentum_name] + sharded_t_old = old_opt_state_state[extended_shard_name][momentum_name] + local_shards = sharded_t._local_shards + for i, shard in enumerate(local_shards): + shard.tensor.copy_( + sharded_t_old._local_shards[i].tensor, non_blocking=True + ) + shard.metadata = sharded_t_old._local_shards[i].metadata else: - momentum_name = tmp_momentum_extender(shard_name) + sharded_t = item[momentum_name] assert len(sharded_t._local_shards) == 1 # local_tensor is updated in-pace for CW sharding @@ -688,15 +771,10 @@ def update_optimizer_state_post_resharding( squared_sum_t = torch.sum(local_tensor, dim=1, keepdim=True) mean_squared_sum_t = torch.div(squared_sum_t, local_tensor_dim) local_tensor = mean_squared_sum_t.T[0] - sharded_t._local_shards = [ - Shard( - tensor=local_tensor, - metadata=shard.metadata, - ) - for shard in sharded_t._local_shards - ] - item[momentum_name] = sharded_t - new_opt_state_state[extended_shard_name] = item + + for i, shard in enumerate(sharded_t._local_shards): + shard.tensor.copy_(local_tensor, non_blocking=True) + shard.metadata = sharded_t._local_shards[i].metadata return new_opt_state @@ -776,6 +854,8 @@ def pad_tensor_to_max_dims( Returns: torch.Tensor: The padded tensor. 
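# Quick illustration of the F.pad convention pad_tensor_to_max_dims relies on:
# for a 2D tensor the pad tuple is (left, right, top, bottom), so padding only on
# the right and bottom grows the tensor toward the target shape. Sizes are arbitrary.
import torch
import torch.nn.functional as F

t = torch.ones(2, 3)
padded = F.pad(t, (0, 5 - t.size(1), 0, 4 - t.size(0)), mode="constant", value=0)
print(padded.shape)  # torch.Size([4, 5])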
""" + if expected_dim_0 == t.size(0) and expected_dim_1 == t.size(1): + return t pad_right = expected_dim_1 - t.size(1) pad_bottom = expected_dim_0 - t.size(0) pad = (0, pad_right, 0, pad_bottom) @@ -789,8 +869,10 @@ def pad_tensor_to_max_dims( # Utils def output_sharding_plan_delta( - old_plan: EmbeddingModuleShardingPlan, new_plan: EmbeddingModuleShardingPlan -) -> EmbeddingModuleShardingPlan: + old_plan: EmbeddingModuleShardingPlan, + new_plan: EmbeddingModuleShardingPlan, + return_data_volume: bool = False, +) -> Tuple[float, EmbeddingModuleShardingPlan]: """ Compute and return a new sharding plan that is the delta between new and old embedding module plans. Assumes that the old and new plan @@ -800,13 +882,23 @@ def output_sharding_plan_delta( ParameterSharding or shards that needs to be moved. """ assert len(old_plan) == len(new_plan) - return EmbeddingModuleShardingPlan( + diff = EmbeddingModuleShardingPlan( { k: copy.deepcopy(v) for k, v in new_plan.items() if v.ranks != old_plan[k].ranks } ) + data_volume: float = 0 + if return_data_volume: + for _, v in diff.items(): + # Pyre-ignore + for shard in v.sharding_spec.shards: + data_volume += ( + shard.shard_sizes[0] * shard.shard_sizes[1] * 4 / (1024 * 1024) + ) # Asumming float datatype + + return (data_volume, diff) """ diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py index ff72f9fa2..c955249d2 100644 --- a/torchrec/distributed/test_utils/test_sharding.py +++ b/torchrec/distributed/test_utils/test_sharding.py @@ -590,7 +590,7 @@ def dynamic_sharding_test( exclude_predfix="sparse.pooled_embedding_arch.embedding_modules._itp_iter", ) - new_module_sharding_plan_delta = output_sharding_plan_delta( + _, new_module_sharding_plan_delta = output_sharding_plan_delta( plan.plan["sparse.ebc"], new_module_sharding_plan # pyre-ignore ) diff --git a/torchrec/distributed/tests/test_dynamic_sharding.py b/torchrec/distributed/tests/test_dynamic_sharding.py index 92a7db0ce..5b1eefd97 100644 --- a/torchrec/distributed/tests/test_dynamic_sharding.py +++ b/torchrec/distributed/tests/test_dynamic_sharding.py @@ -284,7 +284,7 @@ def _test_ebc_resharding( device=ctx.device, ) - new_module_sharding_plan_delta = output_sharding_plan_delta( + _, new_module_sharding_plan_delta = output_sharding_plan_delta( module_sharding_plan, new_module_sharding_plan ) @@ -544,7 +544,7 @@ class MultiRankDMPDynamicShardingTest(ModelParallelTestShared): ), data_type=st.sampled_from([DataType.FP16, DataType.FP32]), random_seed=st.integers(0, 1000), - world_size=st.sampled_from([2, 4, 8]), + world_size=st.sampled_from([8]), ) @settings(verbosity=Verbosity.verbose, max_examples=8, deadline=None) def test_sharding( @@ -672,7 +672,7 @@ def test_output_sharding_plan_delta(self) -> None: device_type="cuda" if torch.cuda.is_available() else "cpu", ) - new_module_sharding_plan_delta = output_sharding_plan_delta( + _, new_module_sharding_plan_delta = output_sharding_plan_delta( module_sharding_plan, new_module_sharding_plan )