]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
refact : fix convert script + zero out KV cache to avoid nans (#3523)
authorGeorgi Gerganov <redacted>
Mon, 9 Oct 2023 11:32:17 +0000 (14:32 +0300)
committerGitHub <redacted>
Mon, 9 Oct 2023 11:32:17 +0000 (14:32 +0300)
* refact : fix convert script + zero out KV cache to avoid nans

* ggml : silu(-inf) should never happen

* metal : assert various kernel requirements

convert-refact-hf-to-gguf.py
examples/parallel/parallel.cpp
ggml-metal.m
ggml-metal.metal
ggml.c
llama.cpp

index e0cd417dbbbc4297aeec8a9979a5148f7190c9c5..bfeabc0825ba7d48f12ac7a89b6f4b737f520fd2 100755 (executable)
@@ -17,33 +17,6 @@ if "NO_LOCAL_GGUF" not in os.environ:
     sys.path.insert(1, str(Path(__file__).parent / "gguf-py" / "gguf"))
 import gguf
 
-
-def bytes_to_unicode():
-    # ref: https://github.com/openai/gpt-2/blob/master/src/encoder.py
-    """
-    Returns list of utf-8 byte and a corresponding list of unicode strings.
-    The reversible bpe codes work on unicode strings.
-    This means you need a large # of unicode characters in your vocab if you want to avoid UNKs.
-    When you're at something like a 10B token dataset you end up needing around 5K for decent coverage.
-    This is a significant percentage of your normal, say, 32K bpe vocab.
-    To avoid that, we want lookup tables between utf-8 bytes and unicode strings.
-    And avoids mapping to whitespace/control characters the bpe code barfs on.
-    """
-    bs = (
-        list(range(ord("!"), ord("~") + 1))
-        + list(range(ord("¡"), ord("¬") + 1))
-        + list(range(ord("®"), ord("ÿ") + 1))
-    )
-    cs = bs[:]
-    n = 0
-    for b in range(2**8):
-        if b not in bs:
-            bs.append(b)
-            cs.append(2**8 + n)
-            n += 1
-    return dict(zip(bs, (chr(n) for n in cs)))
-
-
 def count_model_parts(dir_model: Path) -> int:
     num_parts = 0
     for filename in os.listdir(dir_model):
@@ -153,53 +126,25 @@ tokens: list[bytearray] = []
 scores: list[float] = []
 toktypes: list[int] = []
 
-tokenizer_json_file = dir_model / "tokenizer.json"
-if not tokenizer_json_file.is_file():
-    print(f"Error: Missing {tokenizer_json_file}", file=sys.stderr)
-    sys.exit(1)
-
 # gpt2 tokenizer
 gguf_writer.add_tokenizer_model("gpt2")
 
-with open(tokenizer_json_file, "r", encoding="utf-8") as f:
-    tokenizer_json = json.load(f)
-
 print("gguf: get gpt2 tokenizer vocab")
 
+# ref: https://github.com/cmp-nct/ggllm.cpp/blob/master/falcon_convert.py
+tokenizer = AutoTokenizer.from_pretrained(dir_model)
+
 # The number of tokens in tokenizer.json can differ from the expected vocab size.
 # This causes downstream issues with mismatched tensor sizes when running the inference
-vocab_size = (
-    hparams["vocab_size"]
-    if "vocab_size" in hparams
-    else len(tokenizer_json["model"]["vocab"])
-)
-
-tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
+assert max(tokenizer.vocab.values()) < vocab_size
 
 reverse_vocab = {id: encoded_tok for encoded_tok, id in tokenizer.vocab.items()}
-byte_encoder = bytes_to_unicode()
-byte_decoder = {v: k for k, v in byte_encoder.items()}
 
 for i in range(vocab_size):
-    if i in reverse_vocab:
-        text = reverse_vocab[i]
-        try:
-            text = bytearray([byte_decoder[c] for c in reverse_vocab[i]])
-        except KeyError:
-            text = bytearray()
-            for c in reverse_vocab[i]:
-                if ord(c) < 256:  # single byte character
-                    text.append(byte_decoder[ord(c)])
-                else:  # multibyte special token character
-                    text.extend(c.encode("utf-8"))
-    else:
-        print(f"Key {i} not in tokenizer vocabulary. Padding with an arbitrary token.")
-        pad_token = f"[PAD{i}]".encode("utf8")
-        text = bytearray(pad_token)
-
-    tokens.append(text)
-    scores.append(0.0)  # dymmy
-    toktypes.append(gguf.TokenType.NORMAL)  # dummy
+    tokens.append(reverse_vocab[i] if i in reverse_vocab else f"[PAD{i}]")
+    scores.append(0.0) # dummy
+    toktypes.append(gguf.TokenType.NORMAL)
 
 gguf_writer.add_token_list(tokens)
 gguf_writer.add_token_scores(scores)
index 721888da7de9491e6b85c15f26cbe6e8f6c8475e..04f1e45b955dd8a029159d3ee76a921e4d70a8ed 100644 (file)
@@ -167,7 +167,7 @@ int main(int argc, char ** argv) {
 
     // the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
     // users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
-    llama_batch batch = llama_batch_init(params.n_ctx, 0);
+    llama_batch batch = llama_batch_init(n_ctx, 0);
 
     int32_t n_total_prompt = 0;
     int32_t n_total_gen    = 0;
index 7d67db90f078b20782d490d0361196163976e457..5a23144d0c89133d70ed01ae8d44287ca4350c9f 100644 (file)
@@ -779,8 +779,8 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_CONCAT:
                         {
+                            const int64_t nb = ne00;
 
-                            int64_t nb = ne00;
                             [encoder setComputePipelineState:ctx->pipeline_concat];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -812,6 +812,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&nb   length:sizeof(nb)   atIndex:27];
 
                             const int nth = MIN(1024, ne0);
+
                             [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     case GGML_OP_ADD:
@@ -909,9 +910,10 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                             [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
-                            const int64_t n = ggml_nelements(dst)/4;
+                            const int64_t n = ggml_nelements(dst);
+                            GGML_ASSERT(n % 4 == 0);
 
-                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                            [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                     case GGML_OP_UNARY:
                         switch (ggml_get_unary_op(gf->nodes[i])) {
@@ -921,9 +923,10 @@ void ggml_metal_graph_compute(
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                    const int64_t n = ggml_nelements(dst)/4;
+                                    const int64_t n = ggml_nelements(dst);
+                                    GGML_ASSERT(n % 4 == 0);
 
-                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                                 } break;
                             case GGML_UNARY_OP_RELU:
                                 {
@@ -941,9 +944,10 @@ void ggml_metal_graph_compute(
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                    const int64_t n = ggml_nelements(dst)/4;
+                                    const int64_t n = ggml_nelements(dst);
+                                    GGML_ASSERT(n % 4 == 0);
 
-                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                                 } break;
                             default:
                                 {
@@ -1251,6 +1255,8 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_RMS_NORM:
                         {
+                            GGML_ASSERT(ne00 % 4 == 0);
+
                             float eps;
                             memcpy(&eps, dst->op_params, sizeof(float));
 
index b6288db28660dc9e8902cb7a27369e9af30db695..99b9fd7a748bad355c15e5fb50036b5d2c826af0 100644 (file)
@@ -345,10 +345,11 @@ kernel void kernel_rms_norm(
         uint sgitg[[simdgroup_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint   ntg[[threads_per_threadgroup]]) {
-    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
-    device const float * x_scalar = (device const float *) x;
-    float4 sumf=0;
-    float all_sum=0;
+    device const float4 * x        = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+    device const float  * x_scalar = (device const float  *) x;
+
+    float4 sumf = 0;
+    float all_sum = 0;
 
     // parallel sum
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
@@ -361,6 +362,7 @@ kernel void kernel_rms_norm(
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
+
     // broadcast, simd group number is ntg / 32
     for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
        if (tpitg < i) {
@@ -368,7 +370,9 @@ kernel void kernel_rms_norm(
        }
     }
     if (tpitg == 0) {
-        for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
+        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
+            sum[0] += x_scalar[i];
+        }
         sum[0] /= ne00;
     }
 
@@ -383,7 +387,9 @@ kernel void kernel_rms_norm(
         y[i00] = x[i00] * scale;
     }
     if (tpitg == 0) {
-        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
+        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
+            y_scalar[i00] = x_scalar[i00] * scale;
+        }
     }
 }
 
diff --git a/ggml.c b/ggml.c
index 6d1776ca46741f2d436a0c667e4609dd04c28a38..5bb1da31ba624d4e54ce083925ab8456d2c8e40c 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -11233,7 +11233,7 @@ static void ggml_compute_forward_silu_f32(
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
-            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
             UNUSED(x);
             assert(!isnan(x));
             assert(!isinf(x));
@@ -13066,17 +13066,17 @@ static void ggml_compute_forward_alibi_f32(
 
     assert(n_past >= 0);
 
-    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
-    const int ne1 = src0->ne[1]; // seq_len_without_past
-    const int ne2 = src0->ne[2]; // n_head -> this is k
-    //const int ne3 = src0->ne[3]; // 1 -> bsz
+    const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
+    const int64_t ne1 = src0->ne[1]; // seq_len_without_past
+    const int64_t ne2 = src0->ne[2]; // n_head -> this is k
+    //const int64_t ne3 = src0->ne[3]; // 1 -> bsz
 
-    const int n  = ggml_nrows(src0);
-    const int ne2_ne3 = n/ne1; // ne2*ne3
+    const int64_t n  = ggml_nrows(src0);
+    const int64_t ne2_ne3 = n/ne1; // ne2*ne3
 
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
+    const size_t nb0 = src0->nb[0];
+    const size_t nb1 = src0->nb[1];
+    const size_t nb2 = src0->nb[2];
     //const int nb3 = src0->nb[3];
 
     GGML_ASSERT(nb0 == sizeof(float));
@@ -13088,9 +13088,9 @@ static void ggml_compute_forward_alibi_f32(
     const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
 
-    for (int i = 0; i < ne0; i++) {
-        for (int j = 0; j < ne1; j++) {
-            for (int k = 0; k < ne2_ne3; k++) {
+    for (int64_t i = 0; i < ne0; i++) {
+        for (int64_t j = 0; j < ne1; j++) {
+            for (int64_t k = 0; k < ne2_ne3; k++) {
                 float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
                 float *      pdst = (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
 
@@ -13105,7 +13105,6 @@ static void ggml_compute_forward_alibi_f32(
                 }
 
                 pdst[0] = i * m_k + src[0];
-
             }
         }
     }
index 77f7fa7c1055555f7722f5f3d30fe4ab835ce2d0..24f07daca6181a28c12ed528d3cf33735c55e221 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1325,7 +1325,11 @@ static bool llama_kv_cache_init(
     cache.cells.clear();
     cache.cells.resize(n_ctx);
 
+    // TODO: this should be:
+    //       cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*ggml_tensor_overhead());
+    //       change it and test that it works
     cache.buf.resize(2u*n_elements*ggml_type_size(wtype) + 2u*MB);
+    memset(cache.buf.data, 0, cache.buf.size);
 
     struct ggml_init_params params;
     params.mem_size   = cache.buf.size;