]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
cuda: optimize iq2xxs/iq2xs/iq3xxs dequantization (#19624)
authorDavid Friehs <redacted>
Sun, 15 Feb 2026 17:08:42 +0000 (18:08 +0100)
committerGitHub <redacted>
Sun, 15 Feb 2026 17:08:42 +0000 (22:38 +0530)
* cuda: optimize iq2xxs/iq2xs/iq3xxs dequantization

- load all 8 int8 for a grid position in one load
- calculate signs via popcnt instead of fetching from ksigns table
- broadcast signs to drop individual shift/mask

* cuda: iq2xxs: simplify sum scaling

express `(sum * scale + sum / 2) / 4` as `(sum * (scale * 2 + 1)) / 8`
express `((aux32 >> 28) * 2 + 1)` as `(aux32 >> 27 | 1)`

saves 3 registers for mul_mat_vec_q (152 -> 149) according to nsight
AFAICT no overflow can occur here as iq2xxs values are far too small

* uint -> uint32_t

error: identifier "uint" is undefined

ggml/src/ggml-cuda/mmq.cuh
ggml/src/ggml-cuda/vecdotq.cuh

index f80f98cda2c4cff6f30337a616b00c907ae87ae5..255e59f6fc68729e1ef0fd51fd16a504145fca45 100644 (file)
@@ -2715,14 +2715,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
 #pragma unroll
         for (int l = 0; l < QR2_XXS; ++l) {
-            const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]);
-            const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F];
+            const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]];
+            const uint32_t signs = unpack_ksigns(aux32 >> (7 * l));
 
-            const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
-            const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
+            const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+            const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
 
-            const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
-            const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
+            const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+            const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
 
 #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0;
@@ -2733,12 +2733,12 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
         }
 
-        const int ls = aux32 >> 28;
+        const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
         const float d = bxi->d;
 #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0   + kqsx] = (ls*d + d/2)/4;
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0   + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
 #else
-        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4;
+        x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE)  || defined(AMD_WMMA_AVAILABLE)
     }
 }
@@ -2776,11 +2776,14 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 
     #pragma unroll
         for (int l = 0; l < QR2_XS; ++l) {
-            const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF));
-            const uint32_t * signs    = (const uint32_t *)(ksigns64   + (q2[l] >> 9));
+            const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF];
+            const uint32_t signs = unpack_ksigns(q2[l] >> 9);
+
+            const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+            const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
 
-            const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
-            const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
+            const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+            const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
 
 #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l;
@@ -2904,11 +2907,13 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
 #pragma unroll
         for (int l = 0; l < QR3_XXS; ++l) {
             const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]);
+            const uint32_t signs = unpack_ksigns(aux32 >> (7*l));
 
-            const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F));
+            const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+            const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
 
-            const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
-            const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
+            const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+            const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
 
 #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
             x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l;
index 6baab1176ffe1cf0c85dcc4babf0ceb34bed540e..ab803aca21b1d13ac31d3c156b14d766c924fe13 100644 (file)
@@ -94,6 +94,15 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
 #endif
 }
 
+static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) {
+    // v is a 7 bit int, with the 8th sign being encodable as popcnt
+    // with xor we can "correct" the bit instead of having to mask
+    const uint32_t p = __popc(v) & 1;
+    const uint32_t s = v ^ p << 7;
+    // broadcast over uint to allow for 0x08040201 / 0x80402010 as selectors
+    return s * 0x01010101;
+}
+
 // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called
 // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q
 
@@ -905,22 +914,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
     int sumi = 0;
 #pragma unroll
     for (int k0 = 0; k0 < 8; k0 += 2) {
-        const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]);
-        const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F];
+        const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[k0/2]];
+        const uint32_t signs = unpack_ksigns(aux32 >> (7 * k0 / 2));
 
-        const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000);
-        const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0);
+        const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+        const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0);
         const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0);
         sumi = ggml_cuda_dp4a(grid0, u0, sumi);
 
-        const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000);
-        const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1);
+        const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+        const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1);
         const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1);
         sumi = ggml_cuda_dp4a(grid1, u1, sumi);
     }
 
-    const int ls = aux32 >> 28;
-    sumi = (ls*sumi + sumi/2)/4;
+    const int ls = aux32 >> 27 | 1; // (scale * 2 + 1)
+    sumi = sumi * ls / 8;           // (sumi * scale + sumi / 2) / 4
     const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds);
     return d * sumi;
 }
@@ -942,13 +951,15 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
     int sumi1 = 0;
 #pragma unroll
     for (int l0 = 0; l0 < 8; l0 += 2) {
-        const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF));
-        const uint32_t * signs    = (const uint32_t *)(ksigns64   + (q2[l0/2] >> 9));
-
-        const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]);
-        const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]);
+        const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l0/2] & 0x1FF];
+        const uint32_t signs = unpack_ksigns(q2[l0/2] >> 9);
 
+        const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+        const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
         const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+
+        const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+        const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
         const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
 
         if (l0 < 4) {
@@ -1028,13 +1039,16 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
 #pragma unroll
     for (int l0 = 0; l0 < 8; l0 += 2) {
         const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]);
+        const uint32_t signs = unpack_ksigns(aux32 >> (7*l0/2));
 
-        const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F));
-
-        const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]);
-        const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]);
+        const int signs0 = __vcmpne4(signs & 0x08040201, 0);
+        const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0);
 
         const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0);
+
+        const int signs1 = __vcmpne4(signs & 0x80402010, 0);
+        const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1);
+
         const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1);
 
         sumi = ggml_cuda_dp4a(grid_l, u0, sumi);