]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
Fixed minor bug when enabling FP16 for non intel targets (llama/6464)
authorOuadie EL FAROUKI <redacted>
Fri, 5 Apr 2024 13:35:06 +0000 (14:35 +0100)
committerGeorgi Gerganov <redacted>
Sun, 7 Apr 2024 13:15:57 +0000 (16:15 +0300)
* moved INTEL_MKL guard from gemm_impl to gemm (wrapper)

* Update ggml-sycl.cpp

Co-authored-by: AidanBeltonS <redacted>
---------

Co-authored-by: AidanBeltonS <redacted>
ggml-sycl.cpp

index 2b0e5f54827a5db9479fdb5106a92a63f59a8729..db3c24f60eb896a9d3290494ad502c6b65c8f0f3 100644 (file)
@@ -1664,24 +1664,6 @@ namespace dpct
                               const void *alpha, const void *a, int lda, const void *b,
                               int ldb, const void *beta, void *c, int ldc)
         {
-#ifndef __INTEL_MKL__
-            GGML_UNUSED(q);
-            GGML_UNUSED(a_trans);
-            GGML_UNUSED(b_trans);
-            GGML_UNUSED(m);
-            GGML_UNUSED(n);
-            GGML_UNUSED(k);
-            GGML_UNUSED(alpha);
-            GGML_UNUSED(a);
-            GGML_UNUSED(lda);
-            GGML_UNUSED(b);
-            GGML_UNUSED(ldb);
-            GGML_UNUSED(beta);
-            GGML_UNUSED(c);
-            GGML_UNUSED(ldc);
-            throw std::runtime_error("The oneAPI Math Kernel Library (oneMKL) Interfaces "
-                                     "Project does not support this API.");
-#else
             Ts alpha_value = dpct::get_value(reinterpret_cast<const Ts *>(alpha), q);
             Ts beta_value = dpct::get_value(reinterpret_cast<const Ts *>(beta), q);
             auto data_a = get_memory<const Ta>(a);
@@ -1690,7 +1672,6 @@ namespace dpct
             oneapi::mkl::blas::column_major::gemm(
                 q, a_trans, b_trans, m, n, k, alpha_value, data_a, lda,
                 data_b, ldb, beta_value, data_c, ldc);
-#endif
         }
 
         template <typename VecT, class BinaryOperation, class = void>
@@ -2330,6 +2311,7 @@ namespace dpct
                                           lda, b, ldb, beta, c, ldc);
             break;
         }
+#ifdef __INTEL_MKL__
         case detail::get_type_combination_id(
             library_data_t::real_bfloat16, library_data_t::real_bfloat16,
             library_data_t::real_float, library_data_t::real_float):
@@ -2391,6 +2373,7 @@ namespace dpct
                 q, a_trans, b_trans, m, n, k, &alpha_float, a, lda, b, ldb, &beta_float, c, ldc);
             break;
         }
+#endif // __INTEL_MKL__
         default:
             throw std::runtime_error("the combination of data type is unsupported");
         }