From: Jeff Bolz Date: Sun, 31 Aug 2025 06:27:57 +0000 (-0500) Subject: vulkan: clamp matmul and FA results to the max finite value (#15652) X-Git-Tag: upstream/0.0.6527~199 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=94e82c7eadeb8fff0db4bfd1ab6d8cf65fa6f2e0;p=pkg%2Fggml%2Fsources%2Fllama.cpp vulkan: clamp matmul and FA results to the max finite value (#15652) * vulkan: clamp matmul and FA results to the max finite value * only clamp for fp16 --- diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index d40848e1..482445c6 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -334,6 +334,9 @@ void main() { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) { Of[r][d] *= Lfrcp[r]; +#if defined(ACC_TYPE_MAX) + Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX)); +#endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 97c2a541..63b32171 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -373,6 +373,9 @@ void main() { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Of[r][d] *= ACC_TYPE(Lfrcp[r]); +#if defined(ACC_TYPE_MAX) + Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX); +#endif } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 77ae5ff0..ab647e9b 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -283,6 +283,10 @@ void main() { O = Ldiag*O; +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; coopmat O_D = coopmat(O); diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp index 76ef4b6d..06e83822 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -111,6 +111,10 @@ void main() { } } O *= L; + + const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF); + O = clamp(O, -FLT_MAX, FLT_MAX); + data_d[iq3 * D * N + D * n + d] = O; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 5ecf68a6..7e10e99e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -891,6 +891,20 @@ void main() { barrier(); } +#if defined(ACC_TYPE_MAX) +#ifdef COOPMAT + [[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) { + [[unroll]] for (uint i = 0; i < sums[j].length(); ++i) { + sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX); + } + } +#else + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); + } +#endif +#endif + const uint dr = ir * BM + warp_r * WM; const uint dc = ic * BN + warp_c * WN; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index f5aebf6e..dd1b1760 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -349,6 +349,10 @@ void main() { sum = coopMatMulAdd(mat_a, mat_b, sum); block_k += BK; } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + coopmat mat_d = coopmat(sum); coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose); @@ -388,6 +392,10 @@ void main() { sum = coopMatMulAdd(mat_a, mat_b, sum); block_k += BK; } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + coopmat mat_d = coopmat(sum); coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose); @@ -428,6 +436,10 @@ void main() { sum = coopMatMulAdd(mat_a, mat_b, sum); block_k += BK; } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + coopmat mat_d = coopmat(sum); coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); @@ -485,6 +497,9 @@ void main() { sum = coopMatMulAdd(mat_a, mat_b, sum); } } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif // Convert from ACC_TYPE to D_TYPE coopmat mat_d; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index a9736258..d81bb47e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -323,6 +323,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; + if (f16acc) { + base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\""; + } if (coopmat) { base_dict["COOPMAT"] = "1"; @@ -437,8 +440,12 @@ void process_shaders() { // flash attention for (const auto& f16acc : {false, true}) { - std::string acctype = f16acc ? "float16_t" : "float"; - std::string acctypev4 = f16acc ? "f16vec4" : "vec4"; + std::map fa_base_dict = base_dict; + fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; + fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4"; + if (f16acc) { + fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\""; + } for (const auto& tname : type_names) { if (tname == "f32") { @@ -449,30 +456,30 @@ void process_shaders() { #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc); + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc); } else { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); } #endif #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc); + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc); } else if (tname == "q4_0" || tname == "q8_0") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); } #endif if (tname == "f16") { string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc); + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc); } else if (tname == "q4_0" || tname == "q8_0") { std::string data_a_key = "DATA_A_" + to_uppercase(tname); string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); } } }