]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
metal : fix build and some more comments (llama/10229)
authorGeorgi Gerganov <redacted>
Sat, 9 Nov 2024 09:53:02 +0000 (11:53 +0200)
committerGeorgi Gerganov <redacted>
Fri, 15 Nov 2024 13:21:04 +0000 (15:21 +0200)
ggml/src/ggml-metal.m
ggml/src/ggml-metal.metal

index c112fd866f7faa136fcad13bbcfcfef295d9bd30..04ec5117f64b96990d9a7ea1f845c75d7e7ece40 100644 (file)
@@ -3041,6 +3041,8 @@ static void ggml_metal_encode_node(
 
                 bool use_vec_kernel = false;
 
+                // TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0)
+                //       for now avoiding mainly to keep the number of templates/kernels a bit lower
                 if (ne01 >= 4 || (ne00%128 != 0)) {
                     switch (src1->type) {
                         case GGML_TYPE_F16:
index 1f233ba7f8eaab1bac165ff791fe2309daba0468..779f459681fa1069dedb63df56c72de30b968052 100644 (file)
@@ -3356,8 +3356,8 @@ kernel void kernel_flash_attn_ext_vec(
     const short D4  = D/4;
     const short D16 = D/16;
     const short NW  = N_SIMDWIDTH;
-    const short NL  = NW/4;
-    const short SH  = 2*C; // shared memory per simdgroup
+    const short NL  = NW/4; // note: this can be adjusted to support D%64 == 0 and D%32 == 0
+    const short SH  = 2*C;  // shared memory per simdgroup
 
     const short T = D + nsg*SH; // shared memory size per query in (half)
 
@@ -3448,7 +3448,7 @@ kernel void kernel_flash_attn_ext_vec(
 
             // Q*K^T
             {
-                // each simdgroup processes 1 query and 4 keys
+                // each simdgroup processes 1 query and 4 (NW/NL) keys
                 for (short cc = 0; cc < C/4; ++cc) {
                     qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
 
@@ -3646,7 +3646,7 @@ kernel void kernel_flash_attn_ext_vec(
     half,  half4,  half4x4, \
                    half4x4
 
-typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 64>) flash_attn_ext_vec_t;
+typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4x4, 1, dequantize_f16, half4x4, 1, dequantize_f16, 128>) flash_attn_ext_vec_t;
 
 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4x4,    1, dequantize_f16,  half4x4,     1, dequantize_f16,  128>;
 #if defined(GGML_METAL_USE_BF16)