diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index 838cfda9c..c80dd3bdb 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -434,9 +434,7 @@ def _clip_grad_norm_with_ep(
     if math.isinf(norm_type):
         total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm)
     else:
-        total_norm = (
-            ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
-        )
+        total_norm = ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
         total_norm **= 1.0 / norm_type
 
     if pp_mesh is not None:
@@ -462,9 +460,23 @@ def cp_shard(
     order_sensitive_buffers_seq_dims: dict[str, int],
 ):
     from torch.distributed.tensor.experimental._attention import _context_parallel_shard
+    from torch.distributed.tensor.experimental._load_balancer import (
+        _HeadTailLoadBalancer,
+        _PTRRLoadBalancer,
+    )
     from torch.nn.attention.flex_attention import BlockMask
 
     load_balancer = None
+    """
+    seq_length = inputs.shape[1]
+    load_balancer = _HeadTailLoadBalancer(
+        seq_length, cp_mesh.size(0), cp_mesh.device_type
+    )
+
+    assert isinstance(attention_masks, BlockMask)
+    load_balancer = _PTRRLoadBalancer(attention_masks, cp_mesh.size(0))
+    """
+
     inputs, labels = _context_parallel_shard(
         mesh=cp_mesh,
         buffers=(inputs, labels),
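
For reference, a minimal standalone sketch of the norm combination that the first hunk reflows. The helper name combine_group_norms is hypothetical and there are no EP/PP meshes or distributed reductions here; it only illustrates how two per-group gradient norms merge into one total norm, matching the if/else in _clip_grad_norm_with_ep.

    # Hypothetical helper, not part of torchtitan: combines two per-group
    # gradient norms into a single total norm for a given norm_type.
    import math
    import torch

    def combine_group_norms(
        ep_norm: torch.Tensor, non_ep_norm: torch.Tensor, norm_type: float
    ) -> torch.Tensor:
        if math.isinf(norm_type):
            # inf-norm: the overall maximum is the max of the two group maxima
            return torch.maximum(ep_norm, non_ep_norm)
        # p-norm: sum the p-th powers of the group norms, then take the p-th root
        total = ep_norm**norm_type + non_ep_norm**norm_type
        return total ** (1.0 / norm_type)

    # e.g. combine_group_norms(torch.tensor(3.0), torch.tensor(4.0), 2.0) -> tensor(5.)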