]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Use batched mul_mat pathway (llama/5591)
authorAidanBeltonS <redacted>
Fri, 1 Mar 2024 07:36:47 +0000 (07:36 +0000)
committerGeorgi Gerganov <redacted>
Fri, 8 Mar 2024 09:38:31 +0000 (11:38 +0200)
* Use batched mul_mat pathway

* rm extra line

* Explicitly state scaled data type

---------

Co-authored-by: Abhilash Majumder <redacted>
ggml-sycl.cpp

index a054ec8b92bac708fb4a3a75e9bda578fb24dd8c..6f391b0c650ad5d72cd6666fcb5f30a4dc2cbd3c 100644 (file)
@@ -12726,6 +12726,7 @@ static void ggml_sycl_op_mul_mat(const ggml_tensor *src0,
 
     GGML_ASSERT(dst->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
     GGML_ASSERT(src1->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
 
     GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
 
@@ -13269,31 +13270,23 @@ static void k_compute_batched_ptrs(const sycl::half *src0_as_f16,
     int64_t i03 = i13 / r3;
     int64_t i02 = i12 / r2;
 
-    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02   + i03*nb03;
-    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12/2 + i13*nb13/2;
-    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2   + i13*nbd3;
+    ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_as_f16 + i02*nb02 + i03*nb03;
+    ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_as_f16 + i12*nb12 + i13*nb13;
+    ptrs_dst[0*ne23 + i12 + i13*ne12] = (      char *)         dst + i12*nbd2 + i13*nbd3;
 }
 
-static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
-                                                 const ggml_tensor *src1,
-                                                 ggml_tensor *dst) try {
+static void ggml_sycl_mul_mat_batched_sycl(const ggml_tensor *src0,
+                                             const ggml_tensor *src1,
+                                             ggml_tensor *dst) try {
     GGML_ASSERT(!ggml_is_transposed(src0));
     GGML_ASSERT(!ggml_is_transposed(src1));
 
     GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
-    GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne);
-
-    GGML_TENSOR_LOCALS(int64_t, nb0, src0, nb);
 
-    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne);
-
-    GGML_TENSOR_LOCALS(int64_t, nb1, src1, nb);
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    const int64_t ne1 = ggml_nelements(src1);
-    const int64_t ne  = ggml_nelements(dst);
+    const int64_t ne_dst  = ggml_nelements(dst);
 
     SYCL_CHECK(ggml_sycl_set_device(g_main_device));
     dpct::queue_ptr main_stream = g_syclStreams[g_main_device_index][0];
@@ -13312,11 +13305,16 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
     float * dst_ddf = (float *) dst_extra->data_device[g_main_device_index];
 
     // convert src1 to fp16
-    const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
-    GGML_ASSERT(to_fp16_sycl != nullptr);
-
-    sycl_pool_alloc<sycl::half> src1_as_f16(ne1);
-    to_fp16_sycl(src1_ddf, src1_as_f16.get(), ne1, main_stream);
+    sycl_pool_alloc<sycl::half> src1_f16_alloc;
+    if (src1->type != GGML_TYPE_F16) {
+      const to_fp16_sycl_t to_fp16_sycl = ggml_get_to_fp16_sycl(src1->type);
+      const int64_t ne_src1 = ggml_nelements(src1);
+      src1_f16_alloc.alloc(ne_src1);
+      GGML_ASSERT(to_fp16_sycl != nullptr);
+      to_fp16_sycl(src1_ddf, src1_f16_alloc.get(), ne_src1, main_stream);
+    }
+    sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
+                                                       : src1_f16_alloc.get();
 
     sycl_pool_alloc<sycl::half> dst_f16;
     char * dst_t;
@@ -13337,20 +13335,12 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
     const void * alpha = &alpha_f16;
     const void * beta  = &beta_f16;
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        dst_t = (char *) dst_f16.alloc(ne);
-
-        nbd2 /= sizeof(float) / sizeof(sycl::half);
-        nbd3 /= sizeof(float) / sizeof(sycl::half);
-    } else {
-        dst_t = (char *) dst_ddf;
-
-        cu_compute_type = dpct::library_data_t::real_float;
-        cu_data_type = dpct::library_data_t::real_float;
+    // TODO: Renable (dst->op_params[0] =! GGML_PREC_DEFAULT) pathway
+    // once oneMKL open source supports half, half, float, float: datatypes
+    dst_t = (char *) dst_f16.alloc(ne_dst);
 
-        alpha = &alpha_f32;
-        beta  = &beta_f32;
-    }
+    nbd2 /= sizeof(float) / sizeof(sycl::half);
+    nbd3 /= sizeof(float) / sizeof(sycl::half);
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
@@ -13386,10 +13376,10 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
             *g_sycl_handles[g_main_device_index], oneapi::mkl::transpose::trans,
             oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
             (const char *)src0_as_f16, dpct::library_data_t::real_half,
-            nb01 / sizeof(sycl::half), src0->nb[2] / sizeof(sycl::half),
-            (const char *)src1_as_f16.get(), dpct::library_data_t::real_half,
-            nb11 / sizeof(float), src1->nb[2] / sizeof(float), beta,
-            (char *)dst_t, cu_data_type, ne01, dst->nb[2] / sizeof(float),
+            nb01 / nb00, nb02 / nb00,
+            (const char *)src1_f16, dpct::library_data_t::real_half,
+            nb11 / nb10, nb12 / nb10, beta,
+            (char *)dst_t, cu_data_type, ne01, nb2 / nb0,
             ne12 * ne13, cu_compute_type)));
     } else {
         // use syclGemmBatchedEx
@@ -13409,44 +13399,35 @@ static void ggml_sycl_mul_mat_mat_batched_sycl(const ggml_tensor *src0,
                                          {sycl::aspect::fp16});
 
             main_stream->submit([&](sycl::handler &cgh) {
-                const sycl::half *src1_as_f16_get_ct1 = src1_as_f16.get();
-                const void **ptrs_src_get_ct3 = ptrs_src.get();
-                void **ptrs_dst_get_ct4 = ptrs_dst.get();
-
+                const void **ptrs_src_get = ptrs_src.get();
+                void **ptrs_dst_get = ptrs_dst.get();
+                size_t nb12_scaled = src1->type == GGML_TYPE_F16 ? nb12 : nb12 / 2;
+                size_t nb13_scaled = src1->type == GGML_TYPE_F16 ? nb13 : nb13 / 2;
                 cgh.parallel_for(sycl::nd_range<3>(block_dims, block_dims),
                                  [=](sycl::nd_item<3> item_ct1) {
                                      k_compute_batched_ptrs(
-                                         src0_as_f16, src1_as_f16_get_ct1,
-                                         dst_t, ptrs_src_get_ct3,
-                                         ptrs_dst_get_ct4, ne12, ne13, ne23,
-                                         nb02, nb03, nb12, nb13, nbd2, nbd3, r2,
-                                         r3, item_ct1);
+                                         src0_as_f16, src1_f16,
+                                         dst_t, ptrs_src_get,
+                                         ptrs_dst_get, ne12, ne13, ne23,
+                                         nb02, nb03, nb12_scaled, nb13_scaled,
+                                         nbd2, nbd3, r2, r3, item_ct1);
                                  });
             });
         }
-        /*
-        DPCT1010:95: SYCL uses exceptions to report errors and does not use the
-        error codes. The call was replaced with 0. You need to rewrite this
-        code.
-        */
-        SYCL_CHECK(0);
-
         SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm_batch(
             *g_sycl_handles[g_main_device_index], oneapi::mkl::transpose::trans,
             oneapi::mkl::transpose::nontrans, ne01, ne11, ne10, alpha,
             (const void **)(ptrs_src.get() + 0 * ne23),
-            dpct::library_data_t::real_half, nb01 / sizeof(sycl::half),
+            dpct::library_data_t::real_half, nb01 / nb00,
             (const void **)(ptrs_src.get() + 1 * ne23),
-            dpct::library_data_t::real_half, nb11 / sizeof(float), beta,
+            dpct::library_data_t::real_half, nb11 / nb10, beta,
             (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
             cu_compute_type)));
     }
 #endif
 
-    if (dst->op_params[0] == GGML_PREC_DEFAULT) {
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
-        to_fp32_sycl(dst_f16.get(), dst_ddf, ne, main_stream);
-    }
+    const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
+    to_fp32_sycl(dst_f16.get(), dst_ddf, ne_dst, main_stream);
 }
 catch (sycl::exception const &exc) {
   std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -13491,10 +13472,10 @@ static void ggml_sycl_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         // KQV single-batch
         // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_vec_nc\n");
         ggml_sycl_mul_mat_vec_nc(src0, src1, dst);
-    } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
+    } else if (!split && all_on_device && use_xmx && src0->type == GGML_TYPE_F16 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
         // KQ + KQV multi-batch
-        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_mat_batched_sycl\n");
-        ggml_sycl_mul_mat_mat_batched_sycl(src0, src1, dst);
+        // GGML_SYCL_DEBUG("ggml_sycl_mul_mat_batched_sycl\n");
+        ggml_sycl_mul_mat_batched_sycl(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         // GGML_SYCL_DEBUG("ggml_sycl_op_mul_mat\n");
         ggml_sycl_op_mul_mat(src0, src1, dst, ggml_sycl_op_mul_mat_sycl, false);