Skip to content

Commit c7ec17f

Browse files
fix kernel selection logic
1 parent e267903 commit c7ec17f

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

ggml/src/ggml-cuda/fattn.cu

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -208,8 +208,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
208208
GGML_ASSERT(Q->ne[2] % K->ne[2] == 0);
209209

210210
const int cc = ggml_cuda_info().devices[device].cc;
211-
const int warp_size = ggml_cuda_info().devices[device].warp_size;
212-
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
213211

214212
switch (K->ne[0]) {
215213
case 64:
@@ -267,29 +265,31 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
267265
return BEST_FATTN_KERNEL_NONE;
268266
}
269267

270-
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % (2*warp_size) == 0;
268+
const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0;
271269

272270
// If Turing tensor cores available, use them except for some cases with batch size 1:
273271
if (turing_mma_available(cc)) {
274272
best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16;
275273

276-
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
277-
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
278-
best = BEST_FATTN_KERNEL_VEC;
279-
}
280-
} else {
281-
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
282-
if (Q->ne[1] <= 2) {
274+
if (can_use_vector_kernel) {
275+
if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) {
276+
if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) {
283277
best = BEST_FATTN_KERNEL_VEC;
284278
}
285279
} else {
286-
if (Q->ne[1] == 1) {
287-
best = BEST_FATTN_KERNEL_VEC;
280+
if (cc >= GGML_CUDA_CC_ADA_LOVELACE) {
281+
if (Q->ne[1] <= 2) {
282+
best = BEST_FATTN_KERNEL_VEC;
283+
}
284+
} else {
285+
if (Q->ne[1] == 1) {
286+
best = BEST_FATTN_KERNEL_VEC;
287+
}
288288
}
289289
}
290-
}
291-
if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
292-
best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
290+
if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) {
291+
best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply.
292+
}
293293
}
294294

295295
return best;

0 commit comments

Comments
 (0)