]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
CUDA: use arch list for compatibility check (#11775)
authorJohannes Gäßler <redacted>
Mon, 10 Feb 2025 23:17:22 +0000 (00:17 +0100)
committerGitHub <redacted>
Mon, 10 Feb 2025 23:17:22 +0000 (00:17 +0100)
* CUDA: use arch list for feature availability check

---------

Co-authored-by: Diego Devesa <redacted>
ggml/src/ggml-common.h
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/convert.cu
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/mmq.cu
ggml/src/ggml-cuda/mmq.cuh

index f13fd4dea6f41c1f89132d126eb615ca86b642c1..6c02b69ea239a9b7e162b718a731c23b5502d53e 100644 (file)
@@ -473,7 +473,6 @@ GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
     240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
 GGML_TABLE_END()
 
-//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics
 GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
     0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
     0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
@@ -508,7 +507,6 @@ GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
     0x00ffffffff000000, 0xffffffffff0000ff, 0xffffffffff00ff00, 0x00ffffffff00ffff,
     0xffffffffffff0000, 0x00ffffffffff00ff, 0x00ffffffffffff00, 0xffffffffffffffff,
 GGML_TABLE_END()
-//#endif
 
 
 GGML_TABLE_BEGIN(uint64_t, iq2xxs_grid, 256)
index 174916bc970d7b5d4a41a7cf1b3eb22b8bc63919..2a3244428521b33b037d05c17eb7ffb94a4360bd 100644 (file)
 #define GGML_CUDA_CC_QY1        210
 #define GGML_CUDA_CC_QY2        220
 
+#ifdef __CUDA_ARCH_LIST__
+constexpr bool ggml_cuda_has_arch_impl(int) {
+    return false;
+}
+
+template<class ... Archs>
+constexpr bool ggml_cuda_has_arch_impl(const int arch, const int first, Archs... rest) {
+    return arch == first || ggml_cuda_has_arch_impl(arch, rest...);
+}
+
+constexpr bool ggml_cuda_has_arch(const int arch) {
+    return ggml_cuda_has_arch_impl(arch, __CUDA_ARCH_LIST__);
+}
+
+constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur) {
+    if (cur == 0) {
+        GGML_ABORT("ggml was not compiled with any CUDA arch <= %d", arch);
+    }
+    return cur;
+}
+
+template<class ... Archs>
+constexpr int ggml_cuda_highest_compiled_arch_impl(const int arch, const int cur, const int first, Archs... rest) {
+    if (first <= arch && first > cur) {
+        return ggml_cuda_highest_compiled_arch_impl(arch, first, rest...);
+    } else {
+        return ggml_cuda_highest_compiled_arch_impl(arch, cur, rest...);
+    }
+}
+
+constexpr int ggml_cuda_highest_compiled_arch(const int arch) {
+    return ggml_cuda_highest_compiled_arch_impl(arch, 0, __CUDA_ARCH_LIST__);
+}
+#else
+static int ggml_cuda_highest_compiled_arch(const int arch) {
+    return arch;
+}
+#endif // __CUDA_ARCH_LIST__
+
+// ---------------------------------------------------------------------------------------------------------
+
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
 #if defined(_MSC_VER)
@@ -162,18 +203,32 @@ typedef float2 dfloat2;
 #define FLASH_ATTN_AVAILABLE
 #endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 
-static constexpr bool fast_fp16_available(const int cc) {
+static bool fp16_available(const int cc) {
+    return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
+}
+
+static bool fast_fp16_available(const int cc) {
+    return fp16_available(cc) && cc != 610;
+}
+
+// To be used for feature selection of external libraries, e.g. cuBLAS.
+static bool fast_fp16_hardware_available(const int cc) {
     return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
 }
 
-// Any FP16 tensor cores are available.
-static constexpr bool fp16_mma_available(const int cc) {
+// Any FP16 tensor core instructions are available for ggml code.
+static bool fp16_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
+}
+
+// To be used for feature selection of external libraries, e.g. cuBLAS.
+static bool fp16_mma_hardware_available(const int cc) {
     return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
 }
 
 // Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
-static constexpr bool new_mma_available(const int cc) {
-    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
+static bool new_mma_available(const int cc) {
+    return cc < GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_TURING;
 }
 
 static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
index 5b0dfacefc9dab78c7fe099b6b171ca1c57714f9..795b720d60bcbea9880de7ecccdf64e0adfe6640 100644 (file)
@@ -599,7 +599,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         case GGML_TYPE_Q5_1:
             return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
         case GGML_TYPE_Q8_0:
-            if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_PASCAL) {
+            if (fp16_available(ggml_cuda_info().devices[ggml_cuda_get_device()].cc)) {
                 return dequantize_block_q8_0_f16_cuda;
             }
             return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
index 4dbaefdbafdf483b59f97c8c6098af96a3b3bda9..c95728b08bfe856edf2e754a11e2dabb41a12e58 100644 (file)
@@ -1867,14 +1867,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
 
             const int cc              = ggml_cuda_info().devices[id].cc;
             use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-            any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_available(cc);
-            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
+            any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
+            any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
         }
     } else {
         const int cc              = ggml_cuda_info().devices[ctx.device].cc;
         use_mul_mat_q             = use_mul_mat_q             && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]);
-        any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_available(cc);
-        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc);
+        any_gpus_with_slow_fp16   = any_gpus_with_slow_fp16   || !fast_fp16_hardware_available(cc);
+        any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_hardware_available(cc);
     }
 
     // debug helpers
@@ -3205,8 +3205,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
             if (op->src[0]->ne[0] == 256 && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16) {
                 return true;
             }
-            const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-            return cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+            return fp16_mma_available(ggml_cuda_info().devices[dev_ctx->device].cc) &&
+                op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
         }
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
index 45212f66c0098e76cf095e0babef2d01d8a89768..5dacd131ed55c51714e4675fafd0bfd1f709ac79 100644 (file)
@@ -18,7 +18,7 @@ void ggml_cuda_op_mul_mat_q(
     const int64_t stride00 = ne00 / ggml_blck_size(src0->type);
 
     int id = ggml_cuda_get_device();
-    const int compute_capability = ggml_cuda_info().devices[id].cc;
+    const int cc = ggml_cuda_info().devices[id].cc;
 
     // the main device has a larger memory buffer to hold the results from all GPUs
     // nrows_dst == nrows of the matrix that the kernel writes into
@@ -27,7 +27,8 @@ void ggml_cuda_op_mul_mat_q(
     // The stream-k decomposition is only faster for recent NVIDIA GPUs.
     // Also its fixup needs to allocate a temporary buffer in the memory pool.
     // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
-    const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
+    const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
+        cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
     const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
 
     switch (src0->type) {
@@ -136,7 +137,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return true;
     }
 
-    if (cc < GGML_CUDA_CC_DP4A) {
+    if (ggml_cuda_highest_compiled_arch(cc) < GGML_CUDA_CC_DP4A) {
         return false;
     }
 
@@ -145,7 +146,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
 #endif //GGML_CUDA_FORCE_MMQ
 
     if (cc < GGML_CUDA_CC_OFFSET_AMD) {
-        return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+        return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }
 
     return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
index 7a2c4d85b799f6085a0a81175e2434f869d62503..5391542086cfd3080386360d0d20a41f4730e6d2 100644 (file)
@@ -86,12 +86,13 @@ struct tile_x_sizes {
     int sc;
 };
 
-static constexpr int get_mmq_x_max_host(const int cc) {
+static int get_mmq_x_max_host(const int cc) {
     return new_mma_available(cc) ? 128 :
+        ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ?
 #ifdef GGML_CUDA_FORCE_MMQ
-        cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128                     : 64;
+            128                     : 64;
 #else
-        cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
+            MMQ_DP4A_MAX_BATCH_SIZE : 64;
 #endif // GGML_CUDA_FORCE_MMQ
 }
 
@@ -119,8 +120,9 @@ static constexpr __device__ int get_mmq_x_max_device() {
 #endif // NEW_MMA_AVAILABLE
 }
 
-static constexpr int get_mmq_y_host(const int cc) {
-    return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc)  ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64);
+static int get_mmq_y_host(const int cc) {
+    return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
+        (ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64);
 }
 
 static constexpr __device__ int get_mmq_y_device() {
@@ -2828,7 +2830,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
     const int mmq_x_max = get_mmq_x_max_host(cc);
     const int mmq_y = get_mmq_y_host(cc);
     const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
-    const bool use_stream_k = cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
+    const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
 
     int mmq_x_best  = 0;
     int nparts_best = INT_MAX;