]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-cuda: Add generic NVFP4 MMQ kernel (#21074)
authorMichael Wand <redacted>
Wed, 1 Apr 2026 10:04:58 +0000 (03:04 -0700)
committerGitHub <redacted>
Wed, 1 Apr 2026 10:04:58 +0000 (12:04 +0200)
* Introduced NVFP4 generic MMQ kernel

* Added extra FP8 guard, hope to solve ci HIP failure

* Rename tiles and use HIP_FP8_AVAILABLE

* Removed remaning FP8 straggler and added const int

* Const

* Removed DECL_MMQ_CASE artifact

* Removed newline

* Removed space after else

* Changed HIP FP8 NVFP4 conversion gate

* Added new line to bottom of mmq.cu 270

* Removed extra spaces

* Removed single space in front of else on line 814

* Added NVFP4 to generate cu script so HIP can see it, further tightened logic

* Include generated mmq-instance-nvfp4.cu

* Added NVFP4 mmq to HIP Check ignore list

* Update ggml/src/ggml-cuda/mmq.cuh

Changed to Q3_K tile to read MMQ_MMA_TILE_X_K_NVFP4

Co-authored-by: Johannes Gäßler <redacted>
* Update ggml/src/ggml-cuda/mmq.cuh

Changed to Q3_K tile to read MMQ_MMA_TILE_X_K_NVFP4 in tile assert

Co-authored-by: Johannes Gäßler <redacted>
* Update ggml/src/ggml-cuda/mmq.cuh

Added function name ending for end if

Co-authored-by: Johannes Gäßler <redacted>
* Added function names to closing endif

Co-authored-by: Johannes Gäßler <redacted>
---------

Co-authored-by: Johannes Gäßler <redacted>
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/mmq.cu
ggml/src/ggml-cuda/mmq.cuh
ggml/src/ggml-cuda/template-instances/generate_cu_files.py
ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu [new file with mode: 0644]
scripts/hip/gcn-cdna-vgpr-check.py

index 7d7f20af3a04f5e941f060d9052f6e971326b8d3..9affe023403e845e2adf1847496650ff18bb8ea0 100644 (file)
@@ -800,19 +800,32 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 }
 
 static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) {
-#ifdef FP8_AVAILABLE
-    const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
-#if defined(GGML_USE_HIP) && defined(CDNA3)
-    // ROCm dose not support fp8 in software on devices with fp8 hardware,
+#if defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
+    // ROCm does not support fp8 in software on devices with fp8 hardware,
     // but CDNA3 supports only e4m3_fnuz (no inf).
+    const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
     const __hip_fp8_e4m3_fnuz xf = *reinterpret_cast<const __hip_fp8_e4m3_fnuz *>(&bits);
+    return static_cast<float>(xf) / 2;
 #else
+#if defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP)
+    const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
     const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
-#endif // defined(GGML_USE_HIP) && defined(GGML_USE_HIP)
     return static_cast<float>(xf) / 2;
 #else
-    NO_DEVICE_CODE;
-#endif // FP8_AVAILABLE
+    if (x == 0 || (x == 0x7F && x != 0xFF)) { // Convert NaN to 0.0f
+        return 0.0f;
+    }
+    const int exp = (x >> 3) & 0xF;
+    const int man = x & 0x7;
+    float raw;
+    if (exp == 0) {
+        raw = ldexpf((float) man, -9);
+    } else {
+        raw = ldexpf(1.0f + (float) man / 8.0f, exp - 7);
+    }
+    return static_cast<float>(raw / 2);
+#endif // defined(FP8_AVAILABLE) && !defined(GGML_USE_HIP)
+#endif // defined(GGML_USE_HIP) && defined(CDNA3) && defined(FP8_AVAILABLE) && HIP_VERSION >= 60200000
 }
 
 __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
index d1239b1c5f72060b454d4221c1fccf73a5900cf8..75b62129adea392c229eb411d457bbe29d11f3a5 100644 (file)
@@ -4791,9 +4791,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     case GGML_TYPE_Q5_1:
                     case GGML_TYPE_Q8_0:
                     case GGML_TYPE_MXFP4:
-#ifdef FP8_AVAILABLE
                     case GGML_TYPE_NVFP4:
-#endif // FP8_AVAILABLE
                     case GGML_TYPE_Q2_K:
                     case GGML_TYPE_Q3_K:
                     case GGML_TYPE_Q4_K:
index 9a69f41d1598fb5c8326278348d731e4ccb0975c..27b4145ac9abf023bd87780ea20b009a3d7f9852 100644 (file)
@@ -23,6 +23,9 @@ static void ggml_cuda_mul_mat_q_switch_type(ggml_backend_cuda_context & ctx, con
         case GGML_TYPE_MXFP4:
             mul_mat_q_case<GGML_TYPE_MXFP4>(ctx, args, stream);
             break;
+        case GGML_TYPE_NVFP4:
+            mul_mat_q_case<GGML_TYPE_NVFP4>(ctx, args, stream);
+            break;
         case GGML_TYPE_Q2_K:
             mul_mat_q_case<GGML_TYPE_Q2_K>(ctx, args, stream);
             break;
@@ -273,6 +276,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
         case GGML_TYPE_Q5_1:
         case GGML_TYPE_Q8_0:
         case GGML_TYPE_MXFP4:
+        case GGML_TYPE_NVFP4:
         case GGML_TYPE_Q2_K:
         case GGML_TYPE_Q3_K:
         case GGML_TYPE_Q4_K:
@@ -362,5 +366,4 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
     }
 
     return (!GGML_CUDA_CC_IS_CDNA(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
-
 }
index 255e59f6fc68729e1ef0fd51fd16a504145fca45..51e8dad4ce7b69122a8e8df42167b9afb35536e4 100644 (file)
@@ -68,6 +68,8 @@ static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
             return MMQ_Q8_1_DS_LAYOUT_D4;
         case GGML_TYPE_MXFP4:
             return MMQ_Q8_1_DS_LAYOUT_D4;
+        case GGML_TYPE_NVFP4:
+            return MMQ_Q8_1_DS_LAYOUT_D4;
         case GGML_TYPE_Q2_K:
             return MMQ_Q8_1_DS_LAYOUT_D2S6;
         case GGML_TYPE_Q3_K:
@@ -189,6 +191,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
         case GGML_TYPE_Q5_1:    return MMQ_DP4A_TXS_Q8_1;
         case GGML_TYPE_Q8_0:    return MMQ_DP4A_TXS_Q8_0;
         case GGML_TYPE_MXFP4:   return MMQ_DP4A_TXS_Q8_1;
+        case GGML_TYPE_NVFP4:   return MMQ_DP4A_TXS_Q8_0_16;
         case GGML_TYPE_Q2_K:    return MMQ_DP4A_TXS_Q2_K;
         case GGML_TYPE_Q3_K:    return MMQ_DP4A_TXS_Q3_K;
         case GGML_TYPE_Q4_K:    return MMQ_DP4A_TXS_Q4_K;
@@ -206,12 +209,13 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
     }
 }
 
-#define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0                   + 4)
-#define MMQ_MMA_TILE_X_K_FP4  (2*MMQ_TILE_NE_K + 8                                       + 4)
-#define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0                   + 4)
-#define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K                           + 4)
-#define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2                         + 4)
-#define MMQ_MMA_TILE_X_K_Q6_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K   + MMQ_TILE_NE_K/8 + 7)
+#define MMQ_MMA_TILE_X_K_Q8_0  (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0                   + 4)
+#define MMQ_MMA_TILE_X_K_FP4   (2*MMQ_TILE_NE_K + 8                                       + 4) // MXFP4
+#define MMQ_MMA_TILE_X_K_NVFP4 (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2                         + 4) // NVFP4
+#define MMQ_MMA_TILE_X_K_Q8_1  (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0                   + 4)
+#define MMQ_MMA_TILE_X_K_Q2_K  (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K                           + 4)
+#define MMQ_MMA_TILE_X_K_Q3_K  (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2                         + 4)
+#define MMQ_MMA_TILE_X_K_Q6_K  (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI6_K   + MMQ_TILE_NE_K/8 + 7)
 
 static_assert(MMQ_MMA_TILE_X_K_Q8_0 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
@@ -220,6 +224,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_FP4  % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
+static_assert(MMQ_MMA_TILE_X_K_NVFP4 % 8 == 4, "Wrong padding.");
+
 
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     switch (type) {
@@ -230,6 +236,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_Q8_0:    return MMQ_MMA_TILE_X_K_Q8_0;
         // tile sizes are the same for Q8_1 and FP4 for blackwell
         case GGML_TYPE_MXFP4:   return MMQ_MMA_TILE_X_K_Q8_1;
+        case GGML_TYPE_NVFP4:   return MMQ_MMA_TILE_X_K_NVFP4;
         case GGML_TYPE_Q2_K:    return MMQ_MMA_TILE_X_K_Q2_K;
         case GGML_TYPE_Q3_K:    return MMQ_MMA_TILE_X_K_Q3_K;
         case GGML_TYPE_Q4_K:    return MMQ_MMA_TILE_X_K_Q8_1;
@@ -826,6 +833,65 @@ static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restr
     }
 }
 
+
+template <int mmq_y, bool need_check>
+static __device__ __forceinline__ void load_tiles_nvfp4(const char * __restrict__ x,
+                                                        int * __restrict__ x_tile,
+                                                        const int kb0,
+                                                        const int i_max,
+                                                        const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+    int   * x_qs = (int   *) x_tile;
+    float * x_df = (float *) (x_qs + MMQ_TILE_NE_K*2);
+#else
+    constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_NVFP4, mmq_y);
+    int   * x_qs = (int   *) x_tile;
+    float * x_df = (float *) (x_qs + txs.qs);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+
+    constexpr int threads_per_row = MMQ_ITER_K / QK_NVFP4;
+    constexpr int rows_per_warp = warp_size / threads_per_row;
+    const int kbx = threadIdx.x % threads_per_row;
+    const int row_in_warp = threadIdx.x / threads_per_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
+        int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
+
+        if constexpr (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_nvfp4 * bxi = (const block_nvfp4 *) x + kb0 + i * stride + kbx;
+        const uint32_t * __restrict__ src_qs = reinterpret_cast<const uint32_t *>(bxi->qs);
+        const int kqs = 16 * kbx;
+        const int ksc = 4 * kbx;
+
+#pragma unroll
+        for (int sub = 0; sub < QK_NVFP4 / QK_NVFP4_SUB; ++sub) {
+            const int2 q0 = get_int_from_table_16(src_qs[2 * sub + 0], kvalues_mxfp4);
+            const int2 q1 = get_int_from_table_16(src_qs[2 * sub + 1], kvalues_mxfp4);
+
+#if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+            x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 0] = q0.x;
+            x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 1] = q1.x;
+            x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 2] = q0.y;
+            x_qs[i * MMQ_MMA_TILE_X_K_NVFP4 + kqs + 4 * sub + 3] = q1.y;
+            x_df[i * MMQ_MMA_TILE_X_K_NVFP4 + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
+#else
+            x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 0] = q0.x;
+            x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 1] = q1.x;
+            x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 2] = q0.y;
+            x_qs[i * (2 * MMQ_TILE_NE_K + 1) + kqs + 4 * sub + 3] = q1.y;
+            x_df[i * (2 * MMQ_TILE_NE_K * 2 / QI_NVFP4) + i / (QK_NVFP4_SUB / QI_NVFP4) + ksc + sub] = ggml_cuda_ue4m3_to_fp32(bxi->d[sub]);
+#endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
+        }
+    }
+}
+
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -1229,7 +1295,7 @@ static __device__ __forceinline__ void vec_dot_q8_1_q8_1_mma(
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
-// Used for Q3_K, IQ2_S, and IQ2_XS
+// Used for NVFP4, Q3_K, IQ2_S, and IQ2_XS
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -3261,6 +3327,14 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
 };
 
+template <int mmq_x, int mmq_y, bool need_check>
+struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_NVFP4> {
+    static constexpr int              vdr          = VDR_NVFP4_Q8_1_MMQ;
+    static constexpr load_tiles_mmq_t load_tiles   = load_tiles_nvfp4<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_16_q8_1_mma<mmq_x, mmq_y>;
+    static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_16_q8_1_dp4a<mmq_x, mmq_y>;
+};
+
 template <int mmq_x, int mmq_y, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q2_K> {
     static constexpr int              vdr          = VDR_Q2_K_Q8_1_MMQ;
@@ -4069,6 +4143,7 @@ extern DECL_MMQ_CASE(GGML_TYPE_Q5_0);
 extern DECL_MMQ_CASE(GGML_TYPE_Q5_1);
 extern DECL_MMQ_CASE(GGML_TYPE_Q8_0);
 extern DECL_MMQ_CASE(GGML_TYPE_MXFP4);
+extern DECL_MMQ_CASE(GGML_TYPE_NVFP4);
 extern DECL_MMQ_CASE(GGML_TYPE_Q2_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q3_K);
 extern DECL_MMQ_CASE(GGML_TYPE_Q4_K);
index b7b5832293ec5ae043bfa93d69e60ab7c49926fa..40d51f93fa4d6649695bd4b57a7ea18fdb44f26a 100755 (executable)
@@ -35,7 +35,7 @@ TYPES_MMQ = [
     "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
     "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
     "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
-    "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4"
+    "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4", "GGML_TYPE_NVFP4"
 ]
 
 SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
diff --git a/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu b/ggml/src/ggml-cuda/template-instances/mmq-instance-nvfp4.cu
new file mode 100644 (file)
index 0000000..2cb140d
--- /dev/null
@@ -0,0 +1,5 @@
+// This file has been autogenerated by generate_cu_files.py, do not edit manually.
+
+#include "../mmq.cuh"
+
+DECL_MMQ_CASE(GGML_TYPE_NVFP4);
index 38db47d3d189f521ca3811c82b88a1486bca4737..bbbce52ef39f9ebf2d8f3111f2c1155bda8e233a 100644 (file)
@@ -139,7 +139,11 @@ def main():
         '_ZL18flash_attn_ext_f16ILi96ELi96ELi4ELi8ELb0ELb0EEvPKcS1_S1_S1_S1_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS5_IjLj3EEiiiiiiiiiiiliiliiiiil',
         '_ZL18flash_attn_ext_vecILi128ELi2EL9ggml_type2ELS0_2ELb0EEvPKcS2_S2_S2_S2_PKiPfP15HIP_vector_typeIfLj2EEffffjfiS6_IjLj3EEiiiiiiiiiiiliiliiiiil',
         '_ZL9mul_mat_qIL9ggml_type10ELi16ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
-        '_ZL9mul_mat_qIL9ggml_type12ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii'
+        '_ZL9mul_mat_qIL9ggml_type12ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
+        '_ZL9mul_mat_qIL9ggml_type40ELi112ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
+        '_ZL9mul_mat_qIL9ggml_type40ELi112ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
+        '_ZL9mul_mat_qIL9ggml_type40ELi128ELb0EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii',
+        '_ZL9mul_mat_qIL9ggml_type40ELi128ELb1EEvPKcPKiS4_S4_PfS5_iiiiiiiiiiiiiiiii'
     }
 
     functions = parse_log_file(log_file)