]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add missing inference support for GPTNeoXForCausalLM (Pythia and GPT-NeoX base models...
authorfairydreaming <redacted>
Thu, 23 May 2024 09:49:53 +0000 (11:49 +0200)
committerGitHub <redacted>
Thu, 23 May 2024 09:49:53 +0000 (11:49 +0200)
* convert-hf : add conversion of bloom-style qkv tensor to gpt-style qkv (code borrowed from BloomModel)

* llama : add inference support for LLM_ARCH_GPTNEOX

* llama : add model types for every Pythia variant and GPT-NeoX

Co-authored-by: Stanisław Szymczyk <redacted>
convert-hf-to-gguf.py
llama.cpp

index daad1c4fc725512b8a3592007e74e790a2930b8b..5a00a5e89accbd6b0bf749c8d10c05f578b3275b 100755 (executable)
@@ -673,6 +673,44 @@ class GPTNeoXModel(Model):
         self.gguf_writer.add_parallel_residual(self.hparams.get("use_parallel_residual", True))
         self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
 
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        del bid  # unused
+
+        n_head = self.hparams.get("n_head", self.hparams.get("num_attention_heads"))
+        n_embed = self.hparams.get("hidden_size", self.hparams.get("n_embed"))
+
+        tensors: list[tuple[str, Tensor]] = []
+
+        if re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.weight", name):
+            # Map bloom-style qkv_linear to gpt-style qkv_linear
+            # bloom: https://github.com/huggingface/transformers/blob/main/src/transformers/models/bloom/modeling_bloom.py#L238-L252  # noqa
+            # gpt-2: https://github.com/huggingface/transformers/blob/main/src/transformers/models/gpt2/modeling_gpt2.py#L312  # noqa
+            qkv_weights = data_torch.reshape((n_head, 3, n_embed // n_head, n_embed))
+            data_torch = torch.cat(
+                (
+                    qkv_weights[:, 0, :, :].reshape((-1, n_embed)),
+                    qkv_weights[:, 1, :, :].reshape((-1, n_embed)),
+                    qkv_weights[:, 2, :, :].reshape((-1, n_embed)),
+                ),
+                dim=0,
+            )
+            logger.info("re-format attention.linear_qkv.weight")
+        elif re.match(r"gpt_neox\.layers\.\d+\.attention\.query_key_value\.bias", name):
+            qkv_bias = data_torch.reshape((n_head, 3, n_embed // n_head))
+            data_torch = torch.cat(
+                (
+                    qkv_bias[:, 0, :].reshape((n_embed,)),
+                    qkv_bias[:, 1, :].reshape((n_embed,)),
+                    qkv_bias[:, 2, :].reshape((n_embed,)),
+                ),
+                dim=0,
+            )
+            logger.info("re-format attention.linear_qkv.bias")
+
+        tensors.append((self.map_tensor_name(name), data_torch))
+
+        return tensors
+
 
 @Model.register("BloomForCausalLM")
 class BloomModel(Model):
index 3e09a239000c007b45621f40439a767935507f33..5ff186a579996bf10251d197fd2514d94bb50dc3 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1692,17 +1692,24 @@ static llama_state g_state;
 // available llama models
 enum e_model {
     MODEL_UNKNOWN,
+    MODEL_14M,
     MODEL_17M,
     MODEL_22M,
     MODEL_33M,
+    MODEL_70M,
     MODEL_109M,
     MODEL_137M,
+    MODEL_160M,
     MODEL_335M,
+    MODEL_410M,
     MODEL_0_5B,
     MODEL_1B,
+    MODEL_1_4B,
     MODEL_2B,
+    MODEL_2_8B,
     MODEL_3B,
     MODEL_4B,
+    MODEL_6_9B,
     MODEL_7B,
     MODEL_8B,
     MODEL_12B,
@@ -1734,6 +1741,7 @@ static const size_t GiB = 1024*MiB;
 struct llama_hparams {
     bool vocab_only;
     bool rope_finetuned;
+    bool use_par_res;
 
     uint32_t n_vocab;
     uint32_t n_ctx_train; // context size the model was trained on
@@ -3773,17 +3781,24 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
 
 static const char * llama_model_type_name(e_model type) {
     switch (type) {
+        case MODEL_14M:    return "14M";
         case MODEL_17M:    return "17M";
         case MODEL_22M:    return "22M";
         case MODEL_33M:    return "33M";
+        case MODEL_70M:    return "70M";
         case MODEL_109M:   return "109M";
         case MODEL_137M:   return "137M";
+        case MODEL_160M:   return "160M";
         case MODEL_335M:   return "335M";
+        case MODEL_410M:   return "410M";
         case MODEL_0_5B:   return "0.5B";
         case MODEL_1B:     return "1B";
+        case MODEL_1_4B:   return "1.4B";
         case MODEL_2B:     return "2B";
+        case MODEL_2_8B:   return "2.8B";
         case MODEL_3B:     return "3B";
         case MODEL_4B:     return "4B";
+        case MODEL_6_9B:   return "6.9B";
         case MODEL_7B:     return "7B";
         case MODEL_8B:     return "8B";
         case MODEL_12B:    return "12B";
@@ -4282,6 +4297,52 @@ static void llm_load_hparams(
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GPTNEOX:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps);
+                ml.get_key(LLM_KV_USE_PARALLEL_RESIDUAL, hparams.use_par_res);
+                switch (hparams.n_layer) {
+                    case 6:
+                        switch (hparams.n_ff) {
+                            case 512: model.type = e_model::MODEL_14M; break;
+                            case 2048: model.type = e_model::MODEL_70M; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 12:
+                        switch (hparams.n_ff) {
+                            case 3072: model.type = e_model::MODEL_160M; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 16:
+                        switch (hparams.n_ff) {
+                            case 8192: model.type = e_model::MODEL_1B; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 24:
+                        switch (hparams.n_ff) {
+                            case 4096: model.type = e_model::MODEL_410M; break;
+                            case 8192: model.type = e_model::MODEL_1_4B; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 32:
+                        switch (hparams.n_ff) {
+                            case 10240: model.type = e_model::MODEL_2_8B; break;
+                            case 16384: model.type = e_model::MODEL_6_9B; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 36:
+                        switch (hparams.n_ff) {
+                            case 20480: model.type = e_model::MODEL_12B; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    case 44:
+                        switch (hparams.n_ff) {
+                            case 24576: model.type = e_model::MODEL_20B; break;
+                            default: model.type = e_model::MODEL_UNKNOWN;
+                        } break;
+                    default: model.type = e_model::MODEL_UNKNOWN;
+                }
+            } break;
         default: (void)0;
     }
 
@@ -6033,6 +6094,41 @@ static bool llm_load_tensors(
                         layer.ffn_up   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff});
                     }
                 } break;
+            case LLM_ARCH_GPTNEOX:
+                {
+                    model.tok_embd   = ml.create_tensor(ctx_input,  tn(LLM_TENSOR_TOKEN_EMBD,      "weight"), {n_embd, n_vocab});
+                    // output
+                    {
+                        model.output_norm   = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd});
+                        model.output_norm_b = ml.create_tensor(ctx_output,       tn(LLM_TENSOR_OUTPUT_NORM, "bias"),   {n_embd});
+                        model.output        = ml.create_tensor(ctx_output_split, tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab});
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        ggml_context * ctx_layer = ctx_for_layer(i);
+                        ggml_context * ctx_split = ctx_for_layer_split(i);
+
+                        auto & layer = model.layers[i];
+
+                        layer.attn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd});
+                        layer.attn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_NORM, "bias", i),   {n_embd});
+
+                        layer.wqkv = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa});
+                        layer.bqkv = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_QKV, "bias", i),   {n_embd + 2*n_embd_gqa});
+
+                        layer.wo   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd});
+                        layer.bo   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_ATTN_OUT, "bias", i),   {n_embd});
+
+                        layer.ffn_norm   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd});
+                        layer.ffn_norm_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_NORM, "bias", i),   {n_embd});
+
+                        layer.ffn_down   = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd});
+                        layer.ffn_down_b = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_DOWN, "bias", i),   {n_embd});
+
+                        layer.ffn_up     = ml.create_tensor(ctx_split, tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd, n_ff});
+                        layer.ffn_up_b   = ml.create_tensor(ctx_layer, tn(LLM_TENSOR_FFN_UP,   "bias", i),   {n_ff});
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -10560,6 +10656,140 @@ struct llm_build_context {
 
         return gf;
     }
+
+    struct ggml_cgraph * build_gptneox() {
+        struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false);
+
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_embd_gqa  = hparams.n_embd_v_gqa();
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+
+        struct ggml_tensor * cur;
+        struct ggml_tensor * inpL;
+
+        inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb);
+
+        // inp_pos - contains the positions
+        struct ggml_tensor * inp_pos = build_inp_pos();
+
+        // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
+        struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
+
+        for (int il = 0; il < n_layer; ++il) {
+            cur = llm_build_norm(ctx0, inpL, hparams,
+                    model.layers[il].attn_norm,
+                    model.layers[il].attn_norm_b,
+                    LLM_NORM, cb, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                cur = ggml_mul_mat(ctx0, model.layers[il].wqkv, cur);
+                cb(cur, "wqkv", il);
+
+                cur = ggml_add(ctx0, cur, model.layers[il].bqkv);
+                cb(cur, "bqkv", il);
+
+                struct ggml_tensor * Qcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd,     n_tokens, cur->nb[1], 0*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Kcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd)));
+                struct ggml_tensor * Vcur = ggml_cont(ctx0, ggml_view_2d(ctx0, cur, n_embd_gqa, n_tokens, cur->nb[1], 1*sizeof(float)*(n_embd + n_embd_gqa)));
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Qcur, "Qcur", il);
+
+                Kcur = ggml_rope_ext(
+                    ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos, nullptr,
+                    n_rot, rope_type, 0, n_orig_ctx, freq_base, freq_scale,
+                    ext_factor, attn_factor, beta_fast, beta_slow
+                );
+                cb(Kcur, "Kcur", il);
+
+                cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                struct ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur  = ggml_get_rows(ctx0,  cur, inp_out_ids);
+                inpL = ggml_get_rows(ctx0, inpL, inp_out_ids);
+            }
+
+            // ffn
+            if (hparams.use_par_res) {
+                // attention and ffn are computed in parallel
+                // x = x + attn(ln1(x)) + ffn(ln2(x))
+
+                struct ggml_tensor * attn_out = cur;
+
+                cur = llm_build_norm(ctx0, inpL, hparams,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
+                        NULL,                      NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+                cb(cur, "ffn_out", il);
+
+                cur = ggml_add(ctx0, cur, inpL);
+                cb(cur, "ffn_out", il);
+
+                inpL = ggml_add(ctx0, cur, attn_out);
+                cb(inpL, "l_out", il);
+            } else {
+                // attention and ffn are computed sequentially
+                // x = x + attn(ln1(x))
+                // x = x + ffn(ln2(x))
+
+                struct ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL);
+                cb(ffn_inp, "ffn_inp", il);
+
+                cur = llm_build_norm(ctx0, ffn_inp, hparams,
+                        model.layers[il].ffn_norm,
+                        model.layers[il].ffn_norm_b,
+                        LLM_NORM, cb, il);
+                cb(cur, "ffn_norm", il);
+
+                cur = llm_build_ffn(ctx0, cur,
+                        model.layers[il].ffn_up,   model.layers[il].ffn_up_b,
+                        NULL,                      NULL,
+                        model.layers[il].ffn_down, model.layers[il].ffn_down_b,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_SEQ, cb, il);
+                cb(cur, "ffn_out", il);
+
+                inpL = ggml_add(ctx0, cur, ffn_inp);
+                cb(inpL, "l_out", il);
+            }
+        }
+
+        cur = llm_build_norm(ctx0, inpL, hparams,
+                model.output_norm,
+                model.output_norm_b,
+                LLM_NORM, cb, -1);
+        cb(cur, "result_norm", -1);
+
+        cur = ggml_mul_mat(ctx0, model.output, cur);
+        cb(cur, "result_output", -1);
+
+        ggml_build_forward_expand(gf, cur);
+
+        return gf;
+    }
 };
 
 static struct ggml_cgraph * llama_build_graph_defrag(llama_context & lctx, const std::vector<uint32_t> & ids) {
@@ -10770,6 +11000,10 @@ static struct ggml_cgraph * llama_build_graph(
             {
                 result = llm.build_olmo();
             } break;
+        case LLM_ARCH_GPTNEOX:
+            {
+                result = llm.build_gptneox();
+            } break;
         default:
             GGML_ASSERT(false);
     }
@@ -15762,7 +15996,6 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         // these models do not use RoPE
         case LLM_ARCH_GPT2:
         case LLM_ARCH_GPTJ:
-        case LLM_ARCH_GPTNEOX:
         case LLM_ARCH_MPT:
         case LLM_ARCH_REFACT:
         case LLM_ARCH_BLOOM:
@@ -15798,6 +16031,7 @@ enum llama_rope_type llama_rope_type(const struct llama_model * model) {
         case LLM_ARCH_PHI3:
         case LLM_ARCH_GEMMA:
         case LLM_ARCH_STARCODER2:
+        case LLM_ARCH_GPTNEOX:
             return LLAMA_ROPE_TYPE_NEOX;
 
         // all model arches should be listed explicitly here