torchtitan/distributed/utils.py (18 changes: 15 additions & 3 deletions)
@@ -434,9 +434,7 @@ def _clip_grad_norm_with_ep(
     if math.isinf(norm_type):
         total_norm = torch.maximum(ep_grads_total_norm, non_ep_grads_total_norm)
     else:
-        total_norm = (
-            ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
-        )
+        total_norm = ep_grads_total_norm**norm_type + non_ep_grads_total_norm**norm_type
         total_norm **= 1.0 / norm_type
 
     if pp_mesh is not None:
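Note (not part of the diff): the reformatted expression above still combines the two per-group norms by summing their norm_type-th powers and then taking the 1/norm_type root, which reproduces the gradient norm over the union of the EP and non-EP parameters. A minimal standalone sketch of that identity; the tensors and variable names below are illustrative only, not from torchtitan:

import torch

p = 2.0  # stands in for norm_type; any finite p behaves the same way
ep_grads = [torch.randn(4), torch.randn(3)]      # illustrative expert-parallel grads
non_ep_grads = [torch.randn(5)]                  # illustrative non-EP grads

# per-group norms, analogous to ep_grads_total_norm / non_ep_grads_total_norm
ep_norm = torch.linalg.vector_norm(torch.cat(ep_grads), ord=p)
non_ep_norm = torch.linalg.vector_norm(torch.cat(non_ep_grads), ord=p)

# combine as in the diff: sum of p-th powers, then the 1/p root
total_norm = ep_norm**p + non_ep_norm**p
total_norm **= 1.0 / p

# equals the norm computed over all gradients at once
global_norm = torch.linalg.vector_norm(torch.cat(ep_grads + non_ep_grads), ord=p)
assert torch.allclose(total_norm, global_norm)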
@@ -462,9 +460,23 @@ def cp_shard(
     order_sensitive_buffers_seq_dims: dict[str, int],
 ):
     from torch.distributed.tensor.experimental._attention import _context_parallel_shard
+    from torch.distributed.tensor.experimental._load_balancer import (
+        _HeadTailLoadBalancer,
+        _PTRRLoadBalancer,
+    )
     from torch.nn.attention.flex_attention import BlockMask
 
     load_balancer = None
+    """
+    seq_length = inputs.shape[1]
+    load_balancer = _HeadTailLoadBalancer(
+        seq_length, cp_mesh.size(0), cp_mesh.device_type
+    )
+
+    assert isinstance(attention_masks, BlockMask)
+    load_balancer = _PTRRLoadBalancer(attention_masks, cp_mesh.size(0))
+    """
+
     inputs, labels = _context_parallel_shard(
         mesh=cp_mesh,
         buffers=(inputs, labels),
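Note (not part of the diff): the commented-out block above sketches two ways a load balancer could be constructed for context-parallel sharding, a sequence-length-based _HeadTailLoadBalancer and a BlockMask-based _PTRRLoadBalancer, while the active code still passes load_balancer = None. As I understand the head-tail scheme, the sequence is split into 2 * cp_world_size chunks and rank r gets chunk r from the head plus the mirrored chunk from the tail, so per-rank causal-attention cost is roughly balanced. An illustrative sketch of that chunk assignment; the helper below is my own, not the private PyTorch API, and it assumes the sequence length divides evenly:

def head_tail_chunks(seq_length: int, cp_world_size: int) -> list[list[range]]:
    # Split the sequence into 2 * cp_world_size equal chunks; pair chunk r with
    # its mirror chunk (2 * cp_world_size - 1 - r) so early (cheap) and late
    # (expensive) causal-attention chunks land on the same rank.
    assert seq_length % (2 * cp_world_size) == 0, "illustration assumes an even split"
    chunk = seq_length // (2 * cp_world_size)
    assignment = []
    for r in range(cp_world_size):
        head = range(r * chunk, (r + 1) * chunk)
        mirror = 2 * cp_world_size - 1 - r
        tail = range(mirror * chunk, (mirror + 1) * chunk)
        assignment.append([head, tail])
    return assignment

# seq_length=8, cp_world_size=2 -> rank 0 holds tokens [0, 1] and [6, 7],
#                                  rank 1 holds tokens [2, 3] and [4, 5]
print(head_tail_chunks(8, 2))

The _PTRRLoadBalancer path presumably derives a comparable assignment from the FlexAttention BlockMask rather than assuming plain causal structure, but that is beyond what this diff shows.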