]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
ggml : add NVPL BLAS support (ggml/8329) (llama/8425)
authorNicholai Tukanov <redacted>
Thu, 11 Jul 2024 16:49:15 +0000 (11:49 -0500)
committerGeorgi Gerganov <redacted>
Thu, 8 Aug 2024 19:48:46 +0000 (22:48 +0300)
* ggml : add NVPL BLAS support

* ggml : replace `<BLASLIB>_ENABLE_CBLAS` with `GGML_BLAS_USE_<BLASLIB>`

---------

Co-authored-by: ntukanov <redacted>
ggml/src/ggml-blas.cpp

index d709a357bbf2908ebd9d5123531429b14179d764..a37aa407282b94458805097221cf6aa00fa382f7 100644 (file)
@@ -8,11 +8,12 @@
 #   include <Accelerate/Accelerate.h>
 #elif defined(GGML_BLAS_USE_MKL)
 #   include <mkl.h>
+#elif defined(GGML_BLAS_USE_BLIS)
+#   include <blis.h>
+#elif defined(GGML_BLAS_USE_NVPL)
+#   include <nvpl_blas.h>
 #else
 #   include <cblas.h>
-#   ifdef BLIS_ENABLE_CBLAS
-#       include <blis.h>
-#   endif
 #endif
 
 struct ggml_backend_blas_context {
@@ -140,10 +141,14 @@ static void ggml_backend_blas_mul_mat(ggml_backend_blas_context * ctx, struct gg
     openblas_set_num_threads(ctx->n_threads);
 #endif
 
-#if defined(BLIS_ENABLE_CBLAS)
+#if defined(GGML_BLAS_USE_BLIS)
     bli_thread_set_num_threads(ctx->n_threads);
 #endif
 
+#if defined(GGML_BLAS_USE_NVPL)
+    nvpl_blas_set_num_threads(ctx->n_threads);
+#endif
+
     for (int64_t i13 = 0; i13 < ne13; i13++) {
         for (int64_t i12 = 0; i12 < ne12; i12++) {
             const int64_t i03 = i13/r3;