]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
Enabled more data types for oneMKL gemm_batch (llama/8236)
authorOuadie EL FAROUKI <redacted>
Fri, 5 Jul 2024 12:23:25 +0000 (13:23 +0100)
committerGeorgi Gerganov <redacted>
Mon, 8 Jul 2024 10:03:28 +0000 (13:03 +0300)
src/ggml-sycl.cpp
src/ggml-sycl/dpct/helper.hpp

index 053cc950a8a39e5ca5879230c4c630ed0a6979d0..21006cd7bebf7c682c3c55989d2598436a82121d 100644 (file)
@@ -3493,10 +3493,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
     SYCL_CHECK(ggml_sycl_set_device(ctx.device));
     queue_ptr main_stream = ctx.stream();;
 
-    bool no_mixed_dtypes = main_stream->get_backend() == sycl::backend::ext_oneapi_cuda ||
-                           main_stream->get_backend() == sycl::backend::ext_oneapi_hip;
-
-
     void * src0_ddq = src0->data;
     sycl::half *src0_as_f16 = (sycl::half *)src0_ddq;
     float * src1_ddf = (float *) src1->data;
@@ -3514,15 +3510,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
     sycl::half *src1_f16 = src1->type == GGML_TYPE_F16 ? (sycl::half *)src1_ddf
                                                        : src1_f16_alloc.get();
 
-    ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool());
     char * dst_t;
 
     dpct::library_data_t cu_compute_type = dpct::library_data_t::real_float;
     dpct::library_data_t cu_data_type = dpct::library_data_t::real_float;
-    if (no_mixed_dtypes) {
-        cu_compute_type = dpct::library_data_t::real_half;
-        cu_data_type = dpct::library_data_t::real_half;
-    }
 
     // dst strides
     size_t nbd2 = dst->nb[2];
@@ -3531,26 +3522,10 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
     const float alpha_f32 = 1.0f;
     const float beta_f32 = 0.0f;
 
-    const sycl::half alpha_f16 = 1.0f;
-    const sycl::half beta_f16 = 0.0f;
-
     const void * alpha = &alpha_f32;
     const void * beta  = &beta_f32;
-    if (no_mixed_dtypes) {
-        alpha = &alpha_f16;
-        beta  = &beta_f16;
-    }
-
-    // TODO: Renable (dst->op_params[0] =! GGML_PREC_DEFAULT) pathway
-    // when oneMKL open source supports half, half, float, float: datatypes
 
     dst_t = (char *) dst_ddf;
-    if (no_mixed_dtypes) {
-        dst_t = (char *) dst_f16.alloc(ne_dst);
-
-        nbd2 /= sizeof(float) / sizeof(sycl::half);
-        nbd3 /= sizeof(float) / sizeof(sycl::half);
-    }
 
     GGML_ASSERT(ne12 % ne02 == 0);
     GGML_ASSERT(ne13 % ne03 == 0);
@@ -3612,11 +3587,6 @@ static void ggml_sycl_mul_mat_batched_sycl(ggml_backend_sycl_context & ctx,
             (void **)(ptrs_dst.get() + 0 * ne23), cu_data_type, ne01, ne23,
             cu_compute_type)));
     }
-
-    if (no_mixed_dtypes) {
-        const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16);
-        to_fp32_sycl(dst_f16.get(), dst_ddf, ne_dst, main_stream);
-    }
 }
 catch (sycl::exception const &exc) {
   std::cerr << exc.what() << "Exception caught at file:" << __FILE__
index 5e98660dc888d2d38a6d3abf5d8f12e4b0d331f9..31df1cb9e2cf489b47ade99c93524ee455256609 100644 (file)
@@ -2426,6 +2426,7 @@ namespace dpct
                                            b, ldb, beta, c, ldc, batch_size);
             break;
         }
+#endif
         case detail::get_type_combination_id(
             library_data_t::real_int8, library_data_t::real_int8,
             library_data_t::real_int32, library_data_t::real_int32):
@@ -2458,7 +2459,6 @@ namespace dpct
                 batch_size);
             break;
         }
-#endif
         case detail::get_type_combination_id(
             library_data_t::real_half, library_data_t::real_half,
             library_data_t::real_half, library_data_t::real_float):
@@ -2595,6 +2595,7 @@ namespace dpct
                                            stride_c, batch_size);
             break;
         }
+#endif
         case detail::get_type_combination_id(
             library_data_t::real_int8, library_data_t::real_int8,
             library_data_t::real_int32, library_data_t::real_int32):
@@ -2623,7 +2624,6 @@ namespace dpct
                 beta, c, ldc, stride_c, batch_size);
             break;
         }
-#endif
         case detail::get_type_combination_id(
             library_data_t::real_half, library_data_t::real_half,
             library_data_t::real_half, library_data_t::real_float):