@@ -208,8 +208,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
     GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
 
     const int cc = ggml_cuda_info().devices[device].cc;
-    const int warp_size = ggml_cuda_info().devices[device].warp_size;
-    const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
     switch (K->ne[0]) {
         case 64:
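As I read it (context, not part of the patch): the assertion at the top of this hunk enforces the grouped-query-attention layout, where the number of Q heads (Q->ne[2]) is a whole multiple of the number of KV heads (K->ne[2]); their quotient is the gqa_ratio tested in the second hunk below. A minimal hypothetical illustration with made-up head counts:

// gqa_layout.cpp -- hypothetical illustration, not part of the patch.
// Shows the layout constraint the GGML_ASSERT above enforces.
#include <cassert>

int main() {
    const int n_q_heads  = 32; // example value for Q->ne[2]
    const int n_kv_heads = 8;  // example value for K->ne[2]

    // Q heads must map evenly onto KV heads for grouped-query attention.
    assert(n_q_heads % n_kv_heads == 0);

    const int gqa_ratio = n_q_heads / n_kv_heads; // here: 4 Q heads per KV head
    (void) gqa_ratio;
    return 0;
}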
@@ -267,29 +265,31 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
             return BEST_FATTN_KERNEL_NONE;
     }
 
-    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
+    const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0;
 
     // If Turing tensor cores available, use them except for some cases with batch size 1:
     if (turing_mma_available(cc)) {
         best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16;
 
-        if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
-            if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
-                best = BEST_FATTN_KERNEL_VEC;
-            }
-        } else {
-            if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
-                if (Q->ne[1] <= 2) {
+        if (can_use_vector_kernel) {
+            if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
                     best = BEST_FATTN_KERNEL_VEC;
                 }
             } else {
-                if (Q->ne[1] == 1) {
-                    best = BEST_FATTN_KERNEL_VEC;
+                if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
+                    if (Q->ne[1] <= 2) {
+                        best = BEST_FATTN_KERNEL_VEC;
+                    }
+                } else {
+                    if (Q->ne[1] == 1) {
+                        best = BEST_FATTN_KERNEL_VEC;
+                    }
                 }
             }
-        }
-        if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
-            best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
+            if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
+                best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
+            }
         }
 
         return best;
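Taken together with the first hunk, my reading of the change: the eligibility divisor becomes the constant 64 instead of 2*warp_size (so the warp_size local can go; prec is dropped as apparently unused), and the batch-size-1 downgrades from the MMA path to BEST_FATTN_KERNEL_VEC now apply only when can_use_vector_kernel holds, so head sizes the vector kernels cannot handle stay on BEST_FATTN_KERNEL_MMA_F16. A standalone sketch of that gating, with hypothetical names (pick_kernel, KERNEL_MMA_F16, KERNEL_VEC) and the simplifying assumption that only head size (Q->ne[0]) and batch size (Q->ne[1]) matter:

// kernel_gate.cpp -- hypothetical, simplified model of the selection logic
// above; not the real ggml API.
#include <cstdio>

enum best_fattn_kernel { KERNEL_MMA_F16, KERNEL_VEC };

// Mirrors the new gate: head size at most 256 and a multiple of 64.
static bool can_use_vector_kernel(int head_size) {
    return head_size <= 256 && head_size % 64 == 0;
}

// Simplified stand-in for the restructured branch: the batch-size-1
// downgrade to the vector kernel is considered only when that kernel
// can run at all.
static best_fattn_kernel pick_kernel(int head_size, int batch) {
    best_fattn_kernel best = KERNEL_MMA_F16;
    if (can_use_vector_kernel(head_size) && batch == 1) {
        best = KERNEL_VEC;
    }
    return best;
}

int main() {
    // Head size 80 is not a multiple of 64, so the vector kernel is
    // skipped and the MMA kernel is kept even at batch size 1.
    std::printf("hs= 80, batch=1 -> %s\n", pick_kernel( 80, 1) == KERNEL_VEC ? "VEC" : "MMA_F16");
    std::printf("hs=128, batch=1 -> %s\n", pick_kernel(128, 1) == KERNEL_VEC ? "VEC" : "MMA_F16");
    std::printf("hs=128, batch=8 -> %s\n", pick_kernel(128, 8) == KERNEL_VEC ? "VEC" : "MMA_F16");
    return 0;
}

Under the old structure the Q->ne[1] == 1 downgrades sat outside the eligibility check, so, as far as I can tell, a head size like 80 could at batch size 1 be routed to a vector kernel that does not support it.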