vulkan: clamp matmul and FA results to the max finite value (#15652)
author    Jeff Bolz <redacted>
Sun, 31 Aug 2025 06:27:57 +0000 (01:27 -0500)
committer GitHub <redacted>
Sun, 31 Aug 2025 06:27:57 +0000 (08:27 +0200)
* vulkan: clamp matmul and FA results to the max finite value

* only clamp for fp16
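
Background for the bound used throughout this change: the largest finite IEEE-754 half-precision value is 65504 = (2 - 2^-10) * 2^15, and any fp16 accumulator result beyond it overflows to +/-inf once stored, which is the artifact the added clamps avoid. A minimal host-side sketch of that overflow (illustrative only, not part of the commit; assumes a GCC/Clang toolchain with the _Float16 extension):

    // Illustrative only: fp16 overflows to +inf past its max finite value 65504.
    // GLSL float16_t behaves the same way, hence the float16_t(65504.0) clamp bound.
    #include <cstdio>

    int main() {
        const float max_half = 65504.0f;            // (2 - 2^-10) * 2^15
        _Float16 in_range = (_Float16)max_half;     // representable, stays finite
        _Float16 too_big  = (_Float16)70000.0f;     // not representable -> +inf
        std::printf("in_range=%f too_big=%f\n", (float)in_range, (float)too_big);
        return 0;
    }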

ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp
index d40848e15fe973ad644748d5b6902a065b086be1..482445c6fea2c4c86a2edc9cbbdcafc998018763 100644
@@ -334,6 +334,9 @@ void main() {
     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
         [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
             Of[r][d] *= Lfrcp[r];
+#if defined(ACC_TYPE_MAX)
+            Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
+#endif
         }
     }
 
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp
index 97c2a54129709dcbc0a825535600b9173f5d575c..63b32171b0c07076967faf4a272e048d52506b84 100644
@@ -373,6 +373,9 @@ void main() {
     [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
         [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
             Of[r][d] *= ACC_TYPE(Lfrcp[r]);
+#if defined(ACC_TYPE_MAX)
+            Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+#endif
         }
     }
 
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
index 77ae5ff01d03ebff85b19aacf969702dd80ef992..ab647e9bc8b688407db6d8d4a1e5027658c381f7 100644
@@ -283,6 +283,10 @@ void main() {
 
     O = Ldiag*O;
 
+#if defined(ACC_TYPE_MAX)
+    [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
     uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
 
     coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp
index 76ef4b6dfb571c55cb7d17db3730b488ccb95911..06e83822fe326929bc4c5af164679307bb0d954d 100644
@@ -111,6 +111,10 @@ void main() {
             }
         }
         O *= L;
+
+        const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
+        O = clamp(O, -FLT_MAX, FLT_MAX);
+
         data_d[iq3 * D * N + D * n + d] = O;
     }
 }
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp
index 5ecf68a64383b84a56daf936d9048382f2f287ad..7e10e99e9e8771b3ec6f56f309baef1dbb6a2e67 100644
@@ -891,6 +891,20 @@ void main() {
         barrier();
     }
 
+#if defined(ACC_TYPE_MAX)
+#ifdef COOPMAT
+    [[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) {
+        [[unroll]] for (uint i = 0; i < sums[j].length(); ++i) {
+            sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+        }
+    }
+#else
+    [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
+        sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
+    }
+#endif
+#endif
+
     const uint dr = ir * BM + warp_r * WM;
     const uint dc = ic * BN + warp_c * WN;
 
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
index f5aebf6e93f948386000b0b2fdde42dae2821110..dd1b176049be3b4d60f132bd374ab48778c625c8 100644
@@ -349,6 +349,10 @@ void main() {
                 sum = coopMatMulAdd(mat_a, mat_b, sum);
                 block_k += BK;
             }
+#if defined(ACC_TYPE_MAX)
+            [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
             coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
 
             coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
@@ -388,6 +392,10 @@ void main() {
                 sum = coopMatMulAdd(mat_a, mat_b, sum);
                 block_k += BK;
             }
+#if defined(ACC_TYPE_MAX)
+            [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
             coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
 
             coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
@@ -428,6 +436,10 @@ void main() {
                 sum = coopMatMulAdd(mat_a, mat_b, sum);
                 block_k += BK;
             }
+#if defined(ACC_TYPE_MAX)
+            [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
+
             coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
 
             coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
@@ -485,6 +497,9 @@ void main() {
                 sum = coopMatMulAdd(mat_a, mat_b, sum);
             }
         }
+#if defined(ACC_TYPE_MAX)
+        [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
+#endif
 
         // Convert from ACC_TYPE to D_TYPE
         coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;
diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
index a973625857af7200ab11bc0bba2c008edc85a122..d81bb47e7b7053653a897f98f8f649059d90489e 100644
@@ -323,6 +323,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
     }
 
     base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+    if (f16acc) {
+        base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
+    }
 
     if (coopmat) {
         base_dict["COOPMAT"] = "1";
@@ -437,8 +440,12 @@ void process_shaders() {
 
     // flash attention
     for (const auto& f16acc : {false, true}) {
-        std::string acctype = f16acc ? "float16_t" : "float";
-        std::string acctypev4 = f16acc ? "f16vec4" : "vec4";
+        std::map<std::string, std::string> fa_base_dict = base_dict;
+        fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
+        fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
+        if (f16acc) {
+            fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
+        }
 
         for (const auto& tname : type_names) {
             if (tname == "f32") {
@@ -449,30 +456,30 @@ void process_shaders() {
 #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc);
+                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
             } else {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
-                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
+                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
             }
 #endif
 #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
-                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
             } else if (tname == "q4_0" || tname == "q8_0") {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
-                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
+                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
             }
 #endif
             if (tname == "f16") {
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
-                    merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc);
+                    merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
             } else if (tname == "q4_0" || tname == "q8_0") {
                 std::string data_a_key = "DATA_A_" + to_uppercase(tname);
                 string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
-                    merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
+                    merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
             }
         }
     }
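
The escaped quotes in the ACC_TYPE_MAX value keep the define as a single shader-compiler argument. A hypothetical sketch of how a dictionary entry like this is assumed to become a glslc flag (the actual assembly happens inside the generator's compile step, which this diff does not touch):

    // Hypothetical helper, not copied from vulkan-shaders-gen.cpp: turns dictionary
    // entries into -D defines. With the value "\"float16_t(65504.0)\"" this yields
    //   -DACC_TYPE_MAX="float16_t(65504.0)"
    // so the fp16-accumulator variants compile with the #if defined(ACC_TYPE_MAX)
    // clamps enabled, while fp32 variants leave them compiled out.
    #include <map>
    #include <string>
    #include <vector>

    std::vector<std::string> to_define_flags(const std::map<std::string, std::string>& defs) {
        std::vector<std::string> flags;
        for (const auto& [key, value] : defs) {
            flags.push_back("-D" + key + "=" + value);
        }
        return flags;
    }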