]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: experimental native mxfp4 support for blackwell (llama/17906)
authorAman Gupta <redacted>
Wed, 24 Dec 2025 14:28:26 +0000 (22:28 +0800)
committerGeorgi Gerganov <redacted>
Wed, 31 Dec 2025 15:52:09 +0000 (17:52 +0200)
* CUDA: experimental native mxfp4 support for blackwell

* optimize load_tiles

* optimize quantize_mxfp4

* cleanup

* first pass review: formatting

* use interleaved layout for mma

* mmq: add assert for size

* use __nv_fp4x4_e2m1

* use iter_k as 512, cleanup

* Use 1200 as blackwell instead of 1000

* address review comments

* mmq: fix stride

* quantize.cu: use reference impl of e8m0 scale

* address review comments

* add 120f-virtual + minor fixes

---------

Co-authored-by: Aman Gupta <aman>
ggml/src/ggml-cuda/CMakeLists.txt
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/mma.cuh
ggml/src/ggml-cuda/mmq.cu
ggml/src/ggml-cuda/mmq.cuh
ggml/src/ggml-cuda/quantize.cu
ggml/src/ggml-cuda/quantize.cuh
ggml/src/ggml-cuda/vendors/cuda.h

index 67af1d8ccc1822fb70527a8ac11f9e5a560865c1..f1412e8b196052a71423988f402ee4bdf928d2aa 100644 (file)
@@ -15,6 +15,7 @@ if (CUDAToolkit_FOUND)
         # 80     == Ampere, asynchronous data loading, faster tensor core instructions
         # 86     == RTX 3000, needs CUDA v11.1
         # 89     == RTX 4000, needs CUDA v11.8
+        # 120    == Blackwell, needs CUDA v12.8, FP4 tensor cores
         #
         # XX-virtual == compile CUDA code as PTX, do JIT compilation to binary code on first run
         # XX-real    == compile CUDA code as device code for this specific architecture
@@ -34,6 +35,10 @@ if (CUDAToolkit_FOUND)
             if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.8")
                 list(APPEND CMAKE_CUDA_ARCHITECTURES 89-real)
             endif()
+
+            if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "12.8")
+                list(APPEND CMAKE_CUDA_ARCHITECTURES 120f-virtual)
+            endif()
         endif()
     endif()
     message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
index 9fcb2f9fd2155f432f261a8f51164473ba082858..62e618850bf3e171d0cef0cc37a492e4a4b44163 100644 (file)
 #define GGML_CUDA_CC_TURING          750
 #define GGML_CUDA_CC_AMPERE          800
 #define GGML_CUDA_CC_ADA_LOVELACE    890
+// While BW spans CC 1000, 1100 & 1200, we are integrating Tensor Core instructions available to 1200 family, see
+// https://docs.nvidia.com/cutlass/media/docs/cpp/blackwell_functionality.html#blackwell-sm120-gemms
+#define GGML_CUDA_CC_BLACKWELL       1200
+#define GGML_CUDA_CC_RUBIN           1300
 #define GGML_CUDA_CC_OFFSET_AMD      0x1000000
 #define GGML_CUDA_CC_OFFSET_MTHREADS 0x0100000
 #define GGML_CUDA_CC_IS_NVIDIA(cc)   (cc < GGML_CUDA_CC_OFFSET_MTHREADS)
@@ -246,6 +250,10 @@ static const char * cu_get_error_str(CUresult err) {
 #define AMPERE_MMA_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 
+#if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL && __CUDA_ARCH__ < GGML_CUDA_CC_RUBIN
+#    define BLACKWELL_MMA_AVAILABLE
+#endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_BLACKWELL
+
 #if !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #define CP_ASYNC_AVAILABLE
 #endif // !defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
@@ -316,6 +324,11 @@ static bool cp_async_available(const int cc) {
     return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_AMPERE;
 }
 
+static bool blackwell_mma_available(const int cc) {
+    return GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_BLACKWELL &&
+           ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_RUBIN;
+}
+
 static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
 #if defined(GGML_USE_HIP) && (defined(__GFX9__) || defined(__GFX8__))
     return 64;
@@ -701,6 +714,28 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
 #endif // CUDART_VERSION >= 12050
 }
 
+__device__ __forceinline__ uint8_t ggml_cuda_float_to_fp4_e2m1(float x, float e) {
+    const uint8_t sign_bit = (x < 0.0f) << 3;
+    float         ax       = fabsf(x) * e;
+
+    // Positive LUT
+    static constexpr float pos_lut[8] = { 0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f };
+
+    int   best_i   = 0;
+    float best_err = fabsf(ax - pos_lut[0]);
+
+#pragma unroll
+    for (int i = 1; i < 8; ++i) {
+        const float err = fabsf(ax - pos_lut[i]);
+        if (err < best_err) {
+            best_err = err;
+            best_i   = i;
+        }
+    }
+
+    return static_cast<uint8_t>(best_i | sign_bit);
+}
+
 // See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
 // Precompute mp (m' in the paper) and L such that division
 // can be computed using a multiply (high 32b of 64b result)
index 3268dadfe83c3a3f272cae336dcbb512147e4a01..df9eed711727353136621172bf301bbe63175bf7 100644 (file)
@@ -900,6 +900,27 @@ namespace ggml_cuda_mma {
 #endif // AMPERE_MMA_AVAILABLE
     }
 
+    static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> &     D,
+                                                            const tile<16, 8, int> & A,
+                                                            const tile<8, 8, int> &  B,
+                                                            uint32_t                 a_scale,
+                                                            uint32_t                 b_scale) {
+#ifdef BLACKWELL_MMA_AVAILABLE
+        const int * Axi = (const int *) A.x;
+        const int * Bxi = (const int *) B.x;
+        float *     Dxi = (float *) D.x;
+
+        asm volatile(
+            "mma.sync.aligned.kind::mxf4.block_scale.scale_vec::2X.m16n8k64.row.col.f32.e2m1.e2m1.f32.ue8m0 "
+            "{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3}, "
+            "%10, {0, 0}, %11, {0, 0};"
+            : "+f"(Dxi[0]), "+f"(Dxi[1]), "+f"(Dxi[2]), "+f"(Dxi[3])
+            : "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]), "r"(a_scale), "r"(b_scale));
+#else
+        GGML_UNUSED_VARS(D, A, B, a_scale, b_scale);
+#endif  // BLACKWELL_MMA_AVAILABLE
+    }
+
     static __device__ __forceinline__ void mma(
             tile<16, 8, float> & D, const tile<16, 8, half2> & A, const tile<8, 8, half2> & B) {
 #ifdef TURING_MMA_AVAILABLE
index f7a2cbca90f2576b8ab2e58cbd795adab865b470..6156dcdae74ac4fd5070ea2f9b17c591724578a8 100644 (file)
@@ -1,3 +1,4 @@
+#include "common.cuh"
 #include "mmq.cuh"
 #include "quantize.cuh"
 #include "mmid.cuh"
@@ -114,6 +115,9 @@ void ggml_cuda_mul_mat_q(
     const bool use_stream_k = (GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA)
                             || GGML_CUDA_CC_IS_CDNA(cc);
 
+    // TODO: tighter pool buffer size vs q8 path
+    const bool use_native_mxfp4 = blackwell_mma_available(cc) && src0->type == GGML_TYPE_MXFP4;
+
     if (!ids) {
         const size_t nbytes_src1_q8_1 = ne13*ne12 * ne11*ne10_padded * sizeof(block_q8_1)/QK8_1 +
             get_mmq_x_max_host(cc)*sizeof(block_q8_1_mmq);
@@ -123,12 +127,24 @@ void ggml_cuda_mul_mat_q(
             const int64_t s11 = src1->nb[1] / ts_src1;
             const int64_t s12 = src1->nb[2] / ts_src1;
             const int64_t s13 = src1->nb[3] / ts_src1;
-            quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type,
-                ne10, s11, s12, s13, ne10_padded, ne11, ne12, ne13, stream);
+            if (use_native_mxfp4) {
+                static_assert(sizeof(block_fp4_mmq) == 4 * sizeof(block_q8_1));
+                quantize_mmq_mxfp4_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
+                                        ne11, ne12, ne13, stream);
+
+            } else {
+                quantize_mmq_q8_1_cuda(src1_d, nullptr, src1_q8_1.get(), src0->type, ne10, s11, s12, s13, ne10_padded,
+                                       ne11, ne12, ne13, stream);
+            }
             CUDA_CHECK(cudaGetLastError());
         }
 
-        const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
+        // Stride depends on quantization format
+        const int64_t s12 = use_native_mxfp4 ?
+                                ne11 * ne10_padded * sizeof(block_fp4_mmq) /
+                                    (8 * QK_MXFP4 * sizeof(int))  // block_fp4_mmq holds 256 values (8 blocks of 32)
+                                :
+                                ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
         const int64_t s13 = ne12*s12;
 
         const mmq_args args = {
@@ -175,12 +191,19 @@ void ggml_cuda_mul_mat_q(
         const int64_t s11 = src1->nb[1] / ts_src1;
         const int64_t s12 = src1->nb[2] / ts_src1;
         const int64_t s13 = src1->nb[2] / ts_src1;
-        quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type,
-            ne10, s11, s12, s13, ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+
+        if (use_native_mxfp4) {
+            quantize_mmq_mxfp4_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
+                                    ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+        } else {
+            quantize_mmq_q8_1_cuda(src1_d, ids_src1.get(), src1_q8_1.get(), src0->type, ne10, s11, s12, s13,
+                                   ne10_padded, ne11_flat, ne12_flat, ne13_flat, stream);
+        }
         CUDA_CHECK(cudaGetLastError());
     }
 
-    const int64_t s12 = ne11*ne10_padded * sizeof(block_q8_1)/(QK8_1*sizeof(int));
+    const int64_t s12 = use_native_mxfp4 ? ne11 * ne10_padded * sizeof(block_fp4_mmq) / (8 * QK_MXFP4 * sizeof(int)) :
+                                           ne11 * ne10_padded * sizeof(block_q8_1) / (QK8_1 * sizeof(int));
     const int64_t s13 = ne12*s12;
 
     // Note that ne02 is used instead of ne12 because the number of y channels determines the z dimension of the CUDA grid.
index fa8a72c9c108da5746df08621444a6e8af58eb76..63451ffab7f32678d0ed079a8997e54aff321871 100644 (file)
@@ -11,6 +11,7 @@ using namespace ggml_cuda_mma;
 
 #define MMQ_DP4A_MAX_BATCH_SIZE 64 // Max. batch size to use for dp4a MMQ kernels when FP16 tensor cores are available.
 #define MMQ_ITER_K 256
+#define MMQ_ITER_K_MXFP4_FP4    512
 #define MMQ_NWARPS 8
 
 typedef void (*load_tiles_mmq_t)(const char * __restrict__ x, int * x_tile, const int kbx0, const int i_max, const int stride);
@@ -44,8 +45,15 @@ struct block_q8_1_mmq {
     };
     int8_t qs[4*QK8_1]; // 128 values quantized to 8 bit each
 };
+
+struct block_fp4_mmq {
+    uint32_t d4[4];       // 8 E8M0 scales (1 per 32 values), 2 packed per uint32: d4[0]={s0,s1}, d4[1]={s2,s3}, etc.
+    int8_t   qs[4 * 32];  // 256 FP4 values packed as 4-bit pairs (2 per byte), 8 blocks of 32 values
+};
+
 static_assert(sizeof(block_q8_1_mmq) == 4*QK8_1 + 4*sizeof(half2), "Unexpected block_q8_1_mmq size");
 static_assert(sizeof(block_q8_1_mmq) == 4*sizeof(block_q8_1),      "Unexpected block_q8_1_mmq size");
+static_assert(sizeof(block_fp4_mmq)  == sizeof(block_q8_1_mmq),    "Unexpected block_fp4_mmq size");
 
 static mmq_q8_1_ds_layout mmq_get_q8_1_ds_layout(const ggml_type type_x) {
     switch (type_x) {
@@ -129,6 +137,14 @@ static int get_mmq_y_host(const int cc) {
         ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
 }
 
+static constexpr __device__ int get_iter_k([[maybe_unused]] const ggml_type type) {
+#if defined(BLACKWELL_MMA_AVAILABLE)
+    return type == GGML_TYPE_MXFP4 ? MMQ_ITER_K_MXFP4_FP4 : MMQ_ITER_K;
+#else
+    return MMQ_ITER_K;
+#endif // defined(BLACKWELL_MMA_AVAILABLE)
+}
+
 static constexpr __device__ int get_mmq_y_device() {
 #if defined(GGML_USE_HIP)
 #if defined(RDNA1)
@@ -191,6 +207,7 @@ static constexpr __host__ __device__ tile_x_sizes mmq_get_dp4a_tile_x_sizes(ggml
 }
 
 #define MMQ_MMA_TILE_X_K_Q8_0 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0                   + 4)
+#define MMQ_MMA_TILE_X_K_FP4  (2*MMQ_TILE_NE_K + 8                                       + 4)
 #define MMQ_MMA_TILE_X_K_Q8_1 (2*MMQ_TILE_NE_K + 2*MMQ_TILE_NE_K/QI8_0                   + 4)
 #define MMQ_MMA_TILE_X_K_Q2_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K                           + 4)
 #define MMQ_MMA_TILE_X_K_Q3_K (2*MMQ_TILE_NE_K + MMQ_TILE_NE_K/2                         + 4)
@@ -201,6 +218,8 @@ static_assert(MMQ_MMA_TILE_X_K_Q8_1 % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q2_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q3_K % 8 == 4, "Wrong padding.");
 static_assert(MMQ_MMA_TILE_X_K_Q6_K % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4  % 8 == 4, "Wrong padding.");
+static_assert(MMQ_MMA_TILE_X_K_FP4 == MMQ_MMA_TILE_X_K_Q8_1, "Wrong tile size for MXFP4");
 
 static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
     switch (type) {
@@ -209,6 +228,7 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
         case GGML_TYPE_Q5_0:    return MMQ_MMA_TILE_X_K_Q8_0;
         case GGML_TYPE_Q5_1:    return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q8_0:    return MMQ_MMA_TILE_X_K_Q8_0;
+        // tile sizes are the same for Q8_1 and FP4 for blackwell
         case GGML_TYPE_MXFP4:   return MMQ_MMA_TILE_X_K_Q8_1;
         case GGML_TYPE_Q2_K:    return MMQ_MMA_TILE_X_K_Q2_K;
         case GGML_TYPE_Q3_K:    return MMQ_MMA_TILE_X_K_Q3_K;
@@ -228,7 +248,8 @@ static constexpr __host__ __device__ int mmq_get_mma_tile_x_k(ggml_type type) {
 }
 
 // block_q8_1_mmq has (128 8-bit ints == 32 32-bit ints + 4 32-bit scales)
-#define MMQ_TILE_Y_K (MMQ_TILE_NE_K + MMQ_TILE_NE_K/QI8_1)
+#define MMQ_TILE_Y_K     (MMQ_TILE_NE_K + MMQ_TILE_NE_K / QI8_1)
+#define MMQ_TILE_Y_FP4_K MMQ_TILE_Y_K
 
 static int mmq_get_granularity_host(const int mmq_x, const int cc) {
     if (amd_mfma_available(cc) || amd_wmma_available(cc)) {
@@ -761,6 +782,50 @@ template <int mmq_y, bool need_check> static __device__ __forceinline__ void loa
     }
 }
 
+template <int mmq_y, bool need_check>
+static __device__ __forceinline__ void load_tiles_mxfp4_fp4(const char * __restrict__ x,
+                                                            int * __restrict__ x_tile,
+                                                            const int kbx0,
+                                                            const int i_max,
+                                                            const int stride) {
+    constexpr int nwarps = mmq_get_nwarps_device();
+    constexpr int warp_size = ggml_cuda_get_physical_warp_size();
+
+    int *      x_qs = (int *) x_tile;
+    uint32_t * x_sc = (uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
+
+    const int txi = threadIdx.x;
+
+    constexpr int iter_k = get_iter_k(GGML_TYPE_MXFP4);
+
+    constexpr int threads_per_row = iter_k / QK_MXFP4;  // each thread processes 1 block
+    constexpr int rows_per_warp   = warp_size / threads_per_row;
+    const int     kbx             = txi % threads_per_row;
+    const int     row_in_warp     = txi / threads_per_row;
+
+#pragma unroll
+    for (int i0 = 0; i0 < mmq_y; i0 += rows_per_warp * nwarps) {
+        int i = i0 + threadIdx.y * rows_per_warp + row_in_warp;
+
+        if constexpr (need_check) {
+            i = min(i, i_max);
+        }
+
+        const block_mxfp4 * bxi = (const block_mxfp4 *) x + kbx0 + i * stride + kbx;
+
+        // quantize_mxfp4_mmq permutes nibbles to match the quantized format
+        const int k0 = kbx * 4;
+        memcpy(x_qs + i * MMQ_MMA_TILE_X_K_FP4 + k0, bxi->qs, 16);
+
+        // Load E8M0 scales: pack 2 consecutive scales into one uint32
+        if (kbx % 2 == 0) {
+            uint32_t e = bxi->e;
+            e |= ((bxi + 1)->e << 8);
+            x_sc[i * MMQ_MMA_TILE_X_K_FP4 + kbx / 2] = e;
+        }
+    }
+}
+
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_0_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -931,6 +996,78 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 }
 
+template <int mmq_x, int mmq_y>
+static __device__ __forceinline__ void vec_dot_mxfp4_mxfp4_mma(const int * __restrict__ x,
+                                                               const int * __restrict__ y,
+                                                               float * __restrict__ sum,
+                                                               const int k00) {
+    typedef tile<16, 8, int>   tile_A;
+    typedef tile<8, 8, int>    tile_B;
+    typedef tile<16, 8, float> tile_C;  // Output is float for native scaled MMA
+
+    constexpr int granularity   = mmq_get_granularity_device(mmq_x);
+    constexpr int rows_per_warp = 2 * granularity;
+    constexpr int ntx           = rows_per_warp / tile_C::I;  // Number of x minitiles per warp.
+
+    y += (threadIdx.y % ntx) * (tile_C::J * MMQ_TILE_Y_FP4_K);
+
+    // Match layout from load_tiles_mxfp4_fp4
+    const int *      x_qs = (const int *) x;
+    const uint32_t * x_sc = (const uint32_t *) (x_qs + 2 * MMQ_TILE_NE_K);
+    const int *      y_qs = (const int *) y + 4;
+    const uint32_t * y_sc = (const uint32_t *) y;
+
+    // tile_A has a length of 64 logical values vs. 32 values in block_mxfp4
+    tile_A   A[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
+    uint32_t scaleA[ntx][MMQ_TILE_NE_K / (2 * QI_MXFP4)];
+
+    // Block scale
+    // Each thread has to point to a 4 byte scale value
+    // https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-block-scaling
+
+    const int i0 = (threadIdx.y / ntx) * rows_per_warp;
+
+#pragma unroll
+    for (int n = 0; n < ntx; ++n) {
+#pragma unroll
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
+            const int k0 = k00 + k01;
+
+            load_ldmatrix(A[n][k01 / (2 * QI_MXFP4)], x_qs + (i0 + n * tile_A::I) * MMQ_MMA_TILE_X_K_FP4 + k0,
+                          MMQ_MMA_TILE_X_K_FP4);
+
+            // based on block-scaling document, 2 threads in each quad need to supply to the scale value
+            const int tidx         = threadIdx.x / 4 + (threadIdx.x % 2) * 8;
+            scaleA[n][k01 / (2 * QI_MXFP4)] =
+                *(x_sc + (i0 + n * tile_A::I + tidx) * MMQ_MMA_TILE_X_K_FP4 + k0 / (2 * QI_MXFP4));
+        }
+    }
+
+#pragma unroll
+    for (int j0 = 0; j0 < mmq_x; j0 += ntx * tile_C::J) {
+#pragma unroll
+        for (int k01 = 0; k01 < MMQ_TILE_NE_K; k01 += 2 * QI_MXFP4) {
+            tile_B   B;
+            uint32_t scaleB;  // 2xN scales
+
+            load_generic(B, y_qs + j0 * MMQ_TILE_Y_FP4_K + k01, MMQ_TILE_Y_FP4_K);
+
+            scaleB = y_sc[(j0 + threadIdx.x / 4) * MMQ_TILE_Y_FP4_K + k01 / (2 * QI_MXFP4)];
+
+#pragma unroll
+            for (int n = 0; n < ntx; ++n) {
+                tile_C C;
+
+                mma_block_scaled(C, A[n][k01 / (2 * QI_MXFP4)], B, scaleA[n][k01 / (2 * QI_MXFP4)], scaleB);
+#pragma unroll
+                for (int l = 0; l < tile_C::ne; ++l) {
+                    sum[(j0 / tile_C::J + n) * tile_C::ne + l] += C.x[l];
+                }
+            }
+        }
+    }
+}
+
 template <int mmq_x, int mmq_y>
 static __device__ __forceinline__ void vec_dot_q8_1_q8_1_dp4a(
     const int * __restrict__ x, const int * __restrict__ y, float * __restrict__ sum, const int k00) {
@@ -3109,8 +3246,13 @@ struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_Q8_0> {
 template <int mmq_x, int mmq_y, bool need_check>
 struct mmq_type_traits<mmq_x, mmq_y, need_check, GGML_TYPE_MXFP4> {
     static constexpr int              vdr          = VDR_MXFP4_Q8_1_MMQ;
+#ifdef BLACKWELL_MMA_AVAILABLE
+    static constexpr load_tiles_mmq_t load_tiles  = load_tiles_mxfp4_fp4<mmq_y, need_check>;
+    static constexpr vec_dot_mmq_t    vec_dot_mma = vec_dot_mxfp4_mxfp4_mma<mmq_x, mmq_y>;
+#else
     static constexpr load_tiles_mmq_t load_tiles   = load_tiles_mxfp4<mmq_y, need_check>;
     static constexpr vec_dot_mmq_t    vec_dot_mma  = vec_dot_q8_0_q8_1_mma<mmq_x, mmq_y, MMQ_Q8_1_DS_LAYOUT_D4>;
+#endif // BLACKWELL_MMA_AVAILABLE
     static constexpr vec_dot_mmq_t    vec_dot_dp4a = vec_dot_q8_0_q8_1_dp4a<mmq_x, mmq_y>;
 };
 
@@ -3243,17 +3385,26 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
     constexpr mmq_write_back_t write_back = mmq_write_back_dp4a<mmq_x, mmq_y, need_check>;
 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE)
 
-    constexpr int blocks_per_iter = MMQ_ITER_K / qk;
+#if defined(BLACKWELL_MMA_AVAILABLE)
+    // FP4 tile stores 8 blocks
+    constexpr int ne_block = (type == GGML_TYPE_MXFP4) ? 8 * QK_MXFP4 : 4 * QK8_1;
+#else
+    constexpr int ne_block = 4 * QK8_1;
+#endif  // defined(BLACKWELL_MMA_AVAILABLE)
+
+    constexpr int ITER_K          = get_iter_k(type);
+    constexpr int blocks_per_iter = ITER_K / qk;
 
     float sum[mmq_x*mmq_y / (nwarps*warp_size)] = {0.0f};
 
+    constexpr int sz = sizeof(block_q8_1_mmq) / sizeof(int);
+
     for (int kb0 = kb0_start; kb0 < kb0_stop; kb0 += blocks_per_iter) {
         load_tiles(x, tile_x, offset_x + kb0, tile_x_max_i, stride_row_x);
-
         {
-            const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 0*sizeof(block_q8_1_mmq)/sizeof(int));
+            const int * by0 = y + ncols_y * (kb0 * qk / ne_block) * sz;
 #pragma unroll
-            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
                 int l = l0 + threadIdx.y*warp_size + threadIdx.x;
 
                 tile_y[l] = by0[l];
@@ -3267,9 +3418,9 @@ static __device__ __forceinline__ void mul_mat_q_process_tile(
         __syncthreads();
 
         {
-            const int * by0 = y + ncols_y*(kb0*(qk*sizeof(block_q8_1_mmq) / (4*QK8_1*sizeof(int))) + 1*sizeof(block_q8_1_mmq)/sizeof(int));
+            const int * by0 = y + ncols_y * ((kb0 * qk / ne_block) * sz + sz);
 #pragma unroll
-            for (int l0 = 0; l0 < mmq_x*MMQ_TILE_Y_K; l0 += nwarps*warp_size) {
+            for (int l0 = 0; l0 < mmq_x * MMQ_TILE_Y_K; l0 += nwarps * warp_size) {
                 int l = l0 + threadIdx.y*warp_size + threadIdx.x;
 
                 tile_y[l] = by0[l];
@@ -3401,8 +3552,10 @@ static __global__ void mul_mat_q(
     }
 #endif // (defined(GGML_USE_HIP) && !defined(CDNA3)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
 
+    constexpr int ITER_K = get_iter_k(type);
+
     const     int64_t blocks_per_ne00 = ncols_x / qk;
-    constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
+    constexpr int     blocks_per_iter = ITER_K / qk;
 
     // kbc == k block continuous, current index in continuous ijk space.
     int64_t kbc      = (int64_t) blockIdx.x     *nsamples_y*nchannels_y*ntx*nty*blocks_per_ne00 / gridDim.x;
@@ -3463,7 +3616,7 @@ static __global__ void mul_mat_q(
             __syncthreads();
         }
 
-        offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+        offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
         offset_dst += it*mmq_y;
 
         const int tile_x_max_i = nrows_x  - it*mmq_y - 1;
@@ -3530,7 +3683,7 @@ static __global__ void mul_mat_q(
         __syncthreads();
     }
 
-    offset_y   += (col_low + jt*mmq_x)*(sizeof(block_q8_1_mmq)/sizeof(int));
+    offset_y += (col_low + jt * mmq_x) * (sizeof(block_q8_1_mmq) / sizeof(int));
     offset_dst += it*mmq_y;
 
     const int tile_x_max_i = nrows_x  - it*mmq_y - 1;
@@ -3553,7 +3706,9 @@ static __global__ void mul_mat_q_stream_k_fixup(
         const int ncols_max) {
     constexpr int     mmq_y           = get_mmq_y_device();
     constexpr int     qk              = ggml_cuda_type_traits<type>::qk;
-    constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
+    constexpr int     ITER_K          = get_iter_k(type);
+
+    constexpr int     blocks_per_iter = ITER_K / qk;
     const     int64_t blocks_per_ne00 = ncols_x / qk;
 
     constexpr int nwarps = mmq_get_nwarps_device();
@@ -3711,7 +3866,7 @@ static size_t mmq_get_nbytes_shared(const int mmq_x, const int mmq_y, const int
     const int mmq_tile_x_k = mmq_get_mma_tile_x_k(type);
     const size_t nbs_ids = mmq_x*sizeof(int);
     const size_t nbs_x = (turing_mma_available(cc) || amd_mfma_available(cc) || amd_wmma_available(cc)) ? mmq_y*mmq_tile_x_k*sizeof(int) : txs.qs*sizeof(int) + txs.dm*sizeof(half2) + txs.sc*sizeof(int);
-    const size_t nbs_y = mmq_x*sizeof(block_q8_1_mmq);
+    const size_t nbs_y = mmq_x * (sizeof(block_q8_1_mmq));
     return nbs_ids + nbs_x + GGML_PAD(nbs_y, nwarps*warp_size*sizeof(int));
 }
 
index 5117f9ffc0ff9cd3bdcdbde39b0ee2d8d209503d..a8c68e44b16ee3f721e867af7f6662d8abc9133a 100644 (file)
@@ -47,6 +47,131 @@ static __global__ void quantize_q8_1(
     y[ib].ds = make_half2(d, sum);
 }
 
+__device__ __forceinline__ uint8_t compute_e8m0_scale(float amax) {
+    if (!(amax > 0.0f)) {
+        return 0;
+    }
+
+    // FP4 E2M1: max exponent (unbiased) is 2.
+    constexpr int FP4_E2M1_EMAX = 2;
+
+    const float e = log2f(amax);
+
+    // "even" -> round-to-nearest integer, ties-to-even
+    const int e_int = __float2int_rn(e);
+
+    const int shared_exp = e_int - FP4_E2M1_EMAX;
+
+    int biased = shared_exp + 127;
+
+    biased = max(biased, 0);
+    biased = min(biased, 254);
+
+    return static_cast<uint8_t>(biased);
+}
+
+// quantize values in the format mxfp4 is stored which is interleaved nibbles
+// i.e. a block a0-a31 is represented as a0a16,a1a17 ...a15a31
+static __global__ void quantize_mmq_mxfp4(const float * __restrict__ x,
+                                          const int32_t * __restrict__ ids,
+                                          void * __restrict__ vy,
+                                          const int64_t ne00,
+                                          const int64_t s01,
+                                          const int64_t s02,
+                                          const int64_t s03,
+                                          const int64_t ne0,
+                                          const int     ne1,
+                                          const int     ne2) {
+    constexpr int vals_per_scale = 32;
+    constexpr int vals_per_warp  = 2 * vals_per_scale;  // Each warp processes 2 blocks of 32 = 64 values
+
+    const int warp_id = threadIdx.y;
+    const int lane_id_32 = threadIdx.x;
+
+    const int nwarps = blockDim.y;
+
+    const int64_t warp_start_offset = (blockIdx.y * nwarps + warp_id) * vals_per_warp;
+
+    if (warp_start_offset >= ne0) {
+        return;
+    }
+
+    const int64_t i1 = blockIdx.x;
+    const int64_t i2 = blockIdx.z % ne2;
+    const int64_t i3 = blockIdx.z / ne2;
+
+    const int64_t i01 = ids ? ids[i1] : i1;
+    const int64_t i02 = i2;
+    const int64_t i03 = i3;
+
+    block_fp4_mmq * y = (block_fp4_mmq *) vy;
+
+    const int64_t block_fp4_mmq_size = 8 * QK_MXFP4;  // 256 values
+    const int64_t ib0                = blockIdx.z * ((int64_t) ne1 * (ne0 / block_fp4_mmq_size));
+    const int64_t ib = ib0 + (warp_start_offset / block_fp4_mmq_size) * ne1 + blockIdx.x;
+    const int64_t quad_idx_in_block  = (warp_start_offset % block_fp4_mmq_size) / vals_per_warp;
+
+    const int group_id = lane_id_32 / 4;
+    const int lane_in_group = lane_id_32 % 4;
+    const int base = group_id * 2;
+    char2 * yqs2 = (char2 *) y[ib].qs;
+
+    const int64_t base_pos = i03 * s03 + i02 * s02 + i01 * s01;
+
+    uint8_t scales[2];
+
+#pragma unroll
+    for (int b = 0; b < 2; ++b) {
+        const int64_t i0 = warp_start_offset + b * vals_per_scale + lane_id_32;
+        const float xi = (i0 < ne00) ? x[base_pos + i0] : 0.0f;
+
+        float amax = fabsf(xi);
+#pragma unroll
+        for (int mask = 16; mask > 0; mask >>= 1) {
+            amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, mask, WARP_SIZE));
+        }
+
+        const uint8_t e = compute_e8m0_scale(amax);
+        scales[b] = e;
+        const float inv_s = (amax == 0.0f) ? 0.0f : __frcp_rn(ggml_cuda_e8m0_to_fp32(e));
+
+#if CUDART_VERSION >= 12080
+        const float scaled_val = xi * inv_s;
+
+        const float val0 = __shfl_sync(0xFFFFFFFF, scaled_val, base, WARP_SIZE);
+        const float val1 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 16, WARP_SIZE);
+        const float val2 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 1, WARP_SIZE);
+        const float val3 = __shfl_sync(0xFFFFFFFF, scaled_val, base + 17, WARP_SIZE);
+
+        if (lane_in_group == 0) {
+            __nv_fp4x4_e2m1 fp4_packed(make_float4(val0, val1, val2, val3));
+
+            yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = *(char2 *) &fp4_packed;
+        }
+#else
+        // Fallback: manual FP4 conversion using LUT
+        const uint8_t q_val = ggml_cuda_float_to_fp4_e2m1(xi, inv_s);
+
+        const uint8_t q_lo_0 = __shfl_sync(0xFFFFFFFF, q_val, base,      WARP_SIZE);
+        const uint8_t q_lo_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 1,  WARP_SIZE);
+        const uint8_t q_hi_0 = __shfl_sync(0xFFFFFFFF, q_val, base + 16, WARP_SIZE);
+        const uint8_t q_hi_1 = __shfl_sync(0xFFFFFFFF, q_val, base + 17, WARP_SIZE);
+
+        if (lane_in_group == 0) {
+            char2 q;
+            q.x = (q_hi_0 << 4) | q_lo_0;
+            q.y = (q_hi_1 << 4) | q_lo_1;
+            yqs2[quad_idx_in_block * 16 + b * 8 + group_id] = q;
+        }
+#endif // CUDART_VERSION >= 12080
+    }
+
+    if (lane_id_32 == 0) {
+        // Store 2 scales packed into 1 uint32
+        y[ib].d4[quad_idx_in_block] = (scales[1] << 8) | scales[0];
+    }
+}
+
 template <mmq_q8_1_ds_layout ds_layout>
 static __global__ void quantize_mmq_q8_1(
         const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
@@ -190,3 +315,29 @@ void quantize_mmq_q8_1_cuda(
             break;
     }
 }
+
+void quantize_mmq_mxfp4_cuda(const float *                    x,
+                             const int32_t *                  ids,
+                             void *                           vy,
+                             [[maybe_unused]] const ggml_type type_src0,
+                             const int64_t                    ne00,
+                             const int64_t                    s01,
+                             const int64_t                    s02,
+                             const int64_t                    s03,
+                             const int64_t                    ne0,
+                             const int64_t                    ne1,
+                             const int64_t                    ne2,
+                             const int64_t                    ne3,
+                             cudaStream_t                     stream) {
+    GGML_ASSERT(ne0 % (2 * QK_MXFP4) == 0);
+
+    constexpr int nwarps = 8;
+    constexpr int vals_per_warp  = 2 * QK_MXFP4;
+    constexpr int vals_per_block = nwarps * vals_per_warp;
+
+    const int64_t block_num_y = (ne0 + vals_per_block - 1) / vals_per_block;
+    const dim3    num_blocks(ne1, block_num_y, ne2 * ne3);
+    const dim3    block_size(WARP_SIZE, nwarps, 1);
+
+    quantize_mmq_mxfp4<<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
+}
index 725ab52443c0e9e8720694b7e06d3c6db1c16fd9..6a91df635788ae238f609216360656907ed5e72a 100644 (file)
@@ -25,3 +25,17 @@ void quantize_mmq_q8_1_cuda(
         const float * x, const int32_t * ids, void * vy,
         ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
         int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
+
+void quantize_mmq_mxfp4_cuda(const float *   x,
+                             const int32_t * ids,
+                             void *          vy,
+                             ggml_type       type_src0,
+                             int64_t         ne00,
+                             int64_t         s01,
+                             int64_t         s02,
+                             int64_t         s03,
+                             int64_t         ne0,
+                             int64_t         ne1,
+                             int64_t         ne2,
+                             int64_t         ne3,
+                             cudaStream_t    stream);
index 3b3086778eed8f6773b9baed97d57fa58d19ddc3..ba032cfab4b8d2ad7f199433686380efdad13fb9 100644 (file)
 #include <cuda_fp8.h>
 #endif // CUDART_VERSION >= 12050
 
+#if CUDART_VERSION >= 12080
+#include <cuda_fp4.h>
+#endif // CUDART_VERSION >= 12080
+
 #if CUDART_VERSION < 11020
 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
 #define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH