@@ -7817,7 +7817,7 @@ kernel void kernel_set_rows_f(
7817
7817
#define SG_MAT_ROW 8
7818
7818
7819
7819
// each block_q contains 16*nl weights
7820
- template <typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread T4x4 &)>
7820
+ template <typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short , thread T4x4 &), typename U, typename U2x4 >
7821
7821
kernel void kernel_mul_mm (
7822
7822
constant ggml_metal_kargs_mul_mm & args,
7823
7823
device const char * src0,
@@ -7862,7 +7862,7 @@ kernel void kernel_mul_mm(
7862
7862
device const block_q * x = (device const block_q *)(src0
7863
7863
+ args.nb01 *(r0*BLOCK_SIZE_M + thread_row) + offset0) + offset1;
7864
7864
7865
- device const float * y = (device const float *)(src1
7865
+ device const U * y = (device const U *)(src1
7866
7866
+ args.nb13 *i13
7867
7867
+ args.nb12 *i12
7868
7868
+ args.nb11 *(r1*BLOCK_SIZE_N + thread_col)
@@ -7882,7 +7882,7 @@ kernel void kernel_mul_mm(
7882
7882
+ (tiitg/THREAD_PER_ROW)%8 + (i&7 )*8 ) = temp_a[i/4 ][i%4 ];
7883
7883
}
7884
7884
7885
- *(threadgroup float2x4 *)(sb + 32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
7885
+ *(threadgroup float2x4 *)(sb + 32 *8 *(tiitg%THREAD_PER_COL) + 8 *(tiitg/THREAD_PER_COL)) = (float2x4)( *((device U2x4 *) y) );
7886
7886
7887
7887
il = (il + 2 < nl) ? il + 2 : il % 2 ;
7888
7888
x = (il < 2 ) ? x + (2 + nl - 1 )/nl : x;
@@ -8248,33 +8248,59 @@ template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kerne
8248
8248
// matrix-matrix multiplication
8249
8249
//
8250
8250
8251
// mul_mm_t: function type of kernel_mul_mm, taken from a single reference
// instantiation. The updated line picks the <..., float, float2x4> variant
// because the template now takes two trailing parameters (U, U2x4).
// The same alias is used as the declared type of every explicit instantiation
// that follows, including the ones that pass half/half2x4.
// NOTE(review): this presumably relies on the kernel's parameter list not
// depending on U/U2x4 (the src1 pointer is cast to `device const U *` inside
// the body) - confirm against the full kernel signature.
- typedef decltype (kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1 , dequantize_f32>) mul_mm_t;
8251
+ typedef decltype (kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1 , dequantize_f32, float , float2x4 >) mul_mm_t;
8252
8252
8253
// Explicit instantiations of kernel_mul_mm whose host names end in `_f32`.
// Template arguments, in order:
//   T, T4x4, simdgroup_T8x8, block_q, nl, dequantize_func, U, U2x4
// Every instantiation in this group passes U = float and U2x4 = float2x4, i.e.
// the src1 pointer is read through `device const float *` / `device float2x4 *`.
// The host name follows the pattern kernel_mul_mm_<src0 type>_<src1 type>.
// nl is 1 for the float4x4/half4x4/bfloat4x4 sources, 2 for the q4_0, q4_1,
// q5_0, q5_1, q8_0, mxfp4 and iq4_nl blocks, and QK_NL for the remaining ones
// (the kernel's header comment states that each block_q holds 16*nl weights).
// NOTE(review): the host_name string literals in this view carry stray spaces
// around the kernel name; this looks like an artifact of how the diff was
// captured - confirm against the real file before relying on the exact names.
- template [[host_name(" kernel_mul_mm_f32_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1 , dequantize_f32>;
8254
- template [[host_name(" kernel_mul_mm_f16_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1 , dequantize_f16>;
8253
// Unquantized src0 (f32 / f16), one 4x4 tile per block (nl = 1).
+ template [[host_name(" kernel_mul_mm_f32_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1 , dequantize_f32, float , float2x4>;
8254
+ template [[host_name(" kernel_mul_mm_f16_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1 , dequantize_f16, float , float2x4>;
8255
// The bfloat variant is only compiled when GGML_METAL_HAS_BF16 is defined; it
// is the only instantiation here that uses bfloat types for T/T4x4/simdgroup.
+ #if defined(GGML_METAL_HAS_BF16)
8256
+ template [[host_name(" kernel_mul_mm_bf16_f32" )]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1 , dequantize_bf16, float , float2x4>;
8257
+ #endif
8258
// Quantized src0 block types; each one is paired with its dequantize_* function.
+ template [[host_name(" kernel_mul_mm_q4_0_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2 , dequantize_q4_0, float , float2x4>;
8259
+ template [[host_name(" kernel_mul_mm_q4_1_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2 , dequantize_q4_1, float , float2x4>;
8260
+ template [[host_name(" kernel_mul_mm_q5_0_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2 , dequantize_q5_0, float , float2x4>;
8261
+ template [[host_name(" kernel_mul_mm_q5_1_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2 , dequantize_q5_1, float , float2x4>;
8262
+ template [[host_name(" kernel_mul_mm_q8_0_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2 , dequantize_q8_0, float , float2x4>;
8263
+ template [[host_name(" kernel_mul_mm_mxfp4_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2 , dequantize_mxfp4, float , float2x4>;
8264
+ template [[host_name(" kernel_mul_mm_q2_K_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, float , float2x4>;
8265
+ template [[host_name(" kernel_mul_mm_q3_K_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, float , float2x4>;
8266
+ template [[host_name(" kernel_mul_mm_q4_K_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, float , float2x4>;
8267
+ template [[host_name(" kernel_mul_mm_q5_K_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, float , float2x4>;
8268
+ template [[host_name(" kernel_mul_mm_q6_K_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, float , float2x4>;
8269
+ template [[host_name(" kernel_mul_mm_iq2_xxs_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, float , float2x4>;
8270
+ template [[host_name(" kernel_mul_mm_iq2_xs_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, float , float2x4>;
8271
+ template [[host_name(" kernel_mul_mm_iq3_xxs_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, float , float2x4>;
8272
+ template [[host_name(" kernel_mul_mm_iq3_s_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, float , float2x4>;
8273
+ template [[host_name(" kernel_mul_mm_iq2_s_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, float , float2x4>;
8274
+ template [[host_name(" kernel_mul_mm_iq1_s_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, float , float2x4>;
8275
+ template [[host_name(" kernel_mul_mm_iq1_m_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, float , float2x4>;
8276
+ template [[host_name(" kernel_mul_mm_iq4_nl_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2 , dequantize_iq4_nl, float , float2x4>;
8277
+ template [[host_name(" kernel_mul_mm_iq4_xs_f32" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, float , float2x4>;
8278
+
8279
// Explicit instantiations of kernel_mul_mm whose host names end in `_f16`.
// They mirror the `_f32` group one-for-one (same block_q, nl and dequantize_*
// arguments) but pass U = half and U2x4 = half2x4, so the src1 pointer is read
// through `device const half *` / `device half2x4 *`; the kernel body converts
// the loaded half2x4 to float2x4 before storing it through the threadgroup
// pointer derived from `sb`.
// All of them are still declared with the shared mul_mm_t function type.
// The `-` lines in this span are the previous 6-argument `_f32` instantiations
// that the change removes from this position.
// NOTE(review): the host_name string literals in this view carry stray spaces
// around the kernel name; this looks like an artifact of how the diff was
// captured - confirm against the real file before relying on the exact names.
+ template [[host_name(" kernel_mul_mm_f32_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, float4x4, 1 , dequantize_f32, half, half2x4>;
8280
+ template [[host_name(" kernel_mul_mm_f16_f16" )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half4x4, 1 , dequantize_f16, half, half2x4>;
8255
8281
// The bfloat src0 variant is only compiled when GGML_METAL_HAS_BF16 is defined.
#if defined(GGML_METAL_HAS_BF16)
8256
- template [[host_name(" kernel_mul_mm_bf16_f32 " )]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1 , dequantize_bf16>;
8282
+ template [[host_name(" kernel_mul_mm_bf16_f16 " )]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, bfloat4x4, 1 , dequantize_bf16, half, half2x4 >;
8257
8283
#endif
8258
// Removed: old `_f32` instantiations for the quantized block types
// (without the U / U2x4 arguments).
- template [[host_name(" kernel_mul_mm_q4_0_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2 , dequantize_q4_0>;
8259
- template [[host_name(" kernel_mul_mm_q4_1_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2 , dequantize_q4_1>;
8260
- template [[host_name(" kernel_mul_mm_q5_0_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2 , dequantize_q5_0>;
8261
- template [[host_name(" kernel_mul_mm_q5_1_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2 , dequantize_q5_1>;
8262
- template [[host_name(" kernel_mul_mm_q8_0_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2 , dequantize_q8_0>;
8263
- template [[host_name(" kernel_mul_mm_mxfp4_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2 , dequantize_mxfp4>;
8264
- template [[host_name(" kernel_mul_mm_q2_K_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K>;
8265
- template [[host_name(" kernel_mul_mm_q3_K_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K>;
8266
- template [[host_name(" kernel_mul_mm_q4_K_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K>;
8267
- template [[host_name(" kernel_mul_mm_q5_K_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K>;
8268
- template [[host_name(" kernel_mul_mm_q6_K_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K>;
8269
- template [[host_name(" kernel_mul_mm_iq2_xxs_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs>;
8270
- template [[host_name(" kernel_mul_mm_iq2_xs_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs>;
8271
- template [[host_name(" kernel_mul_mm_iq3_xxs_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs>;
8272
- template [[host_name(" kernel_mul_mm_iq3_s_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s>;
8273
- template [[host_name(" kernel_mul_mm_iq2_s_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s>;
8274
- template [[host_name(" kernel_mul_mm_iq1_s_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s>;
8275
- template [[host_name(" kernel_mul_mm_iq1_m_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m>;
8276
- template [[host_name(" kernel_mul_mm_iq4_nl_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2 , dequantize_iq4_nl>;
8277
- template [[host_name(" kernel_mul_mm_iq4_xs_f32 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs>;
8284
// Added: `_f16` instantiations for the same quantized block types.
+ template [[host_name(" kernel_mul_mm_q4_0_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_0, 2 , dequantize_q4_0, half, half2x4 >;
8285
+ template [[host_name(" kernel_mul_mm_q4_1_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_1, 2 , dequantize_q4_1, half, half2x4 >;
8286
+ template [[host_name(" kernel_mul_mm_q5_0_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_0, 2 , dequantize_q5_0, half, half2x4 >;
8287
+ template [[host_name(" kernel_mul_mm_q5_1_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_1, 2 , dequantize_q5_1, half, half2x4 >;
8288
+ template [[host_name(" kernel_mul_mm_q8_0_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q8_0, 2 , dequantize_q8_0, half, half2x4 >;
8289
+ template [[host_name(" kernel_mul_mm_mxfp4_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_mxfp4, 2 , dequantize_mxfp4, half, half2x4 >;
8290
+ template [[host_name(" kernel_mul_mm_q2_K_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q2_K, QK_NL, dequantize_q2_K, half, half2x4 >;
8291
+ template [[host_name(" kernel_mul_mm_q3_K_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q3_K, QK_NL, dequantize_q3_K, half, half2x4 >;
8292
+ template [[host_name(" kernel_mul_mm_q4_K_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q4_K, QK_NL, dequantize_q4_K, half, half2x4 >;
8293
+ template [[host_name(" kernel_mul_mm_q5_K_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q5_K, QK_NL, dequantize_q5_K, half, half2x4 >;
8294
+ template [[host_name(" kernel_mul_mm_q6_K_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_q6_K, QK_NL, dequantize_q6_K, half, half2x4 >;
8295
+ template [[host_name(" kernel_mul_mm_iq2_xxs_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xxs, QK_NL, dequantize_iq2_xxs, half, half2x4 >;
8296
+ template [[host_name(" kernel_mul_mm_iq2_xs_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_xs, QK_NL, dequantize_iq2_xs, half, half2x4 >;
8297
+ template [[host_name(" kernel_mul_mm_iq3_xxs_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_xxs, QK_NL, dequantize_iq3_xxs, half, half2x4 >;
8298
+ template [[host_name(" kernel_mul_mm_iq3_s_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq3_s, QK_NL, dequantize_iq3_s, half, half2x4 >;
8299
+ template [[host_name(" kernel_mul_mm_iq2_s_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq2_s, QK_NL, dequantize_iq2_s, half, half2x4 >;
8300
+ template [[host_name(" kernel_mul_mm_iq1_s_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_s, QK_NL, dequantize_iq1_s, half, half2x4 >;
8301
+ template [[host_name(" kernel_mul_mm_iq1_m_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq1_m, QK_NL, dequantize_iq1_m, half, half2x4 >;
8302
+ template [[host_name(" kernel_mul_mm_iq4_nl_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_nl, 2 , dequantize_iq4_nl, half, half2x4 >;
8303
+ template [[host_name(" kernel_mul_mm_iq4_xs_f16 " )]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, block_iq4_xs, QK_NL, dequantize_iq4_xs, half, half2x4 >;
8278
8304
8279
8305
//
8280
8306
// indirect matrix-matrix multiplication
0 commit comments