]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : sync llama.cpp (CLBlast)
authorGeorgi Gerganov <redacted>
Fri, 28 Apr 2023 17:34:38 +0000 (20:34 +0300)
committerGeorgi Gerganov <redacted>
Fri, 28 Apr 2023 17:34:52 +0000 (20:34 +0300)
include/ggml/ggml.h
src/ggml.c

index d9d3d214e84e70f827d59b11ec2a35d47fb9b26f..1bbe2db93f5d1e6f79e29e03d8d293a1bdada76f 100644 (file)
@@ -858,10 +858,11 @@ extern "C" {
     GGML_API int ggml_cpu_has_wasm_simd  (void);
     GGML_API int ggml_cpu_has_blas       (void);
     GGML_API int ggml_cpu_has_cublas     (void);
+    GGML_API int ggml_cpu_has_clblast    (void);
+    GGML_API int ggml_cpu_has_gpublas    (void);
     GGML_API int ggml_cpu_has_sse3       (void);
     GGML_API int ggml_cpu_has_vsx        (void);
 
-
     //
     // Internal types and functions exposed for tests and benchmarks
     //
index b3504d1718e44e3446bbd0b6af62043dc7375ab0..44293dac92668a2825125fa9f81e7f0f37ae3442 100644 (file)
@@ -149,6 +149,8 @@ inline static void* ggml_aligned_malloc(size_t size) {
 #include <cblas.h>
 #elif defined(GGML_USE_CUBLAS)
 #include "ggml-cuda.h"
+#elif defined(GGML_USE_CLBLAST)
+#include "ggml-opencl.h"
 #endif
 
 #undef MIN
@@ -3626,6 +3628,24 @@ static void ggml_vec_dot_q8_0_q8_0(const int n, float * restrict s, const void *
     }
 
     *s = vaddvq_f32(sumv0) + vaddvq_f32(sumv1);
+#elif defined(__AVX2__)
+    // Initialize accumulator with zeros
+    __m256 acc = _mm256_setzero_ps();
+
+    // Main loop
+    for (int i = 0; i < nb; ++i) {
+        // Compute combined scale for the block
+        const __m256 d = _mm256_mul_ps( _mm256_broadcast_ss( &x[i].d ), _mm256_broadcast_ss( &y[i].d ) );
+        __m256i bx = _mm256_loadu_si256((const __m256i *)x[i].qs);
+        __m256i by = _mm256_loadu_si256((const __m256i *)y[i].qs);
+
+        const __m256 q = mul_sum_i8_pairs_float(bx, by);
+
+        // Multiply q with scale and accumulate
+        acc = _mm256_fmadd_ps( d, q, acc );
+    }
+
+    *s = hsum_float_8(acc);
 #else
     // scalar
     float sumf = 0.0;
@@ -4345,6 +4365,8 @@ struct ggml_context * ggml_init(struct ggml_init_params params) {
         // initialize cuBLAS
         #if defined(GGML_USE_CUBLAS)
         ggml_init_cublas();
+        #elif defined(GGML_USE_CLBLAST)
+        ggml_cl_init();
         #endif
 
         is_first_call = false;
@@ -8086,7 +8108,7 @@ static void ggml_compute_forward_rms_norm(
 
 // ggml_compute_forward_mul_mat
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
 // helper function to determine if it is better to use BLAS or not
 // for large matrices, BLAS is faster
 static bool ggml_compute_forward_mul_mat_use_blas(
@@ -8111,6 +8133,7 @@ static bool ggml_compute_forward_mul_mat_use_blas(
 
     return false;
 }
+
 #endif
 
 static void ggml_compute_forward_mul_mat_f32(
@@ -8126,7 +8149,7 @@ static void ggml_compute_forward_mul_mat_f32(
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     const int64_t ne10 = src1->ne[0];
 #endif
     const int64_t ne11 = src1->ne[1];
@@ -8183,7 +8206,7 @@ static void ggml_compute_forward_mul_mat_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -8232,8 +8255,15 @@ static void ggml_compute_forward_mul_mat_f32(
 
                 // copy data to host
                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
-#else
+#elif defined(GGML_USE_CLBLAST)
                 // zT = y * xT
+                ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01,
+                        GGML_TYPE_F32);
+#else
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
@@ -8377,7 +8407,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         GGML_ASSERT(nb10 == sizeof(float));
 
@@ -8454,6 +8484,19 @@ static void ggml_compute_forward_mul_mat_f16_f32(
 
                 // copy data to host
                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
+#elif defined(GGML_USE_CLBLAST)
+                const float * x = wdata;
+                const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
+
+                float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
+
+                // zT = y * xT
+                ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01,
+                        GGML_TYPE_F32);
 #else
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
@@ -8628,7 +8671,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
     // nb01 >= nb00 - src0 is not transposed
     //   compute by src0 rows
 
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
         if (params->ith != 0) {
             return;
@@ -8680,7 +8723,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
         else {
             GGML_ASSERT(false);
         }
-#else
+#elif !defined(GGML_USE_CLBLAST)
         float * const wdata = params->wdata;
         dequantize_row_q_t const dequantize_row_q = quantize_fns[type].dequantize_row_q;
 #endif
@@ -8699,6 +8742,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
                 dequantize_row_q_cuda(d_Q, d_X, ne01 * ne00, g_cudaStream);
                 CUDA_CHECK(cudaGetLastError());
+#elif defined(GGML_USE_CLBLAST)
+                const void* x = (char *) src0->data + i03*nb03 + i02*nb02;
 #else
                 {
                     size_t id = 0;
@@ -8725,8 +8770,15 @@ static void ggml_compute_forward_mul_mat_q_f32(
 
                 // copy data to host
                 CUDA_CHECK(cudaMemcpyAsync(d, d_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, g_cudaStream));
-#else
+#elif defined(GGML_USE_CLBLAST)
                 // zT = y * xT
+                ggml_cl_sgemm_wrapper(GGML_BLAS_ORDER_ROW_MAJOR, GGML_BLAS_OP_N, GGML_BLAS_OP_T,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01,
+                        type);
+#else
                 cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
                         ne11, ne01, ne10,
                         1.0f,    y, ne10,
@@ -11566,7 +11618,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         size_t cur = 0;
 
                         if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1; // TODO: this actually is doing nothing
                                                    //       the threads are still spinning
@@ -11583,7 +11635,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
                         } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) {
                             cur = 0;
                         } else if (ggml_is_quantized(node->src0->type) && node->src1->type == GGML_TYPE_F32) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
                             if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) {
                                 node->n_tasks = 1;
                                 cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]);
@@ -13083,7 +13135,7 @@ int ggml_cpu_has_wasm_simd(void) {
 }
 
 int ggml_cpu_has_blas(void) {
-#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS)
+#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) || defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST)
     return 1;
 #else
     return 0;
@@ -13098,6 +13150,18 @@ int ggml_cpu_has_cublas(void) {
 #endif
 }
 
+int ggml_cpu_has_clblast(void) {
+#if defined(GGML_USE_CLBLAST)
+    return 1;
+#else
+    return 0;
+#endif
+}
+
+int ggml_cpu_has_gpublas(void) {
+    return ggml_cpu_has_cublas() || ggml_cpu_has_clblast();
+}
+
 int ggml_cpu_has_sse3(void) {
 #if defined(__SSE3__)
     return 1;