]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
talk-llama : sync llama.cpp
authorGeorgi Gerganov <redacted>
Mon, 29 Sep 2025 12:18:41 +0000 (15:18 +0300)
committerGeorgi Gerganov <redacted>
Mon, 29 Sep 2025 12:18:41 +0000 (15:18 +0300)
21 files changed:
examples/talk-llama/llama-arch.cpp
examples/talk-llama/llama-arch.h
examples/talk-llama/llama-context.cpp
examples/talk-llama/llama-context.h
examples/talk-llama/llama-graph.cpp
examples/talk-llama/llama-graph.h
examples/talk-llama/llama-hparams.h
examples/talk-llama/llama-kv-cache-iswa.cpp
examples/talk-llama/llama-kv-cache-iswa.h
examples/talk-llama/llama-kv-cache.cpp
examples/talk-llama/llama-kv-cache.h
examples/talk-llama/llama-memory-hybrid.cpp
examples/talk-llama/llama-memory-hybrid.h
examples/talk-llama/llama-memory-recurrent.cpp
examples/talk-llama/llama-memory-recurrent.h
examples/talk-llama/llama-memory.h
examples/talk-llama/llama-model.cpp
examples/talk-llama/llama-model.h
examples/talk-llama/llama-vocab.cpp
examples/talk-llama/llama.h
examples/talk-llama/unicode.h

index a4d2973ada5dc9a8289eb236b8a79373312b4f09..4e8d54c4193cc107e82a7e00c40d1fdcc0bcbc9c 100644 (file)
@@ -98,6 +98,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LLADA,            "llada"            },
     { LLM_ARCH_LLADA_MOE,        "llada-moe"        },
     { LLM_ARCH_SEED_OSS,         "seed_oss"         },
+    { LLM_ARCH_GROVEMOE,         "grovemoe"         },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -125,6 +126,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_FEED_FORWARD_LENGTH,               "%s.feed_forward_length"               },
     { LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        "%s.expert_feed_forward_length"        },
     { LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" },
+    { LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH,  "%s.expert_chunk_feed_forward_length"  },
     { LLM_KV_USE_PARALLEL_RESIDUAL,             "%s.use_parallel_residual"             },
     { LLM_KV_TENSOR_DATA_LAYOUT,                "%s.tensor_data_layout"                },
     { LLM_KV_EXPERT_COUNT,                      "%s.expert_count"                      },
@@ -133,6 +135,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_EXPERT_WEIGHTS_SCALE,              "%s.expert_weights_scale"              },
     { LLM_KV_EXPERT_WEIGHTS_NORM,               "%s.expert_weights_norm"               },
     { LLM_KV_EXPERT_GATING_FUNC,                "%s.expert_gating_func"                },
+    { LLM_KV_EXPERT_GROUP_SCALE,                "%s.expert_group_scale"                },
+    { LLM_KV_EXPERTS_PER_GROUP,                 "%s.experts_per_group"                 },
     { LLM_KV_MOE_EVERY_N_LAYERS,                "%s.moe_every_n_layers"                },
     { LLM_KV_NEXTN_PREDICT_LAYERS,              "%s.nextn_predict_layers"              },
     { LLM_KV_POOLING_TYPE,                      "%s.pooling_type"                      },
@@ -721,6 +725,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
             { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
             { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_CLS_OUT,         "cls.output" },
             { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
             { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
             { LLM_TENSOR_ATTN_Q_NORM,     "blk.%d.attn_q_norm" },
@@ -2185,6 +2190,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_GROVEMOE,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,         "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,        "output_norm" },
+            { LLM_TENSOR_OUTPUT,             "output" },
+            { LLM_TENSOR_ATTN_NORM,          "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,             "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_Q_NORM,        "blk.%d.attn_q_norm" },
+            { LLM_TENSOR_ATTN_K,             "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_K_NORM,        "blk.%d.attn_k_norm" },
+            { LLM_TENSOR_ATTN_V,             "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,           "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,           "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE_INP,       "blk.%d.ffn_gate_inp" },
+            { LLM_TENSOR_FFN_GATE_EXPS,      "blk.%d.ffn_gate_exps" },
+            { LLM_TENSOR_FFN_DOWN_EXPS,      "blk.%d.ffn_down_exps" },
+            { LLM_TENSOR_FFN_UP_EXPS,        "blk.%d.ffn_up_exps" },
+            { LLM_TENSOR_FFN_GATE_CHEXPS,    "blk.%d.ffn_gate_chexps" },
+            { LLM_TENSOR_FFN_DOWN_CHEXPS,    "blk.%d.ffn_down_chexps" },
+            { LLM_TENSOR_FFN_UP_CHEXPS,      "blk.%d.ffn_up_chexps" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -2317,6 +2345,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
     {LLM_TENSOR_FFN_DOWN_EXPS,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_GATE_EXPS,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_UP_EXPS,                {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
+    {LLM_TENSOR_FFN_DOWN_CHEXPS,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
+    {LLM_TENSOR_FFN_GATE_CHEXPS,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
+    {LLM_TENSOR_FFN_UP_CHEXPS,              {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
     {LLM_TENSOR_FFN_EXP_PROBS_B,            {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
     // altup / laurel (gemma 3n)
     {LLM_TENSOR_PER_LAYER_TOKEN_EMBD,       {LLM_TENSOR_LAYER_OUTPUT,    GGML_OP_GET_ROWS}},
index d181ce6784ffb9469dcdf1795ec52705320dcf36..b5c6f3d76a62cba5c4c166ace9dceaa06b47c65b 100644 (file)
@@ -102,6 +102,7 @@ enum llm_arch {
     LLM_ARCH_LLADA,
     LLM_ARCH_LLADA_MOE,
     LLM_ARCH_SEED_OSS,
+    LLM_ARCH_GROVEMOE,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -129,6 +130,7 @@ enum llm_kv {
     LLM_KV_FEED_FORWARD_LENGTH,
     LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
     LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
+    LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH,
     LLM_KV_USE_PARALLEL_RESIDUAL,
     LLM_KV_TENSOR_DATA_LAYOUT,
     LLM_KV_EXPERT_COUNT,
@@ -137,6 +139,8 @@ enum llm_kv {
     LLM_KV_EXPERT_WEIGHTS_SCALE,
     LLM_KV_EXPERT_WEIGHTS_NORM,
     LLM_KV_EXPERT_GATING_FUNC,
+    LLM_KV_EXPERT_GROUP_SCALE,
+    LLM_KV_EXPERTS_PER_GROUP,
     LLM_KV_MOE_EVERY_N_LAYERS,
     LLM_KV_NEXTN_PREDICT_LAYERS,
     LLM_KV_POOLING_TYPE,
@@ -301,6 +305,9 @@ enum llm_tensor {
     LLM_TENSOR_FFN_DOWN_SHEXP,
     LLM_TENSOR_FFN_GATE_SHEXP,
     LLM_TENSOR_FFN_UP_SHEXP,
+    LLM_TENSOR_FFN_DOWN_CHEXPS,
+    LLM_TENSOR_FFN_GATE_CHEXPS,
+    LLM_TENSOR_FFN_UP_CHEXPS,
     LLM_TENSOR_FFN_EXP_PROBS_B,
     LLM_TENSOR_ATTN_Q_NORM,
     LLM_TENSOR_ATTN_K_NORM,
index e6f76421cf1319702d476e751db6e45b3ca498ae..d8a8b5e647a8508303f2ff548db816141e0e781a 100644 (file)
@@ -2027,6 +2027,21 @@ void llama_context::perf_reset() {
     n_reused    = 0;
 }
 
+std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> llama_context::memory_breakdown() const {
+    std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> ret;
+    for (const auto & buft_size : model.memory_breakdown()) {
+        ret[buft_size.first].model += buft_size.second;
+    }
+    for (const auto & buft_size : memory->memory_breakdown()) {
+        ret[buft_size.first].context += buft_size.second;
+    }
+    for (const auto & backend_ptr : backends) {
+        ggml_backend_t backend = backend_ptr.get();
+        ret[ggml_backend_sched_get_buffer_type(sched.get(), backend)].compute += ggml_backend_sched_get_buffer_size(sched.get(), backend);
+    }
+    return ret;
+}
+
 //
 // training
 //
@@ -2765,6 +2780,142 @@ void llama_perf_context_reset(llama_context * ctx) {
     ctx->perf_reset();
 }
 
+void llama_memory_breakdown_print(const struct llama_context * ctx) {
+    const std::vector<ggml_backend_dev_t> & devices = ctx->get_model().devices;
+
+    std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown = ctx->memory_breakdown();
+
+    std::vector<std::array<std::string, 9>> table_data;
+    table_data.reserve(devices.size());
+    const std::string template_header = "%s: | %s | %s   %s    %s   %s   %s   %s    %s |\n";
+    const std::string template_gpu    = "%s: | %s | %s = %s + (%s = %s + %s + %s) + %s |\n";
+    const std::string template_other  = "%s: | %s | %s   %s    %s = %s + %s + %s    %s |\n";
+
+    table_data.push_back({template_header, "memory breakdown [MiB]", "total", "free", "self", "model", "context", "compute", "unaccounted"});
+
+    constexpr size_t MiB = 1024 * 1024;
+    const std::vector<std::string> desc_prefixes_strip = {"NVIDIA ", "GeForce ", "Tesla ", "AMD ", "Radeon ", "Instinct "};
+
+    // track seen buffer types to avoid double counting:
+    std::set<ggml_backend_buffer_type_t> seen_buffer_types;
+
+    // accumulative memory breakdown for each device and for host:
+    std::vector<llama_memory_breakdown_data> mb_dev(devices.size());
+    llama_memory_breakdown_data              mb_host;
+
+    for (const auto & buft_mb : memory_breakdown) {
+        ggml_backend_buffer_type_t          buft = buft_mb.first;
+        const llama_memory_breakdown_data & mb   = buft_mb.second;
+        if (ggml_backend_buft_is_host(buft)) {
+            mb_host.model   += mb.model;
+            mb_host.context += mb.context;
+            mb_host.compute += mb.compute;
+            seen_buffer_types.insert(buft);
+            continue;
+        }
+        ggml_backend_dev_t dev = ggml_backend_buft_get_device(buft);
+        if (dev) {
+            int i_dev = -1;
+            for (size_t i = 0; i < devices.size(); i++) {
+                if (devices[i] == dev) {
+                    i_dev = i;
+                    break;
+                }
+            }
+            if (i_dev != -1) {
+                mb_dev[i_dev].model   += mb.model;
+                mb_dev[i_dev].context += mb.context;
+                mb_dev[i_dev].compute += mb.compute;
+                seen_buffer_types.insert(buft);
+                continue;
+            }
+        }
+    }
+
+    // print memory breakdown for each device:
+    for (size_t i = 0; i < devices.size(); i++) {
+        ggml_backend_dev_t          dev = devices[i];
+        llama_memory_breakdown_data mb  = mb_dev[i];
+
+        const std::string name = ggml_backend_dev_name(dev);
+        std::string desc = ggml_backend_dev_description(dev);
+        for (const std::string & prefix : desc_prefixes_strip) {
+            if (desc.length() >= prefix.length() && desc.substr(0, prefix.length()) == prefix) {
+                desc = desc.substr(prefix.length());
+            }
+        }
+
+        size_t free, total;
+        ggml_backend_dev_memory(dev, &free, &total);
+
+        const size_t self = mb.model + mb.context + mb.compute;
+        const size_t unaccounted = total - self - free;
+
+        table_data.push_back({
+            template_gpu,
+            "  - " + name + " (" + desc + ")",
+            std::to_string(total / MiB),
+            std::to_string(free / MiB),
+            std::to_string(self / MiB),
+            std::to_string(mb.model / MiB),
+            std::to_string(mb.context / MiB),
+            std::to_string(mb.compute / MiB),
+            std::to_string(unaccounted / MiB)});
+    }
+
+    // print memory breakdown for host:
+    {
+        const size_t self = mb_host.model + mb_host.context + mb_host.compute;
+        table_data.push_back({
+            template_other,
+            "  - Host",
+            "", // total
+            "", // free
+            std::to_string(self / MiB),
+            std::to_string(mb_host.model / MiB),
+            std::to_string(mb_host.context / MiB),
+            std::to_string(mb_host.compute / MiB),
+            ""}); // unaccounted
+    }
+
+    // print memory breakdown for all remaining buffer types:
+    for (const auto & buft_mb : memory_breakdown) {
+        ggml_backend_buffer_type_t          buft = buft_mb.first;
+        const llama_memory_breakdown_data & mb   = buft_mb.second;
+        if (seen_buffer_types.count(buft) == 1) {
+            continue;
+        }
+        const std::string name = ggml_backend_buft_name(buft);
+        const size_t self = mb.model + mb.context + mb.compute;
+        table_data.push_back({
+            template_other,
+            "  - " + name,
+            "", // total
+            "", // free
+            std::to_string(self / MiB),
+            std::to_string(mb.model / MiB),
+            std::to_string(mb.context / MiB),
+            std::to_string(mb.compute / MiB),
+            ""}); // unaccounted
+        seen_buffer_types.insert(buft);
+    }
+
+    for (size_t j = 1; j < table_data[0].size(); j++) {
+        size_t max_len = 0;
+        for (const auto & td : table_data) {
+            max_len = std::max(max_len, td[j].length());
+        }
+        for (auto & td : table_data) {
+            td[j].insert(j == 1 ? td[j].length() : 0, max_len - td[j].length(), ' ');
+        }
+    }
+    for (const auto & td : table_data) {
+        LLAMA_LOG_INFO(td[0].c_str(),
+            __func__, td[1].c_str(), td[2].c_str(), td[3].c_str(), td[4].c_str(), td[5].c_str(),
+            td[6].c_str(), td[7].c_str(), td[8].c_str());
+    }
+}
+
 //
 // training
 //
index f23aa8ee1368dae4b40218537385b4a4e56851ec..ed6d82cb396f99542abcadc339fd7536df7702c2 100644 (file)
@@ -17,9 +17,17 @@ class llama_batch_allocr;
 class llama_io_read_i;
 class llama_io_write_i;
 
+// "memory" as in abstract memory for the context
 struct llama_memory_i;
 struct llama_memory_context_i;
 
+// "memory" as in physical memory for a buffer type, in bytes
+struct llama_memory_breakdown_data {
+    size_t model   = 0; // memory allocated for the model
+    size_t context = 0; // memory allocated for the context
+    size_t compute = 0; // memory allocated for temporary compute buffers
+};
+
 struct llama_context {
     // init scheduler and compute buffers, reserve worst-case graphs
     llama_context(
@@ -144,6 +152,8 @@ struct llama_context {
     llama_perf_context_data perf_get_data() const;
     void perf_reset();
 
+    std::map<ggml_backend_buffer_type_t, llama_memory_breakdown_data> memory_breakdown() const;
+
     //
     // training
     //
index 9f2e417f1ff4b19be80e6371ff2048b7bad29c7c..90cd885a60a4f6801d61c7abf4ad610ead4f81a7 100644 (file)
@@ -204,7 +204,10 @@ void llm_graph_input_cls::set_input(const llama_ubatch * ubatch) {
         std::vector<int> target_pos(n_seqs_unq, -1);
         std::vector<int> target_row(n_seqs_unq, -1);
 
-        bool last = cparams.pooling_type == LLAMA_POOLING_TYPE_LAST;
+        const bool last = (
+             cparams.pooling_type == LLAMA_POOLING_TYPE_LAST ||
+            (cparams.pooling_type == LLAMA_POOLING_TYPE_RANK && arch == LLM_ARCH_QWEN3) // qwen3 reranking & embedding models use last token
+        );
 
         for (int i = 0; i < n_tokens; ++i) {
             const llama_pos pos = ubatch->pos[i];
@@ -920,15 +923,29 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         selection_probs = logits;
     }
 
+    if (arch == LLM_ARCH_GROVEMOE) {
+        selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
+        cb(selection_probs, "ffn_moe_probs_biased", il);
+    }
+
     // select experts
     ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
     cb(selected_experts->src[0], "ffn_moe_argsort", il);
     cb(selected_experts, "ffn_moe_topk", il);
 
-    ggml_tensor * weights = ggml_get_rows(ctx0,
-            ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
+    if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
+        // TODO: Use scalar div instead when/if implemented
+        ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
+        selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
+        probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
+    } else {
+        probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
+    }
+
+    ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
     cb(weights, "ffn_moe_weights", il);
 
+
     if (gating_op == LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX_WEIGHT) {
         weights = ggml_reshape_2d(ctx0, weights, n_expert_used, n_tokens);
         weights = ggml_soft_max(ctx0, weights); // [n_expert_used, n_tokens]
@@ -952,6 +969,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
         cb(weights, "ffn_moe_weights_scaled", il);
     }
 
+    //call early so that topk-moe can be used
+    ggml_build_forward_expand(gf, weights);
+
     cur = ggml_reshape_3d(ctx0, cur, n_embd, 1, n_tokens);
 
     if (weight_before_ffn) {
@@ -1177,7 +1197,7 @@ ggml_tensor * llm_graph_context::build_inp_mean() const {
 }
 
 ggml_tensor * llm_graph_context::build_inp_cls() const {
-    auto inp = std::make_unique<llm_graph_input_cls>(cparams);
+    auto inp = std::make_unique<llm_graph_input_cls>(cparams, arch);
 
     auto & cur = inp->cls;
 
@@ -1877,34 +1897,32 @@ void llm_graph_context::build_pooling(
         case LLAMA_POOLING_TYPE_RANK:
             {
                 ggml_tensor * inp_cls = build_inp_cls();
-                inp = ggml_get_rows(ctx0, inp, inp_cls);
+                cur = ggml_get_rows(ctx0, inp, inp_cls);
 
+                // classification head
+                // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
                 if (cls) {
-                    // classification head
-                    // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566
-                    cur = ggml_mul_mat(ctx0, cls, inp);
+                    cur = ggml_mul_mat(ctx0, cls, cur);
                     if (cls_b) {
                         cur = ggml_add(ctx0, cur, cls_b);
                     }
                     cur = ggml_tanh(ctx0, cur);
+                }
 
-                    // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
-                    // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
-                    if (cls_out) {
-                        cur = ggml_mul_mat(ctx0, cls_out, cur);
-                        if (cls_out_b) {
-                            cur = ggml_add(ctx0, cur, cls_out_b);
-                        }
-                    }
-                } else if (cls_out) {
-                    // Single layer classification head (direct projection)
-                    // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
-                    cur = ggml_mul_mat(ctx0, cls_out, inp);
+                // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en
+                // https://huggingface.co/jinaai/jina-reranker-v1-tiny-en/blob/cb5347e43979c3084a890e3f99491952603ae1b7/modeling_bert.py#L884-L896
+                // Single layer classification head (direct projection)
+                // https://github.com/huggingface/transformers/blob/f4fc42216cd56ab6b68270bf80d811614d8d59e4/src/transformers/models/bert/modeling_bert.py#L1476
+                if (cls_out) {
+                    cur = ggml_mul_mat(ctx0, cls_out, cur);
                     if (cls_out_b) {
                         cur = ggml_add(ctx0, cur, cls_out_b);
                     }
-                } else {
-                    GGML_ABORT("RANK pooling requires either cls+cls_b or cls_out+cls_out_b");
+                }
+
+                // softmax for qwen3 reranker
+                if (arch == LLM_ARCH_QWEN3) {
+                    cur = ggml_soft_max(ctx0, cur);
                 }
             } break;
         default:
index ca90fdf613f6de0074e465ca414e7b80c070e85a..34b984afeb04379e1e7b5aa4c792931fd3b03b9e 100644 (file)
@@ -206,7 +206,7 @@ public:
 
 class llm_graph_input_cls : public llm_graph_input_i {
 public:
-    llm_graph_input_cls(const llama_cparams & cparams) : cparams(cparams) {}
+    llm_graph_input_cls(const llama_cparams & cparams, const llm_arch arch) : cparams(cparams), arch(arch) {}
     virtual ~llm_graph_input_cls() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -214,6 +214,7 @@ public:
     ggml_tensor * cls; // I32 [n_batch]
 
     const llama_cparams cparams;
+    const llm_arch arch;
 };
 
 class llm_graph_input_rs : public llm_graph_input_i {
index 202cbbd1b288423d80bc1bfd2399fcc7160e46aa..0fe4b569424056425ad79b9307afec868ea24706 100644 (file)
@@ -69,10 +69,13 @@ struct llama_hparams {
     uint32_t n_lora_kv          = 0;
     uint32_t n_ff_exp           = 0;
     uint32_t n_ff_shexp         = 0;
+    uint32_t n_ff_chexp         = 0;
     uint32_t n_expert_shared    = 0;
     uint32_t n_norm_groups      = 0;
+    uint32_t n_group_experts    = 0;
 
-    float    expert_weights_scale = 0.0;
+    float    expert_group_scale   = 0.05f;
+    float    expert_weights_scale = 0.0f;
     bool     expert_weights_norm  = false;
     uint32_t expert_gating_func   = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
     uint32_t moe_every_n_layers   = 0;
index d7342914c6b7cbf17eeb8cae258b12413e16d6e7..827302e6d25bd486562af65c345b21cbeae6ff5a 100644 (file)
@@ -113,6 +113,14 @@ llama_pos llama_kv_cache_iswa::seq_pos_max(llama_seq_id seq_id) const {
     return kv_swa->seq_pos_max(seq_id);
 }
 
+std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache_iswa::memory_breakdown() const {
+    std::map<ggml_backend_buffer_type_t, size_t> mb = kv_base->memory_breakdown();
+    for (const auto & buft_size : kv_swa->memory_breakdown()) {
+        mb[buft_size.first] += buft_size.second;
+    }
+    return mb;
+}
+
 llama_memory_context_ptr llama_kv_cache_iswa::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     GGML_UNUSED(embd_all);
 
index 5ed134b7958005ab3ccc185d93ba6c505a96f961..70ab22f0d60869673b5c8ad110d0906d3b5d4983 100644 (file)
@@ -56,6 +56,8 @@ public:
     llama_pos seq_pos_min(llama_seq_id seq_id) const override;
     llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
+    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
+
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
index 885be072a75c8536c02e45521ecb840ebb1e9b6b..816f2d5de592b9949cfab53f071e0acae80ace20 100644 (file)
@@ -473,6 +473,14 @@ llama_pos llama_kv_cache::seq_pos_max(llama_seq_id seq_id) const {
     return cells.seq_pos_max(seq_id);
 }
 
+std::map<ggml_backend_buffer_type_t, size_t> llama_kv_cache::memory_breakdown() const {
+    std::map<ggml_backend_buffer_type_t, size_t> ret;
+    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
+        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    }
+    return ret;
+}
+
 llama_memory_context_ptr llama_kv_cache::init_batch(
             llama_batch_allocr & balloc,
             uint32_t n_ubatch,
index 30de013f5f7f36b0901716ae63cc10e7bbd907eb..85f0663d8c1d4247b9712100372bf8701cf22df4 100644 (file)
@@ -121,6 +121,8 @@ public:
     llama_pos seq_pos_min(llama_seq_id seq_id) const override;
     llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
+    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
+
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
index ba61ebaa885feffc62ac012563736f59fa4325eb..abf652483c202971a9e3a36d0d200313bf59bc49 100644 (file)
@@ -166,6 +166,14 @@ llama_pos llama_memory_hybrid::seq_pos_max(llama_seq_id seq_id) const {
     return std::min(mem_attn->seq_pos_max(seq_id), mem_recr->seq_pos_max(seq_id));
 }
 
+std::map<ggml_backend_buffer_type_t, size_t> llama_memory_hybrid::memory_breakdown() const {
+    std::map<ggml_backend_buffer_type_t, size_t> mb = mem_attn->memory_breakdown();
+    for (const auto & buft_size : mem_recr->memory_breakdown()) {
+        mb[buft_size.first] += buft_size.second;
+    }
+    return mb;
+}
+
 void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const {
     GGML_UNUSED(flags);
 
index 11a35651782974023d48be9ddba531c0c79e0a26..558cafdf984c9b6f5b75ee743d4c718716ef44a5 100644 (file)
@@ -68,6 +68,8 @@ public:
     llama_pos seq_pos_min(llama_seq_id seq_id) const override;
     llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
+    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
+
     // state write/load
 
     void state_write(llama_io_write_i & io, llama_seq_id seq_id = -1, llama_state_seq_flags flags = 0) const override;
index 08716ed91aed124fcb71c70fec4be6d2ec8e94c7..44645fcdd2d4824edd704e02c57d716db84cbd4d 100644 (file)
@@ -359,6 +359,14 @@ llama_pos llama_memory_recurrent::seq_pos_max(llama_seq_id seq_id) const {
     return result;
 }
 
+std::map<ggml_backend_buffer_type_t, size_t> llama_memory_recurrent::memory_breakdown() const {
+    std::map<ggml_backend_buffer_type_t, size_t> ret;
+    for (const ggml_backend_buffer_ptr & buf_ptr : bufs) {
+        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    }
+    return ret;
+}
+
 llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & balloc, uint32_t n_ubatch, bool embd_all) {
     do {
         balloc.split_reset();
index c4daf00495bc2a43192914e4b19dbdf3079708af..077c6e3ce938da5f6dfd8d7d63f842dec2989fca 100644 (file)
@@ -4,6 +4,7 @@
 #include "llama-graph.h"
 #include "llama-memory.h"
 
+#include <map>
 #include <set>
 #include <vector>
 
@@ -50,6 +51,8 @@ public:
     llama_pos seq_pos_min(llama_seq_id seq_id) const override;
     llama_pos seq_pos_max(llama_seq_id seq_id) const override;
 
+    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const override;
+
     bool prepare(const std::vector<llama_ubatch> & ubatches);
 
     // find a contiguous slot of memory cells and emplace the ubatch there
index ccd1f073b0848c28920dfcd909c753e5482f4c3a..4a157b91fdbdcf6de401032af58a78fe3a8a5c8a 100644 (file)
@@ -2,6 +2,7 @@
 
 #include "llama.h"
 
+#include <map>
 #include <memory>
 #include <functional>
 
@@ -108,6 +109,8 @@ struct llama_memory_i {
     virtual llama_pos seq_pos_min(llama_seq_id seq_id) const = 0;
     virtual llama_pos seq_pos_max(llama_seq_id seq_id) const = 0;
 
+    virtual std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const = 0;
+
     //
     // state write/read
     //
index 981e57083c48d90610fd413496e90e14d858ad7f..ffd9286ef8912f9609ca82d20bac483d5e5fde8d 100644 (file)
@@ -66,6 +66,7 @@ const char * llm_type_name(llm_type type) {
         case LLM_TYPE_1_7B:          return "1.7B";
         case LLM_TYPE_1_8B:          return "1.8B";
         case LLM_TYPE_2B:            return "2B";
+        case LLM_TYPE_2_6B:          return "2.6B";
         case LLM_TYPE_2_8B:          return "2.8B";
         case LLM_TYPE_2_9B:          return "2.9B";
         case LLM_TYPE_3B:            return "3B";
@@ -674,10 +675,17 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_MINICPM:
             {
+                // Backward-compatible defaults for older MiniCPM GGUFs
+                hparams.f_embedding_scale = 12.0f;
+                hparams.f_residual_scale  = 1.4f / sqrtf(float(hparams.n_layer));
+                hparams.f_logit_scale     = hparams.n_embd ? (256.0f / float(hparams.n_embd)) : 1.0f;
+
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
-                ml.get_key(LLM_KV_EMBEDDING_SCALE,             hparams.f_embedding_scale);
-                ml.get_key(LLM_KV_RESIDUAL_SCALE,              hparams.f_residual_scale);
-                ml.get_key(LLM_KV_LOGIT_SCALE,                 hparams.f_logit_scale);
+
+                // Optional KV reads, override defaults if present in newer GGUF exports
+                ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, /*required=*/false);
+                ml.get_key(LLM_KV_RESIDUAL_SCALE, hparams.f_residual_scale, /*required=*/false);
+                ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, /*required=*/false);
 
                 // MiniCPM uses rope by default, unlike Granite which uses it as a switch
                 hparams.rope_finetuned = true;
@@ -1977,10 +1985,11 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 for (uint32_t il = 0; il < hparams.n_layer; ++il) {
                     hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0;
                 }
-                switch (hparams.n_embd) {
-                    case 1024: type = LLM_TYPE_350M; break;
-                    case 1536: type = LLM_TYPE_700M; break;
-                    case 2048: type = LLM_TYPE_1_2B; break;
+                switch (hparams.n_ff()) {
+                    case  4608: type = LLM_TYPE_350M; break;
+                    case  6912: type = LLM_TYPE_700M; break;
+                    case  8192: type = LLM_TYPE_1_2B; break;
+                    case 10752: type = LLM_TYPE_2_6B; break;
                     default:   type = LLM_TYPE_UNKNOWN;
                 }
             } break;
@@ -2007,6 +2016,19 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_GROVEMOE:
+            {
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp);
+                ml.get_key(LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH,  hparams.n_ff_chexp);
+                ml.get_key(LLM_KV_EXPERT_GROUP_SCALE,                hparams.expert_group_scale);
+                ml.get_key(LLM_KV_EXPERTS_PER_GROUP,                 hparams.n_group_experts);
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS,       hparams.f_norm_rms_eps);
+
+                switch (hparams.n_layer) {
+                    case 48: type = LLM_TYPE_30B_A3B; break;
+                    default: type = LLM_TYPE_UNKNOWN;
+                }
+            } break;
         default: throw std::runtime_error("unsupported model architecture");
     }
 
@@ -3165,6 +3187,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                     }
 
+                    // output rerank head
+                    cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED);
+
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
 
@@ -5835,6 +5860,53 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
                     }
                 } break;
+            case LLM_ARCH_GROVEMOE:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
+                    output      = create_tensor(tn(LLM_TENSOR_OUTPUT,      "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
+                    }
+
+                    GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for GROVEMOE");
+                    GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for GROVEMOE");
+                    GGML_ASSERT(hparams.n_group_experts > 0 && "n_group_experts must be > 0 for GROVEMOE");
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q,   "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V,   "weight", i), {n_embd, n_embd_gqa}, 0);
+                        layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
+
+                        layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
+                        layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
+
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+
+                        // MoE branch
+                        const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
+                        const int64_t n_ff_chexp = hparams.n_ff_chexp ? hparams.n_ff_chexp : n_embd_head_k;
+                        const int64_t n_chunk_expert = n_expert / hparams.n_group_experts;
+
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp,   n_embd, n_expert}, 0);
+                        layer.ffn_up_exps   = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS,   "weight", i), {  n_embd, n_ff_exp, n_expert}, 0);
+
+                        layer.ffn_gate_chexps = create_tensor(tn(LLM_TENSOR_FFN_GATE_CHEXPS, "weight", i), {  n_embd, n_ff_chexp, n_chunk_expert}, 0);
+                        layer.ffn_down_chexps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_CHEXPS, "weight", i), {n_ff_chexp,   n_embd, n_chunk_expert}, 0);
+                        layer.ffn_up_chexps   = create_tensor(tn(LLM_TENSOR_FFN_UP_CHEXPS,   "weight", i), {  n_embd, n_ff_chexp, n_chunk_expert}, 0);
+                    }
+                } break;
             default:
                 throw std::runtime_error("unknown architecture");
         }
@@ -6003,6 +6075,14 @@ size_t llama_model::n_devices() const {
     return devices.size();
 }
 
+std::map<ggml_backend_buffer_type_t, size_t> llama_model::memory_breakdown() const {
+    std::map<ggml_backend_buffer_type_t, size_t> ret;
+    for (const ggml_backend_buffer_ptr & buf_ptr : pimpl->bufs) {
+        ret[ggml_backend_buffer_get_type(buf_ptr.get())] += ggml_backend_buffer_get_size(buf_ptr.get());
+    }
+    return ret;
+}
+
 uint64_t llama_model::n_elements() const {
     return pimpl->n_elements;
 }
@@ -6166,6 +6246,13 @@ void llama_model::print_info() const {
         LLAMA_LOG_INFO("%s: expert_gating_func   = %s\n",     __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
     }
 
+    if (arch == LLM_ARCH_GROVEMOE) {
+        LLAMA_LOG_INFO("%s: n_ff_exp             = %d\n",     __func__, hparams.n_ff_exp);
+        LLAMA_LOG_INFO("%s: n_ff_chexp           = %d\n",     __func__, hparams.n_ff_chexp);
+        LLAMA_LOG_INFO("%s: n_group_experts      = %d\n",     __func__, hparams.n_group_experts);
+        LLAMA_LOG_INFO("%s: expert_group_scale   = %.2f\n",   __func__, hparams.expert_group_scale);
+    }
+
     vocab.print_info();
 }
 
@@ -18851,6 +18938,156 @@ struct llm_build_smallthinker : public llm_graph_context{
     }
 };
 
+struct llm_build_grovemoe : public llm_graph_context {
+    llm_build_grovemoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+        const int64_t n_chunk_expert = n_expert / hparams.n_group_experts;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_kv();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL,
+                    model.layers[il].attn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self_attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
+                cb(Qcur, "Qcur_normed", il);
+
+                Qcur = ggml_rope_ext(
+                        ctx0, Qcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
+                cb(Kcur, "Kcur_normed", il);
+
+                Kcur = ggml_rope_ext(
+                        ctx0, Kcur, inp_pos, nullptr,
+                        n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                        ext_factor, attn_factor, beta_fast, beta_slow
+                        );
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn,
+                        model.layers[il].wo, model.layers[il].bo,
+                        Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0,   cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // MoE branch
+            cur = build_norm(ffn_inp,
+                    model.layers[il].ffn_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            ggml_tensor * probs = build_lora_mm(model.layers[il].ffn_gate_inp, cur); // [n_expert, n_tokens]
+            cb(probs, "ffn_moe_logits", il);
+
+            ggml_tensor * moe_out =
+                build_moe_ffn(cur,
+                        nullptr,
+                        model.layers[il].ffn_up_exps,
+                        model.layers[il].ffn_gate_exps,
+                        model.layers[il].ffn_down_exps,
+                        nullptr,
+                        n_expert, n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il, probs);
+            cb(moe_out, "ffn_moe_out", il);
+            cur = moe_out;
+
+            // TODO: Only do the expert selection and weights once
+            moe_out =
+                build_moe_ffn(cur,
+                        nullptr,
+                        model.layers[il].ffn_up_chexps,
+                        model.layers[il].ffn_gate_chexps,
+                        model.layers[il].ffn_down_chexps,
+                        nullptr,
+                        n_chunk_expert, n_expert_used > n_chunk_expert ? n_chunk_expert : n_expert_used,
+                        LLM_FFN_SILU, true,
+                        false, 0.0,
+                        LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
+                        il, probs);
+            cb(moe_out, "ffn_adj_moe_out", il);
+
+            cur = ggml_add(ctx0, cur, ggml_scale(ctx0, moe_out, hparams.expert_group_scale));
+            cb(cur, "ffn_final_moe_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur,
+                model.output_norm, NULL,
+                LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
     llama_memory_i * res;
 
@@ -19377,6 +19614,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
                     llm = std::make_unique<llm_build_smallthinker<false>>(*this, params);
                 }
             } break;
+        case LLM_ARCH_GROVEMOE:
+            {
+                llm = std::make_unique<llm_build_grovemoe>(*this, params);
+            } break;
         default:
             GGML_ABORT("fatal error");
     }
@@ -19582,6 +19823,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_SMALLTHINKER:
         case LLM_ARCH_GLM4_MOE:
         case LLM_ARCH_SEED_OSS:
+        case LLM_ARCH_GROVEMOE:
             return LLAMA_ROPE_TYPE_NEOX;
 
         case LLM_ARCH_QWEN2VL:
index b1981978e3acf8753a96dceff561b1f6e5060cdf..d73ce9693230f8662ac42a3ff9e6bfa369afc9cb 100644 (file)
@@ -7,6 +7,7 @@
 #include "llama-memory.h"
 #include "llama-vocab.h"
 
+#include <map>
 #include <memory>
 #include <string>
 #include <unordered_map>
@@ -58,6 +59,7 @@ enum llm_type {
     LLM_TYPE_1_7B,
     LLM_TYPE_1_8B,
     LLM_TYPE_2B,
+    LLM_TYPE_2_6B,
     LLM_TYPE_2_8B,
     LLM_TYPE_2_9B,
     LLM_TYPE_3B,
@@ -273,6 +275,11 @@ struct llama_layer {
     struct ggml_tensor * ffn_down_shexp     = nullptr;
     struct ggml_tensor * ffn_up_shexp       = nullptr;
 
+    // ff adjugate experts (chexps)
+    struct ggml_tensor * ffn_gate_chexps     = nullptr;
+    struct ggml_tensor * ffn_down_chexps     = nullptr;
+    struct ggml_tensor * ffn_up_chexps       = nullptr;
+
     // ff bias
     struct ggml_tensor * ffn_gate_b = nullptr;
     struct ggml_tensor * ffn_down_b = nullptr; // b2
@@ -452,10 +459,12 @@ struct llama_model {
 
     std::string desc() const;
 
-    size_t size() const;
+    size_t size() const; // file size
     size_t n_tensors() const;
     size_t n_devices() const;
 
+    std::map<ggml_backend_buffer_type_t, size_t> memory_breakdown() const;
+
     // total number of parameters in the model
     uint64_t n_elements() const;
 
index 8cb36661a0c968d15a1d6df1a262087b0803d370..da938af03bf080828e23e2aaa1dc4dbb3fd64158 100644 (file)
@@ -1772,7 +1772,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 const size_t n_precompiled_charsmap = gguf_get_arr_n(ctx, precompiled_charsmap_keyidx);
                 const char * pc = (const char *) gguf_get_arr_data(ctx, precompiled_charsmap_keyidx);
                 precompiled_charsmap.assign(pc, pc + n_precompiled_charsmap);
-#ifdef IS_BIG_ENDIAN
+#if defined(__BYTE_ORDER__) && defined(__ORDER_BIG_ENDIAN__) && __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
                 // correct endiannes of data in precompiled_charsmap binary blob
                 uint32_t * xcda_blob_size = (uint32_t *) &precompiled_charsmap[0];
                 *xcda_blob_size = __builtin_bswap32(*xcda_blob_size);
index 453190e852b51e87502e813c6d4f91d85cd1f888..452d9ec5bf285425b8935bf209a6262dbd0cbade 100644 (file)
@@ -1329,24 +1329,25 @@ extern "C" {
     //
     // Performance utils
     //
-    // NOTE: Used by llama.cpp examples, avoid using in third-party apps. Instead, do your own performance measurements.
+    // NOTE: Used by llama.cpp examples/tools, avoid using in third-party apps. Instead, do your own performance measurements.
     //
 
     struct llama_perf_context_data {
-        double t_start_ms;
-        double t_load_ms;
-        double t_p_eval_ms;
-        double t_eval_ms;
-
-        int32_t n_p_eval;
-        int32_t n_eval;
-        int32_t n_reused; // number of times a ggml compute graph had been reused
+        // ms == milliseconds
+        double t_start_ms;  // absolute start time
+        double t_load_ms;   // time needed for loading the model
+        double t_p_eval_ms; // time needed for processing the prompt
+        double t_eval_ms;   // time needed for generating tokens
+
+        int32_t n_p_eval;   // number of prompt tokens
+        int32_t n_eval;     // number of generated tokens
+        int32_t n_reused;   // number of times a ggml compute graph had been reused
     };
 
     struct llama_perf_sampler_data {
-        double t_sample_ms;
+        double t_sample_ms; // time needed for sampling in ms
 
-        int32_t n_sample;
+        int32_t n_sample;   // number of sampled tokens
     };
 
     LLAMA_API struct llama_perf_context_data llama_perf_context      (const struct llama_context * ctx);
@@ -1358,6 +1359,9 @@ extern "C" {
     LLAMA_API void                           llama_perf_sampler_print(const struct llama_sampler * chain);
     LLAMA_API void                           llama_perf_sampler_reset(      struct llama_sampler * chain);
 
+    // print a breakdown of per-device memory use via LLAMA_LOG:
+    LLAMA_API void llama_memory_breakdown_print(const struct llama_context * ctx);
+
     //
     // training
     //
index 0a5fa2a78ceff3d98031110d45d495b17f56626e..5bd1362ff41bf76c36dff72973533ac89d14b6d7 100644 (file)
@@ -4,6 +4,7 @@
 #include <string>
 #include <vector>
 
+// TODO: reimplement this structure in endian-independent way
 struct unicode_cpt_flags {
     enum {
         UNDEFINED       = 0x0001,
@@ -15,6 +16,10 @@ struct unicode_cpt_flags {
         SYMBOL          = 0x0040,  // regex: \p{S}
         CONTROL         = 0x0080,  // regex: \p{C}
         MASK_CATEGORIES = 0x00FF,
+        WHITESPACE      = 0x0100,
+        LOWERCASE       = 0x0200,
+        UPPERCASE       = 0x0400,
+        NFD             = 0x0800,
     };
 
     // codepoint type
@@ -34,11 +39,49 @@ struct unicode_cpt_flags {
 
     // decode from uint16
     inline unicode_cpt_flags(const uint16_t flags = 0) {
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
         *reinterpret_cast<uint16_t*>(this) = flags;
+#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+        is_undefined   = (flags & UNDEFINED)   ? 1 : 0;
+        is_number      = (flags & NUMBER)      ? 1 : 0;
+        is_letter      = (flags & LETTER)      ? 1 : 0;
+        is_separator   = (flags & SEPARATOR)   ? 1 : 0;
+        is_accent_mark = (flags & ACCENT_MARK) ? 1 : 0;
+        is_punctuation = (flags & PUNCTUATION) ? 1 : 0;
+        is_symbol      = (flags & SYMBOL)      ? 1 : 0;
+        is_control     = (flags & CONTROL)     ? 1 : 0;
+        is_whitespace  = (flags & WHITESPACE)  ? 1 : 0;
+        is_lowercase   = (flags & LOWERCASE)   ? 1 : 0;
+        is_uppercase   = (flags & UPPERCASE)   ? 1 : 0;
+        is_nfd         = (flags & NFD)         ? 1 : 0;
+#else
+#error Unexpected or undefined __BYTE_ORDER__
+#endif
     }
 
     inline uint16_t as_uint() const {
+#if __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__
         return *reinterpret_cast<const uint16_t*>(this);
+#elif __BYTE_ORDER__ == __ORDER_BIG_ENDIAN__
+        uint16_t result =
+              is_undefined   * UNDEFINED
+            + is_number      * NUMBER
+            + is_letter      * LETTER
+            + is_separator   * SEPARATOR
+            + is_accent_mark * ACCENT_MARK
+            + is_punctuation * PUNCTUATION
+            + is_symbol      * SYMBOL
+            + is_control     * CONTROL
+            + is_whitespace  * WHITESPACE
+            + is_lowercase   * LOWERCASE
+            + is_uppercase   * UPPERCASE
+            + is_nfd         * NFD
+            ;
+
+        return result;
+#else
+#error Unexpected or undefined __BYTE_ORDER__
+#endif
     }
 
     inline uint16_t category_flag() const {