]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Add some minimal optimizations for CDNA (llama/10498)
authoruvos <redacted>
Wed, 27 Nov 2024 16:10:08 +0000 (17:10 +0100)
committerGeorgi Gerganov <redacted>
Tue, 3 Dec 2024 19:05:37 +0000 (21:05 +0200)
* Add some minimal optimizations for CDNA

* ggml_cuda: set launch bounds also for GCN as it helps there too

src/ggml-cuda/common.cuh
src/ggml-cuda/ggml-cuda.cu
src/ggml-cuda/mmq.cu
src/ggml-cuda/mmq.cuh
src/ggml-cuda/mmvq.cu
src/ggml-cuda/vendors/hip.h

index b0dd16066b4ba5a46199305f1a2e8d15752ff30c..535118d87928e15b4abf15e15ba0edba08741908 100644 (file)
 #define CC_TURING     750
 #define CC_AMPERE     800
 #define CC_OFFSET_AMD 1000000
-#define CC_RDNA1      (CC_OFFSET_AMD + 1010)
-#define CC_RDNA2      (CC_OFFSET_AMD + 1030)
-#define CC_RDNA3      (CC_OFFSET_AMD + 1100)
+
+// GCN/CNDA, wave size is 64
+#define CC_GCN4       (CC_OFFSET_AMD + 803)  // Tonga, Fiji, Polaris, minimum for fast fp16
+#define CC_VEGA       (CC_OFFSET_AMD + 900)  // Vega56/64, minimum for fp16 dual issue
+#define CC_VEGA20     (CC_OFFSET_AMD + 906)  // MI50/Radeon VII, minimum for dp4a
+#define CC_CDNA       (CC_OFFSET_AMD + 908)  // MI100, minimum for MFMA, acc registers
+#define CC_CDNA2      (CC_OFFSET_AMD + 910)  // MI210, minimum acc register renameing
+#define CC_CDNA3      (CC_OFFSET_AMD + 942)  // MI300
+
+// RNDA removes MFMA, dp4a, xnack, acc registers, wave size is 32
+#define CC_RDNA1      (CC_OFFSET_AMD + 1010) // RX 5000
+#define CC_RDNA2      (CC_OFFSET_AMD + 1030) // RX 6000, minimum for dp4a
+#define CC_RDNA3      (CC_OFFSET_AMD + 1100) // RX 7000, minimum for WMMA
+
 #define CC_QY1        210
 #define CC_QY2        220
 
index 2a78a4393d0f7f30d82bc93d7bd5c30818cbc4d9..d6e4bfdd0d437402795249bee30dd60955192985 100644 (file)
@@ -1107,6 +1107,11 @@ static void ggml_cuda_op_mul_mat_cublas(
         const half alpha_f16 = 1.0f;
         const half beta_f16 = 0.0f;
 
+        cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
+        if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
+            cu_compute_type = CUBLAS_COMPUTE_32F;
+        }
+
         CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
         CUBLAS_CHECK(
             cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
@@ -1114,7 +1119,7 @@ static void ggml_cuda_op_mul_mat_cublas(
                     &alpha_f16, src0_ptr,       CUDA_R_16F, ne00,
                                 src1_ptr,       CUDA_R_16F, ne10,
                     &beta_f16,   dst_f16.get(), CUDA_R_16F, ldc,
-                    CUBLAS_COMPUTE_16F,
+                    cu_compute_type,
                     CUBLAS_GEMM_DEFAULT_TENSOR_OP));
 
         const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
@@ -1607,6 +1612,10 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
     cublasComputeType_t cu_compute_type = CUBLAS_COMPUTE_16F;
     cudaDataType_t      cu_data_type    = CUDA_R_16F;
 
+    if (ggml_cuda_info().devices[ctx.device].cc == CC_CDNA) {
+        cu_compute_type = CUBLAS_COMPUTE_32F;
+    }
+
     // dst strides
     size_t nbd2 = dst->nb[2];
     size_t nbd3 = dst->nb[3];
index ae5c68ab3512901a902defdb8d019f7c3e4d916e..7f7c8c90b6fe2b1ce000df1ee05ccd5843be0d46 100644 (file)
@@ -148,5 +148,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
         return cc < CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
     }
 
-    return cc < CC_RDNA3 || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
+    return (cc < CC_RDNA3 && cc != CC_CDNA && cc != CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
 }
index 425acb20da311f497580aa9690d273039b8efd2d..8d8867121f321c4b68b265106b395b9d668e8d30 100644 (file)
@@ -2570,9 +2570,9 @@ static __device__ void mul_mat_q_process_tile(
 
 template <ggml_type type, int mmq_x, int nwarps, bool need_check>
 #if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
-#if defined(RDNA3) || defined(RDNA2)
+#if defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
     __launch_bounds__(WARP_SIZE*nwarps, 2)
-#endif // defined(RDNA3) || defined(RDNA2)
+#endif // defined(RDNA3) || defined(RDNA2) || defined(CDNA) || defined(GCN)
 #else
 #if __CUDA_ARCH__ >= CC_VOLTA
     __launch_bounds__(WARP_SIZE*nwarps, 1)
index 735975c160dd0d8226c72eade0ec296c033fa24d..02d1509836c4090682a4cd3d0e3e69336f3c1169 100644 (file)
@@ -142,7 +142,7 @@ static void mul_mat_vec_q_cuda(
     int64_t nwarps = 1;
     int64_t rows_per_cuda_block = 1;
 
-    if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
+    if (ggml_cuda_info().devices[id].cc < CC_CDNA || ggml_cuda_info().devices[id].cc == CC_RDNA1) { // NVIDIA and AMD older than RDNA2 but not CDNA
         switch(ncols_y) {
             case 1:
                 nwarps = 4;
index 1f3c70c2e6934ecc58d8bc227bbc5b67a4fbf56b..3205534d66f10a2176de1b91465252a0553b5af4 100644 (file)
 
 #define __CUDA_ARCH__ 1300
 
+#if defined(__gfx803__) || defined(__gfx900__) || defined(__gfx906__)
+#define GCN
+#endif
+
+#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
+#define CDNA
+#endif
+
 #if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
     defined(__gfx1150__) || defined(__gfx1151__)
 #define RDNA3