]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Align GEMM dispatch (llama/7566)
authorMeng, Hengyu <redacted>
Tue, 28 May 2024 23:00:24 +0000 (07:00 +0800)
committerGeorgi Gerganov <redacted>
Wed, 29 May 2024 10:16:38 +0000 (13:16 +0300)
* align GEMM dispatch

src/ggml-sycl.cpp

index dccfe9eb407af7df33f24da034a251c74ec06406..a73448136a4d8d90b253b283e487d7f6e754281b 100644 (file)
@@ -3022,20 +3022,19 @@ static int g_work_group_size = 0;
 // typedef sycl::half ggml_fp16_t;
 
 #define __SYCL_ARCH__ DPCT_COMPATIBILITY_TEMP
-#define VER_4VEC   610          //todo for hardward optimize.
+#define VER_4VEC   130          //todo for hardward optimize.
 #define VER_GEN9      700       //todo for hardward optimize.
 #define VER_GEN12 1000000       //todo for hardward optimize.
 #define VER_GEN13      (VER_GEN12 + 1030)   //todo for hardward optimize.
 
 #define GGML_SYCL_MAX_NODES 8192 //TODO: adapt to hardwares
 
-
-//define for XMX in Intel GPU
-//TODO: currently, it's not used for XMX really.
-#define SYCL_USE_XMX
+#if !defined(GGML_SYCL_FORCE_MMQ)
+    #define SYCL_USE_XMX
+#endif
 
 // max batch size to use MMQ kernels when tensor cores are available
-#define XMX_MAX_BATCH_SIZE 32
+#define MMQ_MAX_BATCH_SIZE 32
 
 
 #if defined(_MSC_VER)
@@ -15249,6 +15248,29 @@ catch (sycl::exception const &exc) {
   std::exit(1);
 }
 
+inline bool ggml_sycl_supports_mmq(enum ggml_type type) {
+    // TODO: accuracy issues in MMQ
+    return false;
+}
+
+bool ggml_sycl_supports_dmmv(enum ggml_type type) {
+    switch (type) {
+        case GGML_TYPE_Q4_0:
+        case GGML_TYPE_Q4_1:
+        case GGML_TYPE_Q5_0:
+        case GGML_TYPE_Q5_1:
+        case GGML_TYPE_Q8_0:
+        case GGML_TYPE_Q2_K:
+        case GGML_TYPE_Q3_K:
+        case GGML_TYPE_Q4_K:
+        case GGML_TYPE_Q5_K:
+        case GGML_TYPE_Q6_K:
+        case GGML_TYPE_F16:
+            return true;
+        default:
+            return false;
+    }
+}
 
 static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const bool all_on_device =
@@ -15265,76 +15287,42 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         }
     }
 
-#ifdef SYCL_USE_XMX
-    const bool use_xmx = true;
-#else
-    const bool use_xmx = false;
-#endif
+    // check data types and tensor shapes for custom matrix multiplication kernels:
+    bool use_dequantize_mul_mat_vec = ggml_sycl_supports_dmmv(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src0->ne[0] % GGML_SYCL_DMMV_X == 0 && src1->ne[1] == 1;
 
-    // debug helpers
-    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
-    //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
-    //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
-    //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
-    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
-    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+    bool use_mul_mat_vec_q =  ggml_is_quantized(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32
+        && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE;
+
+    bool use_mul_mat_q =  ggml_sycl_supports_mmq(src0->type)
+        && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32;
+
+    // mmvq and mmq need the __dp4a instruction which is available for gen12+
+    // Workaround in https://github.com/ggerganov/llama.cpp/commit/95f84d5ce8b449a9b16009434aca800df504a02e
+    use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
+#ifdef SYCL_USE_XMX
+    use_mul_mat_q = use_mul_mat_q && (src1->ne[1] <= MMQ_MAX_BATCH_SIZE);
+#endif // SYCL_USE_XMX
 
-    if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+    if (!split && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
         // KQ single-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_p021\n");
         ggml_sycl_mul_mat_vec_p021(src0, src1, dst);
-    } else if (!split && all_on_device && !use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
         // KQV single-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n");
         ggml_sycl_mul_mat_vec_nc(src0, src1, dst);
-    } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
+    } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
         // KQ + KQV multi-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n");
         ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
-    } else if (src0->type == GGML_TYPE_F32) {
-        // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
-        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
-    } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
-        // GGML_SYCL_DEBUG("ggml_is_quantized or GGML_TYPE_F16\n");
-        if (src1->ne[1] == 1 && src0->ne[0] % GGML_SYCL_DMMV_X == 0) {
-#ifdef GGML_SYCL_FORCE_DMMV
-            const bool use_mul_mat_vec_q = false;
-#else
-            bool use_mul_mat_vec_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
-            use_mul_mat_vec_q = use_mul_mat_vec_q ||
-                (src0->type == GGML_TYPE_IQ2_XXS) || (src0->type == GGML_TYPE_IQ2_XS) || (src0->type == GGML_TYPE_IQ2_S) ||
-                (src0->type == GGML_TYPE_IQ3_XXS) || (src0->type == GGML_TYPE_IQ3_S) ||
-                (src0->type == GGML_TYPE_IQ4_NL) || (src0->type == GGML_TYPE_IQ4_XS) ||
-                (src0->type == GGML_TYPE_IQ1_S) || (src0->type == GGML_TYPE_IQ1_M);
-
-
-#endif // GGML_SYCL_FORCE_DMMV
-
-            if (use_mul_mat_vec_q) {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_vec_q path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
-            } else {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_dequantize_mul_mat_vec path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
-            }
-        } else {
-            bool use_mul_mat_q = min_compute_capability >= VER_4VEC && ggml_is_quantized(src0->type);
-            use_mul_mat_q = use_mul_mat_q && (src0->type != GGML_TYPE_IQ2_XXS);
-
-            if (use_xmx && min_compute_capability >= VER_GEN9 && src1->ne[1] > XMX_MAX_BATCH_SIZE) {
-                use_mul_mat_q = false;
-            }
-
-            if (use_mul_mat_q) {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_q path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
-            } else {
-                // GGML_SYCL_DEBUG("ggml_sycl_mul_mat ggml_sycl_op_mul_mat_sycl path\n");
-                ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
-            }
-        }
+    } else if (use_dequantize_mul_mat_vec) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_dequantize_mul_mat_vec, false);
+    } else if (use_mul_mat_vec_q) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_vec_q, true);
+    } else if (use_mul_mat_q) {
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_q, true);
     } else {
-        GGML_ASSERT(false);
+        ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);
     }
 }