]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml: implement quantized KV cache for FA (llama/7372)
authorJohannes Gäßler <redacted>
Sun, 19 May 2024 14:46:13 +0000 (16:46 +0200)
committerGeorgi Gerganov <redacted>
Tue, 28 May 2024 11:41:08 +0000 (14:41 +0300)
src/ggml.c

index a04c74ddd2cd6f6f5b5ac5028fe77422f6d20b6d..3a104c486339e5cead080f9405cc898a2076d3f1 100644 (file)
@@ -15882,9 +15882,10 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     GGML_ASSERT(ne0 == D);
     GGML_ASSERT(ne2 == N);
 
-    GGML_ASSERT(nbq0 == sizeof(float));
-    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
-    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
+    // input tensor rows must be contiguous
+    GGML_ASSERT(nbq0 == ggml_type_size(q->type));
+    GGML_ASSERT(nbk0 == ggml_type_size(k->type));
+    GGML_ASSERT(nbv0 == ggml_type_size(v->type));
 
     GGML_ASSERT(neq0 == D);
     GGML_ASSERT(nek0 == D);
@@ -15938,6 +15939,11 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+    enum ggml_type    const k_vec_dot_type = type_traits[k->type].vec_dot_type;
+    ggml_from_float_t const q_to_vec_dot   = type_traits[k_vec_dot_type].from_float;
+    ggml_vec_dot_t    const kq_vec_dot     = type_traits[k->type].vec_dot;
+    ggml_to_float_t   const v_to_float     = type_traits[v->type].to_float;
+
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
         // q indices
@@ -15945,17 +15951,22 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 
-        const uint32_t h = iq2; // head
+        const uint32_t h = iq2; // head index
         const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
 
-        float S = 0.0f;
-        float M = -INFINITY;
+        float S = 0.0f;      // sum
+        float M = -INFINITY; // maximum KQ value
 
-        float       * V32 = (float       *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
-        ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
-        ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
+        float       * VKQ32 = (float       *) params->wdata + ith*(3*D + CACHE_LINE_SIZE_F32); // FP32 VKQ accumulator
+        float       * V32   =                 (VKQ32 + 1*D); // (temporary) FP32 V buffer
+        ggml_fp16_t * VKQ16 = (ggml_fp16_t *) (VKQ32 + 1*D); // (temporary) FP16 VKQ accumulator
+        ggml_fp16_t * Q_q   = (ggml_fp16_t *) (VKQ32 + 2*D); // (temporary) buffer for Q converted to quantized/FP16
 
-        memset(V16, 0, D*sizeof(ggml_fp16_t));
+        if (v->type == GGML_TYPE_F16) {
+            memset(VKQ16, 0, D*sizeof(ggml_fp16_t));
+        } else {
+            memset(VKQ32, 0, D*sizeof(float));
+        }
 
         const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
 
@@ -15967,6 +15978,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iv3 = iq3 / rv3;
         const int iv2 = iq2 / rv2;
 
+        const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+        q_to_vec_dot(pq, Q_q, D);
+
         // online softmax / attention
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
@@ -15976,52 +15990,67 @@ static void ggml_compute_forward_flash_attn_ext_f16(
                 continue;
             }
 
-            float s;
+            float s; // KQ value
 
-            // convert Q to F16 in V32
-            {
-                const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+            const char * k_data = (const char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3);
+            kq_vec_dot(D, &s, 0, k_data, 0, Q_q, 0, 1);
 
-                for (int64_t d = 0; d < D; ++d) {
-                    Q16[d] = GGML_FP32_TO_FP16(pq[d]);
-                }
-            }
+            s = s*scale + mv; // scale KQ value and apply mask
 
-            ggml_vec_dot_f16(D,
-                    &s, 0,
-                    (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
-                    Q16, 0, 1);
+            const float Mold = M;
 
-            s = s*scale + mv;
+            float ms = 1.0f; // upon new higher max val, scale VKQ and KQ sum with this value
+            float vs = 1.0f; // post-softmax KQ value, expf(s - M)
 
-            const float Mold = M;
+            const char * v_data = ((const char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
 
-            float ms = 1.0f;
-            float vs = 1.0f;
+            if (v->type== GGML_TYPE_F16) {
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
 
-            if (s > M) {
-                M = s;
-                ms = expf(Mold - M);
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f16(D, VKQ16, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
 
-                // V = V*expf(Mold - M)
-                ggml_vec_scale_f16(D, V16, ms);
+                // V += v*expf(s - M)
+                ggml_vec_mad_f16(D, VKQ16, (const ggml_fp16_t *) v_data, vs);
             } else {
-                vs = expf(s - M);
-            }
+                if (s > M) {
+                    // s is new maximum, ms < 1.0f, vs == expf(s - s) == 1.0f
+                    M = s;
+                    ms = expf(Mold - M);
 
-            const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+                    // V = V*expf(Mold - M)
+                    ggml_vec_scale_f32(D, VKQ32, ms);
+                } else {
+                    // no new maximum, ms == 1.0f, vs != 1.0f
+                    vs = expf(s - M);
+                }
 
-            // V += v*expf(s - M)
-            ggml_vec_mad_f16(D, V16, v16, vs);
+                v_to_float(v_data, V32, D);
 
-            S = S*ms + vs;
+                // V += v*expf(s - M)
+                ggml_vec_mad_f32(D, VKQ32, V32, vs);
+            }
+
+            S = S*ms + vs; // scale and increment sum with partial sum
         }
 
-        // V /= S
-        for (int64_t d = 0; d < D; ++d) {
-            V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
+        if (v->type == GGML_TYPE_F16) {
+            for (int64_t d = 0; d < D; ++d) {
+                VKQ32[d] = GGML_FP16_TO_FP32(VKQ16[d]);
+            }
         }
 
+        // V /= S
+        const float S_inv = 1.0f/S;
+        ggml_vec_scale_f32(D, VKQ32, S_inv);
+
         // dst indices
         const int i1 = iq1;
         const int i2 = iq2;
@@ -16031,7 +16060,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
 
         // permute(0, 2, 1, 3)
-        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
     }
 }
 
@@ -19972,7 +20001,7 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                 {
                     const int64_t ne00 = node->src[0]->ne[0]; // D
 
-                    cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
+                    cur = 3*sizeof(float)*ne00*n_tasks; // 3x head size/thread
                 } break;
             case GGML_OP_FLASH_FF:
                 {