]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Support diffusion models: Add Dream 7B (#14644)
authorAman Gupta <redacted>
Wed, 16 Jul 2025 12:03:51 +0000 (20:03 +0800)
committerGitHub <redacted>
Wed, 16 Jul 2025 12:03:51 +0000 (20:03 +0800)
* Support diffusion models: Add Dream 7B

* Move diffusion to examples

* Move stuff to examples. Add patch to not use kv-cache

* Address review comments

* Make sampling fast

* llama: remove diffusion functions

* Add basic timings + cleanup

* More cleanup

* Review comments: better formating, use LOG instead std::cerr, re-use batch, use ubatch instead of max_length

* fixup!

* Review: move everything to diffusion-cli for now

13 files changed:
common/arg.cpp
common/common.h
convert_hf_to_gguf.py
examples/CMakeLists.txt
examples/diffusion/CMakeLists.txt [new file with mode: 0644]
examples/diffusion/diffusion-cli.cpp [new file with mode: 0644]
gguf-py/gguf/constants.py
include/llama.h
src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp
src/llama-vocab.cpp
src/llama-vocab.h

index 56827a65908beccfb084f678a4d68017492ceb27..4c86f58f2cc33f4775193e4898a25879227a8a72 100644 (file)
@@ -3423,5 +3423,34 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
 
+    // diffusion parameters
+    add_opt(common_arg(
+        { "--diffusion-steps" }, "N",
+        string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
+        [](common_params & params, int value) { params.diffusion.steps = value; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        { "--diffusion-eps" }, "F",
+        string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
+        [](common_params & params, const std::string & value) { params.diffusion.eps = std::stof(value); }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        { "--diffusion-algorithm" }, "N",
+        string_format("diffusion algorithm: 0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY (default: %d)",
+                      params.diffusion.algorithm),
+        [](common_params & params, int value) { params.diffusion.algorithm = value; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        { "--diffusion-alg-temp" }, "F",
+        string_format("algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
+        [](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        { "--diffusion-visual" },
+        string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
+                      params.diffusion.visual_mode ? "true" : "false"),
+        [](common_params & params) { params.diffusion.visual_mode = true; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+
     return ctx_arg;
 }
index 248e82d8732a7e4ecbfa188865394953eb04fd36..e1f272318df7694af1a297708b18914324d6a90a 100644 (file)
@@ -81,6 +81,7 @@ enum llama_example {
     LLAMA_EXAMPLE_LOOKUP,
     LLAMA_EXAMPLE_PARALLEL,
     LLAMA_EXAMPLE_TTS,
+    LLAMA_EXAMPLE_DIFFUSION,
 
     LLAMA_EXAMPLE_COUNT,
 };
@@ -218,6 +219,14 @@ struct common_params_vocoder {
     bool use_guide_tokens = false; // enable guide tokens to improve TTS accuracy            // NOLINT
 };
 
+struct common_params_diffusion {
+    int32_t steps       = 64;     // number of diffusion steps
+    float   eps         = 1e-3f;  // epsilon for timesteps
+    int32_t algorithm   = 0;      // diffusion algorithm (0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY)
+    float   alg_temp    = 0.0f;   // algorithm temperature
+    bool    visual_mode = false;  // show progressive diffusion on screen
+};
+
 enum common_reasoning_format {
     COMMON_REASONING_FORMAT_NONE,
     COMMON_REASONING_FORMAT_DEEPSEEK_LEGACY, // Extract thinking tag contents and return as `message.reasoning_content`, or leave inline in <think> tags in stream mode
@@ -269,6 +278,7 @@ struct common_params {
     struct common_params_sampling    sampling;
     struct common_params_speculative speculative;
     struct common_params_vocoder     vocoder;
+    struct common_params_diffusion   diffusion;
 
     struct common_params_model model;
 
index 764163c438a15e41fff5183caaf2f29da733118d..d802524bba4a0ca5fed1d232b4d10027d0cdd6c1 100755 (executable)
@@ -2778,6 +2778,76 @@ class Qwen2Model(TextModel):
         yield from super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("DreamModel")
+class DreamModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.DREAM
+
+    def get_vocab_base(self) -> tuple[list[str], list[int], str]:
+        tokens: list[str] = []
+        toktypes: list[int] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+
+        vocab_dict = tokenizer.get_vocab()
+        vocab_size = self.hparams.get("vocab_size", len(vocab_dict))
+        assert max(vocab_dict.values()) < vocab_size
+
+        tokpre = self.get_vocab_base_pre(tokenizer)
+
+        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab_dict.items()}
+        added_vocab = tokenizer.get_added_vocab()
+
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                tokens.append(f"[PAD{i}]")
+                toktypes.append(gguf.TokenType.UNUSED)
+            elif reverse_vocab[i] in added_vocab:
+                tokens.append(reverse_vocab[i])
+                # Check if it's a special token - treat special tokens as CONTROL tokens
+                if hasattr(tokenizer, 'added_tokens_decoder') and i in tokenizer.added_tokens_decoder:
+                    if tokenizer.added_tokens_decoder[i].special:
+                        toktypes.append(gguf.TokenType.CONTROL)
+                    else:
+                        toktypes.append(gguf.TokenType.USER_DEFINED)
+                else:
+                    # Fallback: treat all added vocab as control tokens for special tokens like <|im_start|>
+                    toktypes.append(gguf.TokenType.CONTROL)
+            else:
+                tokens.append(reverse_vocab[i])
+                toktypes.append(gguf.TokenType.NORMAL)
+
+        return tokens, toktypes, tokpre
+
+    def set_vocab(self):
+        try:
+            self._set_vocab_sentencepiece()
+        except FileNotFoundError:
+            self._set_vocab_gpt2()
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self._try_set_pooling_type()
+
+        # Dream models use non-causal attention for diffusion
+        self.gguf_writer.add_causal_attention(False)
+        # Handle RoPE scaling similar to Qwen2
+        rope_scaling = self.hparams.get("rope_scaling") or {}
+        if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
+
+        # Add Dream-specific parameters
+        mask_token_id = self.hparams.get("mask_token_id")
+        if mask_token_id is not None:
+            self.gguf_writer.add_mask_token_id(mask_token_id)
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        # Dream model tensors should be mapped directly since it's the base model
+        yield from super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("Ernie4_5_ForCausalLM")
 class Ernie4_5Model(TextModel):
     model_arch = gguf.MODEL_ARCH.ERNIE4_5
index 49e4d2cf8c198b71434b41454d8bbe143e5cc00a..11ff38762b848954e8449e6c5edb601430430fd2 100644 (file)
@@ -33,6 +33,7 @@ else()
     add_subdirectory(speculative-simple)
     add_subdirectory(gen-docs)
     add_subdirectory(training)
+    add_subdirectory(diffusion)
     if (NOT GGML_BACKEND_DL)
         add_subdirectory(convert-llama2c-to-ggml)
         # these examples use the backends directly and cannot be built with dynamic loading
diff --git a/examples/diffusion/CMakeLists.txt b/examples/diffusion/CMakeLists.txt
new file mode 100644 (file)
index 0000000..396549c
--- /dev/null
@@ -0,0 +1,5 @@
+set(TARGET llama-diffusion-cli)
+add_executable(${TARGET} diffusion-cli.cpp)
+install(TARGETS ${TARGET} RUNTIME)
+target_link_libraries(${TARGET} PRIVATE llama common ${CMAKE_THREAD_LIBS_INIT})
+target_compile_features(${TARGET} PRIVATE cxx_std_17)
diff --git a/examples/diffusion/diffusion-cli.cpp b/examples/diffusion/diffusion-cli.cpp
new file mode 100644 (file)
index 0000000..3e11ce1
--- /dev/null
@@ -0,0 +1,507 @@
+#include "arg.h"
+#include "chat.h"
+#include "common.h"
+#include "llama.h"
+#include "log.h"
+
+#include <limits.h>
+#include <string>
+#include <vector>
+#include <algorithm>
+#include <cmath>
+#include <limits>
+#include <random>
+
+typedef bool (*diffusion_step_callback_t)(int32_t step,
+                                          int32_t total_steps,
+                                          const llama_token * tokens,
+                                          int32_t n_tokens,
+                                          void * user_data);
+
+enum diffusion_alg {
+    DIFFUSION_ALG_ORIGIN       = 0,
+    DIFFUSION_ALG_MASKGIT_PLUS = 1,
+    DIFFUSION_ALG_TOPK_MARGIN  = 2,
+    DIFFUSION_ALG_ENTROPY      = 3,
+};
+
+struct diffusion_params {
+    int32_t                   steps;
+    float                     eps;
+    float                     temperature;
+    float                     top_p;
+    int32_t                   top_k;
+    llama_token               mask_token_id;
+    enum diffusion_alg        algorithm;
+    float                     alg_temp;
+    diffusion_step_callback_t step_callback;
+    void *                    step_callback_user_data;
+    int32_t                   seed;
+};
+
+
+static diffusion_params diffusion_default_params() {
+    diffusion_params params        = {};
+    params.steps                   = 64;
+    params.eps                     = 1e-3f;
+    params.temperature             = 0.2f;
+    params.top_p                   = 0.95f;
+    params.top_k                   = 0;
+    params.mask_token_id           = LLAMA_TOKEN_NULL;
+    params.algorithm               = DIFFUSION_ALG_ORIGIN;
+    params.alg_temp                = 0.0f;
+    params.step_callback           = nullptr;
+    params.step_callback_user_data = nullptr;
+    params.seed                    = 0;
+    return params;
+}
+
+static void diffusion_generate(llama_context * ctx,
+                        const llama_token * input_tokens,
+                        llama_token * output_tokens,
+                        int32_t n_input,
+                        int32_t max_length,
+                        struct diffusion_params params,
+                        int32_t & n_generated) {
+
+    n_generated = 0;
+    if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || max_length <= n_input) {
+        return;
+    }
+
+    const llama_model * model = llama_get_model(ctx);
+
+    // Initialize with input and pad with mask tokens
+    std::copy(input_tokens, input_tokens + n_input, output_tokens);
+    std::fill(output_tokens + n_input, output_tokens + max_length, params.mask_token_id);
+
+    std::mt19937 rng(params.seed);
+
+    std::vector<float> timesteps(params.steps + 1);
+    for (int32_t i = 0; i <= params.steps; i++) {
+        timesteps[i] = 1.0f - (float) i / params.steps * (1.0f - params.eps);
+    }
+
+    llama_set_causal_attn(ctx, false);
+
+    int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
+
+    std::vector<llama_token_data> candidates(n_vocab);
+
+    std::vector<llama_token_data> conf_candidates;
+    conf_candidates.reserve(max_length);
+
+    std::vector<int32_t> mask_positions;
+    mask_positions.reserve(max_length);
+
+    struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
+    if (params.top_k > 0) {
+        llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
+    }
+    if (params.top_p < 1.0f) {
+        llama_sampler_chain_add(sampler, llama_sampler_init_top_p(params.top_p, 1));
+    }
+    if (params.temperature > 0.0f) {
+        llama_sampler_chain_add(sampler, llama_sampler_init_temp(params.temperature));
+    }
+    llama_sampler_chain_add(sampler, llama_sampler_init_dist(params.seed));
+
+    struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);
+
+    llama_batch batch = llama_batch_init(max_length, 0, 1);
+    batch.n_tokens    = max_length;
+
+    int64_t total_sampling_time = 0;
+    int64_t total_time = 0;
+
+    int64_t time_start = ggml_time_us();
+    for (int32_t step = 0; step < params.steps; step++) {
+        if (params.step_callback) {
+            if (!params.step_callback(step, params.steps, output_tokens, max_length, params.step_callback_user_data)) {
+                break;
+            }
+        }
+
+        for (int32_t i = 0; i < max_length; i++) {
+            batch.token[i]     = output_tokens[i];
+            batch.pos[i]       = i;
+            batch.n_seq_id[i]  = 1;
+            batch.seq_id[i][0] = 0;
+            batch.logits[i]    = 1;
+        }
+
+        int ret = llama_decode(ctx, batch);
+        if (ret != 0) {
+            LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, step, ret);
+            break;
+        }
+
+        float * raw_logits = llama_get_logits(ctx);
+        if (!raw_logits) {
+            LOG_ERR("%s: failed to get logits at step %d\n", __func__, step);
+            break;
+        }
+
+        auto get_logits_for_pos = [&](int32_t pos) -> const float * {
+            return pos == 0 ? raw_logits : raw_logits + (pos - 1) * n_vocab;
+        };
+
+        int64_t time_start_sampling = ggml_time_us();
+
+        mask_positions.clear();
+        for (int32_t i = 0; i < max_length; i++) {
+            if (output_tokens[i] == params.mask_token_id) {
+                mask_positions.push_back(i);
+            }
+        }
+
+        if (mask_positions.empty()) {
+            break;
+        }
+
+        float t = timesteps[step];
+        float s = timesteps[step + 1];
+
+        if (params.algorithm == DIFFUSION_ALG_ORIGIN) {
+            float p_transfer = (step < params.steps - 1) ? (1.0f - s / t) : 1.0f;
+
+            for (int32_t pos : mask_positions) {
+                if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
+                    const float * pos_logits = get_logits_for_pos(pos);
+                    for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
+                        candidates[token_id].id    = token_id;
+                        candidates[token_id].logit = pos_logits[token_id];
+                        candidates[token_id].p     = 0.0f;
+                    }
+
+                    llama_token_data_array cur_p = {
+                        /* .data       = */ candidates.data(),
+                        /* .size       = */ (size_t) n_vocab,  // Reset size to full vocab
+                        /* .selected   = */ -1,
+                        /* .sorted     = */ false,
+                    };
+
+                    llama_sampler_apply(sampler, &cur_p);
+                    output_tokens[pos] = cur_p.data[cur_p.selected].id;
+                }
+            }
+        } else {
+            std::vector<std::pair<float, int32_t>> confidences;
+            std::vector<llama_token>               sampled_tokens(mask_positions.size());
+
+            for (size_t i = 0; i < mask_positions.size(); i++) {
+                int32_t       pos        = mask_positions[i];
+                const float * pos_logits = get_logits_for_pos(pos);
+
+                for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
+                    candidates[token_id].logit = pos_logits[token_id];
+                    candidates[token_id].p     = 0.0f;
+                    candidates[token_id].id    = token_id;
+                }
+
+                llama_token_data_array cur_p = {
+                    /* .data       = */ candidates.data(),
+                    /* .size       = */ candidates.size(),
+                    /* .selected   = */ -1,
+                    /* .sorted     = */ false,
+                };
+
+                llama_sampler_apply(sampler, &cur_p);
+
+                llama_token sampled_token = cur_p.data[cur_p.selected].id;
+
+                float confidence = 0.0f;
+                if (params.algorithm == DIFFUSION_ALG_ENTROPY) {
+                    const float epsilon = 1e-10f;
+                    for (size_t j = 0; j < cur_p.size; j++) {
+                        float prob = cur_p.data[j].p;
+                        confidence += prob * logf(prob + epsilon);
+                    }
+                } else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) {
+                    confidence = cur_p.data[0].p - cur_p.data[1].p;
+                } else {
+                    confidence = cur_p.data[cur_p.selected].p;
+                }
+
+                sampled_tokens[i] = sampled_token;
+                confidences.emplace_back(confidence, i);
+            }
+
+            int32_t num_transfer =
+                (step < params.steps - 1) ? (int32_t) (mask_positions.size() * (1.0f - s / t)) : mask_positions.size();
+
+            if (num_transfer > 0) {
+                if (params.alg_temp == 0.0f) {
+                    std::partial_sort(confidences.begin(), confidences.begin() + num_transfer, confidences.end(),
+                                      [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
+                                          if (a.first != b.first) {
+                                              return a.first > b.first;
+                                          }
+                                          return a.second < b.second;
+                                      });
+                } else {
+                    conf_candidates.clear();
+
+                    for (int32_t pos = 0; pos < max_length; pos++) {
+                        float conf_logit = -std::numeric_limits<float>::infinity();
+
+                        auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
+                        if (it != mask_positions.end()) {
+                            size_t mask_idx = std::distance(mask_positions.begin(), it);
+                            conf_logit = confidences[mask_idx].first / params.alg_temp;  // Apply temperature scaling
+                        }
+
+                        conf_candidates.emplace_back(llama_token_data{ pos, conf_logit, 0.0f });
+                    }
+
+                    llama_token_data_array conf_array = {
+                        /* .data       = */ conf_candidates.data(),
+                        /* .size       = */ conf_candidates.size(),
+                        /* .selected   = */ -1,
+                        /* .sorted     = */ false,
+                    };
+
+                    for (int32_t i = 0; i < num_transfer; i++) {
+                        // Apply distribution sampler to get selected index
+                        llama_sampler_apply(dist_sampler, &conf_array);
+                        int selected_idx      = conf_array.selected;
+                        confidences[i].second = conf_candidates[selected_idx].id;
+
+                        conf_candidates[selected_idx].p = 0.0f;
+                        conf_array.selected             = -1;
+                    }
+                }
+
+                if (params.alg_temp == 0.0f) {
+                    // Deterministic - use confidence order
+                    for (int32_t i = 0; i < num_transfer; i++) {
+                        int32_t     mask_idx = confidences[i].second;
+                        int32_t     pos      = mask_positions[mask_idx];
+                        llama_token token    = sampled_tokens[mask_idx];
+                        output_tokens[pos]   = token;
+                    }
+                } else {
+                    for (int32_t i = 0; i < num_transfer; i++) {
+                        int32_t pos = confidences[i].second;
+                        auto    it  = std::find(mask_positions.begin(), mask_positions.end(), pos);
+                        if (it != mask_positions.end()) {
+                            int32_t mask_idx   = std::distance(mask_positions.begin(), it);
+                            output_tokens[pos] = sampled_tokens[mask_idx];
+                        }
+                    }
+                }
+            }
+        }
+        int64_t time_end_sampling = ggml_time_us();
+        total_sampling_time += time_end_sampling - time_start_sampling;
+    }
+    int64_t time_end = ggml_time_us();
+    total_time += time_end - time_start;
+
+    LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
+            total_time / 1000.0, total_time / 1000.0 / params.steps, total_sampling_time / 1000.0 / params.steps);
+
+
+    llama_batch_free(batch);
+    llama_sampler_free(sampler);
+    llama_sampler_free(dist_sampler);
+
+    n_generated = max_length;
+}
+
+
+
+
+static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
+    if (!use_chat_template) {
+        return prompt;
+    }
+
+    auto chat_templates = common_chat_templates_init(model, "");
+
+    common_chat_templates_inputs inputs;
+    common_chat_msg              user_msg;
+    user_msg.role                = "user";
+    user_msg.content             = prompt;
+    inputs.add_generation_prompt = true;
+    inputs.messages.push_back(user_msg);
+
+    auto result = common_chat_templates_apply(chat_templates.get(), inputs);
+
+    return result.prompt;
+}
+
+struct callback_data {
+    const common_params_diffusion * diff_params;
+    const llama_vocab *             vocab;
+    int32_t                         n_input;
+};
+
+static bool diffusion_step_callback(int32_t step,
+                                    int32_t total_steps,
+                                    const llama_token * tokens,
+                                    int32_t n_tokens,
+                                    void * user_data) {
+    (void)user_data;
+
+    callback_data * data = static_cast<callback_data *>(user_data);
+
+    auto print_progress_bar = [](int32_t step, int32_t total_steps) {
+        int progress_percent = (step * 100) / total_steps;
+        int progress_bars    = (step * 50) / total_steps;
+        LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
+            step,
+            total_steps,
+            std::string(progress_bars, '=').c_str(),
+            std::string(50 - progress_bars, ' ').c_str(),
+            progress_percent);
+    };
+
+    if (data->diff_params->visual_mode) {
+        // Visual mode: clear
+        LOG_INF("\033[2J\033[H");  // Clear screen and move cursor to top-left
+
+        print_progress_bar(step, total_steps);
+
+        LOG_INF("\n");
+
+        std::string current_text = " ";
+
+        for (int32_t i = data->n_input; i < n_tokens; i++) {
+            std::string token_str;
+            if (tokens[i] != llama_vocab_mask(data->vocab)) {
+                char piece[256];
+                int  n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
+                if (n_chars > 0) {
+                    piece[n_chars] = '\0';
+                    token_str      = piece;
+                }
+            } else {
+                token_str = " ";
+            }
+
+            current_text += token_str;
+        }
+
+        LOG_INF("%s\n", current_text.c_str());
+    } else {
+        print_progress_bar(step, total_steps);
+    }
+
+    return true;
+}
+
+int main(int argc, char ** argv) {
+    ggml_time_init();
+
+    common_params params;
+
+    if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DIFFUSION)) {
+        return 1;
+    }
+
+    const char * alg_names[] = { "ORIGIN", "MASKGIT_PLUS", "TOPK_MARGIN", "ENTROPY" };
+    const char * alg_name    = (params.diffusion.algorithm >= 0 && params.diffusion.algorithm <= 3) ?
+                                   alg_names[params.diffusion.algorithm] :
+                                   "UNKNOWN";
+
+    common_init();
+    llama_backend_init();
+
+    llama_model_params model_params = llama_model_default_params();
+    model_params.n_gpu_layers       = params.n_gpu_layers;
+    model_params.devices            = params.devices.data();
+    model_params.use_mmap           = params.use_mmap;
+    model_params.use_mlock          = params.use_mlock;
+    model_params.check_tensors      = params.check_tensors;
+
+    llama_model * model = llama_model_load_from_file(params.model.path.c_str(), model_params);
+    if (!model) {
+        LOG_ERR("error: failed to load model '%s'\n", params.model.path.c_str());
+        return 1;
+    }
+
+    llama_context_params ctx_params = llama_context_default_params();
+    ctx_params.n_ctx                = params.n_ctx;
+    ctx_params.n_batch              = params.n_batch;
+    ctx_params.n_ubatch             = params.n_ubatch;
+    ctx_params.flash_attn           = params.flash_attn;
+    ctx_params.no_perf              = params.no_perf;
+    ctx_params.type_k               = params.cache_type_k;
+    ctx_params.type_v               = params.cache_type_v;
+
+    llama_context * ctx = llama_init_from_model(model, ctx_params);
+    if (!ctx) {
+        LOG_ERR("error: failed to create context\n");
+        llama_model_free(model);
+        return 1;
+    }
+
+    llama_set_n_threads(ctx, params.cpuparams.n_threads, params.cpuparams_batch.n_threads);
+
+    const llama_vocab * vocab            = llama_model_get_vocab(model);
+    std::string         formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);
+
+    std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt,
+                                                            /*add special tokens*/ true,
+                                                            /*parse special*/ true);
+    int                      n_input      = input_tokens.size();
+
+    if (n_input >= params.n_ctx) {
+        LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
+        llama_free(ctx);
+        llama_model_free(model);
+        return 1;
+    }
+
+    struct diffusion_params ldiff_params = diffusion_default_params();
+    ldiff_params.steps                   = params.diffusion.steps;
+    ldiff_params.eps                     = params.diffusion.eps;
+    ldiff_params.temperature             = params.sampling.temp;
+    ldiff_params.top_p                   = params.sampling.top_p;
+    ldiff_params.top_k                   = params.sampling.top_k;
+    ldiff_params.algorithm               = static_cast<enum diffusion_alg>(params.diffusion.algorithm);
+    ldiff_params.alg_temp                = params.diffusion.alg_temp;
+    ldiff_params.seed                    = params.sampling.seed;
+
+    llama_token mask_token_id = llama_vocab_mask(vocab);
+    GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
+
+    LOG_INF("diffusion_params: - %-25s llama_token      = %d\n", "mask_token_id", mask_token_id);
+    LOG_INF("diffusion_params: - %-25s u32              = %d\n", "steps", params.diffusion.steps);
+    LOG_INF("diffusion_params: - %-25s f32              = %.6f\n", "eps", params.diffusion.eps);
+    LOG_INF("diffusion_params: - %-25s u32              = %d (%s)\n", "algorithm", params.diffusion.algorithm,
+            alg_name);
+    LOG_INF("diffusion_params: - %-25s f32              = %.3f\n", "alg_temp", params.diffusion.alg_temp);
+
+    ldiff_params.mask_token_id = mask_token_id;
+
+    callback_data cb_data = { &params.diffusion, vocab, n_input };
+
+    ldiff_params.step_callback           = diffusion_step_callback;
+    ldiff_params.step_callback_user_data = &cb_data;
+
+    int32_t n_generated = 0;
+
+    std::vector<llama_token> output_tokens(params.n_ubatch);
+    diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, params.n_ubatch,
+                       ldiff_params, n_generated);
+
+    if (n_generated > 0) {
+        if (params.diffusion.visual_mode) {
+            //clear screen and move cursor to top-left
+            LOG_INF("\033[2J\033[H");
+        }
+        output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
+        std::string output_data = common_detokenize(vocab, output_tokens, false);
+        LOG_INF("\n%s\n", output_data.c_str());
+    } else {
+        LOG_INF("Error: diffusion generation failed\n");
+    }
+
+    llama_free(ctx);
+    llama_model_free(model);
+    llama_backend_free();
+
+    return 0;
+}
index 486a165b68b72a49df00f39190e58f158bd0f845..d8afe7696d24384c024cd1ae67f5c4092510b185 100644 (file)
@@ -367,6 +367,7 @@ class MODEL_ARCH(IntEnum):
     HUNYUAN_MOE      = auto()
     SMOLLM3          = auto()
     LFM2             = auto()
+    DREAM            = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -683,6 +684,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.HUNYUAN_MOE:      "hunyuan-moe",
     MODEL_ARCH.SMOLLM3:          "smollm3",
     MODEL_ARCH.LFM2:             "lfm2",
+    MODEL_ARCH.DREAM:            "dream",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -1289,6 +1291,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.DREAM: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.QWEN2VL: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
index 28e84d4d7e27eb89082f4eb3dfe98685420d1a9b..bbe4f8dbfae666efbca5b0b7830abbb7dcef3ff1 100644 (file)
@@ -1005,6 +1005,7 @@ extern "C" {
     LLAMA_API llama_token llama_vocab_sep(const struct llama_vocab * vocab); // sentence separator
     LLAMA_API llama_token llama_vocab_nl (const struct llama_vocab * vocab); // next-line
     LLAMA_API llama_token llama_vocab_pad(const struct llama_vocab * vocab); // padding
+    LLAMA_API llama_token llama_vocab_mask(const struct llama_vocab * vocab); // mask
 
     LLAMA_API bool llama_vocab_get_add_bos(const struct llama_vocab * vocab);
     LLAMA_API bool llama_vocab_get_add_eos(const struct llama_vocab * vocab);
index 5c7a0d087ce528f124b55393eb9179e771dcf576..9454d04e538018d8c5be1f20be8a430e2ed5d5bf 100644 (file)
@@ -85,6 +85,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_HUNYUAN_MOE,      "hunyuan-moe"      },
     { LLM_ARCH_SMOLLM3,          "smollm3"          },
     { LLM_ARCH_LFM2,             "lfm2"             },
+    { LLM_ARCH_DREAM,            "dream"            },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -1891,6 +1892,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
         },
     },
+    {
+        LLM_ARCH_DREAM,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
 };
 
 static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
@@ -2133,3 +2151,12 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
             return false;
     }
 }
+
+bool llm_arch_is_diffusion(const llm_arch & arch) {
+    switch (arch) {
+        case LLM_ARCH_DREAM:
+            return true;
+        default:
+            return false;
+    }
+}
index d4a2dea9ec33dcdc41be2ebcde9879e12a3177de..0ead0d6cdb11be87c5ef8660c288e5111c6e31dd 100644 (file)
@@ -89,6 +89,7 @@ enum llm_arch {
     LLM_ARCH_HUNYUAN_MOE,
     LLM_ARCH_SMOLLM3,
     LLM_ARCH_LFM2,
+    LLM_ARCH_DREAM,
     LLM_ARCH_UNKNOWN,
 };
 
@@ -479,3 +480,4 @@ const llm_tensor_info & llm_tensor_info_for(llm_tensor tensor);
 
 bool llm_arch_is_recurrent(const llm_arch & arch);
 bool llm_arch_is_hybrid   (const llm_arch & arch);
+bool llm_arch_is_diffusion(const llm_arch & arch);
index 1c437d55caded29b1318463d1d61aeed2d7dc888..82ddc5cef67651530bd9776f55a3b9d3935cb604 100644 (file)
@@ -849,6 +849,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                     default: type = LLM_TYPE_UNKNOWN;
                 }
             } break;
+        case LLM_ARCH_DREAM:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                // Dream models are primarily 7B with 28 layers
+                switch (hparams.n_layer) {
+                    case 28:
+                        type = LLM_TYPE_7B;
+                        break;
+                    default:
+                        type = LLM_TYPE_UNKNOWN;
+                }
+                // Set non-causal attention for diffusion models
+                hparams.causal_attn = false;
+            }
+            break;
         case LLM_ARCH_QWEN2MOE:
             {
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp, false);
@@ -2670,6 +2685,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                 } break;
             case LLM_ARCH_QWEN2:
             case LLM_ARCH_QWEN2VL:
+            case LLM_ARCH_DREAM:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
 
@@ -7756,6 +7772,109 @@ struct llm_build_qwen2 : public llm_graph_context {
     }
 };
 
+struct llm_build_dream : public llm_graph_context {
+    llm_build_dream(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) :
+        llm_graph_context(params) {
+        //copied from qwen2
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        auto * inp_attn = build_attn_inp_no_cache();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute Q and K and RoPE them
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                Qcur               = ggml_add(ctx0, Qcur, model.layers[il].bq);
+                cb(Qcur, "Qcur", il);
+
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                Kcur               = ggml_add(ctx0, Kcur, model.layers[il].bk);
+                cb(Kcur, "Kcur", il);
+
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+                Vcur               = ggml_add(ctx0, Vcur, model.layers[il].bv);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                     ext_factor, attn_factor, beta_fast, beta_slow);
+
+                Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                     ext_factor, attn_factor, beta_fast, beta_slow);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, gf, model.layers[il].wo, model.layers[il].bo, Qcur, Kcur, Vcur, nullptr,
+                                 nullptr, 1.0f / sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
+                            model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_qwen2vl : public llm_graph_context {
     llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params, ggml_cgraph * gf) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -16487,6 +16606,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NOMIC_BERT_MOE:
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
+        case LLM_ARCH_DREAM:
             {
                 res = nullptr;
             } break;
@@ -16638,6 +16758,11 @@ llm_graph_result_ptr llama_model::build_graph(
             {
                 llm = std::make_unique<llm_build_qwen2>(*this, params, gf);
             } break;
+        case LLM_ARCH_DREAM:
+            {
+                llm = std::make_unique<llm_build_dream>(*this, params, gf);
+            }
+            break;
         case LLM_ARCH_QWEN2VL:
             {
                 llm = std::make_unique<llm_build_qwen2vl>(*this, params, gf);
@@ -17055,6 +17180,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
         case LLM_ARCH_BITNET:
         case LLM_ARCH_QWEN:
         case LLM_ARCH_QWEN2:
+        case LLM_ARCH_DREAM:
         case LLM_ARCH_QWEN2MOE:
         case LLM_ARCH_QWEN3:
         case LLM_ARCH_QWEN3MOE:
index 8d5c3b1448d61e105e5a13e45790ca8c58dcf8fb..2181c01e31a875b74befdfb63c6943e241bd7495 100644 (file)
@@ -3354,6 +3354,10 @@ llama_token llama_vocab::token_fim_sep() const {
     return pimpl->special_fim_sep_id;
 }
 
+llama_token llama_vocab::token_mask() const {
+    return pimpl->special_mask_id;
+}
+
 bool llama_vocab::get_add_space_prefix() const {
     return pimpl->add_space_prefix;
 }
@@ -3594,6 +3598,10 @@ llama_token llama_vocab_fim_sep(const struct llama_vocab * vocab) {
     return vocab->token_fim_sep();
 }
 
+llama_token llama_vocab_mask(const struct llama_vocab* vocab) {
+    return vocab->token_mask();
+}
+
 // deprecated
 const char * llama_token_get_text(const struct llama_vocab * vocab, llama_token token) {
     return llama_vocab_get_text(vocab, token);
index 1ce8fd307e2d3cb81ef3d31fcc1f699588aec527..842b129e86171dd9f1def51e464e1cac6ab06828 100644 (file)
@@ -101,6 +101,7 @@ struct llama_vocab {
     llama_token token_sep() const;
     llama_token token_nl () const;
     llama_token token_pad() const;
+    llama_token token_mask() const;
 
     llama_token token_prefix() const;
     llama_token token_middle() const;