]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
Add LLaDA 8b Diffusion model (#14771)
authorAman Gupta <redacted>
Thu, 31 Jul 2025 11:49:09 +0000 (19:49 +0800)
committerGitHub <redacted>
Thu, 31 Jul 2025 11:49:09 +0000 (19:49 +0800)
* Add support for Llada-8b: diffusion model

* Add README

* Fix README and convert_hf_to_gguf

* convert_hf_to_gguf.py: address review comments

* Make everything in a single example

* Remove model-specific sampling

* Remove unused argmax

* Remove braced initializers, improve README.md a bit

* Add diffusion specific gguf params in set_vocab, remove setting rope_theta and rms_norm_eps

* Remove adding the mask token

* Move add_add_bos_token to set_vocab

* use add_bool in gguf_writer.py

12 files changed:
common/arg.cpp
common/common.h
convert_hf_to_gguf.py
examples/diffusion/README.md [new file with mode: 0644]
examples/diffusion/diffusion-cli.cpp
gguf-py/gguf/constants.py
gguf-py/gguf/gguf_writer.py
gguf-py/gguf/tensor_mapping.py
include/llama.h
src/llama-arch.cpp
src/llama-arch.h
src/llama-model.cpp

index 060053595dbfd6dfda71e32e1c8eb3c41fe65ee8..74137d2db959d4beb87bfb42623a0477b6e9d162 100644 (file)
@@ -3438,12 +3438,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}));
 
-    // diffusion parameters
     add_opt(common_arg(
         { "--diffusion-steps" }, "N",
         string_format("number of diffusion steps (default: %d)", params.diffusion.steps),
         [](common_params & params, int value) { params.diffusion.steps = value; }
     ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        { "--diffusion-visual" },
+        string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
+                      params.diffusion.visual_mode ? "true" : "false"),
+        [](common_params & params) { params.diffusion.visual_mode = true; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+
     add_opt(common_arg(
         { "--diffusion-eps" }, "F",
         string_format("epsilon for timesteps (default: %.6f)", (double) params.diffusion.eps),
@@ -3451,21 +3457,32 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
     add_opt(common_arg(
         { "--diffusion-algorithm" }, "N",
-        string_format("diffusion algorithm: 0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY (default: %d)",
+        string_format("diffusion algorithm: 0=ORIGIN, 1=ENTROPY_BASED, 2=MARGIN_BASED, 3=RANDOM, 4=LOW_CONFIDENCE (default: %d)",
                       params.diffusion.algorithm),
         [](common_params & params, int value) { params.diffusion.algorithm = value; }
     ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
     add_opt(common_arg(
         { "--diffusion-alg-temp" }, "F",
-        string_format("algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
+        string_format("dream algorithm temperature (default: %.3f)", (double) params.diffusion.alg_temp),
         [](common_params & params, const std::string & value) { params.diffusion.alg_temp = std::stof(value); }
     ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+
     add_opt(common_arg(
-        { "--diffusion-visual" },
-        string_format("enable visual diffusion mode (show progressive generation) (default: %s)",
-                      params.diffusion.visual_mode ? "true" : "false"),
-        [](common_params & params) { params.diffusion.visual_mode = true; }
+        { "--diffusion-block-length" }, "N",
+        string_format("llada block length for generation (default: %d)", params.diffusion.block_length),
+        [](common_params & params, int value) { params.diffusion.block_length = value; }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        { "--diffusion-cfg-scale" }, "F",
+        string_format("llada classifier-free guidance scale (default: %.3f)", (double) params.diffusion.cfg_scale),
+        [](common_params & params, const std::string & value) { params.diffusion.cfg_scale = std::stof(value); }
+    ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
+    add_opt(common_arg(
+        { "--diffusion-add-gumbel-noise" }, "F",
+        string_format("add gumbel noise to the logits if temp > 0.0 (default: %s)", params.diffusion.add_gumbel_noise ? "true" : "false"),
+        [](common_params & params, const std::string & value) { params.diffusion.add_gumbel_noise = std::stof(value); }
     ).set_examples({ LLAMA_EXAMPLE_DIFFUSION }));
 
+
     return ctx_arg;
 }
index 00f42694eafa8ceba8d5f923767a1d8271b6f0b5..38129b99d511ff4b9f8b9826e9e7fc4fe8ba0c64 100644 (file)
@@ -220,11 +220,17 @@ struct common_params_vocoder {
 };
 
 struct common_params_diffusion {
-    int32_t steps       = 64;     // number of diffusion steps
-    float   eps         = 1e-3f;  // epsilon for timesteps
-    int32_t algorithm   = 0;      // diffusion algorithm (0=ORIGIN, 1=MASKGIT_PLUS, 2=TOPK_MARGIN, 3=ENTROPY)
-    float   alg_temp    = 0.0f;   // algorithm temperature
-    bool    visual_mode = false;  // show progressive diffusion on screen
+    int32_t steps         = 128;
+    bool    visual_mode   = false;
+
+    float   eps           = 0;        // epsilon for timesteps
+    int32_t block_length  = 32;       // block length for generation
+
+    int32_t algorithm     = 4;        // default algorithm: low-confidence
+    float   alg_temp      = 0.0f;     // algorithm temperature
+
+    float   cfg_scale     = 0;        // classifier-free guidance scale
+    bool    add_gumbel_noise = false; // add gumbel noise to the logits if temp > 0.0
 };
 
 enum common_reasoning_format {
index 3f5cefe007cca515d72e6237c5defd1f72f141b0..db4112318d4876e07ee4d2fcb6a88cc34c3e9c39 100755 (executable)
@@ -2904,6 +2904,107 @@ class DreamModel(TextModel):
         yield from super().modify_tensors(data_torch, name, bid)
 
 
+@ModelBase.register("LLaDAModelLM")
+class LLaDAModel(TextModel):
+    model_arch = gguf.MODEL_ARCH.LLADA
+    undo_permute = True
+
+    def get_vocab_base(self) -> tuple[list[str], list[int], str]:
+        tokens: list[str] = []
+        toktypes: list[int] = []
+
+        from transformers import AutoTokenizer
+        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)
+
+        vocab_dict = tokenizer.get_vocab()
+        vocab_size = self.hparams.get("vocab_size", len(vocab_dict))
+        assert max(vocab_dict.values()) < vocab_size
+
+        tokpre = self.get_vocab_base_pre(tokenizer)
+
+        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in vocab_dict.items()}
+        added_vocab = tokenizer.get_added_vocab()
+
+        for i in range(vocab_size):
+            if i not in reverse_vocab:
+                tokens.append(f"[PAD{i}]")
+                toktypes.append(gguf.TokenType.UNUSED)
+            elif reverse_vocab[i] in added_vocab:
+                tokens.append(reverse_vocab[i])
+                # Check if it's a special token - treat special tokens as CONTROL tokens
+                if hasattr(tokenizer, 'added_tokens_decoder') and i in tokenizer.added_tokens_decoder:
+                    if tokenizer.added_tokens_decoder[i].special:
+                        toktypes.append(gguf.TokenType.CONTROL)
+                    else:
+                        toktypes.append(gguf.TokenType.USER_DEFINED)
+                else:
+                    # Fallback: treat all added vocab as control tokens for special tokens like <|im_start|>
+                    toktypes.append(gguf.TokenType.CONTROL)
+            else:
+                tokens.append(reverse_vocab[i])
+                toktypes.append(gguf.TokenType.NORMAL)
+
+        return tokens, toktypes, tokpre
+
+    def set_vocab(self):
+        self._set_vocab_gpt2()
+
+        # LLaDA specific parameters
+        self.gguf_writer.add_add_bos_token(True)
+
+    def set_gguf_parameters(self):
+        super().set_gguf_parameters()
+        self._try_set_pooling_type()
+
+        # Add parameters similar to LlamaModel
+        hparams = self.hparams
+        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
+
+        if (rope_dim := hparams.get("head_dim")) is None:
+            n_heads = hparams.get("num_attention_heads", hparams.get("n_heads"))
+            rope_dim = hparams.get("hidden_size", hparams.get("d_model")) // n_heads
+        self.gguf_writer.add_rope_dimension_count(rope_dim)
+
+        # Set context length for LLaDA
+        context_length = self.hparams.get("max_sequence_length", 4096)
+        self.gguf_writer.add_context_length(context_length)
+
+        # Set embedding length (dimension size)
+        embedding_length = self.hparams.get("d_model", 4096)
+        self.gguf_writer.add_embedding_length(embedding_length)
+
+        # Set feed forward length (MLP hidden size)
+        feed_forward_length = self.hparams.get("mlp_hidden_size", 12288)
+        self.gguf_writer.add_feed_forward_length(feed_forward_length)
+
+        # LLaDA models use non-causal attention for diffusion, similar to Dream
+        self.gguf_writer.add_causal_attention(False)
+
+        # LLaDA models don't shift their logits
+        self.gguf_writer.add_diffusion_shift_logits(False)
+
+    @staticmethod
+    def permute(weights: Tensor, n_head: int, n_head_kv: int | None):
+        if n_head_kv is not None and n_head != n_head_kv:
+            n_head = n_head_kv
+        return (weights.reshape(n_head, 2, weights.shape[0] // n_head // 2, *weights.shape[1:])
+                .swapaxes(1, 2)
+                .reshape(weights.shape))
+
+    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        n_head = self.hparams.get("num_attention_heads", self.hparams.get("n_heads"))
+        n_kv_head = self.hparams.get("num_key_value_heads", self.hparams.get("n_kv_heads"))
+
+        if self.undo_permute:
+            if name.endswith(("q_proj.weight", "q_proj.bias")):
+                data_torch = LLaDAModel.permute(data_torch, n_head, n_head)
+            if name.endswith(("k_proj.weight", "k_proj.bias")):
+                data_torch = LLaDAModel.permute(data_torch, n_head, n_kv_head)
+
+        # LLaDA model tensors should be mapped directly since it's the base model
+        yield from super().modify_tensors(data_torch, name, bid)
+
+
 @ModelBase.register("Ernie4_5_ForCausalLM")
 class Ernie4_5Model(TextModel):
     model_arch = gguf.MODEL_ARCH.ERNIE4_5
diff --git a/examples/diffusion/README.md b/examples/diffusion/README.md
new file mode 100644 (file)
index 0000000..26de566
--- /dev/null
@@ -0,0 +1,13 @@
+# Diffusion Text Generation
+
+This directory contains implementations for Diffusion LLMs (DLLMs)
+
+More Info:
+- https://github.com/ggml-org/llama.cpp/pull/14644
+- https://github.com/ggml-org/llama.cpp/pull/14771
+
+
+Example of using Dream architechture: `llama-diffusion-cli -m dream7b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-eps 0.001 --diffusion-algorithm 3 --diffusion-steps 256 --diffusion-visual`
+
+Example of using LLaDA architechture: `llama-diffusion-cli -m llada-8b.gguf -p "write code to train MNIST in pytorch" -ub 512 --diffusion-block-length 32 --diffusion-steps 256 --diffusion-visual`
+
index 3e11ce1160b05418778ddef4b563585dc575586d..8431dcea8fe2af1c3339f0eb1d89e060a66fed41 100644 (file)
 #include "log.h"
 
 #include <limits.h>
-#include <string>
-#include <vector>
+
 #include <algorithm>
 #include <cmath>
+#include <cstring>
 #include <limits>
 #include <random>
+#include <string>
+#include <vector>
 
-typedef bool (*diffusion_step_callback_t)(int32_t step,
-                                          int32_t total_steps,
-                                          const llama_token * tokens,
-                                          int32_t n_tokens,
-                                          void * user_data);
-
-enum diffusion_alg {
-    DIFFUSION_ALG_ORIGIN       = 0,
-    DIFFUSION_ALG_MASKGIT_PLUS = 1,
-    DIFFUSION_ALG_TOPK_MARGIN  = 2,
-    DIFFUSION_ALG_ENTROPY      = 3,
+enum diffusion_algorithm { ORIGIN = 0, ENTROPY_BASED = 1, MARGIN_BASED = 2, RANDOM = 3, CONFIDENCE_BASED = 4 };
+
+// Unified transfer scheduling methods
+enum transfer_schedule {
+    TIMESTEP_BASED = 0,  // Dream-style: (1.0 - s/t) * remaining
+    BLOCK_BASED    = 1,  // LLaDA-style: process in blocks with get_num_transfer_tokens
 };
 
+typedef bool (*diffusion_step_callback_t)(int32_t             step,
+                                          int32_t             total_steps,
+                                          const llama_token * tokens,
+                                          int32_t             n_tokens,
+                                          void *              user_data);
+
 struct diffusion_params {
-    int32_t                   steps;
-    float                     eps;
-    float                     temperature;
-    float                     top_p;
-    int32_t                   top_k;
-    llama_token               mask_token_id;
-    enum diffusion_alg        algorithm;
-    float                     alg_temp;
-    diffusion_step_callback_t step_callback;
-    void *                    step_callback_user_data;
-    int32_t                   seed;
+    int32_t                   steps                   = 0;
+    float                     temperature             = 0;
+    llama_token               mask_token_id           = LLAMA_TOKEN_NULL;
+    diffusion_step_callback_t step_callback           = nullptr;
+    void *                    step_callback_user_data = nullptr;
+    int32_t                   seed                    = 0;
+    bool                      visual_mode             = false;
+    bool                      shift_logits            = false;  // Shift logits by -1 after decode
+
+    float   top_p = 0.;
+    int32_t top_k = 0.;
+
+    diffusion_algorithm algorithm = CONFIDENCE_BASED;
+    transfer_schedule   schedule  = TIMESTEP_BASED;
+
+    float   cfg_scale        = 0.;     // Config scale for classifier-free guidance
+    float   eps              = 0.;     // Timestep scheduling
+    int32_t block_length     = 0;      // Block size (for block scheduling)
+    float   alg_temp         = 0;      // algorithm temperature (0.0 = deterministic)
+    bool    add_gumbel_noise = false;  // Add gumbel noise to the logits if temp > 0.0
+
+    int32_t max_length = 0;            // Maximum sequence length
 };
 
+struct callback_data {
+    diffusion_params *  diff_params;
+    const llama_vocab * vocab;
+    int32_t             n_input;
+};
+
+static float calculate_confidence(const llama_token_data_array & cur_p,
+                                  diffusion_algorithm            algorithm,
+                                  std::mt19937 &                 rng) {
+    switch (algorithm) {
+        case CONFIDENCE_BASED:
+            return cur_p.data[cur_p.selected].p;  // Selected token probability
+
+        case ENTROPY_BASED:
+            {
+                float       entropy = 0.0f;
+                const float epsilon = 1e-10f;
+                for (size_t i = 0; i < cur_p.size; i++) {
+                    float prob = cur_p.data[i].p;
+                    entropy += prob * logf(prob + epsilon);
+                }
+                return -entropy;  // Higher entropy = lower confidence
+            }
+
+        case MARGIN_BASED:
+            return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p;
+
+        case RANDOM:
+            {
+                std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
+                return uniform(rng);  // Random confidence
+            }
+
+        case ORIGIN:
+            return cur_p.data[cur_p.selected].p;
+
+        default:
+            return 0.0f;
+    }
+}
+
+// Unified transfer count calculation function
+static int32_t calculate_transfer_count(int32_t                      step,
+                                        int32_t                      total_steps,
+                                        int32_t                      remaining_masked,
+                                        transfer_schedule            schedule,
+                                        float                        eps,
+                                        const std::vector<int32_t> & num_transfer_tokens = {}) {
+    switch (schedule) {
+        case TIMESTEP_BASED:
+            {
+                float t          = 1.0f - (float) step / total_steps * (1.0f - eps);
+                float s          = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps);
+                float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f;
+                return (int32_t) (remaining_masked * p_transfer);
+            }
+
+        case BLOCK_BASED:
+            if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) {
+                return num_transfer_tokens[step];
+            }
+            return remaining_masked / (total_steps - step);  // Fallback
+
+        default:
+            return remaining_masked / (total_steps - step);
+    }
+}
+
+static bool diffusion_step_callback(int32_t             step,
+                                    int32_t             total_steps,
+                                    const llama_token * tokens,
+                                    int32_t             n_tokens,
+                                    void *              user_data) {
+    (void) user_data;
+
+    callback_data * data = static_cast<callback_data *>(user_data);
+
+    auto print_progress_bar = [](int32_t step, int32_t total_steps) {
+        int progress_percent = (step * 100) / total_steps;
+        int progress_bars    = (step * 50) / total_steps;
+        LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
+                step,
+                total_steps,
+                std::string(progress_bars, '=').c_str(),
+                std::string(50 - progress_bars, ' ').c_str(),
+                progress_percent);
+    };
+
+    if (data->diff_params->visual_mode) {
+        // Visual mode: clear
+        LOG_INF("\033[2J\033[H");  // Clear screen and move cursor to top-left
+
+        print_progress_bar(step, total_steps);
+
+        LOG_INF("\n");
+
+        std::string current_text = " ";
+
+        for (int32_t i = data->n_input; i < n_tokens; i++) {
+            std::string token_str;
+            if (tokens[i] != llama_vocab_mask(data->vocab)) {
+                char piece[256];
+                int  n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
+                if (n_chars > 0) {
+                    piece[n_chars] = '\0';
+                    token_str      = piece;
+                }
+            } else {
+                token_str = " ";
+            }
+
+            current_text += token_str;
+        }
 
-static diffusion_params diffusion_default_params() {
-    diffusion_params params        = {};
-    params.steps                   = 64;
-    params.eps                     = 1e-3f;
-    params.temperature             = 0.2f;
-    params.top_p                   = 0.95f;
-    params.top_k                   = 0;
-    params.mask_token_id           = LLAMA_TOKEN_NULL;
-    params.algorithm               = DIFFUSION_ALG_ORIGIN;
-    params.alg_temp                = 0.0f;
-    params.step_callback           = nullptr;
-    params.step_callback_user_data = nullptr;
-    params.seed                    = 0;
-    return params;
+        LOG_INF("%s\n", current_text.c_str());
+    } else {
+        print_progress_bar(step, total_steps);
+    }
+
+    return true;
 }
 
-static void diffusion_generate(llama_context * ctx,
-                        const llama_token * input_tokens,
-                        llama_token * output_tokens,
-                        int32_t n_input,
-                        int32_t max_length,
-                        struct diffusion_params params,
-                        int32_t & n_generated) {
+static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) {
+    if (temperature == 0.0f) {
+        return;
+    }
+
+    std::uniform_real_distribution<double> uniform(0.0, 1.0);
+    for (int32_t i = 0; i < n_vocab; i++) {
+        double noise        = uniform(rng);
+        // Prevent log(0)
+        noise               = std::max(noise, 1e-20);
+        double gumbel_noise = std::pow(-std::log(noise), temperature);
+        logits[i]           = std::exp(logits[i]) / gumbel_noise;
+    }
+}
+
+static std::vector<int32_t> get_num_transfer_tokens(int32_t mask_count, int32_t steps) {
+    std::vector<int32_t> num_transfer_tokens(steps);
+
+    int32_t base      = mask_count / steps;
+    int32_t remainder = mask_count % steps;
+
+    for (int32_t i = 0; i < steps; i++) {
+        num_transfer_tokens[i] = base + (i < remainder ? 1 : 0);
+    }
+
+    return num_transfer_tokens;
+}
 
+static void diffusion_generate(llama_context *          ctx,
+                               const llama_token *      input_tokens,
+                               llama_token *            output_tokens,
+                               int32_t                  n_input,
+                               const diffusion_params & params,
+                               int32_t &                n_generated) {
     n_generated = 0;
-    if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || max_length <= n_input) {
+    if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) {
         return;
     }
 
@@ -73,27 +218,21 @@ static void diffusion_generate(llama_context * ctx,
 
     // Initialize with input and pad with mask tokens
     std::copy(input_tokens, input_tokens + n_input, output_tokens);
-    std::fill(output_tokens + n_input, output_tokens + max_length, params.mask_token_id);
+    std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id);
 
     std::mt19937 rng(params.seed);
 
-    std::vector<float> timesteps(params.steps + 1);
-    for (int32_t i = 0; i <= params.steps; i++) {
-        timesteps[i] = 1.0f - (float) i / params.steps * (1.0f - params.eps);
-    }
-
     llama_set_causal_attn(ctx, false);
 
     int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
 
     std::vector<llama_token_data> candidates(n_vocab);
-
     std::vector<llama_token_data> conf_candidates;
-    conf_candidates.reserve(max_length);
-
+    conf_candidates.reserve(params.max_length);
     std::vector<int32_t> mask_positions;
-    mask_positions.reserve(max_length);
+    mask_positions.reserve(params.max_length);
 
+    // Setup sampler chain
     struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
     if (params.top_k > 0) {
         llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
@@ -108,210 +247,269 @@ static void diffusion_generate(llama_context * ctx,
 
     struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);
 
-    llama_batch batch = llama_batch_init(max_length, 0, 1);
-    batch.n_tokens    = max_length;
+    llama_batch batch = llama_batch_init(params.max_length, 0, 1);
+    batch.n_tokens    = params.max_length;
 
-    int64_t total_sampling_time = 0;
-    int64_t total_time = 0;
+    // Pre-allocate buffers for CFG if needed
+    int32_t                  logits_size = n_vocab * params.max_length;
+    std::vector<float>       cond_logits_buffer;
+    std::vector<llama_token> un_x_buffer;
+    if (params.cfg_scale > 0.0f) {
+        cond_logits_buffer.resize(logits_size);
+        un_x_buffer.resize(params.max_length);
+    }
 
-    int64_t time_start = ggml_time_us();
-    for (int32_t step = 0; step < params.steps; step++) {
-        if (params.step_callback) {
-            if (!params.step_callback(step, params.steps, output_tokens, max_length, params.step_callback_user_data)) {
-                break;
-            }
-        }
+    // For block-based processing
+    std::vector<int32_t> num_transfer_tokens;
+    int32_t              num_blocks      = 1;
+    int32_t              steps_per_block = params.steps;
 
-        for (int32_t i = 0; i < max_length; i++) {
-            batch.token[i]     = output_tokens[i];
-            batch.pos[i]       = i;
-            batch.n_seq_id[i]  = 1;
-            batch.seq_id[i][0] = 0;
-            batch.logits[i]    = 1;
-        }
+    if (params.schedule == BLOCK_BASED) {
+        GGML_ASSERT(params.max_length % params.block_length == 0);
+        num_blocks = params.max_length / params.block_length;
+        GGML_ASSERT(params.steps % num_blocks == 0);
+        steps_per_block = params.steps / num_blocks;
+    }
 
-        int ret = llama_decode(ctx, batch);
-        if (ret != 0) {
-            LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, step, ret);
-            break;
-        }
+    std::vector<float> confidence(params.max_length);
 
-        float * raw_logits = llama_get_logits(ctx);
-        if (!raw_logits) {
-            LOG_ERR("%s: failed to get logits at step %d\n", __func__, step);
-            break;
+    int64_t total_sampling_time = 0;
+    int64_t total_time          = 0;
+    int64_t time_start          = ggml_time_us();
+
+    for (int block_num = 0; block_num < num_blocks; block_num++) {
+        int32_t block_start = (params.schedule == BLOCK_BASED) ? n_input + block_num * params.block_length : 0;
+        int32_t block_end   = (params.schedule == BLOCK_BASED) ?
+                                  std::min(n_input + (block_num + 1) * params.block_length, params.max_length) :
+                                  params.max_length;
+
+        // Count masked tokens in current block for block-based processing
+        if (params.schedule == BLOCK_BASED) {
+            int32_t block_mask_count = 0;
+            for (int i = block_start; i < block_end; i++) {
+                if (output_tokens[i] == params.mask_token_id) {
+                    block_mask_count++;
+                }
+            }
+            num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block);
         }
 
-        auto get_logits_for_pos = [&](int32_t pos) -> const float * {
-            return pos == 0 ? raw_logits : raw_logits + (pos - 1) * n_vocab;
-        };
-
-        int64_t time_start_sampling = ggml_time_us();
+        for (int32_t step = 0; step < steps_per_block; step++) {
+            int32_t global_step = block_num * steps_per_block + step;
 
-        mask_positions.clear();
-        for (int32_t i = 0; i < max_length; i++) {
-            if (output_tokens[i] == params.mask_token_id) {
-                mask_positions.push_back(i);
+            if (params.step_callback) {
+                if (!params.step_callback(
+                        global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) {
+                    break;
+                }
             }
-        }
 
-        if (mask_positions.empty()) {
-            break;
-        }
+            // Setup batch
+            for (int32_t i = 0; i < params.max_length; i++) {
+                batch.token[i]     = output_tokens[i];
+                batch.pos[i]       = i;
+                batch.n_seq_id[i]  = 1;
+                batch.seq_id[i][0] = 0;
+                batch.logits[i]    = 1;
+            }
 
-        float t = timesteps[step];
-        float s = timesteps[step + 1];
+            float * logits = nullptr;
 
-        if (params.algorithm == DIFFUSION_ALG_ORIGIN) {
-            float p_transfer = (step < params.steps - 1) ? (1.0f - s / t) : 1.0f;
+            if (params.cfg_scale > 0.0f) {
+                int ret = llama_decode(ctx, batch);
+                if (ret != 0) {
+                    LOG_ERR("Failed to generate conditional");
+                    break;
+                }
+                float * cond_logits_ptr = llama_get_logits(ctx);
+                std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float));
 
-            for (int32_t pos : mask_positions) {
-                if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
-                    const float * pos_logits = get_logits_for_pos(pos);
-                    for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
-                        candidates[token_id].id    = token_id;
-                        candidates[token_id].logit = pos_logits[token_id];
-                        candidates[token_id].p     = 0.0f;
-                    }
+                // Unconditional generation (mask input)
+                std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin());
+                for (int32_t i = 0; i < n_input; i++) {
+                    un_x_buffer[i] = params.mask_token_id;
+                }
 
-                    llama_token_data_array cur_p = {
-                        /* .data       = */ candidates.data(),
-                        /* .size       = */ (size_t) n_vocab,  // Reset size to full vocab
-                        /* .selected   = */ -1,
-                        /* .sorted     = */ false,
-                    };
+                for (int32_t i = 0; i < params.max_length; i++) {
+                    batch.token[i] = un_x_buffer[i];
+                }
+                ret = llama_decode(ctx, batch);
+                if (ret != 0) {
+                    LOG_ERR("Failed to generate unconditional");
+                    break;
+                }
+                float * uncond_logits = llama_get_logits(ctx);
 
-                    llama_sampler_apply(sampler, &cur_p);
-                    output_tokens[pos] = cur_p.data[cur_p.selected].id;
+                // Apply CFG
+                for (int32_t i = 0; i < logits_size; i++) {
+                    cond_logits_buffer[i] =
+                        uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]);
                 }
-            }
-        } else {
-            std::vector<std::pair<float, int32_t>> confidences;
-            std::vector<llama_token>               sampled_tokens(mask_positions.size());
-
-            for (size_t i = 0; i < mask_positions.size(); i++) {
-                int32_t       pos        = mask_positions[i];
-                const float * pos_logits = get_logits_for_pos(pos);
-
-                for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
-                    candidates[token_id].logit = pos_logits[token_id];
-                    candidates[token_id].p     = 0.0f;
-                    candidates[token_id].id    = token_id;
+                logits = cond_logits_buffer.data();
+            } else {
+                int ret = llama_decode(ctx, batch);
+                if (ret != 0) {
+                    LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret);
+                    break;
                 }
+                logits = llama_get_logits(ctx);
+            }
 
-                llama_token_data_array cur_p = {
-                    /* .data       = */ candidates.data(),
-                    /* .size       = */ candidates.size(),
-                    /* .selected   = */ -1,
-                    /* .sorted     = */ false,
-                };
+            if (!logits) {
+                LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step);
+                break;
+            }
 
-                llama_sampler_apply(sampler, &cur_p);
+            auto get_logits_for_pos = [&](int32_t pos) -> const float * {
+                if (params.shift_logits) {
+                    return pos == 0 ? logits : logits + (pos - 1) * n_vocab;
+                }
+                return logits + (pos) *n_vocab;
+            };
 
-                llama_token sampled_token = cur_p.data[cur_p.selected].id;
+            int64_t time_start_sampling = ggml_time_us();
 
-                float confidence = 0.0f;
-                if (params.algorithm == DIFFUSION_ALG_ENTROPY) {
-                    const float epsilon = 1e-10f;
-                    for (size_t j = 0; j < cur_p.size; j++) {
-                        float prob = cur_p.data[j].p;
-                        confidence += prob * logf(prob + epsilon);
+            mask_positions.clear();
+            for (int32_t i = 0; i < params.max_length; i++) {
+                if (output_tokens[i] == params.mask_token_id) {
+                    // For block-based, only consider current block
+                    if (params.schedule != BLOCK_BASED || (i >= block_start && i < block_end)) {
+                        mask_positions.push_back(i);
                     }
-                } else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) {
-                    confidence = cur_p.data[0].p - cur_p.data[1].p;
-                } else {
-                    confidence = cur_p.data[cur_p.selected].p;
                 }
+            }
 
-                sampled_tokens[i] = sampled_token;
-                confidences.emplace_back(confidence, i);
+            if (mask_positions.empty()) {
+                break;
             }
 
-            int32_t num_transfer =
-                (step < params.steps - 1) ? (int32_t) (mask_positions.size() * (1.0f - s / t)) : mask_positions.size();
-
-            if (num_transfer > 0) {
-                if (params.alg_temp == 0.0f) {
-                    std::partial_sort(confidences.begin(), confidences.begin() + num_transfer, confidences.end(),
-                                      [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
-                                          if (a.first != b.first) {
-                                              return a.first > b.first;
-                                          }
-                                          return a.second < b.second;
-                                      });
-                } else {
-                    conf_candidates.clear();
-
-                    for (int32_t pos = 0; pos < max_length; pos++) {
-                        float conf_logit = -std::numeric_limits<float>::infinity();
-
-                        auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
-                        if (it != mask_positions.end()) {
-                            size_t mask_idx = std::distance(mask_positions.begin(), it);
-                            conf_logit = confidences[mask_idx].first / params.alg_temp;  // Apply temperature scaling
+            if (params.add_gumbel_noise && params.temperature > 0.0f) {
+                add_gumbel_noise(logits, n_vocab, params.temperature, rng);
+            }
+
+            if (params.algorithm == ORIGIN) {
+                int32_t transfer_count = calculate_transfer_count(
+                    step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
+                float p_transfer = (float) transfer_count / mask_positions.size();
+
+                for (int32_t pos : mask_positions) {
+                    if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
+                        const float * pos_logits = get_logits_for_pos(pos);
+                        for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
+                            candidates[token_id].id    = token_id;
+                            candidates[token_id].logit = pos_logits[token_id];
+                            candidates[token_id].p     = 0.0f;
                         }
 
-                        conf_candidates.emplace_back(llama_token_data{ pos, conf_logit, 0.0f });
+                        llama_token_data_array cur_p = {
+                            candidates.data(),
+                            (size_t) n_vocab,
+                            -1,
+                            false,
+                        };
+
+                        llama_sampler_apply(sampler, &cur_p);
+                        output_tokens[pos] = cur_p.data[cur_p.selected].id;
+                    }
+                }
+            } else {
+                std::vector<std::pair<float, int32_t>> confidences;
+                std::vector<llama_token>               sampled_tokens(mask_positions.size());
+
+                for (size_t i = 0; i < mask_positions.size(); i++) {
+                    int32_t       pos        = mask_positions[i];
+                    const float * pos_logits = get_logits_for_pos(pos);
+
+                    for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
+                        candidates[token_id].logit = pos_logits[token_id];
+                        candidates[token_id].p     = 0.0f;
+                        candidates[token_id].id    = token_id;
                     }
 
-                    llama_token_data_array conf_array = {
-                        /* .data       = */ conf_candidates.data(),
-                        /* .size       = */ conf_candidates.size(),
-                        /* .selected   = */ -1,
-                        /* .sorted     = */ false,
+                    llama_token_data_array cur_p = {
+                        candidates.data(),
+                        candidates.size(),
+                        -1,
+                        false,
                     };
 
-                    for (int32_t i = 0; i < num_transfer; i++) {
-                        // Apply distribution sampler to get selected index
-                        llama_sampler_apply(dist_sampler, &conf_array);
-                        int selected_idx      = conf_array.selected;
-                        confidences[i].second = conf_candidates[selected_idx].id;
+                    llama_sampler_apply(sampler, &cur_p);
+                    llama_token sampled_token = cur_p.data[cur_p.selected].id;
+
+                    float conf = calculate_confidence(cur_p, params.algorithm, rng);
 
-                        conf_candidates[selected_idx].p = 0.0f;
-                        conf_array.selected             = -1;
-                    }
+                    sampled_tokens[i] = sampled_token;
+                    confidences.emplace_back(conf, i);
                 }
 
-                if (params.alg_temp == 0.0f) {
-                    // Deterministic - use confidence order
-                    for (int32_t i = 0; i < num_transfer; i++) {
-                        int32_t     mask_idx = confidences[i].second;
-                        int32_t     pos      = mask_positions[mask_idx];
-                        llama_token token    = sampled_tokens[mask_idx];
-                        output_tokens[pos]   = token;
-                    }
-                } else {
-                    for (int32_t i = 0; i < num_transfer; i++) {
-                        int32_t pos = confidences[i].second;
-                        auto    it  = std::find(mask_positions.begin(), mask_positions.end(), pos);
-                        if (it != mask_positions.end()) {
-                            int32_t mask_idx   = std::distance(mask_positions.begin(), it);
+                int32_t transfer_count = calculate_transfer_count(
+                    step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
+
+                if (transfer_count > 0) {
+                    if (params.alg_temp == 0.0f) {
+                        std::partial_sort(confidences.begin(),
+                                          confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
+                                          confidences.end(),
+                                          [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
+                                              if (a.first != b.first) {
+                                                  return a.first > b.first;
+                                              }
+                                              return a.second < b.second;
+                                          });
+
+                        for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
+                            int32_t mask_idx   = confidences[i].second;
+                            int32_t pos        = mask_positions[mask_idx];
                             output_tokens[pos] = sampled_tokens[mask_idx];
                         }
+                    } else {
+                        conf_candidates.clear();
+                        for (size_t i = 0; i < confidences.size(); i++) {
+                            float conf_logit = confidences[i].first / params.alg_temp;
+                            conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f });
+                        }
+
+                        llama_token_data_array conf_array = {
+                            conf_candidates.data(),
+                            conf_candidates.size(),
+                            -1,
+                            false,
+                        };
+
+                        for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
+                            llama_sampler_apply(dist_sampler, &conf_array);
+                            int32_t selected_idx = conf_array.selected;
+                            int32_t mask_idx     = selected_idx;
+                            int32_t pos          = mask_positions[mask_idx];
+                            output_tokens[pos]   = sampled_tokens[mask_idx];
+
+                            conf_candidates[selected_idx].p = 0.0f;
+                            conf_array.selected             = -1;
+                        }
                     }
                 }
             }
+
+            int64_t time_end_sampling = ggml_time_us();
+            total_sampling_time += time_end_sampling - time_start_sampling;
         }
-        int64_t time_end_sampling = ggml_time_us();
-        total_sampling_time += time_end_sampling - time_start_sampling;
     }
+
     int64_t time_end = ggml_time_us();
     total_time += time_end - time_start;
 
     LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
-            total_time / 1000.0, total_time / 1000.0 / params.steps, total_sampling_time / 1000.0 / params.steps);
-
+            total_time / 1000.0,
+            total_time / 1000.0 / params.steps,
+            total_sampling_time / 1000.0 / params.steps);
 
     llama_batch_free(batch);
     llama_sampler_free(sampler);
     llama_sampler_free(dist_sampler);
 
-    n_generated = max_length;
+    n_generated = params.max_length;
 }
 
-
-
-
 static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
     if (!use_chat_template) {
         return prompt;
@@ -331,66 +529,6 @@ static std::string format_input_text(const std::string & prompt, bool use_chat_t
     return result.prompt;
 }
 
-struct callback_data {
-    const common_params_diffusion * diff_params;
-    const llama_vocab *             vocab;
-    int32_t                         n_input;
-};
-
-static bool diffusion_step_callback(int32_t step,
-                                    int32_t total_steps,
-                                    const llama_token * tokens,
-                                    int32_t n_tokens,
-                                    void * user_data) {
-    (void)user_data;
-
-    callback_data * data = static_cast<callback_data *>(user_data);
-
-    auto print_progress_bar = [](int32_t step, int32_t total_steps) {
-        int progress_percent = (step * 100) / total_steps;
-        int progress_bars    = (step * 50) / total_steps;
-        LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
-            step,
-            total_steps,
-            std::string(progress_bars, '=').c_str(),
-            std::string(50 - progress_bars, ' ').c_str(),
-            progress_percent);
-    };
-
-    if (data->diff_params->visual_mode) {
-        // Visual mode: clear
-        LOG_INF("\033[2J\033[H");  // Clear screen and move cursor to top-left
-
-        print_progress_bar(step, total_steps);
-
-        LOG_INF("\n");
-
-        std::string current_text = " ";
-
-        for (int32_t i = data->n_input; i < n_tokens; i++) {
-            std::string token_str;
-            if (tokens[i] != llama_vocab_mask(data->vocab)) {
-                char piece[256];
-                int  n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
-                if (n_chars > 0) {
-                    piece[n_chars] = '\0';
-                    token_str      = piece;
-                }
-            } else {
-                token_str = " ";
-            }
-
-            current_text += token_str;
-        }
-
-        LOG_INF("%s\n", current_text.c_str());
-    } else {
-        print_progress_bar(step, total_steps);
-    }
-
-    return true;
-}
-
 int main(int argc, char ** argv) {
     ggml_time_init();
 
@@ -400,11 +538,6 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    const char * alg_names[] = { "ORIGIN", "MASKGIT_PLUS", "TOPK_MARGIN", "ENTROPY" };
-    const char * alg_name    = (params.diffusion.algorithm >= 0 && params.diffusion.algorithm <= 3) ?
-                                   alg_names[params.diffusion.algorithm] :
-                                   "UNKNOWN";
-
     common_init();
     llama_backend_init();
 
@@ -421,6 +554,12 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    if (!llama_model_is_diffusion(model)) {
+        LOG_ERR("error: unsupported model for diffusion");
+        llama_model_free(model);
+        return 1;
+    }
+
     llama_context_params ctx_params = llama_context_default_params();
     ctx_params.n_ctx                = params.n_ctx;
     ctx_params.n_batch              = params.n_batch;
@@ -442,10 +581,12 @@ int main(int argc, char ** argv) {
     const llama_vocab * vocab            = llama_model_get_vocab(model);
     std::string         formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);
 
-    std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt,
+    std::vector<llama_token> input_tokens = common_tokenize(vocab,
+                                                            formatted_prompt,
                                                             /*add special tokens*/ true,
                                                             /*parse special*/ true);
-    int                      n_input      = input_tokens.size();
+
+    int n_input = input_tokens.size();
 
     if (n_input >= params.n_ctx) {
         LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
@@ -454,44 +595,79 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    struct diffusion_params ldiff_params = diffusion_default_params();
-    ldiff_params.steps                   = params.diffusion.steps;
-    ldiff_params.eps                     = params.diffusion.eps;
-    ldiff_params.temperature             = params.sampling.temp;
-    ldiff_params.top_p                   = params.sampling.top_p;
-    ldiff_params.top_k                   = params.sampling.top_k;
-    ldiff_params.algorithm               = static_cast<enum diffusion_alg>(params.diffusion.algorithm);
-    ldiff_params.alg_temp                = params.diffusion.alg_temp;
-    ldiff_params.seed                    = params.sampling.seed;
-
     llama_token mask_token_id = llama_vocab_mask(vocab);
     GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
 
-    LOG_INF("diffusion_params: - %-25s llama_token      = %d\n", "mask_token_id", mask_token_id);
-    LOG_INF("diffusion_params: - %-25s u32              = %d\n", "steps", params.diffusion.steps);
-    LOG_INF("diffusion_params: - %-25s f32              = %.6f\n", "eps", params.diffusion.eps);
-    LOG_INF("diffusion_params: - %-25s u32              = %d (%s)\n", "algorithm", params.diffusion.algorithm,
-            alg_name);
-    LOG_INF("diffusion_params: - %-25s f32              = %.3f\n", "alg_temp", params.diffusion.alg_temp);
+    bool visual_mode = params.diffusion.visual_mode;
 
-    ldiff_params.mask_token_id = mask_token_id;
+    int32_t                  n_generated = 0;
+    std::vector<llama_token> output_tokens(params.n_ubatch);
 
-    callback_data cb_data = { &params.diffusion, vocab, n_input };
+    struct diffusion_params diff_params;
 
-    ldiff_params.step_callback           = diffusion_step_callback;
-    ldiff_params.step_callback_user_data = &cb_data;
+    char shift_logits_str[8];
+    if (llama_model_meta_val_str(model, "diffusion.shift_logits", shift_logits_str, sizeof(shift_logits_str)) >= 0) {
+        diff_params.shift_logits = (strcmp(shift_logits_str, "true") == 0);
+    } else {
+        diff_params.shift_logits = true;
+    }
 
-    int32_t n_generated = 0;
+    //Use either eps or block length, but not both
+    GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0));
 
-    std::vector<llama_token> output_tokens(params.n_ubatch);
-    diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, params.n_ubatch,
-                       ldiff_params, n_generated);
+    if (params.diffusion.eps) {
+        diff_params.schedule = TIMESTEP_BASED;
+        diff_params.eps      = params.diffusion.eps;
+    } else if (params.diffusion.block_length) {
+        diff_params.schedule     = BLOCK_BASED;
+        diff_params.block_length = params.diffusion.block_length;
+    }
+
+    diff_params.mask_token_id    = mask_token_id;
+    diff_params.seed             = params.sampling.seed;
+    diff_params.temperature      = params.sampling.temp;
+    diff_params.steps            = params.diffusion.steps;
+    diff_params.algorithm        = static_cast<diffusion_algorithm>(params.diffusion.algorithm);
+    diff_params.max_length       = params.n_ubatch;
+    diff_params.top_p            = params.sampling.top_p;
+    diff_params.top_k            = params.sampling.top_k;
+    diff_params.visual_mode      = params.diffusion.visual_mode;
+    diff_params.add_gumbel_noise = params.diffusion.add_gumbel_noise;
+
+    diff_params.step_callback           = diffusion_step_callback;
+    callback_data cb_data               = { &diff_params, vocab, n_input };
+    diff_params.step_callback_user_data = &cb_data;
+
+    const char * alg_names[]   = { "ORIGIN", "ENTROPY_BASED", "MARGIN_BASED", "RANDOM", "CONFIDENCE_BASED" };
+    const char * sched_names[] = { "TIMESTEP_BASED", "BLOCK_BASED" };
+    const char * alg_name =
+        (diff_params.algorithm >= 0 && diff_params.algorithm <= 4) ? alg_names[diff_params.algorithm] : "UNKNOWN";
+    const char * sched_name =
+        (diff_params.schedule >= 0 && diff_params.schedule <= 1) ? sched_names[diff_params.schedule] : "UNKNOWN";
+
+    LOG_INF("diffusion_params: - %-25s llama_token      = %d\n", "mask_token_id", mask_token_id);
+    LOG_INF("diffusion_params: - %-25s u32              = %d\n", "steps", diff_params.steps);
+    LOG_INF("diffusion_params: - %-25s u32              = %d\n", "max_length", diff_params.max_length);
+    LOG_INF("diffusion_params: - %-25s enum             = %d (%s)\n", "algorithm", diff_params.algorithm, alg_name);
+    LOG_INF("diffusion_params: - %-25s enum             = %d (%s)\n", "schedule", diff_params.schedule, sched_name);
+    LOG_INF("diffusion_params: - %-25s f32              = %.3f\n", "temperature", diff_params.temperature);
+    if (diff_params.schedule == TIMESTEP_BASED) {
+        LOG_INF("diffusion_params: - %-25s f32              = %.6f\n", "eps", diff_params.eps);
+        LOG_INF("diffusion_params: - %-25s f32              = %.3f\n", "alg_temp", diff_params.alg_temp);
+    }
+    if (diff_params.schedule == BLOCK_BASED) {
+        LOG_INF("diffusion_params: - %-25s u32              = %d\n", "block_length", diff_params.block_length);
+        LOG_INF("diffusion_params: - %-25s f32              = %.3f\n", "cfg_scale", diff_params.cfg_scale);
+    }
+
+    diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, diff_params, n_generated);
 
     if (n_generated > 0) {
-        if (params.diffusion.visual_mode) {
+        if (visual_mode) {
             //clear screen and move cursor to top-left
             LOG_INF("\033[2J\033[H");
         }
+
         output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
         std::string output_data = common_detokenize(vocab, output_tokens, false);
         LOG_INF("\n%s\n", output_data.c_str());
index c97b61d09c71148859ecdd0bb7686a9c5c86e9da..ef47ea7359eda07bd31d2b064f2afc22ff6c4e1a 100644 (file)
@@ -279,6 +279,9 @@ class Keys:
         class Projector:
             STACK_FACTOR    = "clip.audio.projector.stack_factor"
 
+    class Diffusion:
+        SHIFT_LOGITS        = "diffusion.shift_logits"
+
 #
 # recommended mapping of model tensor names for storage in gguf
 #
@@ -377,6 +380,7 @@ class MODEL_ARCH(IntEnum):
     LFM2             = auto()
     DREAM            = auto()
     SMALLTHINKER     = auto()
+    LLADA            = auto()
 
 
 class VISION_PROJECTOR_TYPE(IntEnum):
@@ -697,6 +701,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
     MODEL_ARCH.LFM2:             "lfm2",
     MODEL_ARCH.DREAM:            "dream",
     MODEL_ARCH.SMALLTHINKER:     "smallthinker",
+    MODEL_ARCH.LLADA:            "llada",
 }
 
 VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -1318,6 +1323,21 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_DOWN,
         MODEL_TENSOR.FFN_UP,
     ],
+    MODEL_ARCH.LLADA: [
+        MODEL_TENSOR.TOKEN_EMBD,
+        MODEL_TENSOR.OUTPUT_NORM,
+        MODEL_TENSOR.OUTPUT,
+        MODEL_TENSOR.ROPE_FREQS,
+        MODEL_TENSOR.ATTN_NORM,
+        MODEL_TENSOR.ATTN_Q,
+        MODEL_TENSOR.ATTN_K,
+        MODEL_TENSOR.ATTN_V,
+        MODEL_TENSOR.ATTN_OUT,
+        MODEL_TENSOR.FFN_NORM,
+        MODEL_TENSOR.FFN_GATE,
+        MODEL_TENSOR.FFN_DOWN,
+        MODEL_TENSOR.FFN_UP,
+    ],
     MODEL_ARCH.QWEN2VL: [
         MODEL_TENSOR.TOKEN_EMBD,
         MODEL_TENSOR.OUTPUT_NORM,
index 4f23f9b0246197a7ab026f2566f328ec59fe51f0..f4fd64ad822fae7ab85137b5a06d7c715750b976 100644 (file)
@@ -1047,6 +1047,11 @@ class GGUFWriter:
     def add_audio_stack_factor(self, value: int) -> None:
         self.add_uint32(Keys.ClipAudio.Projector.STACK_FACTOR, value)
 
+    # diffusion models
+
+    def add_diffusion_shift_logits(self, value: bool) -> None:
+        self.add_bool(Keys.Diffusion.SHIFT_LOGITS, value)
+
     def _pack(self, fmt: str, value: Any, skip_pack_prefix: bool = False) -> bytes:
         pack_prefix = ''
         if not skip_pack_prefix:
index bfd4fd37a3f6868a44a5271d607dc7726c5c9bb3..15adbfa7818455a93e67a54a6a6faed472ea4ce2 100644 (file)
@@ -32,6 +32,7 @@ class TensorNameMap:
             "model.word_embeddings",                     # bailingmoe
             "language_model.model.embed_tokens",         # llama4
             "encoder",                                   # neobert
+            "model.transformer.wte",                     # llada
         ),
 
         # Token type embeddings
@@ -71,6 +72,7 @@ class TensorNameMap:
             "head",                      # rwkv
             "head.out",                  # wavtokenizer
             "lm_head",                   # llama4
+            "model.transformer.ff_out",  # llada
         ),
 
         # Output norm
@@ -94,6 +96,7 @@ class TensorNameMap:
             "model.ln_out",                            # rwkv7
             "backbone.final_layer_norm",               # wavtokenizer
             "model.norm",                              # llama4
+            "model.transformer.ln_f",                  # llada
         ),
 
         # Rope frequencies
@@ -139,6 +142,7 @@ class TensorNameMap:
             "model.layers.{bid}.input_layernorm",                   # llama4
             "transformer_encoder.{bid}.attention_norm",             # neobert
             "model.layers.{bid}.operator_norm",                     # lfm2
+            "model.transformer.blocks.{bid}.attn_norm",             # llada
         ),
 
         # Attention norm 2
@@ -183,6 +187,7 @@ class TensorNameMap:
             "transformer.decoder_layer.{bid}.multi_head_attention.query",# Grok
             "transformer.h.{bid}.attn.attention.q_proj",                 # exaone
             "model.layers.{bid}.self_attn.q_proj",                       # llama4
+            "model.transformer.blocks.{bid}.q_proj",                     # llada
         ),
 
         # Attention key
@@ -199,6 +204,7 @@ class TensorNameMap:
             "transformer.decoder_layer.{bid}.multi_head_attention.key",# Grok
             "transformer.h.{bid}.attn.attention.k_proj",               # exaone
             "model.layers.{bid}.self_attn.k_proj",                     # llama4
+            "model.transformer.blocks.{bid}.k_proj",                   # llada
         ),
 
         # Attention value
@@ -214,6 +220,7 @@ class TensorNameMap:
             "transformer.decoder_layer.{bid}.multi_head_attention.value",# Grok
             "transformer.h.{bid}.attn.attention.v_proj",                 # exaone
             "model.layers.{bid}.self_attn.v_proj",                       # llama4
+            "model.transformer.blocks.{bid}.v_proj",                     # llada
         ),
 
         # Attention output
@@ -246,6 +253,7 @@ class TensorNameMap:
             "transformer.h.{bid}.attn.attention.out_proj",                  # exaone
             "model.layers.{bid}.self_attn.o_proj",                          # llama4
             "transformer_encoder.{bid}.wo",                                 # neobert
+            "model.transformer.blocks.{bid}.attn_out",                      # llada
         ),
 
         # Attention output norm
@@ -291,6 +299,7 @@ class TensorNameMap:
             "model.layers.{bid}.post_attention_layernorm",                   # llama4
             "transformer_encoder.{bid}.ffn_norm",                            # neobert
             "model.layers.layers.{bid}.pre_mlp_norm",                        # plamo2
+            "model.transformer.blocks.{bid}.ff_norm",                        # llada
         ),
 
         # Post feed-forward norm
@@ -364,6 +373,7 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.up_proj",                # llama4 jamba granite-hybrid
             "transformer_encoder.{bid}.ffn.w12",                      # neobert
             "model.layers.{bid}.block_sparse_moe.up",                 # smallthinker
+            "model.transformer.blocks.{bid}.up_proj",                  # llada
         ),
 
         MODEL_TENSOR.FFN_UP_EXP: (
@@ -405,6 +415,7 @@ class TensorNameMap:
             "transformer.h.{bid}.mlp.c_fc_0",             # exaone
             "model.layers.{bid}.feed_forward.gate_proj",  # llama4 jamba granite-hybrid
             "model.layers.{bid}.block_sparse_moe.gate",   # smallthinker
+            "model.transformer.blocks.{bid}.ff_proj",     # llada
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
@@ -454,6 +465,7 @@ class TensorNameMap:
             "model.layers.{bid}.feed_forward.down_proj",              # llama4 jamba granite-hybrid
             "transformer_encoder.{bid}.ffn.w3",                       # neobert
             "model.layers.{bid}.block_sparse_moe.down",               # smallthinker
+            "model.transformer.blocks.{bid}.ff_out",                   # llada
         ),
 
         MODEL_TENSOR.FFN_DOWN_EXP: (
index 6f454a508a06c80bb92fab68f97f06e6cb95ccd5..1a51e74a8d63f663d0aa70007cb669965aec2c95 100644 (file)
@@ -537,6 +537,9 @@ extern "C" {
     // Returns true if the model is recurrent (like Mamba, RWKV, etc.)
     LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model);
 
+    // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.)
+    LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model);
+
     // Returns 0 on success
     LLAMA_API uint32_t llama_model_quantize(
             const char * fname_inp,
index dbf977443ae85f2870ff2d7eb134ca1b615385e1..15fb9d0b5080932517638ac7391bf557cb3d4b0e 100644 (file)
@@ -89,6 +89,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
     { LLM_ARCH_LFM2,             "lfm2"             },
     { LLM_ARCH_DREAM,            "dream"            },
     { LLM_ARCH_SMALLTHINKER,     "smallthinker"     },
+    { LLM_ARCH_LLADA,            "llada"            },
     { LLM_ARCH_UNKNOWN,          "(unknown)"        },
 };
 
@@ -1972,6 +1973,23 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
             { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
         },
     },
+    {
+        LLM_ARCH_LLADA,
+        {
+            { LLM_TENSOR_TOKEN_EMBD,      "token_embd" },
+            { LLM_TENSOR_OUTPUT_NORM,     "output_norm" },
+            { LLM_TENSOR_OUTPUT,          "output" },
+            { LLM_TENSOR_ATTN_NORM,       "blk.%d.attn_norm" },
+            { LLM_TENSOR_ATTN_Q,          "blk.%d.attn_q" },
+            { LLM_TENSOR_ATTN_K,          "blk.%d.attn_k" },
+            { LLM_TENSOR_ATTN_V,          "blk.%d.attn_v" },
+            { LLM_TENSOR_ATTN_OUT,        "blk.%d.attn_output" },
+            { LLM_TENSOR_FFN_NORM,        "blk.%d.ffn_norm" },
+            { LLM_TENSOR_FFN_GATE,        "blk.%d.ffn_gate" },
+            { LLM_TENSOR_FFN_DOWN,        "blk.%d.ffn_down" },
+            { LLM_TENSOR_FFN_UP,          "blk.%d.ffn_up" },
+        },
+    },
     {
         LLM_ARCH_UNKNOWN,
         {
@@ -2224,6 +2242,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
 bool llm_arch_is_diffusion(const llm_arch & arch) {
     switch (arch) {
         case LLM_ARCH_DREAM:
+        case LLM_ARCH_LLADA:
             return true;
         default:
             return false;
index 8267a8d3aa49128c932dcf44098ff544a1c3a3f2..8ea80806c9c8dc0e6a7a82f3e71deed12527ba30 100644 (file)
@@ -93,6 +93,7 @@ enum llm_arch {
     LLM_ARCH_LFM2,
     LLM_ARCH_DREAM,
     LLM_ARCH_SMALLTHINKER,
+    LLM_ARCH_LLADA,
     LLM_ARCH_UNKNOWN,
 };
 
index e3aa9e6f91af92b7412f09c566f5c9eb8cf5756d..92a7efed3dab3eaefe0f54843a4b051e4ada7967 100644 (file)
@@ -869,6 +869,21 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                 hparams.causal_attn = false;
             }
             break;
+        case LLM_ARCH_LLADA:
+            {
+                ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                // LLaDA-8B has 32 layers, similar to LLaMA but for diffusion
+                switch (hparams.n_layer) {
+                    case 32:
+                        type = LLM_TYPE_8B;
+                        break;
+                    default:
+                        type = LLM_TYPE_UNKNOWN;
+                }
+                // Set non-causal attention for diffusion models
+                hparams.causal_attn = false;
+            }
+            break;
         case LLM_ARCH_QWEN2MOE:
             {
                 ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH,        hparams.n_ff_exp, false);
@@ -2149,6 +2164,53 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         }
                     }
                 } break;
+            case LLM_ARCH_LLADA:
+                {
+                    tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+                    // output
+                    output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+                    output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+
+                    // if output is NULL, init from the input tok embed
+                    if (output == NULL) {
+                        output =
+                            create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
+                    }
+
+                    for (int i = 0; i < n_layer; ++i) {
+                        auto & layer = layers[i];
+
+                        layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+
+                        // Use separate Q, K, V projections without bias, matching LLaDALlamaBlock
+                        layer.wq =
+                            create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
+                        layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
+                        layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
+                        // No bias for QKV projections as per config: include_bias=false, include_qkv_bias=false
+                        layer.wo =
+                            create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
+                        layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED);
+
+                        layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
+
+                        layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot / 2 },
+                                                         TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
+                        layer.ffn_up   = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
+
+                        // optional MLP bias
+                        layer.ffn_gate_b =
+                            create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_b =
+                            create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED);
+                        layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED);
+                    }
+                }
+                break;
             case LLM_ARCH_LLAMA4:
                 {
                     tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
@@ -8042,6 +8104,106 @@ struct llm_build_dream : public llm_graph_context {
     }
 };
 
+struct llm_build_llada : public llm_graph_context {
+    llm_build_llada(const llama_model & model, const llm_graph_params & params) :
+        llm_graph_context(params) {
+        // LLaDA is similar to LLaMA but uses non-causal attention for diffusion
+        const int64_t n_embd_head = hparams.n_embd_head_v;
+
+        GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+        GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+        ggml_tensor * cur;
+        ggml_tensor * inpL;
+
+        inpL = build_inp_embd(model.tok_embd);
+
+        // inp_pos - contains the positions
+        ggml_tensor * inp_pos = build_inp_pos();
+
+        // Non-causal attention for diffusion
+        auto * inp_attn = build_attn_inp_no_cache();
+
+        ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+        for (int il = 0; il < n_layer; ++il) {
+            ggml_tensor * inpSA = inpL;
+
+            // norm
+            cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "attn_norm", il);
+
+            // self-attention
+            {
+                // compute separate Q, K, V projections without bias, matching LLaDALlamaBlock
+                ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+                ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+                ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+                Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+                Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+                Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                     ext_factor, attn_factor, beta_fast, beta_slow);
+
+                Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+                                     ext_factor, attn_factor, beta_fast, beta_slow);
+
+                cb(Qcur, "Qcur", il);
+                cb(Kcur, "Kcur", il);
+                cb(Vcur, "Vcur", il);
+
+                cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr,
+                                 1.0f / sqrtf(float(n_embd_head)), il);
+            }
+
+            if (il == n_layer - 1 && inp_out_ids) {
+                cur   = ggml_get_rows(ctx0, cur, inp_out_ids);
+                inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+            }
+
+            ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+            cb(ffn_inp, "ffn_inp", il);
+
+            // feed-forward network
+            cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+            cb(cur, "ffn_norm", il);
+
+            cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
+                            model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+            cb(cur, "ffn_out", il);
+
+            cur = ggml_add(ctx0, cur, ffn_inp);
+
+            cur = build_cvec(cur, il);
+            cb(cur, "l_out", il);
+
+            // input for next layer
+            inpL = cur;
+        }
+
+        cur = inpL;
+
+        cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+        cb(cur, "result_norm", -1);
+        res->t_embd = cur;
+
+        // lm_head
+        cur = build_lora_mm(model.output, cur);
+
+        cb(cur, "result_output", -1);
+        res->t_logits = cur;
+
+        ggml_build_forward_expand(gf, cur);
+    }
+};
+
 struct llm_build_qwen2vl : public llm_graph_context {
     llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
         const int64_t n_embd_head = hparams.n_embd_head_v;
@@ -17201,6 +17363,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
         case LLM_ARCH_NEO_BERT:
         case LLM_ARCH_WAVTOKENIZER_DEC:
         case LLM_ARCH_DREAM:
+        case LLM_ARCH_LLADA:
             {
                 res = nullptr;
             } break;
@@ -17367,6 +17530,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
                 llm = std::make_unique<llm_build_dream>(*this, params);
             }
             break;
+        case LLM_ARCH_LLADA:
+            {
+                llm = std::make_unique<llm_build_llada>(*this, params);
+            }
+            break;
         case LLM_ARCH_QWEN2VL:
             {
                 llm = std::make_unique<llm_build_qwen2vl>(*this, params);
@@ -17765,6 +17933,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
 
         // use what we call a normal RoPE, operating on pairs of consecutive head values
         case LLM_ARCH_LLAMA:
+        case LLM_ARCH_LLADA:
         case LLM_ARCH_LLAMA4:
         case LLM_ARCH_DECI:
         case LLM_ARCH_BAICHUAN:
@@ -17943,6 +18112,10 @@ bool llama_model_is_recurrent(const llama_model * model) {
     return llm_arch_is_recurrent(model->arch);
 }
 
+bool llama_model_is_diffusion(const llama_model * model) {
+    return llm_arch_is_diffusion(model->arch);
+}
+
 const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
     return model->tensors_by_name;
 }