]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
fix ggml_sycl_mul_mat_id() to match the change of api (llama/7436)
authorNeo Zhang <redacted>
Tue, 28 May 2024 09:53:37 +0000 (17:53 +0800)
committerGeorgi Gerganov <redacted>
Wed, 29 May 2024 10:16:38 +0000 (13:16 +0300)
* fix mul_mat_id to match the change of api

* rm comment

* rm unused or duplicated code, rename as review comment

src/ggml-sycl.cpp

index d5384b2e065a863d4cffede3a0a362987cec8eb9..022a52aeb6b786d0c09ceddda819191b9c5e54dc 100644 (file)
@@ -2944,6 +2944,57 @@ namespace dpct
     using shared_memory = detail::device_memory<T, shared, Dimension>;
 
 
+    template <typename T,
+            sycl::access::address_space addressSpace =
+                sycl::access::address_space::global_space,
+            sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+            sycl::memory_scope memoryScope = sycl::memory_scope::device>
+    inline T atomic_fetch_add(T *addr, T operand) {
+    auto atm =
+        sycl::atomic_ref<T, memoryOrder, memoryScope, addressSpace>(addr[0]);
+    return atm.fetch_add(operand);
+    }
+
+    template <sycl::access::address_space addressSpace =
+                sycl::access::address_space::global_space,
+            sycl::memory_order memoryOrder = sycl::memory_order::relaxed,
+            sycl::memory_scope memoryScope = sycl::memory_scope::device,
+            typename T1, typename T2>
+    inline T1 atomic_fetch_add(T1 *addr, T2 operand) {
+    auto atm =
+        sycl::atomic_ref<T1, memoryOrder, memoryScope, addressSpace>(addr[0]);
+    return atm.fetch_add(operand);
+    }
+
+    template <typename T, sycl::access::address_space addressSpace =
+                            sycl::access::address_space::global_space>
+    inline T atomic_fetch_add(T *addr, T operand,
+                            sycl::memory_order memoryOrder) {
+    switch (memoryOrder) {
+        case sycl::memory_order::relaxed:
+            return atomic_fetch_add<T, addressSpace, sycl::memory_order::relaxed,
+                                    sycl::memory_scope::device>(addr, operand);
+        case sycl::memory_order::acq_rel:
+            return atomic_fetch_add<T, addressSpace, sycl::memory_order::acq_rel,
+                                    sycl::memory_scope::device>(addr, operand);
+        case sycl::memory_order::seq_cst:
+            return atomic_fetch_add<T, addressSpace, sycl::memory_order::seq_cst,
+                                    sycl::memory_scope::device>(addr, operand);
+        default:
+            assert(false && "Invalid memory_order for atomics. Valid memory_order for "
+                            "atomics are: sycl::memory_order::relaxed, "
+                            "sycl::memory_order::acq_rel, sycl::memory_order::seq_cst!");
+        }
+    }
+
+    template <sycl::access::address_space addressSpace =
+                sycl::access::address_space::global_space,
+            typename T1, typename T2>
+    inline T1 atomic_fetch_add(T1 *addr, T2 operand,
+                            sycl::memory_order memoryOrder) {
+    atomic_fetch_add<T1, addressSpace>(addr, operand, memoryOrder);
+    }
+
 } // COPY from DPCT head files
 
 #define GGML_COMMON_DECL_SYCL
@@ -3060,6 +3111,7 @@ void   ggml_sycl_get_device_description(int device, char * description, size_t d
 bool   ggml_backend_is_sycl(ggml_backend_t backend);
 int    ggml_backend_sycl_get_device(ggml_backend_t backend);
 int    get_main_device();
+static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer);
 void   print_ggml_tensor(const char*name, struct ggml_tensor *src);
 void   log_tensor_with_cnt(const char* name, struct ggml_tensor * src, int stop_cnt);
 
@@ -15459,22 +15511,86 @@ static void ggml_sycl_mul_mat_id_sycl(ggml_tensor * dst) {
 }
 #endif
 
+struct mmid_row_mapping {
+    int32_t i1;
+    int32_t i2;
+};
+
+__dpct_inline__ static void k_copy_src1_to_contiguous(
+    const char *__restrict__ src1_original, char *__restrict__ src1_contiguous,
+    int *__restrict__ cur_src1_row, mmid_row_mapping *__restrict__ row_mapping,
+    const char *__restrict ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
+    int64_t ne11, int64_t ne10, size_t nb11, size_t nb12,
+    const sycl::nd_item<3> &item_ct1, int &src1_row) {
+    int32_t iid1 = item_ct1.get_group(2);
+    int32_t id = item_ct1.get_group(1);
+
+    const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
+
+    if (row_id_i != i02) {
+        return;
+    }
+
+    const int64_t i11 = id % ne11;
+    const int64_t i12 = iid1;
+
+    if (item_ct1.get_local_id(2) == 0) {
+        src1_row =
+            dpct::atomic_fetch_add<sycl::access::address_space::generic_space>(
+                cur_src1_row, 1);
+        row_mapping[src1_row] = {id, iid1};
+    }
+    /*
+    DPCT1065:194: Consider replacing sycl::nd_item::barrier() with
+    sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better
+    performance if there is no access to global memory.
+    */
+    item_ct1.barrier();
+
+    const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
+    float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+
+#pragma unroll
+    for (int i = item_ct1.get_local_id(2); i < ne10;
+         i += item_ct1.get_local_range(2)) {
+        src1_row_contiguous[i] = src1_row_original[i];
+    }
+}
+
+__dpct_inline__ static void k_copy_dst_from_contiguous(
+    char *__restrict__ dst_original, const char *__restrict__ dst_contiguous,
+    const mmid_row_mapping *__restrict__ row_mapping, int64_t ne0, size_t nb1,
+    size_t nb2, const sycl::nd_item<3> &item_ct1) {
+    int32_t i = item_ct1.get_group(2);
+
+    const int32_t i1 = row_mapping[i].i1;
+    const int32_t i2 = row_mapping[i].i2;
+
+    const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
+    float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
+
+#pragma unroll
+    for (int j = item_ct1.get_local_id(2); j < ne0;
+         j += item_ct1.get_local_range(2)) {
+        dst_row_original[j] = dst_row_contiguous[j];
+    }
+}
+
 static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
                                  const ggml_tensor *src1,
                                  ggml_tensor *dst) try {
-    GGML_ASSERT(src0->backend != GGML_BACKEND_TYPE_GPU_SPLIT &&
-                "mul_mat_id does not support split buffers");
+    GGML_ASSERT(!ggml_backend_buffer_is_sycl_split(src0->buffer) && "mul_mat_id does not support split buffers");
+
     const ggml_tensor *ids = dst->src[2];
-    const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
+    GGML_TENSOR_BINARY_OP_LOCALS
 
-    const size_t nb11 = src1->nb[1];
-    const size_t nb1 = dst->nb[1];
+    const dpct::queue_ptr stream = g_syclStreams[g_main_device][0];
 
-    const int32_t id = ((int32_t *)dst->op_params)[0];
-    const int32_t n_as = src0->ne[2];
+    const int64_t n_as = ne02;
+    const int64_t n_ids = ids->ne[0];
 
     std::vector<char> ids_host(ggml_nbytes(ids));
-    const char *ids_dev = (const char *)ids->data;
+    const char * ids_dev = (const char *) ids->data;
 
     SYCL_CHECK(CHECK_TRY_ERROR(
         stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
@@ -15514,24 +15630,40 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
 
     src0_row.ne[2] = 1;
     src0_row.ne[3] = 1;
-    src0_row.nb[3] = src0->nb[2];
-
-    if (src1->ne[1] == 1) {
-        for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-            const int32_t row_id =
-                *(const int32_t *)(ids_host.data() + i01 * ids->nb[1] +
-                                   id * ids->nb[0]);
-
-            GGML_ASSERT(row_id >= 0 && row_id < n_as);
+    src0_row.nb[3] = nb02;
+
+    src1_row.ne[1] = 1;
+    src1_row.ne[2] = 1;
+    src1_row.ne[3] = 1;
+    src1_row.nb[2] = nb11;
+    src1_row.nb[3] = nb11;
+
+    dst_row.ne[1] = 1;
+    dst_row.ne[2] = 1;
+    dst_row.ne[3] = 1;
+    dst_row.nb[2] = nb1;
+    dst_row.nb[3] = nb1;
+    if (ne12 == 1) {
+        for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+            for (int64_t id = 0; id < n_ids; id++) {
+                const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
+                GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+                const int64_t i11 = id % ne11;
+                const int64_t i12 = iid1;
+
+                const int64_t i1 = id;
+                const int64_t i2 = i12;
 
             src0_row_extra.data_device[g_main_device] =
-                src0_original + row_id * src0->nb[2];
+                src0_original + i02*nb02;
             src1_row_extra.data_device[g_main_device] =
-                src1_original + i01 * src1->nb[1];
+                src1_original + + i11*nb11 + i12*nb12;
             dst_row_extra.data_device[g_main_device] =
-                dst_original + i01 * dst->nb[1];
+                dst_original + i1*nb1   + i2*nb2;
 
             ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
+            }
         }
     } else {
         sycl_pool_alloc<char> src1_contiguous(sizeof(float)*ggml_nelements(src1));
@@ -15540,64 +15672,98 @@ static void ggml_sycl_mul_mat_id(const ggml_tensor *src0,
         src1_row_extra.data_device[g_main_device] = src1_contiguous.get();
         dst_row_extra.data_device[g_main_device]  =  dst_contiguous.get();
 
-        for (int32_t row_id = 0; row_id < n_as; ++row_id) {
+        for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+            for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+                for (int64_t id = 0; id < n_ids; id++) {
+                    const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
 
-                if (row_id_i != row_id) {
-                    continue;
-                }
+                    GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
 
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
+                    if (row_id_i != i02) {
+                        continue;
+                    }
 
-                SYCL_CHECK(CHECK_TRY_ERROR(
-                    stream->memcpy(src1_contiguous.get() + num_src1_rows * nb11,
-                                   src1_original + i01 * nb11, nb11)));
-                num_src1_rows++;
+                    num_src1_rows++;
+                }
             }
 
             if (num_src1_rows == 0) {
                 continue;
             }
 
-            src0_row_extra.data_device[g_main_device] =
-                src0_original + row_id * src0->nb[2];
 
+            sycl_pool_alloc<int> dev_cur_src1_row(1);
+            sycl_pool_alloc<mmid_row_mapping> dev_row_mapping(num_src1_rows);
+            SYCL_CHECK(CHECK_TRY_ERROR(
+                stream->memset(dev_cur_src1_row.get(), 0, sizeof(int))));
+
+            {
+                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne10, 768u));
+                sycl::range<3> grid_dims(1, n_ids, ids->ne[1]);
+                stream->submit([&](sycl::handler &cgh) {
+                    sycl::local_accessor<int, 0> src1_row_acc(cgh);
+
+                    char *__restrict src1_contiguous_get =
+                        src1_contiguous.get();
+                    int *__restrict dev_cur_src1_row_get =
+                        dev_cur_src1_row.get();
+                    mmid_row_mapping *__restrict dev_row_mapping_get =
+                        dev_row_mapping.get();
+                    size_t ids_nb_ct6 = ids->nb[1];
+                    size_t ids_nb_ct7 = ids->nb[0];
+
+                    cgh.parallel_for(
+                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                        [=](sycl::nd_item<3> item_ct1) {
+                            k_copy_src1_to_contiguous(
+                                src1_original, src1_contiguous_get,
+                                dev_cur_src1_row_get,
+                                dev_row_mapping_get, ids_dev, i02,
+                                ids_nb_ct6, ids_nb_ct7, ne11, ne10, nb11, nb12,
+                                item_ct1, src1_row_acc);
+                        });
+                });
+            }
+
+            src0_row_extra.data_device[g_main_device] = src0_original + i02*nb02;
+
+            GGML_ASSERT(nb11 == sizeof(float)*ne10);
+            GGML_ASSERT(nb1 == sizeof(float)*ne0);
             src1_row.ne[1] = num_src1_rows;
-            dst_row.ne[1] = num_src1_rows;
 
             src1_row.nb[1] = nb11;
             src1_row.nb[2] = num_src1_rows*nb11;
             src1_row.nb[3] = num_src1_rows*nb11;
 
+            dst_row.ne[1] = num_src1_rows;
             dst_row.nb[1] = nb1;
             dst_row.nb[2] = num_src1_rows*nb1;
             dst_row.nb[3] = num_src1_rows*nb1;
 
             ggml_sycl_mul_mat(&src0_row, &src1_row, &dst_row);
 
-            num_src1_rows = 0;
-            for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
-                const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
-
-                if (row_id_i != row_id) {
-                    continue;
-                }
-
-                GGML_ASSERT(row_id >= 0 && row_id < n_as);
-
-                SYCL_CHECK(CHECK_TRY_ERROR(stream->memcpy(
-                    dst_original + i01 * nb1,
-                    dst_contiguous.get() + num_src1_rows * nb1, nb1)));
-                num_src1_rows++;
+            {
+                sycl::range<3> block_dims(1, 1, std::min((unsigned int)ne0, 768u));
+                sycl::range<3> grid_dims(1, 1, num_src1_rows);
+                stream->submit([&](sycl::handler &cgh) {
+                    const char *__restrict dst_contiguous_get =
+                        dst_contiguous.get();
+                    const mmid_row_mapping *__restrict dev_row_mapping_get =
+                        dev_row_mapping.get();
+
+                    cgh.parallel_for(
+                        sycl::nd_range<3>(grid_dims * block_dims, block_dims),
+                        [=](sycl::nd_item<3> item_ct1) {
+                            k_copy_dst_from_contiguous(dst_original,
+                                                       dst_contiguous_get,
+                                                       dev_row_mapping_get,
+                                                       ne0, nb1, nb2, item_ct1);
+                        });
+                });
             }
         }
     }
-
-    if (dst->backend == GGML_BACKEND_TYPE_CPU) {
-        SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
-    }
 }
 catch (sycl::exception const &exc) {
   std::cerr << exc.what() << "Exception caught at file:" << __FILE__
@@ -16580,10 +16746,9 @@ GGML_CALL static const char * ggml_backend_sycl_split_buffer_get_name(ggml_backe
     UNUSED(buffer);
 }
 
-// unused at the moment
-//static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
-//    return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
-//}
+static bool ggml_backend_buffer_is_sycl_split(ggml_backend_buffer_t buffer) {
+   return buffer->iface.get_name == ggml_backend_sycl_split_buffer_get_name;
+}
 
 GGML_CALL static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) {
     ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context;