]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: copy iq4_nl LUT into shared memory (llama/10409)
authorJeff Bolz <redacted>
Wed, 20 Nov 2024 07:40:18 +0000 (01:40 -0600)
committerGeorgi Gerganov <redacted>
Tue, 3 Dec 2024 19:05:37 +0000 (21:05 +0200)
src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp
src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp
src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp
src/ggml-vulkan/vulkan-shaders/mul_mm.comp
src/ggml-vulkan/vulkan-shaders/types.comp
src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 34ef3da30b82cec56d5c6b7664c0035148cb094c..8de14fc03f102957c1538d448791148b37bf298c 100644 (file)
@@ -10,6 +10,8 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
 void main() {
     const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
 
+    init_iq4nl_shmem();
+
     const uint tid = gl_LocalInvocationID.x % 64;
     const uint il  = tid/32;
     const uint ir  = tid%32;
index 8d30b63c165b2effbf85bc6deb3c40957b5a713a..7f608315b68443ac38bfaf93fb5b469749ef437d 100644 (file)
@@ -12,6 +12,10 @@ void main() {
     const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
     const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
 
+#if defined(DATA_A_IQ4_NL)
+    init_iq4nl_shmem();
+#endif
+
     if (i00 >= p.ne00) {
         return;
     }
index 00807a060c46bd4a1d107b9701fe00997c42228b..2d5b8e4661312e1d14e305a35e6a945ce4376d02 100644 (file)
@@ -161,6 +161,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
 void main() {
     const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
 
+#if defined(DATA_A_IQ4_NL)
+    init_iq4nl_shmem();
+#endif
+
     // do NUM_ROWS at a time, unless there aren't enough remaining rows
     if (first_row + NUM_ROWS <= p.stride_d) {
         compute_outputs(first_row, NUM_ROWS);
index fffdd18189d5586805ce7982fa11b503849a6135..2ff5c430519bce79be5ee672649252fbb29c04d2 100644 (file)
@@ -75,6 +75,10 @@ shared u16vec2 row_ids[3072];
 #endif
 
 void main() {
+#if defined(DATA_A_IQ4_NL)
+    init_iq4nl_shmem();
+#endif
+
 #ifdef MUL_MAT_ID
     const uint expert_idx = gl_GlobalInvocationID.z;
 #else
index 7a34820bc2789c2e9acb5acd33dad729dcba0cc3..bc28e0ab857aa086f7297369743efbb7b3357a3f 100644 (file)
@@ -298,10 +298,21 @@ struct block_iq4_nl_packed16
 #define A_TYPE block_iq4_nl
 #define A_TYPE_PACKED16 block_iq4_nl_packed16
 
-const int8_t kvalues_iq4nl[16] = {
+const int8_t kvalues_iq4nl_const[16] = {
     int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
     int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113)
 };
+
+shared FLOAT_TYPE kvalues_iq4nl[16];
+
+void init_iq4nl_shmem()
+{
+    // copy the table into shared memory and sync
+    if (gl_LocalInvocationIndex.x < 16) {
+        kvalues_iq4nl[gl_LocalInvocationIndex.x] = FLOAT_TYPE(kvalues_iq4nl_const[gl_LocalInvocationIndex.x]);
+    }
+    barrier();
+}
 #endif
 
 #endif // !defined(GGML_TYPES_COMP)
index f753109556f933b1fe5269aa0688be8375923b51..6bbe8e96edd9f741032c74a53f2e9604a3f51281 100644 (file)
@@ -331,11 +331,11 @@ void process_shaders() {
             shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
 
             if (tname == "f16") {
-                string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+                string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
             } else {
-                string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}});
+                string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
             }
-            string_to_spv("get_rows_" + tname + "_f32", shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
+            string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
         }
     }