git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
llama : support StableLM 2 1.6B (#5052)
author compilade <redacted>
Mon, 22 Jan 2024 11:21:52 +0000 (06:21 -0500)
committer GitHub <redacted>
Mon, 22 Jan 2024 11:21:52 +0000 (13:21 +0200)
* llama : support StableLM 2 1.6B

* convert : fix Qwen's set_vocab wrongly naming all special tokens [PAD{id}]

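A toy Python sketch of the fix, with made-up token ids (background, not part
of the commit): reverse_vocab used to be built from the mergeable ranks
alone, so special-token ids were never found and every special token fell
through to the [PAD{id}] branch. Combining both maps first fixes the lookup:

    vocab = {"hello": 0, "world": 1}        # mergeable BPE tokens
    added_vocab = {"<|endoftext|>": 2}      # specials, NOT a subset of vocab
    reverse_vocab = {id_: tok for tok, id_ in (vocab | added_vocab).items()}
    names = [reverse_vocab.get(i, f"[PAD{i}]") for i in range(4)]
    # -> ['hello', 'world', '<|endoftext|>', '[PAD3]']
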
* convert : refactor Qwen's set_vocab to use it for StableLM 2 too

* nix : add tiktoken to llama-python-extra

* convert : use presence of tokenizer.json to determine StableLM tokenizer loader

It's a less arbitrary heuristic than the vocab size.

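A minimal sketch of that heuristic (pick_vocab_loader is an invented name for
illustration, not the converter's API):

    from pathlib import Path

    def pick_vocab_loader(dir_model: Path) -> str:
        # HF "fast" tokenizers ship a tokenizer.json; Qwen-style BPE
        # vocabs (including StableLM 2 1.6B's) do not
        if (dir_model / "tokenizer.json").is_file():
            return "gpt2"   # -> Model._set_vocab_gpt2
        return "qwen"       # -> Model._set_vocab_qwen
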
.devops/nix/package.nix
convert-hf-to-gguf.py
llama.cpp

index 1ef2c6cd9180a14565cf04a11d680b7a32946a49..e8534956f1d96224fdd33ac934f6f08c998f59b0 100644 (file)
--- a/.devops/nix/package.nix
+++ b/.devops/nix/package.nix
@@ -73,6 +73,7 @@ let
     ps: [
       ps.numpy
       ps.sentencepiece
+      ps.tiktoken
       ps.torchWithoutCuda
       ps.transformers
     ]
index 4d995ef786b46f86b1cb10b94411d6ed5ec48a85..7a0a8c3dbaca30020faa1d3e1a9d2e28fa836e8a 100755 (executable)
--- a/convert-hf-to-gguf.py
+++ b/convert-hf-to-gguf.py
@@ -289,6 +289,58 @@ class Model:
         special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
         special_vocab.add_to_gguf(self.gguf_writer)
 
+    def _set_vocab_qwen(self):
+        dir_model = self.dir_model
+        hparams = self.hparams
+        tokens: list[bytearray] = []
+        toktypes: list[int] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+        vocab_size = hparams["vocab_size"]
+        assert max(tokenizer.get_vocab().values()) < vocab_size
+
+        merges = []
+        vocab = {}
+        mergeable_ranks = tokenizer.mergeable_ranks
+        for token, rank in mergeable_ranks.items():
+            vocab[QwenModel.token_bytes_to_string(token)] = rank
+            if len(token) == 1:
+                continue
+            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
+            assert len(merged) == 2
+            merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
+
+        # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
+        added_vocab = tokenizer.special_tokens
+        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in (vocab | added_vocab).items()}
+
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                pad_token = f"[PAD{i}]".encode("utf-8")
+                tokens.append(bytearray(pad_token))
+                toktypes.append(gguf.TokenType.USER_DEFINED)
+            elif reverse_vocab[i] in added_vocab:
+                tokens.append(reverse_vocab[i])
+                toktypes.append(gguf.TokenType.CONTROL)
+            else:
+                tokens.append(reverse_vocab[i])
+                toktypes.append(gguf.TokenType.NORMAL)
+
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
+        special_vocab.merges = merges
+        # only add special tokens when they were not already loaded from config.json
+        if len(special_vocab.special_token_ids) == 0:
+            special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"])
+            special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"])
+        # this one is usually not in config.json anyway
+        special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def _set_vocab_sentencepiece(self):
         from sentencepiece import SentencePieceProcessor
 
@@ -877,6 +929,13 @@ class PersimmonModel(Model):
 
 
 class StableLMModel(Model):
+    def set_vocab(self):
+        if (self.dir_model / "tokenizer.json").is_file():
+            self._set_vocab_gpt2()
+        else:
+            # StableLM 2 1.6B uses a vocab in a similar format to Qwen's vocab
+            self._set_vocab_qwen()
+
     def set_gguf_parameters(self):
         hparams = self.hparams
         block_count = hparams["num_hidden_layers"]
@@ -922,52 +981,7 @@ class QwenModel(Model):
         return parts
 
     def set_vocab(self):
-        dir_model = self.dir_model
-        hparams = self.hparams
-        tokens: list[bytearray] = []
-        toktypes: list[int] = []
-
-        from transformers import AutoTokenizer
-        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
-        vocab_size = hparams["vocab_size"]
-        assert max(tokenizer.get_vocab().values()) < vocab_size
-
-        merges = []
-        vocab = {}
-        mergeable_ranks = tokenizer.mergeable_ranks
-        for token, rank in mergeable_ranks.items():
-            vocab[self.token_bytes_to_string(token)] = rank
-            if len(token) == 1:
-                continue
-            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
-            assert len(merged) == 2
-            merges.append(' '.join(map(self.token_bytes_to_string, merged)))
-
-        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in vocab.items()}
-        added_vocab = tokenizer.special_tokens
-
-        for i in range(vocab_size):
-            if i not in reverse_vocab:
-                pad_token = f"[PAD{i}]".encode("utf-8")
-                tokens.append(bytearray(pad_token))
-                toktypes.append(gguf.TokenType.USER_DEFINED)
-            elif reverse_vocab[i] in added_vocab:
-                tokens.append(reverse_vocab[i])
-                toktypes.append(gguf.TokenType.CONTROL)
-            else:
-                tokens.append(reverse_vocab[i])
-                toktypes.append(gguf.TokenType.NORMAL)
-
-        self.gguf_writer.add_tokenizer_model("gpt2")
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_types(toktypes)
-
-        special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
-        special_vocab.merges = merges
-        special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"])
-        special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"])
-        special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
-        special_vocab.add_to_gguf(self.gguf_writer)
+        self._set_vocab_qwen()
 
     def set_gguf_parameters(self):
         self.gguf_writer.add_name("Qwen")
index c56d31163fb52fc49f752ebfd570fc5572871d5b..8c906a22f0ba90092a9477dea8b86a8b8ad16e17 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -2877,6 +2877,7 @@ static void llm_load_hparams(
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
 
                 switch (hparams.n_layer) {
+                    case 24: model.type = e_model::MODEL_1B; break;
                     case 32: model.type = e_model::MODEL_3B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
@@ -3700,6 +3701,11 @@ static bool llm_load_tensors(
                         layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa});
                         layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
 
+                        // optional bias tensors, present in Stable LM 2 1.6B
+                        layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q,   "bias", i), {n_embd},     false);
+                        layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K,   "bias", i), {n_embd_gqa}, false);
+                        layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V,   "bias", i), {n_embd_gqa}, false);
+
                         layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
                         layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
 
@@ -5598,12 +5604,24 @@ struct llm_build_context {
                 // compute Q and K and RoPE them
                 struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
                 cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
 
                 struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
                 cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
 
                 struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
                 cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
 
                 Qcur = ggml_rope_custom(
                     ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
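
The guarded additions amount to Q = Wq·x + bq only when the checkpoint
actually provides a bias; a toy numpy sketch of that logic (shapes invented
for illustration, not the ggml API):

    import numpy as np

    n_embd, n_tokens = 8, 3
    x  = np.ones((n_embd, n_tokens))      # current hidden states
    wq = np.ones((n_embd, n_embd))
    bq = np.ones((n_embd, 1))             # None for models without the bias
    q = wq @ x                            # ggml_mul_mat analogue
    if bq is not None:                    # mirrors `if (model.layers[il].bq)`
        q = q + bq                        # broadcast over the token dimension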