        special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
        special_vocab.add_to_gguf(self.gguf_writer)
+    def _set_vocab_qwen(self):
+        dir_model = self.dir_model
+        hparams = self.hparams
+        tokens: list[bytearray] = []
+        toktypes: list[int] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+        vocab_size = hparams["vocab_size"]
+        assert max(tokenizer.get_vocab().values()) < vocab_size
+
+        merges = []
+        vocab = {}
+        mergeable_ranks = tokenizer.mergeable_ranks
+        for token, rank in mergeable_ranks.items():
+            vocab[QwenModel.token_bytes_to_string(token)] = rank
+            if len(token) == 1:
+                continue
+            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
+            assert len(merged) == 2
+            merges.append(' '.join(map(QwenModel.token_bytes_to_string, merged)))
+
+        # for this kind of tokenizer, added_vocab is not a subset of vocab, so they need to be combined
+        added_vocab = tokenizer.special_tokens
+        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in (vocab | added_vocab).items()}
+
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                pad_token = f"[PAD{i}]".encode("utf-8")
+                tokens.append(bytearray(pad_token))
+                toktypes.append(gguf.TokenType.USER_DEFINED)
+            elif reverse_vocab[i] in added_vocab:
+                tokens.append(reverse_vocab[i])
+                toktypes.append(gguf.TokenType.CONTROL)
+            else:
+                tokens.append(reverse_vocab[i])
+                toktypes.append(gguf.TokenType.NORMAL)
+
+        self.gguf_writer.add_tokenizer_model("gpt2")
+        self.gguf_writer.add_token_list(tokens)
+        self.gguf_writer.add_token_types(toktypes)
+
+        special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
+        special_vocab.merges = merges
+        # only add special tokens when they were not already loaded from config.json
+        if len(special_vocab.special_token_ids) == 0:
+            special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"])
+            special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"])
+        # this one is usually not in config.json anyway
+        special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
+        special_vocab.add_to_gguf(self.gguf_writer)
+
    def _set_vocab_sentencepiece(self):
        from sentencepiece import SentencePieceProcessor
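The new _set_vocab_qwen helper relies on two existing QwenModel static methods to turn a tiktoken-style mergeable_ranks table back into a GPT-2-style vocab and merge list. The sketch below is an approximation of what those helpers are assumed to do, not the exact code in convert-hf-to-gguf.py: token_bytes_to_string maps raw token bytes onto the printable GPT-2 byte alphabet, and bpe re-splits a merged token into the two lower-ranked pieces it was built from.

# Illustrative sketch only (approximation of the QwenModel helpers used above).
from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode

def token_bytes_to_string(b: bytes) -> str:
    # map each raw byte onto the printable GPT-2 byte alphabet
    byte_encoder = bytes_to_unicode()
    return ''.join(byte_encoder[byte] for byte in b)

def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
    # greedily re-apply merges cheaper than max_rank; with max_rank set to the
    # token's own rank this leaves exactly the two parts that merge into it
    parts = [bytes([b]) for b in token]
    while True:
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    return parts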
class StableLMModel(Model):
+    def set_vocab(self):
+        if (self.dir_model / "tokenizer.json").is_file():
+            self._set_vocab_gpt2()
+        else:
+            # StableLM 2 1.6B uses a vocab in a similar format to Qwen's vocab
+            self._set_vocab_qwen()
+
    def set_gguf_parameters(self):
        hparams = self.hparams
        block_count = hparams["num_hidden_layers"]
        return parts
    def set_vocab(self):
-        dir_model = self.dir_model
-        hparams = self.hparams
-        tokens: list[bytearray] = []
-        toktypes: list[int] = []
-
-        from transformers import AutoTokenizer
-        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
-        vocab_size = hparams["vocab_size"]
-        assert max(tokenizer.get_vocab().values()) < vocab_size
-
-        merges = []
-        vocab = {}
-        mergeable_ranks = tokenizer.mergeable_ranks
-        for token, rank in mergeable_ranks.items():
-            vocab[self.token_bytes_to_string(token)] = rank
-            if len(token) == 1:
-                continue
-            merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
-            assert len(merged) == 2
-            merges.append(' '.join(map(self.token_bytes_to_string, merged)))
-
-        reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in vocab.items()}
-        added_vocab = tokenizer.special_tokens
-
-        for i in range(vocab_size):
-            if i not in reverse_vocab:
-                pad_token = f"[PAD{i}]".encode("utf-8")
-                tokens.append(bytearray(pad_token))
-                toktypes.append(gguf.TokenType.USER_DEFINED)
-            elif reverse_vocab[i] in added_vocab:
-                tokens.append(reverse_vocab[i])
-                toktypes.append(gguf.TokenType.CONTROL)
-            else:
-                tokens.append(reverse_vocab[i])
-                toktypes.append(gguf.TokenType.NORMAL)
-
-        self.gguf_writer.add_tokenizer_model("gpt2")
-        self.gguf_writer.add_token_list(tokens)
-        self.gguf_writer.add_token_types(toktypes)
-
-        special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
-        special_vocab.merges = merges
-        special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"])
-        special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"])
-        special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
-        special_vocab.add_to_gguf(self.gguf_writer)
+        self._set_vocab_qwen()
    def set_gguf_parameters(self):
        self.gguf_writer.add_name("Qwen")
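Compared with the removed QwenModel.set_vocab above, the shared helper changes two things: reverse_vocab is now built from the union of vocab and added_vocab, and the bos/eos special tokens are only set when config.json did not already provide them. A small illustration, with made-up token ids, of why the union matters:

# Made-up values for illustration: in a Qwen-style tokenizer the special tokens
# are not part of mergeable_ranks, so a reverse_vocab built from vocab alone has
# no entry for their ids and they would be emitted as [PAD...] placeholders
# instead of CONTROL tokens.
vocab = {"a": 0, "b": 1}            # recovered from tokenizer.mergeable_ranks
added_vocab = {"<|endoftext|>": 2}  # from tokenizer.special_tokens

old_reverse = {id_: tok for tok, id_ in vocab.items()}
new_reverse = {id_: tok for tok, id_ in (vocab | added_vocab).items()}

assert 2 not in old_reverse               # id 2 would have become "[PAD2]"
assert new_reverse[2] == "<|endoftext|>"  # id 2 is now a proper CONTROL token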
    ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
    switch (hparams.n_layer) {
+       case 24: model.type = e_model::MODEL_1B; break;
        case 32: model.type = e_model::MODEL_3B; break;
        default: model.type = e_model::MODEL_UNKNOWN;
    }
    layer.wv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa});
    layer.wo = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+   // optional bias tensors, present in Stable LM 2 1.6B
+   layer.bq = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, false);
+   layer.bk = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, false);
+   layer.bv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, false);
+
    layer.ffn_norm = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
    layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd});
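The trailing false argument to create_tensor marks these bias tensors as optional (not required), so StableLM GGUFs converted without attention biases keep loading; when a bias is absent the pointer stays null, which is what the null checks in the graph-building changes below test for.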
    // compute Q and K and RoPE them
    struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
    cb(Qcur, "Qcur", il);
+   if (model.layers[il].bq) {
+       Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+       cb(Qcur, "Qcur", il);
+   }
    struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
    cb(Kcur, "Kcur", il);
+   if (model.layers[il].bk) {
+       Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+       cb(Kcur, "Kcur", il);
+   }
    struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
    cb(Vcur, "Vcur", il);
+   if (model.layers[il].bv) {
+       Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+       cb(Vcur, "Vcur", il);
+   }
    Qcur = ggml_rope_custom(
        ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,