
Commit 10ddbd8

use save for bwd

1 parent 7fa3708

1 file changed (+11, -33 lines)

megatron/core/transformer/dot_product_attention.py (11 additions, 33 deletions)
@@ -59,7 +59,7 @@ def eager_attn_fwd(q, k, v, attn_bias, sinks, scale, dropout):
     attn_output = einops.rearrange(attn_output, 'b h s d -> b s h d')
     attn_output = attn_output.contiguous()

-    return attn_output, None
+    return attn_output, probs


 # @torch.compile
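
Note on this hunk: eager_attn_fwd now returns the softmax probabilities instead of None, so the backward pass can reuse them rather than rebuilding the logits. A minimal sketch of the idea, using hypothetical names and shapes rather than the repo's exact code:

    import torch
    import torch.nn.functional as F

    def toy_attn_fwd(q, k, v, scale):
        # q, k, v: (batch, heads, seq, dim)
        attn_w = torch.matmul(q, k.transpose(-2, -1)) * scale
        probs = F.softmax(attn_w, dim=-1)
        out = torch.matmul(probs, v)
        # return probs alongside the output so backward can reuse them
        return out, probs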
@@ -71,20 +71,6 @@ def eager_attn_bwd(q, k, v, attn_bias, sinks, scale, dropout, attn_output, probs
     _k_T = einops.rearrange(k, 'b s h d -> b h s d')
     _v_T = einops.rearrange(v, ' b s h d -> b h d s')

-    # recompute probs and slice attn_w from probs
-    if probs is None:
-        _q = einops.rearrange(q, 'b s h d -> b h s d')
-        _k = einops.rearrange(k, 'b s h d -> b h d s')
-        attn_w = torch.matmul(_q, _k) * scale
-        attn_w = attn_w + attn_bias
-        if sinks is None:
-            logits = attn_w
-        else:
-            _sinks = sinks.reshape(1, h, 1, 1).expand(b, -1, sq, 1)
-            logits = torch.cat([attn_w, _sinks], dim=-1)
-        probs = F.softmax(logits, dim=-1, dtype=logits.dtype)
-        del _q, _k, logits
-
     if sinks is None:
         attn_w = probs
     else:
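
Note on this hunk: with probs now always supplied by the forward pass, the recomputation branch in eager_attn_bwd (rebuilding attn_w, re-adding attn_bias and the sinks column, and re-running softmax) is no longer needed. As a companion to the sketch above, one way a backward can consume saved probabilities directly, shown here for the gradient with respect to v (hypothetical helper, not the repo's code):

    def toy_attn_bwd_dv(probs, dout):
        # probs: (batch, heads, seq_q, seq_k); dout: (batch, heads, seq_q, dim)
        # out = probs @ v, so dv = probs^T @ dout -- no softmax recomputation
        return torch.matmul(probs.transpose(-2, -1), dout)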
@@ -182,7 +168,7 @@ def forward(

         nheads = q.shape[2]
         nheads_k = k.shape[2]
-        heads_k_stride = nheads_k
+        heads_k_stride = 1
         assert nheads % nheads_k == 0 and nheads_k % heads_k_stride == 0
         outs = []
         probs = []
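
Note on this hunk: heads_k_stride drops from nheads_k (all KV heads at once) to 1, so the forward processes KV heads one at a time, appending one entry to outs and probs per chunk, and the gathered KV buffer only ever holds heads_k_stride heads. A small sketch of that chunking pattern (assumed structure, not the repo's exact loop):

    def iter_head_chunks(nheads_k, heads_k_stride):
        # yield one slice of KV heads per iteration
        assert nheads_k % heads_k_stride == 0
        for start in range(0, nheads_k, heads_k_stride):
            yield slice(start, start + heads_k_stride)

    # e.g. with 8 KV heads and heads_k_stride = 1 this yields eight width-1 slices,
    # so outs and probs each end up with nheads_k // heads_k_stride entries.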
@@ -227,38 +213,30 @@ def forward(
         out = torch.cat(outs, dim=2)
         out = einops.rearrange(out, 'b s h d -> s b h d')

-        ctx.save_for_backward(q, k, v)
-        ctx.outs = outs
-        ctx.probs = probs
-        ctx.attention_mask = attention_mask
+        ctx.save_for_backward(q, k, v, attention_mask, *outs, *probs)
         ctx.dropout = attention_dropout
         ctx.scale = softmax_scale
-        ctx.op = None
-        ctx.output_dtype = None
-        ctx.heads_k_stride = heads_k_stride
+        ctx.heads_k_stride = heads_k_stride  # TODO make it configurable
         ctx.pg = pg

         return out

     @staticmethod
     def backward(ctx, dout):
-        q, k, v = ctx.saved_tensors
-        outs = ctx.outs
-        probs = ctx.probs
-        attention_mask = ctx.attention_mask
-        op = None
-        output_dtype = ctx.output_dtype
+        q, k, v, attention_mask, *rest = ctx.saved_tensors
+        nheads = q.shape[2]
+        nheads_k = k.shape[2]
         heads_k_stride = ctx.heads_k_stride
-        pg = ctx.pg
+        assert nheads_k % heads_k_stride == 0
+        outs = rest[:nheads_k // heads_k_stride]
+        probs = rest[nheads_k // heads_k_stride:]

+        pg = ctx.pg
         cp_size = 1
         if pg is not None:
             cp_size = torch.distributed.get_world_size(pg)
         comm = AllGatherComm(group=pg)

-        nheads = q.shape[2]
-        nheads_k = k.shape[2]
-
         kv_buffer = torch.empty(
             (2, k.shape[0] * cp_size, k.shape[1], heads_k_stride, k.shape[3]),
             dtype=k.dtype,
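
Note on this hunk: the per-chunk outputs and probabilities (plus attention_mask) move from ad-hoc ctx attributes into ctx.save_for_backward, which is the "use save for bwd" change in the commit message; backward then recovers the two lists by slicing the saved-tensor tuple at nheads_k // heads_k_stride. A generic sketch of this save-and-unpack pattern for a variable number of tensors (hypothetical toy Function, not the repo's class):

    import torch

    class ChunkedSave(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, n_chunks: int):
            outs = [x * (i + 1) for i in range(n_chunks)]   # stand-in per-chunk outputs
            probs = [x.sigmoid() for _ in range(n_chunks)]  # stand-in per-chunk probs
            ctx.save_for_backward(x, *outs, *probs)         # tensors tracked by autograd
            ctx.n_chunks = n_chunks                         # plain Python value stays on ctx
            return sum(outs)

        @staticmethod
        def backward(ctx, dout):
            x, *rest = ctx.saved_tensors
            outs = rest[:ctx.n_chunks]                      # first block: per-chunk outputs
            probs = rest[ctx.n_chunks:]                     # second block: per-chunk probs
            # d(sum_i x * (i + 1)) / dx = sum_i (i + 1)
            grad_x = dout * sum(i + 1 for i in range(ctx.n_chunks))
            return grad_x, None                             # no gradient for n_chunks

One design note: tensors passed through save_for_backward are visible to autograd (e.g. to saved-tensor hooks and its sanity checks), whereas bare Python references stashed on ctx are not.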
