]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
mat vec double buffer (llama/12188)
authorEve <redacted>
Mon, 10 Mar 2025 19:28:11 +0000 (19:28 +0000)
committerGeorgi Gerganov <redacted>
Thu, 27 Mar 2025 09:06:03 +0000 (11:06 +0200)
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp

index 8cdc640e80e31bb6badda522578814a88f8a898b..423ceb8a3df463b917e3f828e3d163cfcfab0c16 100644 (file)
@@ -5,23 +5,24 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-shared FLOAT_TYPE sccache1[BLOCK_SIZE/16][16];
-shared FLOAT_TYPE sccache2[BLOCK_SIZE/16][16];
+shared FLOAT_TYPE sccache1[2][BLOCK_SIZE/16][16];
+shared FLOAT_TYPE sccache2[2][BLOCK_SIZE/16][16];
 
 FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+uint csel = 0;
 
 void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
     const uint y_idx = i * QUANT_K + y_offset;
 
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        csel ^= 1;
 
-        barrier();
         if (!all_threads) { // when we don't have enough blocks to use all threads
             if (i < num_blocks_per_row) {
                 const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
-                sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
-                sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
+                sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
+                sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
             }
             barrier();
 
@@ -29,8 +30,8 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
                 continue;
         } else {
             const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
-            sccache1[ix][itid] = FLOAT_TYPE(scale & 0xF);
-            sccache2[ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
+            sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
+            sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
             barrier();
         }
 
@@ -57,22 +58,22 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
             FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
             FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
             [[unroll]] for (int l = 0; l < 2; ++l) {
-                sum1 = fma(FLOAT_TYPE(b0[l]),   sccache1[ix][    8*v_im] * qs_u32_0[l  ],
-                       fma(FLOAT_TYPE(b16[l]),  sccache1[ix][1 + 8*v_im] * qs_u32_0[l+2],
-                       fma(FLOAT_TYPE(b32[l]),  sccache1[ix][2 + 8*v_im] * qs_u32_2[l  ],
-                       fma(FLOAT_TYPE(b48[l]),  sccache1[ix][3 + 8*v_im] * qs_u32_2[l+2],
-                       fma(FLOAT_TYPE(b64[l]),  sccache1[ix][4 + 8*v_im] * qs_u32_4[l  ],
-                       fma(FLOAT_TYPE(b80[l]),  sccache1[ix][5 + 8*v_im] * qs_u32_4[l+2],
-                       fma(FLOAT_TYPE(b96[l]),  sccache1[ix][6 + 8*v_im] * qs_u32_6[l  ],
-                       fma(FLOAT_TYPE(b112[l]), sccache1[ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
-                sum2 = fma(FLOAT_TYPE(b0[l]),   sccache2[ix][    8*v_im],
-                       fma(FLOAT_TYPE(b16[l]),  sccache2[ix][1 + 8*v_im],
-                       fma(FLOAT_TYPE(b32[l]),  sccache2[ix][2 + 8*v_im],
-                       fma(FLOAT_TYPE(b48[l]),  sccache2[ix][3 + 8*v_im],
-                       fma(FLOAT_TYPE(b64[l]),  sccache2[ix][4 + 8*v_im],
-                       fma(FLOAT_TYPE(b80[l]),  sccache2[ix][5 + 8*v_im],
-                       fma(FLOAT_TYPE(b96[l]),  sccache2[ix][6 + 8*v_im],
-                       fma(FLOAT_TYPE(b112[l]), sccache2[ix][7 + 8*v_im], sum2))))))));
+                sum1 = fma(FLOAT_TYPE(b0[l]),   sccache1[csel][ix][    8*v_im] * qs_u32_0[l  ],
+                       fma(FLOAT_TYPE(b16[l]),  sccache1[csel][ix][1 + 8*v_im] * qs_u32_0[l+2],
+                       fma(FLOAT_TYPE(b32[l]),  sccache1[csel][ix][2 + 8*v_im] * qs_u32_2[l  ],
+                       fma(FLOAT_TYPE(b48[l]),  sccache1[csel][ix][3 + 8*v_im] * qs_u32_2[l+2],
+                       fma(FLOAT_TYPE(b64[l]),  sccache1[csel][ix][4 + 8*v_im] * qs_u32_4[l  ],
+                       fma(FLOAT_TYPE(b80[l]),  sccache1[csel][ix][5 + 8*v_im] * qs_u32_4[l+2],
+                       fma(FLOAT_TYPE(b96[l]),  sccache1[csel][ix][6 + 8*v_im] * qs_u32_6[l  ],
+                       fma(FLOAT_TYPE(b112[l]), sccache1[csel][ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
+                sum2 = fma(FLOAT_TYPE(b0[l]),   sccache2[csel][ix][    8*v_im],
+                       fma(FLOAT_TYPE(b16[l]),  sccache2[csel][ix][1 + 8*v_im],
+                       fma(FLOAT_TYPE(b32[l]),  sccache2[csel][ix][2 + 8*v_im],
+                       fma(FLOAT_TYPE(b48[l]),  sccache2[csel][ix][3 + 8*v_im],
+                       fma(FLOAT_TYPE(b64[l]),  sccache2[csel][ix][4 + 8*v_im],
+                       fma(FLOAT_TYPE(b80[l]),  sccache2[csel][ix][5 + 8*v_im],
+                       fma(FLOAT_TYPE(b96[l]),  sccache2[csel][ix][6 + 8*v_im],
+                       fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
             }
             temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
         }
index 3116fad165be0b994bb6ecec67ce555b34123562..e91724a28db2289e8e4705b7201da367b5c590c3 100644 (file)
@@ -5,20 +5,21 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-shared FLOAT_TYPE sccache[BLOCK_SIZE/16][2][8];
+shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][2][8];
 
 FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+uint csel = 0;
 
 void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
     const uint y_idx = i * QUANT_K + y_offset;
 
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        csel ^= 1;
 
         if (!all_threads) { // when we don't have enough blocks to use all threads
-            barrier();
             if (i < num_blocks_per_row)
-                sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
+                sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
             barrier();
 
             if (i >= num_blocks_per_row)
@@ -40,8 +41,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co
         const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
 
         if (all_threads) {
-            barrier();
-            sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
+            sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
             barrier();
         }
 
@@ -59,14 +59,14 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, co
 
             FLOAT_TYPE sum = FLOAT_TYPE(0.0);
             [[unroll]] for (int l = 0; l < 2; ++l) {
-                sum = fma(FLOAT_TYPE(  b0[l]) * sccache[ix][v_im][0], qs_u32_0[l  ] - hmk_0[l  ],
-                      fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
-                      fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l  ] - hmk_1[l  ],
-                      fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
-                      fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l  ] - hmk_2[l  ],
-                      fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
-                      fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l  ] - hmk_3[l  ],
-                      fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
+                sum = fma(FLOAT_TYPE(  b0[l]) * sccache[csel][ix][v_im][0], qs_u32_0[l  ] - hmk_0[l  ],
+                      fma(FLOAT_TYPE( b16[l]) * sccache[csel][ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
+                      fma(FLOAT_TYPE( b32[l]) * sccache[csel][ix][v_im][2], qs_u32_2[l  ] - hmk_1[l  ],
+                      fma(FLOAT_TYPE( b48[l]) * sccache[csel][ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
+                      fma(FLOAT_TYPE( b64[l]) * sccache[csel][ix][v_im][4], qs_u32_4[l  ] - hmk_2[l  ],
+                      fma(FLOAT_TYPE( b80[l]) * sccache[csel][ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
+                      fma(FLOAT_TYPE( b96[l]) * sccache[csel][ix][v_im][6], qs_u32_6[l  ] - hmk_3[l  ],
+                      fma(FLOAT_TYPE(b112[l]) * sccache[csel][ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
             }
             temp[j][n] = fma(d, sum, temp[j][n]);
         }
index f05f96b5efb9d3ef6b2256514f0099be45bbfefc..d53d9ee0a2723cb20cede02b4f2d3252513d2178 100644 (file)
@@ -6,20 +6,21 @@
 
 layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
 
-shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16];
+shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][16];
 
 FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
+uint csel = 0;
 
 void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
     const uint y_idx = i * QUANT_K + y_offset;
 
     [[unroll]] for (uint n = 0; n < num_rows; ++n) {
         const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
+        csel ^= 1;
 
         if (!all_threads) { // when we don't have enough blocks to use all threads
-            barrier();
             if (i < num_blocks_per_row)
-                sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
+                sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
             barrier();
 
             if (i >= num_blocks_per_row)
@@ -51,8 +52,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
         const vec4 q3 = vec4(unpack8(q3_u32)) - 32;
 
         if (all_threads) {
-            barrier();
-            sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
+            sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
             barrier();
         }
 
@@ -71,7 +71,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
                 sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]);
                 sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]);
             }
-            temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]);
+            temp[j][n] = fma(fma(sum[0], sccache[csel][ix][s_offset], fma(sum[1], sccache[csel][ix][s_offset + 2], fma(sum[2], sccache[csel][ix][s_offset + 4], sum[3] * sccache[csel][ix][s_offset + 6]))), d, temp[j][n]);
         }
     }
 }