]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
model : Add support for Arcee AI's upcoming AFM model (#14185)
authorBartowski <redacted>
Sun, 15 Jun 2025 23:04:06 +0000 (00:04 +0100)
committerGitHub <redacted>
Sun, 15 Jun 2025 23:04:06 +0000 (01:04 +0200)
* Add Arcee AFM support

* Add draft update code

* Fix linter and update URL, may still not be final

* Update src/llama-model.cpp

Co-authored-by: Xuan-Son Nguyen <redacted>
* Remote accidental blank line

---------

Co-authored-by: Xuan-Son Nguyen <redacted>
convert_hf_to_gguf.py
convert_hf_to_gguf_update.py
gguf-py/gguf/constants.py
src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp
src/llama-vocab.cpp

index cff72c85fab6971f8916366f8ad35bef78a99142..2232a7d82349ea9a334e3199deab22479989330d 100755 (executable)
@@ -2020,6 +2020,20 @@ class LlamaModel(TextModel):
                 raise ValueError(f"Unprocessed experts: {experts}")
 
 
+@ModelBase.register("ArceeForCausalLM")
+class ArceeModel(LlamaModel):
+    model_arch = gguf.MODEL_ARCH.ARCEE
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self._try_set_pooling_type()
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+
+
 @ModelBase.register(
     "LlavaForConditionalGeneration", # pixtral
     "Mistral3ForConditionalGeneration", # mistral small 3.1
index 2f733f0973686f3f7d57508864348a4629243942..fae4f72605f46926dc60a7dc87de343c74ec13f7 100755 (executable)
@@ -128,6 +128,7 @@ models = [
     {"name": "llama4",           "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct", },
     {"name": "pixtral",          "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistral-community/pixtral-12b", },
     {"name": "seed-coder",       "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ByteDance-Seed/Seed-Coder-8B-Base", },
+    {"name": "arcee",            "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/arcee-ai/AFM-4.5B", }, # TODO confirm final URL
 ]
 
 # some models are known to be broken upstream, so we will skip them as exceptions
index 8de2f7a531967b325a4237879f353ab6e0477ad8..9b2143c7c2eaa40e1ed3232056fbce57c0b1ae15 100644 (file)
@@ -344,6 +344,7 @@ class MODEL_ARCH(IntEnum):
     PLM              = auto()
     BAILINGMOE       = auto()
     DOTS1            = auto()
+    ARCEE            = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -624,7 +625,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.WAVTOKENIZER_DEC: "wavtokenizer-dec",
     MODEL_ARCH.PLM:              "plm",
     MODEL_ARCH.BAILINGMOE:       "bailingmoe",
-    MODEL_ARCH.DOTS1:            "dots1"
+    MODEL_ARCH.DOTS1:            "dots1",
+    MODEL_ARCH.ARCEE:            "arcee",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -2070,6 +2072,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_UP_EXP,
         MODEL_TENSOR.FFN_UP_SHEXP,
     ],
+    MODEL_ARCH.ARCEE: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.ATTN_ROT_EMBD,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     # TODO
 }
 
index f8f76eedd4fa687b6c31eddf4f4bf3a5512d9659..a3e7c861ca02f45e36195389f9bbb5fe58c13927 100644 (file)
@@ -73,6 +73,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_PLM,              "plm"              },
     { LLM_ARCH_BAILINGMOE,       "bailingmoe"       },
     { LLM_ARCH_DOTS1,            "dots1"            },
+    { LLM_ARCH_ARCEE,            "arcee"            },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -244,6 +245,24 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP_EXPS,     "blk.%d.ffn_up_exps" },
         },
     },
+    {
+        LLM_ARCH_ARCEE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ROPE_FREQS,      "rope_freqs" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_ATTN_ROT_EMBD,   "blk.%d.attn_rot_embd" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_LLAMA4,
         {
index 18f6d6b94f13716119d3f069e28c24306dd1b4a4..168fdcb401cfda294d6a497c06ea2557f4d530d7 100644 (file)
@@ -77,6 +77,7 @@ enum llm_arch {
     LLM_ARCH_PLM,
     LLM_ARCH_BAILINGMOE,
     LLM_ARCH_DOTS1,
+    LLM_ARCH_ARCEE,
     LLM_ARCH_UNKNOWN,
 };
 
index fdd5fefd6e77863d60510d9ffc68497c7122a9b8..dcc8b0be725633f740e3f58d8dbd8b4dc616b605 100644 (file)
@@ -599,6 +599,16 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     hparams.use_kq_norm = false;
                 }
             } break;
+        case LLM_ARCH_ARCEE:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+
+                // Arcee uses the same structure as Llama
+                switch (hparams.n_layer) {
+                    case 36: type = LLM_TYPE_4B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         case LLM_ARCH_DECI:
             {
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
@@ -4190,6 +4200,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         }
                     }
                 } break;
+            case LLM_ARCH_ARCEE:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_k_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_v_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), {n_rot/2}, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {  n_ff, n_embd}, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP,   "weight", i), {n_embd,   n_ff}, 0);
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -13411,6 +13452,141 @@ struct llm_build_dots1 : public llm_graph_context {
     }
 };
 
+struct llm_build_arcee : public llm_graph_context {
+    llm_build_arcee(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv_unified();
+
+        const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // rope freq factors for llama3; may return nullptr for llama2 and other models
+                ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
+
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+                if (model.layers[il].bq) {
+                    Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                    cb(Qcur, "Qcur", il);
+                }
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+                if (model.layers[il].bk) {
+                    Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                    cb(Kcur, "Kcur", il);
+                }
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+                if (model.layers[il].bv) {
+                    Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                    cb(Vcur, "Vcur", il);
+                }
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, rope_factors,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, kq_scale, il);
+                cb(cur, "attn_out", il);
+            }
+
+            if (il == n_layer - 1) {
+                // skip computing output for unused tokens
+                ggml_tensor * inp_out_ids = build_inp_out_ids();
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            // ARCEE uses relu^2 instead of silu
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur,
+                    model.layers[il].ffn_up,   NULL, NULL,
+                    NULL,                      NULL, NULL,
+                    model.layers[il].ffn_down, NULL, NULL,
+                    NULL,
+                    LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+            cb(cur, "ffn_out", il);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
     llama_memory_i * res;
 
@@ -13753,6 +13929,10 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_dots1>(*this, params, gf);
             } break;
+        case LLM_ARCH_ARCEE:
+            {
+                llm = std::make_unique<llm_build_arcee>(*this, params, gf);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -13902,6 +14082,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_GRANITE_MOE:
         case LLM_ARCH_CHAMELEON:
         case LLM_ARCH_BAILINGMOE:
+        case LLM_ARCH_ARCEE:
             return LLAMA_ROPE_TYPE_NORM;
 
         // the pairs of head values are offset by n_rot/2
index 905d7c4281d9c6eaf01063fa5e332f50c12eb5a8..dd2251ef3cbefa7bc6d88fecab69cec4bc70b7ad 100644 (file)
@@ -1987,6 +1987,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                     || t.first == "<|eom_id|>"
                     || t.first == "<EOT>"
                     || t.first == "_<EOT>"
+                    || t.first == "<|end_of_text|>"
                ) {
                 special_eog_ids.insert(t.second);
                 if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) {