@@ -59,7 +59,7 @@ def eager_attn_fwd(q, k, v, attn_bias, sinks, scale, dropout):
     attn_output = einops.rearrange(attn_output, 'b h s d -> b s h d')
     attn_output = attn_output.contiguous()
 
-    return attn_output, None
+    return attn_output, probs
 
 
 # @torch.compile
@@ -71,20 +71,6 @@ def eager_attn_bwd(q, k, v, attn_bias, sinks, scale, dropout, attn_output, probs
     _k_T = einops.rearrange(k, 'b s h d -> b h s d')
     _v_T = einops.rearrange(v, 'b s h d -> b h d s')
 
-    # recompute probs and slice attn_w from probs
-    if probs is None:
-        _q = einops.rearrange(q, 'b s h d -> b h s d')
-        _k = einops.rearrange(k, 'b s h d -> b h d s')
-        attn_w = torch.matmul(_q, _k) * scale
-        attn_w = attn_w + attn_bias
-        if sinks is None:
-            logits = attn_w
-        else:
-            _sinks = sinks.reshape(1, h, 1, 1).expand(b, -1, sq, 1)
-            logits = torch.cat([attn_w, _sinks], dim=-1)
-        probs = F.softmax(logits, dim=-1, dtype=logits.dtype)
-        del _q, _k, logits
-
     if sinks is None:
         attn_w = probs
     else:
@@ -182,7 +168,7 @@ def forward(
 
         nheads = q.shape[2]
         nheads_k = k.shape[2]
-        heads_k_stride = nheads_k
+        heads_k_stride = 1
        assert nheads % nheads_k == 0 and nheads_k % heads_k_stride == 0
         outs = []
         probs = []
@@ -227,38 +213,30 @@ def forward(
         out = torch.cat(outs, dim=2)
         out = einops.rearrange(out, 'b s h d -> s b h d')
 
-        ctx.save_for_backward(q, k, v)
-        ctx.outs = outs
-        ctx.probs = probs
-        ctx.attention_mask = attention_mask
+        ctx.save_for_backward(q, k, v, attention_mask, *outs, *probs)
         ctx.dropout = attention_dropout
         ctx.scale = softmax_scale
-        ctx.op = None
-        ctx.output_dtype = None
-        ctx.heads_k_stride = heads_k_stride
+        ctx.heads_k_stride = heads_k_stride  # TODO make it configurable
         ctx.pg = pg
 
         return out
 
     @staticmethod
     def backward(ctx, dout):
-        q, k, v = ctx.saved_tensors
-        outs = ctx.outs
-        probs = ctx.probs
-        attention_mask = ctx.attention_mask
-        op = None
-        output_dtype = ctx.output_dtype
+        q, k, v, attention_mask, *rest = ctx.saved_tensors
+        nheads = q.shape[2]
+        nheads_k = k.shape[2]
         heads_k_stride = ctx.heads_k_stride
-        pg = ctx.pg
+        assert nheads_k % heads_k_stride == 0
+        outs = rest[:nheads_k // heads_k_stride]
+        probs = rest[nheads_k // heads_k_stride:]
 
+        pg = ctx.pg
         cp_size = 1
         if pg is not None:
             cp_size = torch.distributed.get_world_size(pg)
         comm = AllGatherComm(group=pg)
 
-        nheads = q.shape[2]
-        nheads_k = k.shape[2]
-
         kv_buffer = torch.empty(
             (2, k.shape[0] * cp_size, k.shape[1], heads_k_stride, k.shape[3]),
             dtype=k.dtype,
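
Note on the save_for_backward change above: tensors stashed as plain attributes on ctx (the old ctx.outs / ctx.probs) bypass autograd's saved-tensor tracking, so routing them through ctx.save_for_backward and re-slicing ctx.saved_tensors in backward is the supported pattern. A minimal sketch of that packing/unpacking, using a hypothetical ChunkedIdentity function rather than the class in this diff:

import torch


class ChunkedIdentity(torch.autograd.Function):
    """Toy example: forward saves a variable-length list of per-chunk tensors."""

    @staticmethod
    def forward(ctx, x, n_chunks):
        # Stand-ins for the per-head-group `outs` / `probs` lists in the diff.
        chunks = list(torch.chunk(x, n_chunks, dim=-1))
        # Tensors go through save_for_backward; only plain Python state stays on ctx.
        ctx.save_for_backward(x, *chunks)
        ctx.n_chunks = n_chunks
        return torch.cat(chunks, dim=-1)

    @staticmethod
    def backward(ctx, dout):
        # Fixed tensors first, then the variable-length tail, mirroring
        # `q, k, v, attention_mask, *rest = ctx.saved_tensors` above.
        x, *chunks = ctx.saved_tensors
        assert len(chunks) == ctx.n_chunks
        # One gradient per forward input (None for the non-tensor n_chunks).
        return dout, None


x = torch.randn(2, 8, requires_grad=True)
ChunkedIdentity.apply(x, 4).sum().backward()

The same splitting logic applies in the diff's backward: the first len(outs) tensors of rest are the per-head-group outputs and the remainder are the softmax probabilities, with the split point given by nheads_k // heads_k_stride.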