]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
CUDA: Accelerate MXFP4 table lookup using `__byte_perm` (llama/15451)
authorQeeweew <redacted>
Mon, 25 Aug 2025 21:21:22 +0000 (05:21 +0800)
committerGeorgi Gerganov <redacted>
Fri, 5 Sep 2025 09:54:04 +0000 (12:54 +0300)
* CUDA: optimize get_int_from_table_16

* CUDA: use v_perm_b32 to replace byte_perm on AMD GPUs

* revise documentation

---------

Co-authored-by: xix <redacted>
Co-authored-by: Johannes Gäßler <redacted>
src/ggml-cuda/vecdotq.cuh

index d60292b83b1067dd202ba3baa15b1485a9f5ac4f..6baab1176ffe1cf0c85dcc4babf0ceb34bed540e 100644 (file)
@@ -28,7 +28,58 @@ static __device__ __forceinline__ int get_int_b4(const void * x, const int & i32
     return ((const int *) x)[i32]; // assume at least 4 byte alignment
 }
 
+// q4 contains 8 indices with 4 bit each.
+// This function selects those bytes from table that are at those indices and returns them as int2.
+// The first int contains the bytes with even indices in q4, the second int contains the bytes with odd indices in q4.
 static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, const int8_t * table) {
+#if defined(GGML_USE_HIP)
+    // Load the 16-byte table into four 32-bit unsigned integers.
+    const uint32_t *values = (const uint32_t *)table;
+
+    const uint32_t q_even = q4;
+    const uint32_t q_odd  = (q4 >> 4);
+
+    // Perform lookups in the lower half of the table (indices 0-7).
+    uint32_t v_even_low = __builtin_amdgcn_perm(values[1], values[0], q_even & 0x07070707);
+    uint32_t v_odd_low = __builtin_amdgcn_perm(values[1], values[0], q_odd & 0x07070707);
+
+    // Perform lookups in the upper half of the table (indices 8-15).
+    uint32_t v_even_high = __builtin_amdgcn_perm(values[3], values[2], q_even & 0x07070707);
+    uint32_t v_odd_high = __builtin_amdgcn_perm(values[3], values[2], q_odd & 0x07070707);
+
+    // Select between the low and high results based on the MSB of each index nibble.
+    uint32_t mask_even = 0x03020100 | ((q_even & 0x08080808) >> 1);
+    uint32_t res_x = __builtin_amdgcn_perm(v_even_high, v_even_low, mask_even);
+    uint32_t mask_odd = 0x03020100 | ((q_odd & 0x08080808) >> 1);
+    uint32_t res_y = __builtin_amdgcn_perm(v_odd_high, v_odd_low, mask_odd);
+
+    return make_int2(res_x, res_y);
+#elif !defined(GGML_USE_MUSA)
+    // CUDA does not have an instruction for selecting bytes with 4 bit indices.
+    // However, __byte_perm is an instruction that selects bytes with 3 bit indices that can be used instead.
+    const uint32_t * table32 = (const uint32_t *) table;
+
+    // __byte_perm selects bytes based on the lower 16 bits in its third argument.
+    // Therefore, do 2 iterations over the 32 bits in q4 with 0 and 16 shift.
+    // To handle the fourth bit, first call _byte_perm both for the low and the high 64 bit of table, using the low 3 bits.
+    // Then, call __byte_perm again to select from the low and high bytes based on the fourth bit.
+    uint32_t tmp[2];
+    const uint32_t low_high_selection_indices = (0x32103210 | ((q4 & 0x88888888) >> 1));
+#pragma unroll
+    for (uint32_t i = 0; i < 2; ++i) {
+        const uint32_t shift = 16 * i;
+
+        const uint32_t low  = __byte_perm(table32[0], table32[1], q4 >> shift);
+        const uint32_t high = __byte_perm(table32[2], table32[3], q4 >> shift);
+        tmp[i] = __byte_perm(low, high, low_high_selection_indices >> shift);
+    }
+
+    // tmp contains the bytes from tyble in the same order as the 4 bit indices in q4.
+    // However, for the result we need ints with all even/odd 4 bit indices in q4.
+    // Therefore, 2 more calls to __byte_perm to put the bytes in the correct order.
+    return make_int2(__byte_perm(tmp[0], tmp[1], 0x6420), __byte_perm(tmp[0], tmp[1], 0x7531));
+#else
+    // Generic implementation.
     const int      q0_32  = (q4 >> 0) & 0x0F0F0F0F;
     const int8_t * q0_8   = (const int8_t *) &q0_32;
     const char4    val0_8 = make_char4(
@@ -40,6 +91,7 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con
         table[q1_8[0]], table[q1_8[1]], table[q1_8[2]], table[q1_8[3]]);
 
     return make_int2(*((const int *) &val0_8), *((const int *) &val1_8));
+#endif
 }
 
 // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called