]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: increase LOAD_VEC_A to 8 (IQ1/IQ2) or 4 (IQ3) (llama/14485)
authorEve <redacted>
Sun, 6 Jul 2025 10:29:36 +0000 (10:29 +0000)
committerGeorgi Gerganov <redacted>
Sat, 12 Jul 2025 13:05:00 +0000 (16:05 +0300)
Commit taken from remyoudompheng's PR https://github.com/ggml-org/llama.cpp/pull/12260

Co-authored-by: Rémy Oudompheng <redacted>
src/ggml-vulkan/vulkan-shaders/mul_mm.comp
src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 26163b167c7ed712f98439ee260e4056e999ff2e..888ce79f6ec113ac5bf47fe6fce154c086ab21ad 100644 (file)
@@ -500,10 +500,9 @@ void main() {
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
 
-            const uint ib = idx / 128;                  // 2 values per idx
-            const uint ib32 = (idx % 128) / 16;         // 0..7
-            const uint ib8 = (idx % 128) / 4;
-            const int i8 = 2 * int(idx % 4);
+            const uint ib = idx / 32;                  // 8 values per idx
+            const uint ib32 = (idx % 32) / 4;         // 0..7
+            const uint ib8 = idx % 32;
 
             const float d = float(data_a[ib].d);
             const uint qh = data_a[ib].qh[ib32];
@@ -512,22 +511,16 @@ void main() {
             const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
             const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]);
 
-            const ivec2 gvec = ivec2(
-              bitfieldExtract(grid, 2 * (i8), 2),
-              bitfieldExtract(grid, 2 * (i8 + 1), 2)
-            );
-            const vec2 v = dl * (vec2(gvec) + delta);
-
-            buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+            [[unroll]] for (int k = 0; k < 8; ++k) {
+                buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
+            }
 #elif defined(DATA_A_IQ1_M)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
 
-            const uint ib = idx / 128;                  // 2 values per idx
-            const uint ib8 = (idx % 128) / 4;
+            const uint ib = idx / 32;  // 8 values per idx
+            const uint ib8 = idx % 32;
             const uint ib16 = ib8 / 2;
-            const int i8 = 2 * int(idx % 4);
 
             const uint16_t[4] scales = data_a[ib].scales;
             const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
@@ -538,21 +531,17 @@ void main() {
             const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
             const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
             const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
-            const ivec2 gvec = ivec2(
-              bitfieldExtract(grid, 2 * (i8), 2),
-              bitfieldExtract(grid, 2 * (i8 + 1), 2)
-            );
-            const vec2 v = dl * (vec2(gvec) + delta);
 
-            buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+            [[unroll]] for (int k = 0; k < 8; ++k) {
+                buf_a[buf_idx + k] = FLOAT_TYPE(dl * (bitfieldExtract(grid, 2 * k, 2) + delta));
+            }
 #elif defined(DATA_A_IQ2_XXS)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
 
-            const uint ib = idx / 128;                  // 2 values per idx
-            const uint ib32 = (idx % 128) / 16;         // 0..7
-            const uint ib8 = (idx / 4) % 4;
+            const uint ib = idx / 32;                 // 8 values per idx
+            const uint ib32 = (idx % 32) / 4;         // 0..7
+            const uint ib8 = idx % 4;
 
             const float d = float(data_a[ib].d);
             const uint qs = data_a[ib].qs[8 * ib32 + ib8];
@@ -562,63 +551,81 @@ void main() {
                 data_a[ib].qs[8*ib32 + 6],
                 data_a[ib].qs[8*ib32 + 7]
             ));
-            const float db = d * 0.25 * (0.5 + (signs >> 28));
+            const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28)));
             const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7);
-            const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
-            const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
-            const uint grid = iq2xxs_grid[qs][(idx % 4) / 2] >> (16 * (idx & 1));
-            const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
-
-            buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+            const uint sign = sign7 | (bitCount(sign7) << 7);
+            const uvec2 grid = iq2xxs_grid[qs];
+            const vec4 grid0 = vec4(unpack8(grid.x));
+            const vec4 grid1 = vec4(unpack8(grid.y));
+
+            buf_a[buf_idx    ] = db * FLOAT_TYPE((sign &   1) != 0 ? -grid0.x : grid0.x);
+            buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign &   2) != 0 ? -grid0.y : grid0.y);
+            buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign &   4) != 0 ? -grid0.z : grid0.z);
+            buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign &   8) != 0 ? -grid0.w : grid0.w);
+            buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign &  16) != 0 ? -grid1.x : grid1.x);
+            buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign &  32) != 0 ? -grid1.y : grid1.y);
+            buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign &  64) != 0 ? -grid1.z : grid1.z);
+            buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
 #elif defined(DATA_A_IQ2_XS)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
 
-            const uint ib = idx / 128;                  // 2 values per idx
-            const uint ib32 = (idx % 128) / 16;         // 0..7
-            const uint ib8 = (idx / 4) % 4;             // 0..3
+            const uint ib = idx / 32;            // 8 values per idx
+            const uint ib32 = (idx % 32) / 4;    // 0..7
+            const uint ib8 = idx % 4;            // 0..3
 
             const float d = float(data_a[ib].d);
             const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
-            const float db = d * 0.25 * (0.5 + scale);
+            const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
             const uint qs = data_a[ib].qs[4 * ib32 + ib8];
             const uint sign7 = qs >> 9;
-            const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
-            const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
-            const uint grid = iq2xs_grid[qs & 511][(idx % 4) / 2] >> (16 * (idx & 1));
-            const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
-
-            buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+            const uint sign = sign7 | (bitCount(sign7) << 7);
+            const uvec2 grid = iq2xs_grid[qs & 511];
+            const vec4 grid0 = vec4(unpack8(grid.x));
+            const vec4 grid1 = vec4(unpack8(grid.y));
+
+            buf_a[buf_idx    ] = db * FLOAT_TYPE((sign &   1) != 0 ? -grid0.x : grid0.x);
+            buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign &   2) != 0 ? -grid0.y : grid0.y);
+            buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign &   4) != 0 ? -grid0.z : grid0.z);
+            buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign &   8) != 0 ? -grid0.w : grid0.w);
+            buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign &  16) != 0 ? -grid1.x : grid1.x);
+            buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign &  32) != 0 ? -grid1.y : grid1.y);
+            buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign &  64) != 0 ? -grid1.z : grid1.z);
+            buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
 #elif defined(DATA_A_IQ2_S)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
 
-            const uint ib = idx / 128;        // 2 values per idx
-            const uint ib8 = (idx % 128) / 4; // 0..31
-            const uint ib32 = ib8 / 4;        // 0..7
+            const uint ib = idx / 32;  // 8 values per idx
+            const uint ib8 = idx % 32; // 0..31
+            const uint ib32 = ib8 / 4; // 0..7
 
             const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
             const uint qs = data_a[ib].qs[ib8];
             const uint qh = data_a[ib].qh[ib32];
             const uint qhshift = 2 * (ib8 % 4);
-            const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
+            const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8];
 
             const float d = float(data_a[ib].d);
-            const float db = d * 0.25 * (0.5 + scale);
-            const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
-            const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
-            const vec2 v = db * vec2(sign01) * vec2(unpack8(uint32_t(grid)).xy); // vec4 used due to #12147
-
-            buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+            const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale));
+            const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)];
+            const vec4 grid0 = vec4(unpack8(grid.x));
+            const vec4 grid1 = vec4(unpack8(grid.y));
+
+            buf_a[buf_idx    ] = db * FLOAT_TYPE((sign &   1) != 0 ? -grid0.x : grid0.x);
+            buf_a[buf_idx + 1] = db * FLOAT_TYPE((sign &   2) != 0 ? -grid0.y : grid0.y);
+            buf_a[buf_idx + 2] = db * FLOAT_TYPE((sign &   4) != 0 ? -grid0.z : grid0.z);
+            buf_a[buf_idx + 3] = db * FLOAT_TYPE((sign &   8) != 0 ? -grid0.w : grid0.w);
+            buf_a[buf_idx + 4] = db * FLOAT_TYPE((sign &  16) != 0 ? -grid1.x : grid1.x);
+            buf_a[buf_idx + 5] = db * FLOAT_TYPE((sign &  32) != 0 ? -grid1.y : grid1.y);
+            buf_a[buf_idx + 6] = db * FLOAT_TYPE((sign &  64) != 0 ? -grid1.z : grid1.z);
+            buf_a[buf_idx + 7] = db * FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w);
 #elif defined(DATA_A_IQ3_XXS)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
 
-            const uint ib = idx / 128;                  // 2 values per idx
-            const uint iqs = (idx % 128) / 2;           // 0..63
+            const uint ib = idx / 64;            // 4 values per idx
+            const uint iqs = idx % 64;           // 0..63
             const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
 
             const float d = float(data_a[ib].d);
@@ -631,33 +638,36 @@ void main() {
             ));
             const float db = d * 0.5 * (0.5 + (signs >> 28));
             const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
-            const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
-            const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
-            const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
-            const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
-
-            buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+            const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2));
+            const uint grid = iq3xxs_grid[qs];
+            const vec4 v = db * vec4(unpack8(grid));
+
+            buf_a[buf_idx    ] = FLOAT_TYPE((sign &   1) != 0 ? -v.x : v.x);
+            buf_a[buf_idx + 1] = FLOAT_TYPE((sign &   2) != 0 ? -v.y : v.y);
+            buf_a[buf_idx + 2] = FLOAT_TYPE((sign &   4) != 0 ? -v.z : v.z);
+            buf_a[buf_idx + 3] = FLOAT_TYPE((sign &   8) != 0 ? -v.w : v.w);
 #elif defined(DATA_A_IQ3_S)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
 
-            const uint ib = idx / 128;                  // 2 values per idx
-            const uint iqs = (idx % 128) / 2;           // 0..63
+            const uint ib = idx / 64;            // 4 values per idx
+            const uint iqs = idx % 64;           // 0..63
             const uint iqh = iqs / 8;
 
             const float d = float(data_a[ib].d);
             const uint qs = data_a[ib].qs[iqs];
             const uint qh = data_a[ib].qh[iqh];
-            const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (2 * (idx % 4)));
+            const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2)));
             const uint scale = data_a[ib].scales[iqs / 16];
             const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
             const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
-            const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
-            const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); // vec4 used due to #12147
+            const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)];
+            const vec4 v = db * vec4(unpack8(grid));
 
-            buf_a[buf_idx    ] = FLOAT_TYPE(v.x);
-            buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
+            buf_a[buf_idx    ] = FLOAT_TYPE((sign &   1) != 0 ? -v.x : v.x);
+            buf_a[buf_idx + 1] = FLOAT_TYPE((sign &   2) != 0 ? -v.y : v.y);
+            buf_a[buf_idx + 2] = FLOAT_TYPE((sign &   4) != 0 ? -v.z : v.z);
+            buf_a[buf_idx + 3] = FLOAT_TYPE((sign &   8) != 0 ? -v.w : v.w);
 #elif defined(DATA_A_IQ4_XS)
             const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
             const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
index 2698522ed7101fd610459fa8690b0c8ab2b918ed..30f78fabb3c8519373affb1111c58c82c422daee 100644 (file)
@@ -360,9 +360,9 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
 
     for (const auto& tname : type_names) {
         std::string load_vec_quant = "2";
-        if ((tname == "q4_0") || (tname == "q4_1"))
+        if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
             load_vec_quant = "8";
-        else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq4_nl"))
+        else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl"))
             load_vec_quant = "4";
 
         if (tname == "bf16") {