git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml-cuda: Add NVFP4 dp4a kernel (#20644)
author: Michael Wand <redacted>
Thu, 26 Mar 2026 08:54:03 +0000 (01:54 -0700)
committer: GitHub <redacted>
Thu, 26 Mar 2026 08:54:03 +0000 (09:54 +0100)
Added check for dst_t to cuda_cast template for float
Restored ggml_cuda_ue4m3_to_fp32, changed vecdot ints to int32_t
Added CUDART/HIP Check and HIP/fp8 include
Added NVFP4 to test-backend-ops
Added hip_fp8_e4m3 to __nv_fp8_e4m3 typedef

---------

Co-authored-by: Johannes Gäßler <redacted>
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/convert.cu
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/mmvq.cu
ggml/src/ggml-cuda/vecdotq.cuh
ggml/src/ggml-cuda/vendors/cuda.h
ggml/src/ggml-cuda/vendors/hip.h
tests/test-backend-ops.cpp

index 36d8a3aaab29e2bddd556e1a9acadedd110692bb..9f93c70d21d3504aca467aa5f5cd4a754f4d61d2 100644 (file)
@@ -799,6 +799,16 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 #endif // CUDART_VERSION >= 12050
 }
 
+static __device__ __forceinline__ float ggml_cuda_ue4m3_to_fp32(uint8_t x) {
+#ifdef FP8_AVAILABLE
+    const uint32_t bits = x * (x != 0x7F && x != 0xFF); // Convert NaN to 0.0f to match CPU implementation.
+    const __nv_fp8_e4m3 xf = *reinterpret_cast<const __nv_fp8_e4m3 *>(&bits);
+    return static_cast<float>(xf) / 2;
+#else
+    NO_DEVICE_CODE;
+#endif // FP8_AVAILABLE
+}
+
 __device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
     const uint8_t sign_bit = (x < 0.0f) << 3;
     float         ax       = fabsf(x) * e;
@@ -931,6 +941,13 @@ struct ggml_cuda_type_traits<GGML_TYPE_MXFP4> {
     static constexpr int qi = QI_MXFP4;
 };
 
+template<>
+struct ggml_cuda_type_traits<GGML_TYPE_NVFP4> {
+    static constexpr int qk = QK_NVFP4;
+    static constexpr int qr = QR_NVFP4;
+    static constexpr int qi = QI_NVFP4;
+};
+
 template<>
 struct ggml_cuda_type_traits<GGML_TYPE_Q2_K> {
     static constexpr int qk = QK_K;
index b70492c7d6cf7eb7c24a8b351600da2e1fafb3c7..79ccfe568a234204bfab4b40c9470d497f646015 100644 (file)
@@ -617,6 +617,45 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t
     dequantize_block_mxfp4<<<nb, 32, 0, stream>>>(vx, y);
 }
 
+template <typename dst_t>
+static __global__ void dequantize_block_nvfp4(
+        const void * __restrict__ vx,
+        dst_t * __restrict__ yy,
+        const int64_t ne) {
+    const int64_t i = blockIdx.x;
+    const int     tid = threadIdx.x;
+
+    const int64_t base = i * QK_NVFP4;
+    if (base >= ne) {
+        return;
+    }
+
+    const block_nvfp4 * x = (const block_nvfp4 *) vx;
+    const block_nvfp4 & xb = x[i];
+
+    const int sub = tid / (QK_NVFP4_SUB / 2);
+    const int j = tid % (QK_NVFP4_SUB / 2);
+
+    const float d = ggml_cuda_ue4m3_to_fp32(xb.d[sub]);
+    const uint8_t q = xb.qs[sub * (QK_NVFP4_SUB / 2) + j];
+
+    const int64_t y0 = base + sub * QK_NVFP4_SUB + j;
+    const int64_t y1 = y0 + QK_NVFP4_SUB / 2;
+
+    yy[y0] = ggml_cuda_cast<dst_t>(d * kvalues_mxfp4[q & 0x0F]);
+    yy[y1] = ggml_cuda_cast<dst_t>(d * kvalues_mxfp4[q >> 4]);
+}
+
+template <typename dst_t>
+static void dequantize_row_nvfp4_cuda(
+        const void * vx,
+        dst_t * y,
+        const int64_t k,
+        cudaStream_t stream) {
+    GGML_ASSERT(k % QK_NVFP4 == 0);
+    const int nb = k / QK_NVFP4;
+    dequantize_block_nvfp4<<<nb, 32, 0, stream>>>(vx, y, k);
+}
 template <typename src_t, typename dst_t>
 static __global__ void convert_unary(
         const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
@@ -715,6 +754,8 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_MXFP4:
             return dequantize_row_mxfp4_cuda;
+        case GGML_TYPE_NVFP4:
+            return dequantize_row_nvfp4_cuda;
         case GGML_TYPE_F32:
             return convert_unary_cont_cuda<float>;
         case GGML_TYPE_BF16:
@@ -766,6 +807,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_MXFP4:
             return dequantize_row_mxfp4_cuda;
+        case GGML_TYPE_NVFP4:
+            return dequantize_row_nvfp4_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cont_cuda<half>;
         case GGML_TYPE_BF16:
index a31e843e153e87eeb7e8f540004b28aaf116eb68..cc80eb3ffc2a4c1830bcf1e70227f85bcf6972ae 100644 (file)
@@ -1297,7 +1297,12 @@ static void ggml_cuda_op_mul_mat_cublas(
     const bool supports_bf16 = GGML_CUDA_CC_IS_NVIDIA(cc) || GGML_CUDA_CC_IS_AMD(cc) ||
         (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2);
 
-    const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
+    const bool use_fp16 =
+        src0->type != GGML_TYPE_NVFP4 &&
+        (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) &&
+        ggml_is_contiguous(src0) &&
+        row_diff == src0->ne[1] &&
+        dst->op_params[0] == GGML_PREC_DEFAULT;
 
     if (supports_bf16 && src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
         ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
@@ -4781,6 +4786,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                     case GGML_TYPE_Q5_1:
                     case GGML_TYPE_Q8_0:
                     case GGML_TYPE_MXFP4:
+#ifdef FP8_AVAILABLE
+                    case GGML_TYPE_NVFP4:
+#endif // FP8_AVAILABLE
                     case GGML_TYPE_Q2_K:
                     case GGML_TYPE_Q3_K:
                     case GGML_TYPE_Q4_K:
index 024b3d8cf22b740f010034e635ac67f10900de19..66bd8beeae7b2a702d1799bc3d918ef3ea3bbab7 100644 (file)
@@ -15,6 +15,7 @@ static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type)
         case GGML_TYPE_Q5_1:    return vec_dot_q5_1_q8_1;
         case GGML_TYPE_Q8_0:    return vec_dot_q8_0_q8_1;
         case GGML_TYPE_MXFP4:   return vec_dot_mxfp4_q8_1;
+        case GGML_TYPE_NVFP4:   return vec_dot_nvfp4_q8_1;
         case GGML_TYPE_Q2_K:    return vec_dot_q2_K_q8_1;
         case GGML_TYPE_Q3_K:    return vec_dot_q3_K_q8_1;
         case GGML_TYPE_Q4_K:    return vec_dot_q4_K_q8_1;
@@ -41,6 +42,7 @@ static constexpr __host__ __device__ int get_vdr_mmvq(ggml_type type) {
         case GGML_TYPE_Q5_1:    return VDR_Q5_1_Q8_1_MMVQ;
         case GGML_TYPE_Q8_0:    return VDR_Q8_0_Q8_1_MMVQ;
         case GGML_TYPE_MXFP4:   return VDR_MXFP4_Q8_1_MMVQ;
+        case GGML_TYPE_NVFP4:   return VDR_NVFP4_Q8_1_MMVQ;
         case GGML_TYPE_Q2_K:    return VDR_Q2_K_Q8_1_MMVQ;
         case GGML_TYPE_Q3_K:    return VDR_Q3_K_Q8_1_MMVQ;
         case GGML_TYPE_Q4_K:    return VDR_Q4_K_Q8_1_MMVQ;
@@ -626,6 +628,12 @@ static void mul_mat_vec_q_switch_type(
                  nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                  nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
             break;
+        case GGML_TYPE_NVFP4:
+            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_NVFP4>
+                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
+                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
+                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
+            break;
         case GGML_TYPE_Q2_K:
             mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
                 (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
index ab803aca21b1d13ac31d3c156b14d766c924fe13..40b2b41e7e82955c6e040e2cc65ff6f3b2b59b3c 100644 (file)
@@ -322,6 +322,38 @@ static __device__ __forceinline__ float vec_dot_mxfp4_q8_1(
     return d * sumi;
 }
 
+#define VDR_NVFP4_Q8_1_MMVQ 4
+#define VDR_NVFP4_Q8_1_MMQ  8
+
+static __device__ __forceinline__ float vec_dot_nvfp4_q8_1(
+                                        const void * __restrict__ vbq,
+                                        const block_q8_1 * __restrict__ bq8_1,
+                                        const int32_t & kbx,
+                                        const int32_t & iqs) {
+
+    const block_nvfp4 * bq4 = (const block_nvfp4 *) vbq + kbx;
+    float sum = 0.0f;
+#pragma unroll
+    for (int i = 0; i < VDR_NVFP4_Q8_1_MMVQ/2; i++) {
+        const int32_t iqs0 = iqs + 2*i;
+        const int32_t iqs1 = iqs0 + 1;
+        const int32_t is = iqs0 >> 1;
+        const int2 v0 = get_int_from_table_16(get_int_b4(bq4->qs, iqs0), kvalues_mxfp4);
+        const int2 v1 = get_int_from_table_16(get_int_b4(bq4->qs, iqs1), kvalues_mxfp4);
+        const block_q8_1 * bq8 = bq8_1 + (is >> 1);
+        const int32_t i8 = ((is & 1) << 2);
+
+        int sumi = ggml_cuda_dp4a(v0.x, get_int_b4(bq8->qs, i8 + 0), 0);
+        sumi = ggml_cuda_dp4a(v0.y, get_int_b4(bq8->qs, i8 + 2), sumi);
+        sumi = ggml_cuda_dp4a(v1.x, get_int_b4(bq8->qs, i8 + 1), sumi);
+        sumi = ggml_cuda_dp4a(v1.y, get_int_b4(bq8->qs, i8 + 3), sumi);
+
+        const float d = ggml_cuda_ue4m3_to_fp32(bq4->d[is]) * __low2float(bq8->ds);
+        sum += d * float(sumi);
+    }
+
+    return sum;
+}
 #define VDR_Q2_K_Q8_1_MMVQ 1
 #define VDR_Q2_K_Q8_1_MMQ  4
 
index ba032cfab4b8d2ad7f199433686380efdad13fb9..07bc47df3b8f72cdf3855e220fbb92366eb56900 100644 (file)
@@ -6,9 +6,10 @@
 #include <cuda_bf16.h>
 #include <cuda_fp16.h>
 
-#if CUDART_VERSION >= 12050
+#if CUDART_VERSION >= 11080
 #include <cuda_fp8.h>
-#endif // CUDART_VERSION >= 12050
+#define FP8_AVAILABLE
+#endif // CUDART_VERSION >= 11080
 
 #if CUDART_VERSION >= 12080
 #include <cuda_fp4.h>
index 35d1e1a06398b42616a1791605813101ae745ece..9d9ba1ee219e656a03faf8dc01201186918f54c8 100644 (file)
 typedef __hip_bfloat16 nv_bfloat16;
 typedef __hip_bfloat162 nv_bfloat162;
 
+#if HIP_VERSION >= 60200000
+#include <hip/hip_fp8.h>
+typedef __hip_fp8_e4m3 __nv_fp8_e4m3;
+#define FP8_AVAILABLE
+#endif // HIP_VERSION >= 60200000
+
 typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
 typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
 static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
index ac284db2b8a4c6d4feb578797d959904ebb0180e..6a4f9b634b222643ab6bf376851a8bf71e8dc46b 100644 (file)
@@ -7284,7 +7284,7 @@ static const ggml_type all_types[] = {
     GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
     GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
     GGML_TYPE_Q8_0,
-    GGML_TYPE_MXFP4,
+    GGML_TYPE_MXFP4, GGML_TYPE_NVFP4,
     GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
     GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
     GGML_TYPE_Q6_K,
@@ -7300,7 +7300,7 @@ static const ggml_type base_types[] = {
     GGML_TYPE_Q4_0,
     GGML_TYPE_Q4_1, // for I8MM tests
     GGML_TYPE_Q4_K,
-    GGML_TYPE_MXFP4, // TODO: or "other"
+    GGML_TYPE_MXFP4, GGML_TYPE_NVFP4, // TODO: or "other"
     GGML_TYPE_IQ2_XXS
 };