]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
CUDA: rename macros to avoid conflicts with WinAPI (llama/10736)
authorAndreas Kieslinger <redacted>
Tue, 10 Dec 2024 17:23:24 +0000 (18:23 +0100)
committerGeorgi Gerganov <redacted>
Wed, 18 Dec 2024 10:52:16 +0000 (12:52 +0200)
* Renames NVIDIA GPU-architecture flags to avoid name clashes with WinAPI. (e.g. CC_PASCAL, GPU architecture or WinAPI pascal compiler flag?)

* Reverts erroneous rename in SYCL-code.

* Renames GGML_CUDA_MIN_CC_DP4A to GGML_CUDA_CC_DP4A.

* Renames the rest of the compute capability macros for consistency.

ggml/src/ggml-common.h
ggml/src/ggml-cuda/common.cuh
ggml/src/ggml-cuda/convert.cu
ggml/src/ggml-cuda/fattn.cu
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/mma.cuh
ggml/src/ggml-cuda/mmq.cu
ggml/src/ggml-cuda/mmq.cuh
ggml/src/ggml-cuda/mmvq.cu
ggml/src/ggml-cuda/sum.cu

index 7fd2aadeca1afd1a3db24f348957ef47eb785b02..f13fd4dea6f41c1f89132d126eb615ca86b642c1 100644 (file)
@@ -473,7 +473,7 @@ GGML_TABLE_BEGIN(uint8_t, ksigns_iq2xs, 128)
     240, 113, 114, 243, 116, 245, 246, 119, 120, 249, 250, 123, 252, 125, 126, 255,
 GGML_TABLE_END()
 
-//#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
+//#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A // lowest compute capability for integer intrinsics
 GGML_TABLE_BEGIN(uint64_t, ksigns64, 128)
     0x0000000000000000, 0xff000000000000ff, 0xff0000000000ff00, 0x000000000000ffff,
     0xff00000000ff0000, 0x0000000000ff00ff, 0x0000000000ffff00, 0xff00000000ffffff,
index 535118d87928e15b4abf15e15ba0edba08741908..2c0a562267281c21bca9072e4a56a4e52c7ddd57 100644 (file)
 #define CUDART_HMAX   11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
 #define CUDART_HMASK  12000 // CUDA 12.0, min. ver. for half2 -> uint mask comparisons
 
-#define CC_PASCAL     600
-#define MIN_CC_DP4A   610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
-#define CC_VOLTA      700
-#define CC_TURING     750
-#define CC_AMPERE     800
-#define CC_OFFSET_AMD 1000000
+#define GGML_CUDA_CC_PASCAL     600
+#define GGML_CUDA_CC_DP4A       610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
+#define GGML_CUDA_CC_VOLTA      700
+#define GGML_CUDA_CC_TURING     750
+#define GGML_CUDA_CC_AMPERE     800
+#define GGML_CUDA_CC_OFFSET_AMD 1000000
 
 // GCN/CNDA, wave size is 64
-#define CC_GCN4       (CC_OFFSET_AMD + 803)  // Tonga, Fiji, Polaris, minimum for fast fp16
-#define CC_VEGA       (CC_OFFSET_AMD + 900)  // Vega56/64, minimum for fp16 dual issue
-#define CC_VEGA20     (CC_OFFSET_AMD + 906)  // MI50/Radeon VII, minimum for dp4a
-#define CC_CDNA       (CC_OFFSET_AMD + 908)  // MI100, minimum for MFMA, acc registers
-#define CC_CDNA2      (CC_OFFSET_AMD + 910)  // MI210, minimum acc register renameing
-#define CC_CDNA3      (CC_OFFSET_AMD + 942)  // MI300
+#define GGML_CUDA_CC_GCN4       (GGML_CUDA_CC_OFFSET_AMD + 803)  // Tonga, Fiji, Polaris, minimum for fast fp16
+#define GGML_CUDA_CC_VEGA       (GGML_CUDA_CC_OFFSET_AMD + 900)  // Vega56/64, minimum for fp16 dual issue
+#define GGML_CUDA_CC_VEGA20     (GGML_CUDA_CC_OFFSET_AMD + 906)  // MI50/Radeon VII, minimum for dp4a
+#define GGML_CUDA_CC_CDNA       (GGML_CUDA_CC_OFFSET_AMD + 908)  // MI100, minimum for MFMA, acc registers
+#define GGML_CUDA_CC_CDNA2      (GGML_CUDA_CC_OFFSET_AMD + 910)  // MI210, minimum acc register renameing
+#define GGML_CUDA_CC_CDNA3      (GGML_CUDA_CC_OFFSET_AMD + 942)  // MI300
 
 // RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32
-#define CC_RDNA1      (CC_OFFSET_AMD + 1010) // RX 5000
-#define CC_RDNA2      (CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
-#define CC_RDNA3      (CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA
+#define GGML_CUDA_CC_RDNA1      (GGML_CUDA_CC_OFFSET_AMD + 1010) // RX 5000
+#define GGML_CUDA_CC_RDNA2      (GGML_CUDA_CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
+#define GGML_CUDA_CC_RDNA3      (GGML_CUDA_CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA
 
-#define CC_QY1        210
-#define CC_QY2        220
+#define GGML_CUDA_CC_QY1        210
+#define GGML_CUDA_CC_QY2        220
 
 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
 
@@ -131,36 +131,36 @@ typedef float dfloat; // dequantize float
 typedef float2 dfloat2;
 #endif // GGML_CUDA_F16
 
-#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
 #define FP16_AVAILABLE
-#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
+#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
 
 #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 #define FAST_FP16_AVAILABLE
 #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 #define FP16_MMA_AVAILABLE
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 #define INT8_MMA_AVAILABLE
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_TURING
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
 
-#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 #define FLASH_ATTN_AVAILABLE
-#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= CC_QY1)
+#endif // !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
 
 static constexpr bool fast_fp16_available(const int cc) {
-    return cc >= CC_PASCAL && cc != 610;
+    return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
 }
 
 static constexpr bool fp16_mma_available(const int cc) {
-    return cc < CC_OFFSET_AMD && cc >= CC_VOLTA;
+    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
 }
 
 static constexpr bool int8_mma_available(const int cc) {
-    return cc < CC_OFFSET_AMD && cc >= CC_TURING;
+    return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
 }
 
 [[noreturn]]
@@ -187,7 +187,7 @@ static __device__ void no_device_code(
 #endif // __CUDA_ARCH__
 
 static __device__ __forceinline__ int warp_reduce_sum(int x) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
     return __reduce_add_sync(0xffffffff, x);
 #else
 #pragma unroll
@@ -195,7 +195,7 @@ static __device__ __forceinline__ int warp_reduce_sum(int x) {
         x += __shfl_xor_sync(0xffffffff, x, offset, 32);
     }
     return x;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_AMPERE
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 }
 
 static __device__ __forceinline__ float warp_reduce_sum(float x) {
@@ -284,7 +284,7 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
 }
 
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
 #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {
        x = ggml_cuda_hmax2(x, __shfl_xor_sync(0xffffffff, x, offset, 32));
@@ -293,7 +293,7 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 #else
    GGML_UNUSED(x);
    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
+#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
 }
 
 #if CUDART_VERSION < CUDART_HMASK
@@ -333,13 +333,13 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
 
 #else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 
-#if __CUDA_ARCH__ >= MIN_CC_DP4A
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
     return __dp4a(a, b, c);
-#else // __CUDA_ARCH__ >= MIN_CC_DP4A
+#else // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
     const int8_t * a8 = (const int8_t *) &a;
     const int8_t * b8 = (const int8_t *) &b;
     return c + a8[0]*b8[0] + a8[1]*b8[1] + a8[2]*b8[2] + a8[3]*b8[3];
-#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_DP4A
 
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 }
index c0a4447075c6ef7a630ca836caa07cb2189cc66b..3896f956d7338f4cc25b7bbc1fe26acbf48d902a 100644 (file)
@@ -26,7 +26,7 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
 
 template <bool need_check>
 static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int64_t k) {
-#if __CUDA_ARCH__ >= CC_PASCAL
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
     constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
 
     const int64_t   i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
@@ -64,7 +64,7 @@ static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, h
     GGML_UNUSED(y);
     GGML_UNUSED(k);
     NO_DEVICE_CODE;
-#endif // __CUDA_ARCH__ >= CC_PASCAL
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
 }
 
 template<typename dst_t>
@@ -599,7 +599,7 @@ to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
         case GGML_TYPE_Q5_1:
             return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
         case GGML_TYPE_Q8_0:
-            if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= CC_PASCAL) {
+            if (ggml_cuda_info().devices[ggml_cuda_get_device()].cc >= GGML_CUDA_CC_PASCAL) {
                 return dequantize_block_q8_0_f16_cuda;
             }
             return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
index 0e7ebbc5393523a823e48e6b3d16cff88d0b6c83..0b26b0f8e0595c5bc2a7c5faa1e1667bd048664d 100644 (file)
@@ -304,7 +304,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
 
     // On AMD the tile kernels perform poorly, use the vec kernel instead:
-    if (cc >= CC_OFFSET_AMD) {
+    if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
         if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
             ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
         } else {
index 15fcb2a65cbcf1607631412520d83183dc4b37b7..c180adc84b861a9ca003a7af1047d8b6444e1c96 100644 (file)
@@ -177,7 +177,7 @@ static ggml_cuda_device_info ggml_cuda_init() {
         info.devices[id].smpb  = prop.sharedMemPerBlock;
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
         info.devices[id].smpbo = prop.sharedMemPerBlock;
-        info.devices[id].cc = 100*prop.major + 10*prop.minor + CC_OFFSET_AMD;
+        info.devices[id].cc = 100*prop.major + 10*prop.minor + GGML_CUDA_CC_OFFSET_AMD;
 #else
         info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
         info.devices[id].cc = 100*prop.major + 10*prop.minor;
@@ -1081,7 +1081,7 @@ static void ggml_cuda_op_mul_mat_cublas(
 
     const int compute_capability = ggml_cuda_info().devices[id].cc;
 
-    if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
+    if (compute_capability >= GGML_CUDA_CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
         // convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
         ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
         if (src0->type != GGML_TYPE_F16) {
@@ -1108,7 +1108,7 @@ static void ggml_cuda_op_mul_mat_cublas(
         const half beta_f16 = 0.0f;
 
         cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
-        if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
+        if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
             cu_compute_type = CUBLAS_COMPUTE_32F;
         }
 
@@ -1612,7 +1612,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
     cudaDataType_t      cu_data_type    = CUDA_R_16F;
 
-    if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
+    if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
         cu_compute_type = CUBLAS_COMPUTE_32F;
     }
 
@@ -2357,7 +2357,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
     std::vector<void *> ggml_cuda_cpy_fn_ptrs;
 
     if (cuda_ctx->cuda_graph->graph == nullptr) {
-        if (ggml_cuda_info().devices[cuda_ctx->device].cc < CC_AMPERE) {
+        if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
             cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
 #ifndef NDEBUG
             GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
@@ -3028,7 +3028,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                 return true;
             }
             const int cc = ggml_cuda_info().devices[dev_ctx->device].cc;
-            return cc >= CC_VOLTA && cc < CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
+            return cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16;
         }
         case GGML_OP_CROSS_ENTROPY_LOSS:
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
index a452a3cc3d152f850d95b8534c10b883bd3e03e2..7d11540afd2371730a6d88bc880ce905d57c5b23 100644 (file)
@@ -171,7 +171,7 @@ struct mma_int_C_I16J8 {
 
     __device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
 #ifdef INT8_MMA_AVAILABLE
-#if __CUDA_ARCH__ >= CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
             : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
             : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_B.x[0]));
@@ -183,7 +183,7 @@ struct mma_int_C_I16J8 {
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
             : "+r"(x[2]), "+r"(x[3])
             : "r"(mma_A.x[1]), "r"(mma_B.x[0]));
-#endif // __CUDA_ARCH__ >= CC_AMPERE
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
         GGML_UNUSED(mma_A);
         GGML_UNUSED(mma_B);
@@ -193,7 +193,7 @@ struct mma_int_C_I16J8 {
 
     __device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
 #ifdef INT8_MMA_AVAILABLE
-#if __CUDA_ARCH__ >= CC_AMPERE
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
         asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
             : "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
             : "r"(mma_A.x[0]), "r"(mma_A.x[1]), "r"(mma_A.x[2]), "r"(mma_A.x[3]), "r"(mma_B.x[0]), "r"(mma_B.x[1]));
@@ -211,7 +211,7 @@ struct mma_int_C_I16J8 {
         asm("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0, %1}, {%2}, {%3}, {%0, %1};"
             : "+r"(x[2]), "+r"(x[3])
             : "r"(mma_A.x[3]), "r"(mma_B.x[1]));
-#endif // __CUDA_ARCH__ >= CC_AMPERE
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
 #else
         GGML_UNUSED(mma_A);
         GGML_UNUSED(mma_B);
index 7f7c8c90b6fe2b1ce000df1ee05ccd5843be0d46..270251df4f115d7aaee6dc82556a7e2aeaead23b 100644 (file)
@@ -27,7 +27,7 @@ void ggml_cuda_op_mul_mat_q(
     // The stream-k decomposition is only faster for recent NVIDIA GPUs.
     // Also its fixup needs to allocate a temporary buffer in the memory pool.
     // There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
-    const bool use_stream_k = compute_capability >= CC_VOLTA && compute_capability < CC_OFFSET_AMD && src1_ncols == ne11;
+    const bool use_stream_k = compute_capability >= GGML_CUDA_CC_VOLTA && compute_capability < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
     const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
 
     switch (src0->type) {
@@ -136,7 +136,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return true;
     }
 
-    if (cc < MIN_CC_DP4A) {
+    if (cc < GGML_CUDA_CC_DP4A) {
         return false;
     }
 
@@ -144,9 +144,9 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
     return true;
 #endif //GGML_CUDA_FORCE_MMQ
 
-    if (cc < CC_OFFSET_AMD) {
-        return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    if (cc < GGML_CUDA_CC_OFFSET_AMD) {
+        return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }
 
-    return (cc < CC_RDNA3 && cc != CC_CDNA && cc != CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    return (cc < GGML_CUDA_CC_RDNA3 && cc != GGML_CUDA_CC_CDNA && cc != GGML_CUDA_CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }
index 8d8867121f321c4b68b265106b395b9d668e8d30..3cd508a1d4a1c93db27368afbbe8dbcc0c27dbf7 100644 (file)
@@ -89,9 +89,9 @@ struct tile_x_sizes {
 static constexpr int get_mmq_x_max_host(const int cc) {
     return int8_mma_available(cc) ? 128 :
 #ifdef GGML_CUDA_FORCE_MMQ
-        cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? 128                     : 64;
+        cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? 128                     : 64;
 #else
-        cc >= CC_VOLTA && cc < CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
+        cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ? MMQ_DP4A_MAX_BATCH_SIZE : 64;
 #endif // GGML_CUDA_FORCE_MMQ
 }
 
@@ -104,23 +104,23 @@ static constexpr __device__ int get_mmq_x_max_device() {
     return 128;
 #else // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 
-#if __CUDA_ARCH__ >= CC_VOLTA
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 #ifdef GGML_CUDA_FORCE_MMQ
     return MMQ_DP4A_MAX_BATCH_SIZE;
 #else // GGML_CUDA_FORCE_MMQ
     return 128;
 #endif // GGML_CUDA_FORCE_MMQ
-#else // __CUDA_ARCH__ >= CC_VOLTA
+#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
     return 64;
-#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 #endif // INT8_MMA_AVAILABLE
 }
 
 static constexpr int get_mmq_y_host(const int cc) {
-    return cc >= CC_OFFSET_AMD ? (cc == CC_RDNA1 ? 64 : 128) : (cc >= CC_VOLTA ? 128 : 64);
+    return cc >= GGML_CUDA_CC_OFFSET_AMD ? (cc == GGML_CUDA_CC_RDNA1 ? 64 : 128) : (cc >= GGML_CUDA_CC_VOLTA ? 128 : 64);
 }
 
 static constexpr __device__ int get_mmq_y_device() {
@@ -131,11 +131,11 @@ static constexpr __device__ int get_mmq_y_device() {
     return 128;
 #endif // defined RDNA1
 #else
-#if __CUDA_ARCH__ >= CC_VOLTA
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
     return 128;
 #else
     return 64;
-#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 }
 
@@ -2574,11 +2574,11 @@ template <ggml_type type, int mmq_x, int nwarps, bool need_check>
     __launch_bounds__(WARP_SIZE*nwarps, 2)
 #endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
 #else
-#if __CUDA_ARCH__ >= CC_VOLTA
+#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
     __launch_bounds__(WARP_SIZE*nwarps, 1)
 #else
     __launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // __CUDA_ARCH__ >= CC_VOLTA
+#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
 #endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
 static __global__ void mul_mat_q(
     const char * __restrict__ x, const char * __restrict__ yc, float * __restrict__ dst, float * __restrict__ tmp_fixup,
@@ -2594,7 +2594,7 @@ static __global__ void mul_mat_q(
     constexpr int mmq_y = get_mmq_y_device();
 
     // On AMD or old CUDA the performance with stream-k was worse, use conventional tiling instead:
-#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+#if (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
     {
         constexpr bool fixup = false;
         mul_mat_q_process_tile<type, mmq_x, nwarps, need_check, fixup>
@@ -2602,7 +2602,7 @@ static __global__ void mul_mat_q(
                 blockIdx.x, blockIdx.y, 0, ne00/qk);
         return;
     }
-#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < CC_VOLTA
+#endif // (defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ < GGML_CUDA_CC_VOLTA
 
     const     int64_t blocks_per_ne00 = ne00 / qk;
     constexpr int     blocks_per_iter = MMQ_ITER_K / qk;
@@ -2825,7 +2825,7 @@ void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cuda
     const int mmq_x_max = get_mmq_x_max_host(cc);
     const int mmq_y = get_mmq_y_host(cc);
     const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
-    const bool use_stream_k = cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
+    const bool use_stream_k = cc >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
 
     int mmq_x_best  = 0;
     int nparts_best = INT_MAX;
index 02d1509836c4090682a4cd3d0e3e69336f3c1169..e3b912d875b5b19434d9f51aa403849e65d71e62 100644 (file)
@@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda(
     int64_t nwarps = 1;
     int64_t rows_per_cuda_block = 1;
 
-    if (ggml_cuda_info().devices[id].cc < CC_CDNA || ggml_cuda_info().devices[id].cc == CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
+    if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_CDNA || ggml_cuda_info().devices[id].cc == GGML_CUDA_CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
         switch(ncols_y) {
             case 1:
                 nwarps = 4;
index 31cfe53941fe08b28c3ab88ff20a5cc30514bb3c..e0dafc1d2049ddebe367a9c46d535ae10dbebcf6 100644 (file)
@@ -3,8 +3,6 @@
 #endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11700
 
 #ifdef USE_CUB
-// On Windows CUB uses libraries with variables called CC_PASCAL which conflict with the define in common.cuh.
-// For this reason CUB must be included BEFORE anything else.
 #include <cub/cub.cuh>
 using namespace cub;
 #endif // USE_CUB