import sys
from enum import IntEnum
from pathlib import Path
-from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast
+from typing import TYPE_CHECKING, Any, ContextManager, Iterator, cast, Optional
import numpy as np
import torch
return PersimmonModel
if model_architecture in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
return StableLMModel
+ if model_architecture == "QWenLMHeadModel":
+ return QwenModel
return Model
def _is_model_safetensors(self) -> bool:
return gguf.MODEL_ARCH.PERSIMMON
if arch in ("StableLMEpochForCausalLM", "LlavaStableLMEpochForCausalLM"):
return gguf.MODEL_ARCH.STABLELM
+ if arch == "QWenLMHeadModel":
+ return gguf.MODEL_ARCH.QWEN
raise NotImplementedError(f'Architecture "{arch}" not supported!')
self.gguf_writer.add_parallel_residual(hparams["use_parallel_residual"] if "use_parallel_residual" in hparams else True)
self.gguf_writer.add_layer_norm_eps(1e-5)
+
+class QwenModel(Model):
+ @staticmethod
+ def token_bytes_to_string(b):
+ from transformers.models.gpt2.tokenization_gpt2 import bytes_to_unicode
+ byte_encoder = bytes_to_unicode()
+ return ''.join([byte_encoder[ord(char)] for char in b.decode('latin-1')])
+
+ @staticmethod
+ def bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: Optional[int] = None) -> list[bytes]:
+ parts = [bytes([b]) for b in token]
+ while True:
+ min_idx = None
+ min_rank = None
+ for i, pair in enumerate(zip(parts[:-1], parts[1:])):
+ rank = mergeable_ranks.get(pair[0] + pair[1])
+ if rank is not None and (min_rank is None or rank < min_rank):
+ min_idx = i
+ min_rank = rank
+ if min_rank is None or (max_rank is not None and min_rank >= max_rank):
+ break
+ assert min_idx is not None
+ parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
+ return parts
+
+ def set_vocab(self):
+ dir_model = self.dir_model
+ hparams = self.hparams
+ tokens: list[bytearray] = []
+ toktypes: list[int] = []
+
+ from transformers import AutoTokenizer # type: ignore[attr-defined]
+ tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)
+ vocab_size = hparams["vocab_size"]
+ assert max(tokenizer.get_vocab().values()) < vocab_size
+
+ merges = []
+ vocab = {}
+ mergeable_ranks = tokenizer.mergeable_ranks
+ for token, rank in mergeable_ranks.items():
+ vocab[self.token_bytes_to_string(token)] = rank
+ if len(token) == 1:
+ continue
+ merged = QwenModel.bpe(mergeable_ranks, token, max_rank=rank)
+ assert len(merged) == 2
+ merges.append(' '.join(map(self.token_bytes_to_string, merged)))
+
+ reverse_vocab = {id_ : encoded_tok for encoded_tok, id_ in vocab.items()}
+ added_vocab = tokenizer.special_tokens
+
+ for i in range(vocab_size):
+ if i not in reverse_vocab:
+ pad_token = f"[PAD{i}]".encode("utf-8")
+ tokens.append(bytearray(pad_token))
+ toktypes.append(gguf.TokenType.USER_DEFINED)
+ elif reverse_vocab[i] in added_vocab:
+ tokens.append(reverse_vocab[i])
+ toktypes.append(gguf.TokenType.CONTROL)
+ else:
+ tokens.append(reverse_vocab[i])
+ toktypes.append(gguf.TokenType.NORMAL)
+
+ self.gguf_writer.add_tokenizer_model("gpt2")
+ self.gguf_writer.add_token_list(tokens)
+ self.gguf_writer.add_token_types(toktypes)
+
+ special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
+ special_vocab.merges = merges
+ special_vocab._set_special_token("bos", tokenizer.special_tokens["<|endoftext|>"])
+ special_vocab._set_special_token("eos", tokenizer.special_tokens["<|endoftext|>"])
+ special_vocab._set_special_token("unk", tokenizer.special_tokens["<|endoftext|>"])
+ special_vocab.add_to_gguf(self.gguf_writer)
+
+ def set_gguf_parameters(self):
+ self.gguf_writer.add_name("Qwen")
+ self.gguf_writer.add_context_length(self.hparams["max_position_embeddings"])
+ self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
+ self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
+ self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
+ self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"])
+ self.gguf_writer.add_rope_dimension_count(self.hparams["hidden_size"] // self.hparams["num_attention_heads"])
+ self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
+ self.gguf_writer.add_layer_norm_rms_eps(self.hparams["layer_norm_epsilon"])
+
+ def write_tensors(self):
+ block_count = self.hparams["num_hidden_layers"]
+ model_kv = dict(self.get_tensors())
+ tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
+ for name, data_torch in model_kv.items():
+ # we don't need these
+ if name.endswith(".rotary_emb.inv_freq"):
+ continue
+
+ old_dtype = data_torch.dtype
+
+ # convert any unsupported data types to float32
+ if data_torch.dtype not in (torch.float16, torch.float32):
+ data_torch = data_torch.to(torch.float32)
+
+ data = data_torch.squeeze().numpy()
+
+ # map tensor names
+ new_name = tensor_map.get_name(name, try_suffixes=(".weight", ".bias"))
+ if new_name is None:
+ print(f"Can not map tensor {name!r}")
+ sys.exit()
+
+ n_dims = len(data.shape)
+ data_dtype = data.dtype
+
+ # if f32 desired, convert any float16 to float32
+ if self.ftype == 0 and data_dtype == np.float16:
+ data = data.astype(np.float32)
+
+ # TODO: Why cant we use these float16 as-is? There should be not reason to store float16 as float32
+ if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
+ data = data.astype(np.float32)
+
+ # if f16 desired, convert any float32 2-dim weight tensors to float16
+ if self.ftype == 1 and data_dtype == np.float32 and name.endswith(".weight") and n_dims == 2:
+ data = data.astype(np.float16)
+
+ print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
+ self.gguf_writer.add_tensor(new_name, data)
+
###### CONVERSION LOGIC ######
# Token embeddings
MODEL_TENSOR.TOKEN_EMBD: (
"gpt_neox.embed_in", # gptneox
- "transformer.wte", # gpt2 gpt-j mpt refact
+ "transformer.wte", # gpt2 gpt-j mpt refact qwen
"transformer.word_embeddings", # falcon
"word_embeddings", # bloom
"model.embed_tokens", # llama-hf
# Output
MODEL_TENSOR.OUTPUT: (
"embed_out", # gptneox
- "lm_head", # gpt2 mpt falcon llama-hf baichuan
+ "lm_head", # gpt2 mpt falcon llama-hf baichuan qwen
"output", # llama-pth bloom
"word_embeddings_for_head", # persimmon
),
"norm", # llama-pth
"embeddings.LayerNorm", # bert
"transformer.norm_f", # mpt
- "ln_f", # refact bloom
+ "ln_f", # refact bloom qwen
"language_model.encoder.final_layernorm", # persimmon
),
# Attention norm
MODEL_TENSOR.ATTN_NORM: (
"gpt_neox.layers.{bid}.input_layernorm", # gptneox
- "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact
+ "transformer.h.{bid}.ln_1", # gpt2 gpt-j refact qwen
"transformer.blocks.{bid}.norm_1", # mpt
"transformer.h.{bid}.input_layernorm", # falcon7b
"h.{bid}.input_layernorm", # bloom
# Attention query-key-value
MODEL_TENSOR.ATTN_QKV: (
"gpt_neox.layers.{bid}.attention.query_key_value", # gptneox
- "transformer.h.{bid}.attn.c_attn", # gpt2
+ "transformer.h.{bid}.attn.c_attn", # gpt2 qwen
"transformer.blocks.{bid}.attn.Wqkv", # mpt
"transformer.h.{bid}.self_attention.query_key_value", # falcon
"h.{bid}.self_attention.query_key_value", # bloom
# Attention output
MODEL_TENSOR.ATTN_OUT: (
"gpt_neox.layers.{bid}.attention.dense", # gptneox
- "transformer.h.{bid}.attn.c_proj", # gpt2 refact
+ "transformer.h.{bid}.attn.c_proj", # gpt2 refact qwen
"transformer.blocks.{bid}.attn.out_proj", # mpt
"transformer.h.{bid}.self_attention.dense", # falcon
"h.{bid}.self_attention.dense", # bloom
# Feed-forward norm
MODEL_TENSOR.FFN_NORM: (
"gpt_neox.layers.{bid}.post_attention_layernorm", # gptneox
- "transformer.h.{bid}.ln_2", # gpt2 refact
+ "transformer.h.{bid}.ln_2", # gpt2 refact qwen
"h.{bid}.post_attention_layernorm", # bloom
"transformer.blocks.{bid}.norm_2", # mpt
"model.layers.{bid}.post_attention_layernorm", # llama-hf
"encoder.layer.{bid}.intermediate.dense", # bert
"transformer.h.{bid}.mlp.fc_in", # gpt-j
"language_model.encoder.layers.{bid}.mlp.dense_h_to_4h", # persimmon
+ "transformer.h.{bid}.mlp.w1", # qwen
),
# Feed-forward gate
MODEL_TENSOR.FFN_GATE: (
"model.layers.{bid}.mlp.gate_proj", # llama-hf refact
"layers.{bid}.feed_forward.w1", # llama-pth
+ "transformer.h.{bid}.mlp.w2", # qwen
),
# Feed-forward down
MODEL_TENSOR.FFN_DOWN: (
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
- "transformer.h.{bid}.mlp.c_proj", # gpt2 refact
+ "transformer.h.{bid}.mlp.c_proj", # gpt2 refact qwen
"transformer.blocks.{bid}.ffn.down_proj", # mpt
"transformer.h.{bid}.mlp.dense_4h_to_h", # falcon
"h.{bid}.mlp.dense_4h_to_h", # bloom
LLM_ARCH_REFACT,
LLM_ARCH_BLOOM,
LLM_ARCH_STABLELM,
+ LLM_ARCH_QWEN,
LLM_ARCH_UNKNOWN,
};
{ LLM_ARCH_REFACT, "refact" },
{ LLM_ARCH_BLOOM, "bloom" },
{ LLM_ARCH_STABLELM, "stablelm" },
+ { LLM_ARCH_QWEN, "qwen" },
};
enum llm_kv {
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
},
},
+ {
+ LLM_ARCH_QWEN,
+ {
+ { LLM_TENSOR_TOKEN_EMBD, "token_embd" },
+ { LLM_TENSOR_OUTPUT_NORM, "output_norm" },
+ { LLM_TENSOR_OUTPUT, "output" },
+ { LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
+ { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
+ { LLM_TENSOR_ATTN_QKV, "blk.%d.attn_qkv" },
+ { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
+ { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+ { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+ { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+ { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
+ },
+ },
{
LLM_ARCH_UNKNOWN,
default: model.type = e_model::MODEL_UNKNOWN;
}
} break;
+ case LLM_ARCH_QWEN:
+ {
+ GGUF_GET_KEY(ctx, hparams.f_norm_rms_eps, gguf_get_val_f32, GGUF_TYPE_FLOAT32, true, kv(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS));
+ switch (hparams.n_layer) {
+ case 32: model.type = e_model::MODEL_7B; break;
+ case 40: model.type = e_model::MODEL_13B; break;
+ default: model.type = e_model::MODEL_UNKNOWN;
+ }
+ } break;
default: (void)0;
}
}
}
} break;
+ case LLM_ARCH_QWEN:
+ {
+ model.tok_embd = ml.create_tensor(ctx, tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, GGML_BACKEND_CPU);
+ {
+ ggml_backend_type backend_norm;
+ ggml_backend_type backend_output;
+
+ if (n_gpu_layers > int(n_layer)) {
+ // norm is not performance relevant on its own but keeping it in VRAM reduces data copying
+ // on Windows however this is detrimental unless everything is on the GPU
+#ifndef _WIN32
+ backend_norm = llama_backend_offload;
+#else
+ backend_norm = n_gpu_layers <= (int) n_layer + 2 ? GGML_BACKEND_CPU : llama_backend_offload;
+#endif // _WIN32
+
+ backend_output = llama_backend_offload_split;
+ } else {
+ backend_norm = GGML_BACKEND_CPU;
+ backend_output = GGML_BACKEND_CPU;
+ }
+
+ model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
+ model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
+
+ if (backend_norm == GGML_BACKEND_GPU) {
+ vram_weights += ggml_nbytes(model.output_norm);
+ }
+ if (backend_output == GGML_BACKEND_GPU_SPLIT) {
+ vram_weights += ggml_nbytes(model.output);
+ }
+ }
+
+ const uint32_t n_ff = hparams.n_ff / 2;
+
+ const int i_gpu_start = n_layer - n_gpu_layers;
+
+ model.layers.resize(n_layer);
+
+ for (uint32_t i = 0; i < n_layer; ++i) {
+ const ggml_backend_type backend = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload; // NOLINT
+ const ggml_backend_type backend_split = int(i) < i_gpu_start ? GGML_BACKEND_CPU : llama_backend_offload_split; // NOLINT
+
+ auto & layer = model.layers[i];
+
+ layer.attn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, backend);
+
+ layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd * 3}, backend_split);
+ layer.bqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd * 3}, backend);
+ layer.wo = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, backend_split);
+
+ layer.ffn_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, backend);
+
+ layer.ffn_gate = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, backend_split);
+ layer.ffn_down = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
+ layer.ffn_up = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
+
+ if (backend == GGML_BACKEND_GPU) {
+ vram_weights +=
+ ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.bqkv) +
+ ggml_nbytes(layer.wo) + ggml_nbytes(layer.ffn_norm) + ggml_nbytes(layer.ffn_gate) +
+ ggml_nbytes(layer.ffn_down) + ggml_nbytes(layer.ffn_up);
+ }
+ }
+ } break;
default:
throw std::runtime_error("unknown architecture");
return gf;
}
+
+ struct ggml_cgraph * build_qwen() {
+ struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+ struct ggml_tensor * cur;
+ struct ggml_tensor * inpL;
+
+ inpL = llm_build_inp_embd(ctx0, hparams, batch, model.tok_embd, cb);
+ cb(inpL, "inp_embd", -1);
+
+ // inp_pos - contains the positions
+ struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_tokens);
+ cb(inp_pos, "inp_pos", -1);
+
+ // KQ_scale
+ struct ggml_tensor * KQ_scale = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1);
+ cb(KQ_scale, "KQ_scale", -1);
+
+ // KQ_mask (mask for 1 head, it wil be broadcasted to all heads)
+ struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_kv, n_tokens, 1);
+ cb(KQ_mask, "KQ_mask", -1);
+
+ // shift the entire K-cache if needed
+ if (do_rope_shift) {
+ llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, LLM_ROPE_NEOX, n_ctx, n_embd_head, freq_base, freq_scale, cb);
+ }
+
+ for (int il = 0; il < n_layer; ++il) {
+ struct ggml_tensor * inpSA = inpL;
+
+ cur = llm_build_norm(ctx0, inpL, hparams,
+ model.layers[il].attn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+ cb(cur, "wqkv", il);
+
+ cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+ cb(cur, "bqkv", il);
+
+ struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+ struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+ struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd, n_tokens, cur->nb[1], 2*sizeof(float)*(n_embd)));
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+
+ // using mode = 2 for neox mode
+ Qcur = ggml_rope_custom(
+ ctx0, Qcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Qcur, "Qcur", il);
+
+ Kcur = ggml_rope_custom(
+ ctx0, Kcur, inp_pos, n_embd_head, 2, 0, n_orig_ctx,
+ freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
+ );
+ cb(Kcur, "Kcur", il);
+
+ llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
+
+ cur = llm_build_kqv(ctx0, hparams, kv_self,
+ model.layers[il].wo, NULL,
+ Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
+ cb(cur, "kqv_out", il);
+ }
+
+ struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // feed-forward forward
+ {
+ cur = llm_build_norm(ctx0, ffn_inp, hparams,
+ model.layers[il].ffn_norm, NULL,
+ LLM_NORM_RMS, cb, il);
+ cb(cur, "ffn_norm", il);
+
+ cur = llm_build_ffn(ctx0, cur,
+ model.layers[il].ffn_up, NULL,
+ model.layers[il].ffn_gate, NULL,
+ model.layers[il].ffn_down, NULL,
+ LLM_FFN_SILU, LLM_FFN_PAR, cb, il);
+ cb(cur, "ffn_out", il);
+ }
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = llm_build_norm(ctx0, cur, hparams,
+ model.output_norm, NULL,
+ LLM_NORM_RMS, cb, -1);
+ cb(cur, "result_norm", -1);
+
+ // lm_head
+ cur = ggml_mul_mat(ctx0, model.output, cur);
+ cb(cur, "result_output", -1);
+
+ ggml_build_forward_expand(gf, cur);
+
+ return gf;
+ }
};
//
{
result = llm.build_stablelm();
} break;
+ case LLM_ARCH_QWEN:
+ {
+ result = llm.build_qwen();
+ } break;
default:
GGML_ASSERT(false);
}