]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: add RTE variants for glu/add/sub/mul/div (llama/14653)
authorJeff Bolz <redacted>
Tue, 15 Jul 2025 19:32:11 +0000 (14:32 -0500)
committerGeorgi Gerganov <redacted>
Sat, 19 Jul 2025 14:47:23 +0000 (17:47 +0300)
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp
src/ggml-vulkan/vulkan-shaders/generic_binary_head.comp
src/ggml-vulkan/vulkan-shaders/glu_head.comp
src/ggml-vulkan/vulkan-shaders/im2col.comp
src/ggml-vulkan/vulkan-shaders/rope_head.comp
src/ggml-vulkan/vulkan-shaders/rte.comp [new file with mode: 0644]
src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 416ee3bd3f70acecca32cfd9a47ee7deebe3b726..9f5646bf29daecb8bbdccfa574bcfecf97f60bcf 100644 (file)
@@ -2835,10 +2835,11 @@ static void ggml_vk_load_shaders(vk_device& device) {
         return s;
     };
 
+    bool rte = device->float_controls_rte_fp16;
 #define CREATE_BINARY(name, namemod, spec) \
     for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
         ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
-                                #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
+                                #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \
                                 "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
 
     CREATE_BINARY(add, , {0})
@@ -2890,8 +2891,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
 #undef CREATE_UNARY
 
 #define CREATE_GLU(name)  \
-    ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);  \
-    ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);
+    if (device->float_controls_rte_fp16) {  \
+        ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \
+        ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \
+    } else {    \
+        ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \
+        ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true);   \
+    }
 
     CREATE_GLU(geglu)
     CREATE_GLU(reglu)
index e06547e48f7fe1ecd171da854359e60251843ffd..27d6b7464f62c0fc5fd12ee536792124f9c0da67 100644 (file)
@@ -1,10 +1,6 @@
 #version 450
 
-#if RTE16
-#extension GL_EXT_spirv_intrinsics : enable
-spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
-#endif // RTE16
-
+#include "rte.comp"
 #include "types.comp"
 
 #if defined(SET_ROWS) && QUANT_K == 1
index 062e2a4cdf2d89eaa4e0bef52f011df304e9af73..4b4316cf3d9f263e140b4dfb655158c446e6d596 100644 (file)
@@ -1,6 +1,8 @@
 #extension GL_EXT_shader_16bit_storage : require
 #extension GL_EXT_control_flow_attributes : require
 
+#include "rte.comp"
+
 layout (push_constant) uniform parameter
 {
     uint ne;
index 41a29889075f69f107645882b4d19b178e11c819..004a61fc1625480e92ddbe68b76caf33f1604e53 100644 (file)
@@ -1,5 +1,7 @@
 #extension GL_EXT_shader_16bit_storage : require
 
+#include "rte.comp"
+
 layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
index 09aa849e8815c13fc2adcf7d437a7aeb0d28f3ae..17c7ccb90d001a44229ac59cecdf360cd2f51487 100644 (file)
@@ -1,12 +1,9 @@
 #version 450
 
 #extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_spirv_intrinsics: enable
 #extension GL_EXT_control_flow_attributes : require
 
-#if RTE16
-spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
-#endif
+#include "rte.comp"
 
 layout (push_constant) uniform parameter
 {
index 96c9c4cbd307ced519430adfeee6218e5869db63..00e203e73bd1b4448724c2c3fd255ad709177c24 100644 (file)
@@ -1,11 +1,8 @@
 #include "types.comp"
 
 #extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_spirv_intrinsics: enable
 
-#if RTE16
-spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
-#endif
+#include "rte.comp"
 
 layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in;
 
diff --git a/src/ggml-vulkan/vulkan-shaders/rte.comp b/src/ggml-vulkan/vulkan-shaders/rte.comp
new file mode 100644 (file)
index 0000000..ad51c1e
--- /dev/null
@@ -0,0 +1,5 @@
+
+#if RTE16
+#extension GL_EXT_spirv_intrinsics : enable
+spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits
+#endif // RTE16
index d4a4e4c5290d8795a78dc297b43b6b3e49c07f3b..809c0bd9bd305f735c0b39156106ee425b1cff57 100644 (file)
@@ -537,8 +537,10 @@ void process_shaders() {
     for (auto src0_f16 : {false, true}) {
     for (auto src1_f16 : {false, true}) {
     for (auto dst_f16  : {false, true}) {
-        auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
-        string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}});
+    for (auto rte      : {false, true}) {
+        auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : "");
+        string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}});
+    }
     }
     }
     }
@@ -592,16 +594,19 @@ void process_shaders() {
     string_to_spv("sigmoid_f16",    "sigmoid.comp",     {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
     string_to_spv("sigmoid_f32",    "sigmoid.comp",     {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
 
-    string_to_spv("geglu_f16",      "geglu.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
-    string_to_spv("geglu_f32",      "geglu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
-    string_to_spv("reglu_f16",      "reglu.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
-    string_to_spv("reglu_f32",      "reglu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
-    string_to_spv("swiglu_f16",     "swiglu.comp",      {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
-    string_to_spv("swiglu_f32",     "swiglu.comp",      {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
-    string_to_spv("geglu_erf_f16",  "geglu_erf.comp",   {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
-    string_to_spv("geglu_erf_f32",  "geglu_erf.comp",   {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
-    string_to_spv("geglu_quick_f16","geglu_quick.comp", {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
-    string_to_spv("geglu_quick_f32","geglu_quick.comp", {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    for (auto rte : {false, true}) {
+        std::string suffix = rte ? "_rte" : "";
+        string_to_spv("geglu_f16" + suffix,      "geglu.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("geglu_f32" + suffix,      "geglu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("reglu_f16" + suffix,      "reglu.comp",       {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("reglu_f32" + suffix,      "reglu.comp",       {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("swiglu_f16" + suffix,     "swiglu.comp",      {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("swiglu_f32" + suffix,     "swiglu.comp",      {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("geglu_erf_f16" + suffix,  "geglu_erf.comp",   {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("geglu_erf_f32" + suffix,  "geglu_erf.comp",   {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"},   {"RTE16", rte ? "1" : "0"}});
+        string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"},       {"D_TYPE", "float"},       {"RTE16", rte ? "1" : "0"}});
+    }
 
     string_to_spv("leaky_relu_f32", "leaky_relu.comp",  {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("silu_back_f32",  "silu_back.comp",   {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
@@ -709,11 +714,59 @@ void write_output_files() {
             std::remove(path.c_str());
         }
     }
+
+    std::string suffixes[2] = {"_f32", "_f16"};
     for (const char *op : {"add", "sub", "mul", "div"}) {
-        fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op);
-        fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op);
-        fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op);
-        fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op);
+        fprintf(hdr, "extern unsigned char *%s_data[2][2][2][2];\n", op);
+        fprintf(hdr, "extern uint64_t %s_len[2][2][2][2];\n", op);
+        std::string data = "unsigned char *" + std::string(op) + "_data[2][2][2][2] = ";
+        std::string len = "uint64_t " + std::string(op) + "_len[2][2][2][2] = ";
+        for (uint32_t t0 = 0; t0 < 2; ++t0) {
+            if (t0 == 0) {
+                data += "{";
+                len += "{";
+            }
+            for (uint32_t t1 = 0; t1 < 2; ++t1) {
+                if (t1 == 0) {
+                    data += "{";
+                    len += "{";
+                }
+                for (uint32_t t2 = 0; t2 < 2; ++t2) {
+                    if (t2 == 0) {
+                        data += "{";
+                        len += "{";
+                    }
+                    for (uint32_t rte = 0; rte < 2; ++rte) {
+                        if (rte == 0) {
+                            data += "{";
+                            len += "{";
+                        }
+                        data += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
+                        len  += op + suffixes[t0] + suffixes[t1] + suffixes[t2] + ((rte != 0) ? "_rte" : "");
+                        data += "_data,";
+                        len  += "_len,";
+                        if (rte == 1) {
+                            data += "}, ";
+                            len += "}, ";
+                        }
+                    }
+                    if (t2 == 1) {
+                        data += "}, ";
+                        len += "}, ";
+                    }
+                }
+                if (t1 == 1) {
+                    data += "}, ";
+                    len += "}, ";
+                }
+            }
+            if (t0 == 1) {
+                data += "};\n";
+                len += "};\n";
+            }
+        }
+        fprintf(src, data.c_str());
+        fprintf(src, len.c_str());
     }
     fclose(hdr);
     fclose(src);