]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : use constexpr in FA kernels + fix typedef (#12659)
authorGeorgi Gerganov <redacted>
Sun, 30 Mar 2025 19:04:04 +0000 (22:04 +0300)
committerGitHub <redacted>
Sun, 30 Mar 2025 19:04:04 +0000 (22:04 +0300)
* metal : use constexpr in FA kernels

ggml-ci

* cont

ggml-ci

* cont : fix typedef

ggml-ci

ggml/src/ggml-metal/ggml-metal.metal

index 1c0ca5adf66919c8d8b2b23bfff6899b4425f39f..80d0765b4fc0e7d506a7e9824177dab1529f41df 100644 (file)
@@ -3128,14 +3128,15 @@ kernel void kernel_flash_attn_ext(
     const int iq2 = tgpig[1];
     const int iq1 = tgpig[0]*Q;
 
-    const short DK4  = DK/4;
-    const short DK8  = DK/8;
-    const short DK16 = DK/16;
-    const short DV4  = DV/4;
-    const short DV8  = DV/8;
-    const short DV16 = DV/16;
-    const short NW  = N_SIMDWIDTH;
-    const short SH  = (2*C + Q); // shared memory per simdgroup (s_t == float)
+    constexpr short DK4  = DK/4;
+    constexpr short DK8  = DK/8;
+    constexpr short DK16 = DK/16;
+    constexpr short DV4  = DV/4;
+    constexpr short DV8  = DV/8;
+    constexpr short DV16 = DV/16;
+
+    constexpr short NW  = N_SIMDWIDTH;
+    constexpr short SH  = (2*C + Q); // shared memory per simdgroup (s_t == float)
 
     const short TS = nsg*SH;   // shared memory size per query in (s_t == float)
     const short T  = DK + 2*TS; // shared memory size per query in (half)
@@ -3641,11 +3642,11 @@ kernel void kernel_flash_attn_ext_vec(
     const int iq2 = tgpig[1];
     const int iq1 = tgpig[0];
 
-    const short DK4 = DK/4;
-    const short DV4 = DV/4;
-    const short NW  = N_SIMDWIDTH;
-    const short NL  = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
-    const short SH  = 2*C;   // shared memory per simdgroup
+    constexpr short DK4 = DK/4;
+    constexpr short DV4 = DV/4;
+    constexpr short NW  = N_SIMDWIDTH;
+    constexpr short NL  = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
+    constexpr short SH  = 2*C;   // shared memory per simdgroup
 
     const short T = DK + nsg*SH; // shared memory size per query in (half)
 
@@ -3956,7 +3957,7 @@ kernel void kernel_flash_attn_ext_vec(
     half,  half4, \
            half4
 
-typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 128>) flash_attn_ext_vec_t;
+typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
 
 template [[host_name("kernel_flash_attn_ext_vec_f16_h128")]]  kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4,             1, dequantize_f16_t4,  half4,       1, dequantize_f16_t4,  128, 128, 4>;
 #if defined(GGML_METAL_USE_BF16)