]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Fix nasty bug in ggml_compute_forward_mul_mat_f32() and reenable BLAS
authorGeorgi Gerganov <redacted>
Sat, 25 Mar 2023 14:09:54 +0000 (16:09 +0200)
committerGeorgi Gerganov <redacted>
Sat, 25 Mar 2023 14:10:14 +0000 (16:10 +0200)
ggml.c
llama.cpp

diff --git a/ggml.c b/ggml.c
index db68ed144908bb0ac64c6d364ce900e9b324cb7c..625ef600799e04ce56808c06f38bf8b6ec50168e 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -2638,7 +2638,7 @@ static inline int ggml_up(int n, int m) {
 
 // assert that pointer is aligned to GGML_MEM_ALIGN
 #define ggml_assert_aligned(ptr) \
-    assert(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
+    GGML_ASSERT(((uintptr_t) (ptr))%GGML_MEM_ALIGN == 0)
 
 ////////////////////////////////////////////////////////////////////////////////
 
@@ -4566,7 +4566,7 @@ static void ggml_compute_forward_dup_f16(
 
     if (src0->nb[0] == sizeof(ggml_fp16_t)) {
         if (dst->type == GGML_TYPE_F16) {
-            int id = 0;
+            size_t id = 0;
             const size_t rs = ne00*nb00;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -4582,7 +4582,7 @@ static void ggml_compute_forward_dup_f16(
                 }
             }
         } else if (dst->type == GGML_TYPE_F32) {
-            int id = 0;
+            size_t id = 0;
             float * dst_ptr = (float *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -4604,7 +4604,7 @@ static void ggml_compute_forward_dup_f16(
         //printf("%s: this is not optimal - fix me\n", __func__);
 
         if (dst->type == GGML_TYPE_F32) {
-            int id = 0;
+            size_t id = 0;
             float * dst_ptr = (float *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -4620,7 +4620,7 @@ static void ggml_compute_forward_dup_f16(
                 }
             }
         } else if (dst->type == GGML_TYPE_F16) {
-            int id = 0;
+            size_t id = 0;
             ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -4670,7 +4670,7 @@ static void ggml_compute_forward_dup_f32(
 
     if (src0->nb[0] == sizeof(float)) {
         if (dst->type == GGML_TYPE_F32) {
-            int id = 0;
+            size_t id = 0;
             const size_t rs = ne00*nb00;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -4686,7 +4686,7 @@ static void ggml_compute_forward_dup_f32(
                 }
             }
         } else if (dst->type == GGML_TYPE_F16) {
-            int id = 0;
+            size_t id = 0;
             ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -4708,7 +4708,7 @@ static void ggml_compute_forward_dup_f32(
         //printf("%s: this is not optimal - fix me\n", __func__);
 
         if (dst->type == GGML_TYPE_F32) {
-            int id = 0;
+            size_t id = 0;
             float * dst_ptr = (float *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -4724,7 +4724,7 @@ static void ggml_compute_forward_dup_f32(
                 }
             }
         } else if (dst->type == GGML_TYPE_F16) {
-            int id = 0;
+            size_t id = 0;
             ggml_fp16_t * dst_ptr = (ggml_fp16_t *) dst->data;
 
             for (int i03 = 0; i03 < ne03; i03++) {
@@ -5854,20 +5854,11 @@ static bool ggml_compute_forward_mul_mat_use_blas(
     const int ne0 = dst->ne[0];
     const int ne1 = dst->ne[1];
 
-    // TMP: disable BLAS for now there is definitely a bug
-    return false;
-
     // TODO: find the optimal values for these
     if (ggml_is_contiguous(src0) &&
         ggml_is_contiguous(src1) && ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32))) {
 
-        // disable BLAS for Q4_0 and Q4_1
-        // there is a bug that has to be fixed before enabling
-        if (src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1) {
-            return false;
-        }
-
-        //printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);
+        /*printf("BLAS: %d %d %d %d %d\n", ne0, ne1, ne10, ne00, ne01);*/
         return true;
     }
 
@@ -5960,19 +5951,17 @@ static void ggml_compute_forward_mul_mat_f32(
 
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
-                const float * x = (float *) (src0->data);
+                const float * x = (float *) ((char *) src0->data + i02*nb02 + i03*nb03);
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
                 // zT = y * xT
-                {
-                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                            ne11, ne01, ne10,
-                            1.0f,    y, ne10,
-                                     x, ne10,
-                            0.0f,    d, ne01);
-                }
+                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01);
             }
         }
 
@@ -6208,7 +6197,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 {
-                    int id = 0;
+                    size_t id = 0;
                     for (int i01 = 0; i01 < ne01; ++i01) {
                         for (int i00 = 0; i00 < ne00; ++i00) {
                             wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
@@ -6219,43 +6208,14 @@ static void ggml_compute_forward_mul_mat_f16_f32(
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
-                //      float * z =                          wdata + ne00*ne01;
-
-                // z = x * yT
-                //{
-                //    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                //            ne01, ne11, ne00,
-                //            1.0f, x, ne00,
-                //                  y, ne00,
-                //            0.0f, z, ne11);
-                //}
-
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
-                // transpose z
-                //for (int j = 0; j < ne11; ++j) {
-                //    for (int i = 0; i < ne01; ++i) {
-                //        d[j*ne01 + i] = z[i*ne11 + j];
-                //    }
-                //}
-
-                {
-#if 1
-                    // zT = y * xT
-                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                            ne11, ne01, ne10,
-                            1.0f,    y, ne00,
-                                     x, ne00,
-                            0.0f,    d, ne01);
-#else
-                    // zT = (xT * y)T
-                    cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans,
-                            ne01, ne11, ne10,
-                            1.0f,    x, ne00,
-                                     y, ne00,
-                            0.0f,    d, ne01);
-#endif
-                }
+                // zT = y * xT
+                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01);
             }
         }
 
@@ -6269,7 +6229,7 @@ static void ggml_compute_forward_mul_mat_f16_f32(
         if (nb01 >= nb00) {
             ggml_fp16_t * const wdata = params->wdata;
 
-            int id = 0;
+            size_t id = 0;
             for (int i13 = 0; i13 < ne13; ++i13) {
                 for (int i12 = 0; i12 < ne12; ++i12) {
                     for (int i11 = 0; i11 < ne11; ++i11) {
@@ -6514,7 +6474,7 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 {
-                    int id = 0;
+                    size_t id = 0;
                     for (int i01 = 0; i01 < ne01; ++i01) {
                         //for (int i00 = 0; i00 < ne00; ++i00) {
                         //    wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
@@ -6527,43 +6487,14 @@ static void ggml_compute_forward_mul_mat_q4_0_f32(
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
-                //      float * z =                          wdata + ne00*ne01;
-
-                // z = x * yT
-                //{
-                //    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                //            ne01, ne11, ne00,
-                //            1.0f, x, ne00,
-                //                  y, ne00,
-                //            0.0f, z, ne11);
-                //}
-
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
-                // transpose z
-                //for (int j = 0; j < ne11; ++j) {
-                //    for (int i = 0; i < ne01; ++i) {
-                //        d[j*ne01 + i] = z[i*ne11 + j];
-                //    }
-                //}
-
-                {
-#if 1
-                    // zT = y * xT
-                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                            ne11, ne01, ne10,
-                            1.0f,    y, ne00,
-                                     x, ne00,
-                            0.0f,    d, ne01);
-#else
-                    // zT = (xT * y)T
-                    cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans,
-                            ne01, ne11, ne10,
-                            1.0f,    x, ne00,
-                                     y, ne00,
-                            0.0f,    d, ne01);
-#endif
-                }
+                // zT = y * xT
+                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01);
             }
         }
 
@@ -6814,7 +6745,7 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
         for (int i03 = 0; i03 < ne03; i03++) {
             for (int i02 = 0; i02 < ne02; i02++) {
                 {
-                    int id = 0;
+                    size_t id = 0;
                     for (int i01 = 0; i01 < ne01; ++i01) {
                         //for (int i00 = 0; i00 < ne00; ++i00) {
                         //    wdata[id++] = GGML_FP16_TO_FP32(*(ggml_fp16_t *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00));
@@ -6827,43 +6758,14 @@ static void ggml_compute_forward_mul_mat_q4_1_f32(
                 const float * x = wdata;
                 const float * y = (float *) ((char *) src1->data + i02*nb12 + i03*nb13);
 
-                //      float * z =                          wdata + ne00*ne01;
-
-                // z = x * yT
-                //{
-                //    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                //            ne01, ne11, ne00,
-                //            1.0f, x, ne00,
-                //                  y, ne00,
-                //            0.0f, z, ne11);
-                //}
-
                 float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3);
 
-                // transpose z
-                //for (int j = 0; j < ne11; ++j) {
-                //    for (int i = 0; i < ne01; ++i) {
-                //        d[j*ne01 + i] = z[i*ne11 + j];
-                //    }
-                //}
-
-                {
-#if 1
-                    // zT = y * xT
-                    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
-                            ne11, ne01, ne10,
-                            1.0f,    y, ne00,
-                                     x, ne00,
-                            0.0f,    d, ne01);
-#else
-                    // zT = (xT * y)T
-                    cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans,
-                            ne01, ne11, ne10,
-                            1.0f,    x, ne00,
-                                     y, ne00,
-                            0.0f,    d, ne01);
-#endif
-                }
+                // zT = y * xT
+                cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
+                        ne11, ne01, ne10,
+                        1.0f,    y, ne10,
+                                 x, ne10,
+                        0.0f,    d, ne01);
             }
         }
 
index 14de611a97bc7fad385e56128f05a518bbac835d..bb7bdeadfc51c7201add72eb5c2fb2a2e2a1ec2a 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -917,8 +917,7 @@ static bool llama_eval_internal(
             struct ggml_tensor * KQ_scaled =
                 ggml_scale(ctx0,
                         KQ,
-                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head))
-                        );
+                        ggml_new_f32(ctx0, 1.0f/sqrt(float(n_embd)/n_head)));
 
             // KQ_masked = mask_past(KQ_scaled)
             struct ggml_tensor * KQ_masked = ggml_diag_mask_inf(ctx0, KQ_scaled, n_past);
@@ -934,7 +933,7 @@ static bool llama_eval_internal(
                                 ggml_view_1d(ctx0, kv_self.v, (n_past + N)*n_embd, il*n_ctx*ggml_element_size(kv_self.v)*n_embd),
                                 n_embd/n_head, n_head, n_past + N),
                             1, 2, 0, 3),
-                    ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_past + N, n_embd/n_head, n_head));
+                    ggml_new_tensor_3d(ctx0, kv_self.v->type, n_past + N, n_embd/n_head, n_head));
 
             // KQV = transpose(V) * KQ_soft_max
             struct ggml_tensor * KQV = ggml_mul_mat(ctx0, V_trans, KQ_soft_max);