dnnl::primitive_attr primitive_attr;
primitive_attr.set_scratchpad_mode(dnnl::scratchpad_mode::user);
+#ifdef GGML_SYCL_F16
+ primitive_attr.set_fpmath_mode(dnnl::fpmath_mode::f16);
+#endif
auto a_mem = dnnl::memory(a_in_md, eng, const_cast<void*>(a));
auto b_mem = dnnl::memory(b_in_md, eng, const_cast<void*>(b));
const sycl::half *src1_ptr = src1->type == GGML_TYPE_F16
? (const sycl::half *)src1->data + src1_padded_row_size
: src1_as_f16.get();
- ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
#if GGML_SYCL_DNNL
if (!g_ggml_sycl_disable_dnn) {
DnnlGemmWrapper::row_gemm(ctx, src1_ncols, row_diff, ne10, src1_ptr,
DnnlGemmWrapper::to_dt<sycl::half>(), src0_ptr, DnnlGemmWrapper::to_dt<sycl::half>(),
- dst_f16.get(), DnnlGemmWrapper::to_dt<sycl::half>(), stream);
- scope_op_debug_print scope_dbg_print(__func__, "/to_fp32_sycl", dst, /*num_src=*/2,
- " : converting dst to fp32");
- const to_fp32_sycl_t to_fp32_sycl = ggml_get_to_fp32_sycl(GGML_TYPE_F16, dst);
- to_fp32_sycl(dst_f16.get(), dst_dd_i, row_diff* src1_ncols, stream);
+ dst_dd_i, DnnlGemmWrapper::to_dt<float>(), stream);
}
else
#endif
{
+ ggml_sycl_pool_alloc<sycl::half> dst_f16(ctx.pool(), row_diff * src1_ncols);
+
const sycl::half alpha_f16 = 1.0f;
const sycl::half beta_f16 = 0.0f;
SYCL_CHECK(CHECK_TRY_ERROR(dpct::gemm(