]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
fix the mul_mat_id ut issues (llama/8427)
authorChen Xi <redacted>
Fri, 12 Jul 2024 00:52:04 +0000 (00:52 +0000)
committerGeorgi Gerganov <redacted>
Sat, 27 Jul 2024 15:26:12 +0000 (18:26 +0300)
* fix part of mul_mat_id

* skip the bfloat 16 sycl ut

Signed-off-by: Chen Xi <redacted>
---------

Signed-off-by: Chen Xi <redacted>
Co-authored-by: Meng, Hengyu <redacted>
Co-authored-by: Chen Xi <redacted>
src/ggml-backend.c
src/ggml-sycl.cpp

index 13c71c310c446d66d99e57167c2588618e8f1026..dbbaa3941febea4ec8e74df17c2e2dcbd808a8de 100644 (file)
@@ -394,7 +394,7 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
 
 // backend registry
 
-#define GGML_REG_MAX_BACKENDS 16
+#define GGML_REG_MAX_BACKENDS 64
 
 struct ggml_backend_reg {
     char name[128];
index 9c419ba896706400b2cef39fa1a6700a501ba344..5a890237f24535b16caa6cc974a70bcae5fa9643 100644 (file)
@@ -3768,37 +3768,13 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
         stream->memcpy(ids_host.data(), ids_dev, ggml_nbytes(ids))));
     SYCL_CHECK(CHECK_TRY_ERROR(stream->wait()));
 
-    const ggml_tensor_extra_gpu *src0_extra =
-        (const ggml_tensor_extra_gpu *)src0->extra;
-    const ggml_tensor_extra_gpu *src1_extra =
-        (const ggml_tensor_extra_gpu *)src1->extra;
-    const ggml_tensor_extra_gpu *dst_extra =
-        (const ggml_tensor_extra_gpu *)dst->extra;
-
-    ggml_tensor_extra_gpu src0_row_extra;
-    ggml_tensor_extra_gpu src1_row_extra;
-    ggml_tensor_extra_gpu dst_row_extra;
-
     ggml_tensor src0_row = *src0;
     ggml_tensor src1_row = *src1;
     ggml_tensor dst_row = *dst;
 
-    src1_row.backend = GGML_BACKEND_TYPE_GPU;
-    dst_row.backend  = GGML_BACKEND_TYPE_GPU;
-
-    src0_row.extra = &src0_row_extra;
-    src1_row.extra = &src1_row_extra;
-    dst_row.extra = &dst_row_extra;
-
-    char *src0_original = src1->backend == GGML_BACKEND_TYPE_CPU
-                              ? (char *)src0->data
-                              : (char *)src0_extra->data_device[ctx.device];
-    char *src1_original = src1->backend == GGML_BACKEND_TYPE_CPU
-                              ? (char *)src1->data
-                              : (char *)src1_extra->data_device[ctx.device];
-    char *dst_original = dst->backend == GGML_BACKEND_TYPE_CPU
-                             ? (char *)dst->data
-                             : (char *)dst_extra->data_device[ctx.device];
+    char *src0_original = (char *)src0->data;
+    char *src1_original = (char *)src1->data;
+    char *dst_original = (char *)dst->data;
 
     src0_row.ne[2] = 1;
     src0_row.ne[3] = 1;
@@ -3827,12 +3803,9 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
                 const int64_t i1 = id;
                 const int64_t i2 = i12;
 
-            src0_row_extra.data_device[ctx.device] =
-                src0_original + i02*nb02;
-            src1_row_extra.data_device[ctx.device] =
-                src1_original + + i11*nb11 + i12*nb12;
-            dst_row_extra.data_device[ctx.device] =
-                dst_original + i1*nb1   + i2*nb2;
+            src0_row.data = src0_original + i02*nb02;
+            src1_row.data = src1_original + + i11*nb11 + i12*nb12;
+            dst_row.data = dst_original + i1*nb1   + i2*nb2;
 
             ggml_sycl_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
             }
@@ -3841,8 +3814,8 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
         ggml_sycl_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
         ggml_sycl_pool_alloc<char>  dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
 
-        src1_row_extra.data_device[ctx.device] = src1_contiguous.get();
-        dst_row_extra.data_device[ctx.device]  =  dst_contiguous.get();
+        src1_row.data = src1_contiguous.get();
+        dst_row.data  =  dst_contiguous.get();
 
         for (int64_t i02 = 0; i02 < n_as; i02++) {
             int64_t num_src1_rows = 0;
@@ -3898,7 +3871,7 @@ static void ggml_sycl_mul_mat_id(ggml_backend_sycl_context & ctx, const ggml_ten
                 });
             }
 
-            src0_row_extra.data_device[ctx.device] = src0_original + i02*nb02;
+            src0_row.data = src0_original + i02*nb02;
 
             GGML_ASSERT(nb11 == sizeof(float)*ne10);
             GGML_ASSERT(nb1 == sizeof(float)*ne0);
@@ -5221,6 +5194,10 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
                         return false;
                     }
                 }
+                ggml_type src0_type = op->src[0]->type;
+                if (src0_type == GGML_TYPE_BF16) {
+                    return false;
+                }
                 return true;
             } break;
         case GGML_OP_GET_ROWS: