Skip to content

Commit 8ba0ff7

Browse files
fix kernel selection logic
1 parent e267903 commit 8ba0ff7

File tree

1 file changed

+15
-16
lines changed

1 file changed

+15
-16
lines changed

ggml/src/ggml-cuda/fattn.cu

Lines changed: 15 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
198198
return BEST_FATTN_KERNEL_NONE;
199199
#endif// FLASH_ATTN_AVAILABLE
200200

201-
const ggml_tensor * KQV = dst;
202201
const ggml_tensor * Q = dst->src[0];
203202
const ggml_tensor * K = dst->src[1];
204203
const ggml_tensor * V = dst->src[2];
@@ -208,8 +207,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
208207
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
209208

210209
const int cc = ggml_cuda_info().devices[device].cc;
211-
const int warp_size = ggml_cuda_info().devices[device].warp_size;
212-
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
213210

214211
switch (K->ne[0]) {
215212
case 64:
@@ -267,29 +264,31 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
267264
return BEST_FATTN_KERNEL_NONE;
268265
}
269266

270-
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
267+
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0;
271268

272269
// If Turing tensor cores available, use them except for some cases with batch size 1:
273270
if (turing_mma_available(cc)) {
274271
best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16;
275272

276-
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
277-
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
278-
best = BEST_FATTN_KERNEL_VEC;
279-
}
280-
} else {
281-
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
282-
if (Q->ne[1] <= 2) {
273+
if (can_use_vector_kernel) {
274+
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
275+
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
283276
best = BEST_FATTN_KERNEL_VEC;
284277
}
285278
} else {
286-
if (Q->ne[1] == 1) {
287-
best = BEST_FATTN_KERNEL_VEC;
279+
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
280+
if (Q->ne[1] <= 2) {
281+
best = BEST_FATTN_KERNEL_VEC;
282+
}
283+
} else {
284+
if (Q->ne[1] == 1) {
285+
best = BEST_FATTN_KERNEL_VEC;
286+
}
288287
}
289288
}
290-
}
291-
if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
292-
best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
289+
if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
290+
best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
291+
}
293292
}
294293

295294
return best;

0 commit comments

Comments (0)