]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
sycl: Remove not needed copy f16->f32 for dnnl mul mat (llama/14125)
authorAnton Mitkov <redacted>
Thu, 12 Jun 2025 13:15:11 +0000 (14:15 +0100)
committerGeorgi Gerganov <redacted>
Wed, 18 Jun 2025 09:40:34 +0000 (12:40 +0300)
ggml/src/ggml-sycl/gemm.hpp
ggml/src/ggml-sycl/ggml-sycl.cpp

index 6cbc7e0f6938cc8767444944c39f2748b3b0c7fd..5efe03d364b1b2d11661827139fbd792849259cd 100644 (file)
@@ -65,6 +65,9 @@ public:
 
         dnnl::primitive_attr primitive_attr;
         primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+#ifdef GGML_SYCL_F16
+        primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);
+#endif
 
         auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
         auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
index 3693b0a4337a53cad200b6cf6c9bdecd2eb5740b..feb30304fc0927ab643470dcf8eaeacfa36c623e 100644 (file)
@@ -2127,21 +2127,18 @@ inline void ggml_sycl_op_mul_mat_sycl(
         const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
                 ? (const sycl::half *)src1->data + src1_padded_row_size
                                          : src1_as_f16.get();
-        ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
 
 #if GGML_SYCL_DNNL
         if (!g_ggml_sycl_disable_dnn) {
             DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
                                       DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
-                                      dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
-            scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
-                                                 " : converting dst to fp32");
-            const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
-            to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+                                      dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
         }
         else
 #endif
         {
+            ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
+
             const sycl::half alpha_f16 = 1.0f;
             const sycl::half beta_f16  = 0.0f;
             SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(