@@ -198,7 +198,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
198
198
return BEST_FATTN_KERNEL_NONE;
199
199
#endif // FLASH_ATTN_AVAILABLE
200
200
201
- const ggml_tensor * KQV = dst;
202
201
const ggml_tensor * Q = dst->src [0 ];
203
202
const ggml_tensor * K = dst->src [1 ];
204
203
const ggml_tensor * V = dst->src [2 ];
@@ -208,8 +207,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
208
207
GGML_ASSERT (Q->ne [2 ] % K->ne [2 ] == 0 );
209
208
210
209
const int cc = ggml_cuda_info ().devices [device].cc ;
211
- const int warp_size = ggml_cuda_info ().devices [device].warp_size ;
212
- const enum ggml_prec prec = ggml_flash_attn_ext_get_prec (KQV);
213
210
214
211
switch (K->ne [0 ]) {
215
212
case 64 :
@@ -267,29 +264,31 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
267
264
return BEST_FATTN_KERNEL_NONE;
268
265
}
269
266
270
- const bool can_use_vector_kernel = Q->ne [0 ] <= 256 && Q->ne [0 ] % ( 2 *warp_size) == 0 ;
267
+ const bool can_use_vector_kernel = Q->ne [0 ] <= 256 && Q->ne [0 ] % 64 == 0 ;
271
268
272
269
// If Turing tensor cores available, use them except for some cases with batch size 1:
273
270
if (turing_mma_available (cc)) {
274
271
best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16;
275
272
276
- if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
277
- if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne [1 ] == 1 && Q->ne [3 ] == 1 && !(gqa_ratio > 4 && K->ne [1 ] >= 8192 )) {
278
- best = BEST_FATTN_KERNEL_VEC;
279
- }
280
- } else {
281
- if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
282
- if (Q->ne [1 ] <= 2 ) {
273
+ if (can_use_vector_kernel) {
274
+ if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
275
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne [1 ] == 1 && Q->ne [3 ] == 1 && !(gqa_ratio > 4 && K->ne [1 ] >= 8192 )) {
283
276
best = BEST_FATTN_KERNEL_VEC;
284
277
}
285
278
} else {
286
- if (Q->ne [1 ] == 1 ) {
287
- best = BEST_FATTN_KERNEL_VEC;
279
+ if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
280
+ if (Q->ne [1 ] <= 2 ) {
281
+ best = BEST_FATTN_KERNEL_VEC;
282
+ }
283
+ } else {
284
+ if (Q->ne [1 ] == 1 ) {
285
+ best = BEST_FATTN_KERNEL_VEC;
286
+ }
288
287
}
289
288
}
290
- }
291
- if ((gqa_ratio % 2 != 0 || !mask) && Q-> ne [ 1 ] == 1 ) {
292
- best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
289
+ if ((gqa_ratio % 2 != 0 || !mask) && Q-> ne [ 1 ] == 1 ) {
290
+ best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
291
+ }
293
292
}
294
293
295
294
return best;
0 commit comments