#include "log.h"
#include <limits.h>
-#include <string>
-#include <vector>
+
#include <algorithm>
#include <cmath>
+#include <cstring>
#include <limits>
#include <random>
+#include <string>
+#include <vector>
-typedef bool (*diffusion_step_callback_t)(int32_t step,
- int32_t total_steps,
- const llama_token * tokens,
- int32_t n_tokens,
- void * user_data);
-
-enum diffusion_alg {
- DIFFUSION_ALG_ORIGIN = 0,
- DIFFUSION_ALG_MASKGIT_PLUS = 1,
- DIFFUSION_ALG_TOPK_MARGIN = 2,
- DIFFUSION_ALG_ENTROPY = 3,
+enum diffusion_algorithm { ORIGIN = 0, ENTROPY_BASED = 1, MARGIN_BASED = 2, RANDOM = 3, CONFIDENCE_BASED = 4 };
+
+// Unified transfer scheduling methods
+enum transfer_schedule {
+ TIMESTEP_BASED = 0, // Dream-style: (1.0 - s/t) * remaining
+ BLOCK_BASED = 1, // LLaDA-style: process in blocks with get_num_transfer_tokens
};
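+
+// Example: with 10 masked tokens and 5 steps, BLOCK_BASED transfers 2 tokens per step, while
+// TIMESTEP_BASED transfers a (1 - s/t) fraction of whatever is still masked at each step.
+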
+typedef bool (*diffusion_step_callback_t)(int32_t step,
+ int32_t total_steps,
+ const llama_token * tokens,
+ int32_t n_tokens,
+ void * user_data);
+
struct diffusion_params {
- int32_t steps;
- float eps;
- float temperature;
- float top_p;
- int32_t top_k;
- llama_token mask_token_id;
- enum diffusion_alg algorithm;
- float alg_temp;
- diffusion_step_callback_t step_callback;
- void * step_callback_user_data;
- int32_t seed;
+ int32_t steps = 0;
+ float temperature = 0;
+ llama_token mask_token_id = LLAMA_TOKEN_NULL;
+ diffusion_step_callback_t step_callback = nullptr;
+ void * step_callback_user_data = nullptr;
+ int32_t seed = 0;
+ bool visual_mode = false;
+ bool shift_logits = false; // Shift logits by -1 after decode
+
+ float top_p = 0.0f;
+ int32_t top_k = 0;
+
+ diffusion_algorithm algorithm = CONFIDENCE_BASED;
+ transfer_schedule schedule = TIMESTEP_BASED;
+
+ float cfg_scale = 0.0f; // Guidance scale for classifier-free guidance (0 = disabled)
+ float eps = 0.0f; // Epsilon for timestep-based scheduling
+ int32_t block_length = 0; // Block size (for block-based scheduling)
+ float alg_temp = 0.0f; // Algorithm temperature (0.0 = deterministic confidence ranking)
+ bool add_gumbel_noise = false; // Add Gumbel noise to the logits if temperature > 0.0
+
+ int32_t max_length = 0; // Maximum sequence length
};
+struct callback_data {
+ diffusion_params * diff_params;
+ const llama_vocab * vocab;
+ int32_t n_input;
+};
+
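+// Score how confident the model is about a single masked position; higher-scoring positions are unmasked first.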
+static float calculate_confidence(const llama_token_data_array & cur_p,
+ diffusion_algorithm algorithm,
+ std::mt19937 & rng) {
+ switch (algorithm) {
+ case CONFIDENCE_BASED:
+ return cur_p.data[cur_p.selected].p; // Selected token probability
+
+ case ENTROPY_BASED:
+ {
+ // Accumulate sum(p * log p), i.e. the negative entropy, so peaked distributions score higher
+ float neg_entropy = 0.0f;
+ const float epsilon = 1e-10f;
+ for (size_t i = 0; i < cur_p.size; i++) {
+ float prob = cur_p.data[i].p;
+ neg_entropy += prob * logf(prob + epsilon);
+ }
+ return neg_entropy; // Higher entropy = lower confidence
+ }
+
+ case MARGIN_BASED:
+ return (cur_p.size > 1) ? cur_p.data[0].p - cur_p.data[1].p : cur_p.data[0].p;
+
+ case RANDOM:
+ {
+ std::uniform_real_distribution<float> uniform(0.0f, 1.0f);
+ return uniform(rng); // Random confidence
+ }
+
+ case ORIGIN:
+ return cur_p.data[cur_p.selected].p;
+
+ default:
+ return 0.0f;
+ }
+}
+
+// Unified transfer count calculation function
+static int32_t calculate_transfer_count(int32_t step,
+ int32_t total_steps,
+ int32_t remaining_masked,
+ transfer_schedule schedule,
+ float eps,
+ const std::vector<int32_t> & num_transfer_tokens = {}) {
+ switch (schedule) {
+ case TIMESTEP_BASED:
+ {
+ float t = 1.0f - (float) step / total_steps * (1.0f - eps);
+ float s = 1.0f - (float) (step + 1) / total_steps * (1.0f - eps);
+ float p_transfer = (step < total_steps - 1) ? (1.0f - s / t) : 1.0f;
+ return (int32_t) (remaining_masked * p_transfer);
+ }
+
+ case BLOCK_BASED:
+ if (!num_transfer_tokens.empty() && step < (int32_t) num_transfer_tokens.size()) {
+ return num_transfer_tokens[step];
+ }
+ return remaining_masked / (total_steps - step); // Fallback: spread the remaining masks evenly over the remaining steps
+
+ default:
+ return remaining_masked / (total_steps - step);
+ }
+}
+
+static bool diffusion_step_callback(int32_t step,
+ int32_t total_steps,
+ const llama_token * tokens,
+ int32_t n_tokens,
+ void * user_data) {
+
+ callback_data * data = static_cast<callback_data *>(user_data);
+
+ auto print_progress_bar = [](int32_t step, int32_t total_steps) {
+ int progress_percent = (step * 100) / total_steps;
+ int progress_bars = (step * 50) / total_steps;
+ LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
+ step,
+ total_steps,
+ std::string(progress_bars, '=').c_str(),
+ std::string(50 - progress_bars, ' ').c_str(),
+ progress_percent);
+ };
+
+ if (data->diff_params->visual_mode) {
+ // Visual mode: clear
+ LOG_INF("\033[2J\033[H"); // Clear screen and move cursor to top-left
+
+ print_progress_bar(step, total_steps);
+
+ LOG_INF("\n");
+
+ std::string current_text = " ";
+
+ for (int32_t i = data->n_input; i < n_tokens; i++) {
+ std::string token_str;
+ if (tokens[i] != llama_vocab_mask(data->vocab)) {
+ char piece[256];
+ int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
+ if (n_chars > 0) {
+ token_str.assign(piece, n_chars); // avoid writing a terminator past the end of the buffer
+ }
+ } else {
+ token_str = " ";
+ }
+
+ current_text += token_str;
+ }
-static diffusion_params diffusion_default_params() {
- diffusion_params params = {};
- params.steps = 64;
- params.eps = 1e-3f;
- params.temperature = 0.2f;
- params.top_p = 0.95f;
- params.top_k = 0;
- params.mask_token_id = LLAMA_TOKEN_NULL;
- params.algorithm = DIFFUSION_ALG_ORIGIN;
- params.alg_temp = 0.0f;
- params.step_callback = nullptr;
- params.step_callback_user_data = nullptr;
- params.seed = 0;
- return params;
+ LOG_INF("%s\n", current_text.c_str());
+ } else {
+ print_progress_bar(step, total_steps);
+ }
+
+ return true;
}
-static void diffusion_generate(llama_context * ctx,
- const llama_token * input_tokens,
- llama_token * output_tokens,
- int32_t n_input,
- int32_t max_length,
- struct diffusion_params params,
- int32_t & n_generated) {
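+// Gumbel-style noise: each logit becomes exp(logit) / (-log(u))^temperature with u ~ Uniform(0, 1),
+// so the subsequent selection gets noisier as the temperature grows.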
+static void add_gumbel_noise(float * logits, int32_t n_vocab, float temperature, std::mt19937 & rng) {
+ if (temperature == 0.0f) {
+ return;
+ }
+
+ std::uniform_real_distribution<double> uniform(0.0, 1.0);
+ for (int32_t i = 0; i < n_vocab; i++) {
+ double noise = uniform(rng);
+ // Prevent log(0)
+ noise = std::max(noise, 1e-20);
+ double gumbel_noise = std::pow(-std::log(noise), temperature);
+ logits[i] = std::exp(logits[i]) / gumbel_noise;
+ }
+}
+
+static std::vector<int32_t> get_num_transfer_tokens(int32_t mask_count, int32_t steps) {
+ std::vector<int32_t> num_transfer_tokens(steps);
+
+ int32_t base = mask_count / steps;
+ int32_t remainder = mask_count % steps;
+
+ for (int32_t i = 0; i < steps; i++) {
+ num_transfer_tokens[i] = base + (i < remainder ? 1 : 0);
+ }
+
+ return num_transfer_tokens;
+}
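+
+// Core generation loop: start from the prompt padded with mask tokens, then repeatedly decode the full
+// sequence, sample candidates for the masked positions and unmask a scheduled number of them per step.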
+static void diffusion_generate(llama_context * ctx,
+ const llama_token * input_tokens,
+ llama_token * output_tokens,
+ int32_t n_input,
+ const diffusion_params & params,
+ int32_t & n_generated) {
n_generated = 0;
- if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || max_length <= n_input) {
+ if (!ctx || !input_tokens || !output_tokens || n_input <= 0 || params.max_length <= n_input) {
return;
}
// Initialize with input and pad with mask tokens
std::copy(input_tokens, input_tokens + n_input, output_tokens);
- std::fill(output_tokens + n_input, output_tokens + max_length, params.mask_token_id);
+ std::fill(output_tokens + n_input, output_tokens + params.max_length, params.mask_token_id);
std::mt19937 rng(params.seed);
- std::vector<float> timesteps(params.steps + 1);
- for (int32_t i = 0; i <= params.steps; i++) {
- timesteps[i] = 1.0f - (float) i / params.steps * (1.0f - params.eps);
- }
-
llama_set_causal_attn(ctx, false);
int32_t n_vocab = llama_vocab_n_tokens(llama_model_get_vocab(model));
std::vector<llama_token_data> candidates(n_vocab);
-
std::vector<llama_token_data> conf_candidates;
- conf_candidates.reserve(max_length);
-
+ conf_candidates.reserve(params.max_length);
std::vector<int32_t> mask_positions;
- mask_positions.reserve(max_length);
+ mask_positions.reserve(params.max_length);
+ // Setup sampler chain
struct llama_sampler * sampler = llama_sampler_chain_init(llama_sampler_chain_default_params());
if (params.top_k > 0) {
llama_sampler_chain_add(sampler, llama_sampler_init_top_k(params.top_k));
struct llama_sampler * dist_sampler = llama_sampler_init_dist(params.seed);
- llama_batch batch = llama_batch_init(max_length, 0, 1);
- batch.n_tokens = max_length;
+ llama_batch batch = llama_batch_init(params.max_length, 0, 1);
+ batch.n_tokens = params.max_length;
- int64_t total_sampling_time = 0;
- int64_t total_time = 0;
+ // Pre-allocate buffers for CFG if needed
+ int32_t logits_size = n_vocab * params.max_length;
+ std::vector<float> cond_logits_buffer;
+ std::vector<llama_token> un_x_buffer;
+ if (params.cfg_scale > 0.0f) {
+ cond_logits_buffer.resize(logits_size);
+ un_x_buffer.resize(params.max_length);
+ }
- int64_t time_start = ggml_time_us();
- for (int32_t step = 0; step < params.steps; step++) {
- if (params.step_callback) {
- if (!params.step_callback(step, params.steps, output_tokens, max_length, params.step_callback_user_data)) {
- break;
- }
- }
+ // For block-based processing
+ std::vector<int32_t> num_transfer_tokens;
+ int32_t num_blocks = 1;
+ int32_t steps_per_block = params.steps;
- for (int32_t i = 0; i < max_length; i++) {
- batch.token[i] = output_tokens[i];
- batch.pos[i] = i;
- batch.n_seq_id[i] = 1;
- batch.seq_id[i][0] = 0;
- batch.logits[i] = 1;
- }
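+ // With block scheduling the sequence is denoised in equal-sized blocks, each receiving an equal share of the steps.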
+ if (params.schedule == BLOCK_BASED) {
+ GGML_ASSERT(params.max_length % params.block_length == 0);
+ num_blocks = params.max_length / params.block_length;
+ GGML_ASSERT(params.steps % num_blocks == 0);
+ steps_per_block = params.steps / num_blocks;
+ }
- int ret = llama_decode(ctx, batch);
- if (ret != 0) {
- LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, step, ret);
- break;
- }
+ std::vector<float> confidence(params.max_length);
- float * raw_logits = llama_get_logits(ctx);
- if (!raw_logits) {
- LOG_ERR("%s: failed to get logits at step %d\n", __func__, step);
- break;
+ int64_t total_sampling_time = 0;
+ int64_t total_time = 0;
+ int64_t time_start = ggml_time_us();
+
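+ // Outer loop over blocks; with timestep scheduling there is a single block spanning the whole sequence.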
+ for (int block_num = 0; block_num < num_blocks; block_num++) {
+ int32_t block_start = (params.schedule == BLOCK_BASED) ? n_input + block_num * params.block_length : 0;
+ int32_t block_end = (params.schedule == BLOCK_BASED) ?
+ std::min(n_input + (block_num + 1) * params.block_length, params.max_length) :
+ params.max_length;
+
+ // Count masked tokens in current block for block-based processing
+ if (params.schedule == BLOCK_BASED) {
+ int32_t block_mask_count = 0;
+ for (int i = block_start; i < block_end; i++) {
+ if (output_tokens[i] == params.mask_token_id) {
+ block_mask_count++;
+ }
+ }
+ num_transfer_tokens = get_num_transfer_tokens(block_mask_count, steps_per_block);
}
- auto get_logits_for_pos = [&](int32_t pos) -> const float * {
- return pos == 0 ? raw_logits : raw_logits + (pos - 1) * n_vocab;
- };
-
- int64_t time_start_sampling = ggml_time_us();
+ for (int32_t step = 0; step < steps_per_block; step++) {
+ int32_t global_step = block_num * steps_per_block + step;
- mask_positions.clear();
- for (int32_t i = 0; i < max_length; i++) {
- if (output_tokens[i] == params.mask_token_id) {
- mask_positions.push_back(i);
+ if (params.step_callback) {
+ if (!params.step_callback(
+ global_step, params.steps, output_tokens, params.max_length, params.step_callback_user_data)) {
+ break;
+ }
}
- }
- if (mask_positions.empty()) {
- break;
- }
+ // Setup batch
+ for (int32_t i = 0; i < params.max_length; i++) {
+ batch.token[i] = output_tokens[i];
+ batch.pos[i] = i;
+ batch.n_seq_id[i] = 1;
+ batch.seq_id[i][0] = 0;
+ batch.logits[i] = 1;
+ }
- float t = timesteps[step];
- float s = timesteps[step + 1];
+ float * logits = nullptr;
- if (params.algorithm == DIFFUSION_ALG_ORIGIN) {
- float p_transfer = (step < params.steps - 1) ? (1.0f - s / t) : 1.0f;
+ if (params.cfg_scale > 0.0f) {
+ int ret = llama_decode(ctx, batch);
+ if (ret != 0) {
+ LOG_ERR("Failed to generate conditional");
+ break;
+ }
+ float * cond_logits_ptr = llama_get_logits(ctx);
+ std::memcpy(cond_logits_buffer.data(), cond_logits_ptr, logits_size * sizeof(float));
- for (int32_t pos : mask_positions) {
- if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
- const float * pos_logits = get_logits_for_pos(pos);
- for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
- candidates[token_id].id = token_id;
- candidates[token_id].logit = pos_logits[token_id];
- candidates[token_id].p = 0.0f;
- }
+ // Unconditional generation (mask input)
+ std::copy(output_tokens, output_tokens + params.max_length, un_x_buffer.begin());
+ for (int32_t i = 0; i < n_input; i++) {
+ un_x_buffer[i] = params.mask_token_id;
+ }
- llama_token_data_array cur_p = {
- /* .data = */ candidates.data(),
- /* .size = */ (size_t) n_vocab, // Reset size to full vocab
- /* .selected = */ -1,
- /* .sorted = */ false,
- };
+ for (int32_t i = 0; i < params.max_length; i++) {
+ batch.token[i] = un_x_buffer[i];
+ }
+ ret = llama_decode(ctx, batch);
+ if (ret != 0) {
+ LOG_ERR("Failed to generate unconditional");
+ break;
+ }
+ float * uncond_logits = llama_get_logits(ctx);
- llama_sampler_apply(sampler, &cur_p);
- output_tokens[pos] = cur_p.data[cur_p.selected].id;
+ // Apply classifier-free guidance: logits = uncond + (cfg_scale + 1) * (cond - uncond)
+ for (int32_t i = 0; i < logits_size; i++) {
+ cond_logits_buffer[i] =
+ uncond_logits[i] + (params.cfg_scale + 1.0f) * (cond_logits_buffer[i] - uncond_logits[i]);
}
- }
- } else {
- std::vector<std::pair<float, int32_t>> confidences;
- std::vector<llama_token> sampled_tokens(mask_positions.size());
-
- for (size_t i = 0; i < mask_positions.size(); i++) {
- int32_t pos = mask_positions[i];
- const float * pos_logits = get_logits_for_pos(pos);
-
- for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
- candidates[token_id].logit = pos_logits[token_id];
- candidates[token_id].p = 0.0f;
- candidates[token_id].id = token_id;
+ logits = cond_logits_buffer.data();
+ } else {
+ int ret = llama_decode(ctx, batch);
+ if (ret != 0) {
+ LOG_ERR("%s: failed to decode at step %d, ret = %d\n", __func__, global_step, ret);
+ break;
}
+ logits = llama_get_logits(ctx);
+ }
- llama_token_data_array cur_p = {
- /* .data = */ candidates.data(),
- /* .size = */ candidates.size(),
- /* .selected = */ -1,
- /* .sorted = */ false,
- };
+ if (!logits) {
+ LOG_ERR("%s: failed to get logits at step %d\n", __func__, global_step);
+ break;
+ }
- llama_sampler_apply(sampler, &cur_p);
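+ // When shift_logits is set, the token at position pos is predicted from the logits at position pos - 1.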
+ auto get_logits_for_pos = [&](int32_t pos) -> const float * {
+ if (params.shift_logits) {
+ return pos == 0 ? logits : logits + (pos - 1) * n_vocab;
+ }
+ return logits + pos * n_vocab;
+ };
- llama_token sampled_token = cur_p.data[cur_p.selected].id;
+ int64_t time_start_sampling = ggml_time_us();
- float confidence = 0.0f;
- if (params.algorithm == DIFFUSION_ALG_ENTROPY) {
- const float epsilon = 1e-10f;
- for (size_t j = 0; j < cur_p.size; j++) {
- float prob = cur_p.data[j].p;
- confidence += prob * logf(prob + epsilon);
+ mask_positions.clear();
+ for (int32_t i = 0; i < params.max_length; i++) {
+ if (output_tokens[i] == params.mask_token_id) {
+ // For block-based, only consider current block
+ if (params.schedule != BLOCK_BASED || (i >= block_start && i < block_end)) {
+ mask_positions.push_back(i);
}
- } else if (params.algorithm == DIFFUSION_ALG_TOPK_MARGIN) {
- confidence = cur_p.data[0].p - cur_p.data[1].p;
- } else {
- confidence = cur_p.data[cur_p.selected].p;
}
+ }
- sampled_tokens[i] = sampled_token;
- confidences.emplace_back(confidence, i);
+ if (mask_positions.empty()) {
+ break;
}
- int32_t num_transfer =
- (step < params.steps - 1) ? (int32_t) (mask_positions.size() * (1.0f - s / t)) : mask_positions.size();
-
- if (num_transfer > 0) {
- if (params.alg_temp == 0.0f) {
- std::partial_sort(confidences.begin(), confidences.begin() + num_transfer, confidences.end(),
- [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
- if (a.first != b.first) {
- return a.first > b.first;
- }
- return a.second < b.second;
- });
- } else {
- conf_candidates.clear();
-
- for (int32_t pos = 0; pos < max_length; pos++) {
- float conf_logit = -std::numeric_limits<float>::infinity();
-
- auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
- if (it != mask_positions.end()) {
- size_t mask_idx = std::distance(mask_positions.begin(), it);
- conf_logit = confidences[mask_idx].first / params.alg_temp; // Apply temperature scaling
+ if (params.add_gumbel_noise && params.temperature > 0.0f) {
+ add_gumbel_noise(logits, n_vocab, params.temperature, rng);
+ }
+
+ if (params.algorithm == ORIGIN) {
+ int32_t transfer_count = calculate_transfer_count(
+ step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
+ float p_transfer = (float) transfer_count / mask_positions.size();
+
+ for (int32_t pos : mask_positions) {
+ if (std::uniform_real_distribution<float>(0.0f, 1.0f)(rng) < p_transfer) {
+ const float * pos_logits = get_logits_for_pos(pos);
+ for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
+ candidates[token_id].id = token_id;
+ candidates[token_id].logit = pos_logits[token_id];
+ candidates[token_id].p = 0.0f;
}
- conf_candidates.emplace_back(llama_token_data{ pos, conf_logit, 0.0f });
+ llama_token_data_array cur_p = {
+ candidates.data(),
+ (size_t) n_vocab,
+ -1,
+ false,
+ };
+
+ llama_sampler_apply(sampler, &cur_p);
+ output_tokens[pos] = cur_p.data[cur_p.selected].id;
+ }
+ }
+ } else {
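+ // Confidence-based algorithms: sample a candidate for every masked position, then commit only the
+ // highest-confidence ones this step (or draw them by confidence when alg_temp > 0).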
+ std::vector<std::pair<float, int32_t>> confidences;
+ std::vector<llama_token> sampled_tokens(mask_positions.size());
+
+ for (size_t i = 0; i < mask_positions.size(); i++) {
+ int32_t pos = mask_positions[i];
+ const float * pos_logits = get_logits_for_pos(pos);
+
+ for (int32_t token_id = 0; token_id < n_vocab; token_id++) {
+ candidates[token_id].logit = pos_logits[token_id];
+ candidates[token_id].p = 0.0f;
+ candidates[token_id].id = token_id;
}
- llama_token_data_array conf_array = {
- /* .data = */ conf_candidates.data(),
- /* .size = */ conf_candidates.size(),
- /* .selected = */ -1,
- /* .sorted = */ false,
+ llama_token_data_array cur_p = {
+ candidates.data(),
+ candidates.size(),
+ -1,
+ false,
};
- for (int32_t i = 0; i < num_transfer; i++) {
- // Apply distribution sampler to get selected index
- llama_sampler_apply(dist_sampler, &conf_array);
- int selected_idx = conf_array.selected;
- confidences[i].second = conf_candidates[selected_idx].id;
+ llama_sampler_apply(sampler, &cur_p);
+ llama_token sampled_token = cur_p.data[cur_p.selected].id;
+
+ float conf = calculate_confidence(cur_p, params.algorithm, rng);
- conf_candidates[selected_idx].p = 0.0f;
- conf_array.selected = -1;
- }
+ sampled_tokens[i] = sampled_token;
+ confidences.emplace_back(conf, i);
}
- if (params.alg_temp == 0.0f) {
- // Deterministic - use confidence order
- for (int32_t i = 0; i < num_transfer; i++) {
- int32_t mask_idx = confidences[i].second;
- int32_t pos = mask_positions[mask_idx];
- llama_token token = sampled_tokens[mask_idx];
- output_tokens[pos] = token;
- }
- } else {
- for (int32_t i = 0; i < num_transfer; i++) {
- int32_t pos = confidences[i].second;
- auto it = std::find(mask_positions.begin(), mask_positions.end(), pos);
- if (it != mask_positions.end()) {
- int32_t mask_idx = std::distance(mask_positions.begin(), it);
+ int32_t transfer_count = calculate_transfer_count(
+ step, steps_per_block, mask_positions.size(), params.schedule, params.eps, num_transfer_tokens);
+
+ if (transfer_count > 0) {
+ if (params.alg_temp == 0.0f) {
+ std::partial_sort(confidences.begin(),
+ confidences.begin() + std::min(transfer_count, (int32_t) confidences.size()),
+ confidences.end(),
+ [](const std::pair<float, int32_t> & a, const std::pair<float, int32_t> & b) {
+ if (a.first != b.first) {
+ return a.first > b.first;
+ }
+ return a.second < b.second;
+ });
+
+ for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
+ int32_t mask_idx = confidences[i].second;
+ int32_t pos = mask_positions[mask_idx];
output_tokens[pos] = sampled_tokens[mask_idx];
}
+ } else {
+ conf_candidates.clear();
+ for (size_t i = 0; i < confidences.size(); i++) {
+ float conf_logit = confidences[i].first / params.alg_temp;
+ conf_candidates.emplace_back(llama_token_data{ (int32_t) i, conf_logit, 0.0f });
+ }
+
+ llama_token_data_array conf_array = {
+ conf_candidates.data(),
+ conf_candidates.size(),
+ -1,
+ false,
+ };
+
+ for (int32_t i = 0; i < std::min(transfer_count, (int32_t) confidences.size()); i++) {
+ llama_sampler_apply(dist_sampler, &conf_array);
+ int32_t selected_idx = conf_array.selected;
+ int32_t mask_idx = selected_idx;
+ int32_t pos = mask_positions[mask_idx];
+ output_tokens[pos] = sampled_tokens[mask_idx];
+
+ // mask out the chosen position so it cannot be drawn again
+ conf_candidates[selected_idx].logit = -std::numeric_limits<float>::infinity();
+ conf_candidates[selected_idx].p = 0.0f;
+ conf_array.selected = -1;
+ }
}
}
}
+
+ int64_t time_end_sampling = ggml_time_us();
+ total_sampling_time += time_end_sampling - time_start_sampling;
}
- int64_t time_end_sampling = ggml_time_us();
- total_sampling_time += time_end_sampling - time_start_sampling;
}
+
int64_t time_end = ggml_time_us();
total_time += time_end - time_start;
LOG_INF("\ntotal time: %0.2fms, time per step: %0.2fms, sampling time per step: %0.2fms\n",
- total_time / 1000.0, total_time / 1000.0 / params.steps, total_sampling_time / 1000.0 / params.steps);
-
+ total_time / 1000.0,
+ total_time / 1000.0 / params.steps,
+ total_sampling_time / 1000.0 / params.steps);
llama_batch_free(batch);
llama_sampler_free(sampler);
llama_sampler_free(dist_sampler);
- n_generated = max_length;
+ n_generated = params.max_length;
}
-
-
-
static std::string format_input_text(const std::string & prompt, bool use_chat_template, llama_model * model) {
if (!use_chat_template) {
return prompt;
return result.prompt;
}
-struct callback_data {
- const common_params_diffusion * diff_params;
- const llama_vocab * vocab;
- int32_t n_input;
-};
-
-static bool diffusion_step_callback(int32_t step,
- int32_t total_steps,
- const llama_token * tokens,
- int32_t n_tokens,
- void * user_data) {
- (void)user_data;
-
- callback_data * data = static_cast<callback_data *>(user_data);
-
- auto print_progress_bar = [](int32_t step, int32_t total_steps) {
- int progress_percent = (step * 100) / total_steps;
- int progress_bars = (step * 50) / total_steps;
- LOG_INF("\rdiffusion step: %d/%d [%s%s] %d%%",
- step,
- total_steps,
- std::string(progress_bars, '=').c_str(),
- std::string(50 - progress_bars, ' ').c_str(),
- progress_percent);
- };
-
- if (data->diff_params->visual_mode) {
- // Visual mode: clear
- LOG_INF("\033[2J\033[H"); // Clear screen and move cursor to top-left
-
- print_progress_bar(step, total_steps);
-
- LOG_INF("\n");
-
- std::string current_text = " ";
-
- for (int32_t i = data->n_input; i < n_tokens; i++) {
- std::string token_str;
- if (tokens[i] != llama_vocab_mask(data->vocab)) {
- char piece[256];
- int n_chars = llama_token_to_piece(data->vocab, tokens[i], piece, sizeof(piece), 0, false);
- if (n_chars > 0) {
- piece[n_chars] = '\0';
- token_str = piece;
- }
- } else {
- token_str = " ";
- }
-
- current_text += token_str;
- }
-
- LOG_INF("%s\n", current_text.c_str());
- } else {
- print_progress_bar(step, total_steps);
- }
-
- return true;
-}
-
int main(int argc, char ** argv) {
ggml_time_init();
return 1;
}
- const char * alg_names[] = { "ORIGIN", "MASKGIT_PLUS", "TOPK_MARGIN", "ENTROPY" };
- const char * alg_name = (params.diffusion.algorithm >= 0 && params.diffusion.algorithm <= 3) ?
- alg_names[params.diffusion.algorithm] :
- "UNKNOWN";
-
common_init();
llama_backend_init();
return 1;
}
+ if (!llama_model_is_diffusion(model)) {
+ LOG_ERR("error: unsupported model for diffusion");
+ llama_model_free(model);
+ return 1;
+ }
+
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = params.n_ctx;
ctx_params.n_batch = params.n_batch;
const llama_vocab * vocab = llama_model_get_vocab(model);
std::string formatted_prompt = format_input_text(params.prompt, params.enable_chat_template, model);
- std::vector<llama_token> input_tokens = common_tokenize(vocab, formatted_prompt,
+ std::vector<llama_token> input_tokens = common_tokenize(vocab,
+ formatted_prompt,
/*add special tokens*/ true,
/*parse special*/ true);
- int n_input = input_tokens.size();
+
+ int n_input = input_tokens.size();
if (n_input >= params.n_ctx) {
LOG_ERR("error: input too long (%d tokens), max context is %d\n", n_input, params.n_ctx);
return 1;
}
- struct diffusion_params ldiff_params = diffusion_default_params();
- ldiff_params.steps = params.diffusion.steps;
- ldiff_params.eps = params.diffusion.eps;
- ldiff_params.temperature = params.sampling.temp;
- ldiff_params.top_p = params.sampling.top_p;
- ldiff_params.top_k = params.sampling.top_k;
- ldiff_params.algorithm = static_cast<enum diffusion_alg>(params.diffusion.algorithm);
- ldiff_params.alg_temp = params.diffusion.alg_temp;
- ldiff_params.seed = params.sampling.seed;
-
llama_token mask_token_id = llama_vocab_mask(vocab);
GGML_ASSERT(mask_token_id != LLAMA_TOKEN_NULL);
- LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
- LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", params.diffusion.steps);
- LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", params.diffusion.eps);
- LOG_INF("diffusion_params: - %-25s u32 = %d (%s)\n", "algorithm", params.diffusion.algorithm,
- alg_name);
- LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", params.diffusion.alg_temp);
+ bool visual_mode = params.diffusion.visual_mode;
- ldiff_params.mask_token_id = mask_token_id;
+ int32_t n_generated = 0;
+ std::vector<llama_token> output_tokens(params.n_ubatch);
- callback_data cb_data = { ¶ms.diffusion, vocab, n_input };
+ struct diffusion_params diff_params;
- ldiff_params.step_callback = diffusion_step_callback;
- ldiff_params.step_callback_user_data = &cb_data;
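+ // Whether to read logits shifted by one position (the original behaviour of this example); taken from
+ // GGUF metadata and defaulting to shifted when the key is absent.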
+ char shift_logits_str[8];
+ if (llama_model_meta_val_str(model, "diffusion.shift_logits", shift_logits_str, sizeof(shift_logits_str)) >= 0) {
+ diff_params.shift_logits = (strcmp(shift_logits_str, "true") == 0);
+ } else {
+ diff_params.shift_logits = true;
+ }
- int32_t n_generated = 0;
+ // Use either eps or block length, but not both
+ GGML_ASSERT((params.diffusion.eps == 0) ^ (params.diffusion.block_length == 0));
- std::vector<llama_token> output_tokens(params.n_ubatch);
- diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, params.n_ubatch,
- ldiff_params, n_generated);
+ if (params.diffusion.eps) {
+ diff_params.schedule = TIMESTEP_BASED;
+ diff_params.eps = params.diffusion.eps;
+ } else if (params.diffusion.block_length) {
+ diff_params.schedule = BLOCK_BASED;
+ diff_params.block_length = params.diffusion.block_length;
+ }
+
+ diff_params.mask_token_id = mask_token_id;
+ diff_params.seed = params.sampling.seed;
+ diff_params.temperature = params.sampling.temp;
+ diff_params.steps = params.diffusion.steps;
+ diff_params.algorithm = static_cast<diffusion_algorithm>(params.diffusion.algorithm);
+ diff_params.max_length = params.n_ubatch;
+ diff_params.top_p = params.sampling.top_p;
+ diff_params.top_k = params.sampling.top_k;
+ diff_params.visual_mode = params.diffusion.visual_mode;
+ diff_params.add_gumbel_noise = params.diffusion.add_gumbel_noise;
+ diff_params.alg_temp = params.diffusion.alg_temp;
+
+ diff_params.step_callback = diffusion_step_callback;
+ callback_data cb_data = { &diff_params, vocab, n_input };
+ diff_params.step_callback_user_data = &cb_data;
+
+ const char * alg_names[] = { "ORIGIN", "ENTROPY_BASED", "MARGIN_BASED", "RANDOM", "CONFIDENCE_BASED" };
+ const char * sched_names[] = { "TIMESTEP_BASED", "BLOCK_BASED" };
+ const char * alg_name =
+ (diff_params.algorithm >= 0 && diff_params.algorithm <= 4) ? alg_names[diff_params.algorithm] : "UNKNOWN";
+ const char * sched_name =
+ (diff_params.schedule >= 0 && diff_params.schedule <= 1) ? sched_names[diff_params.schedule] : "UNKNOWN";
+
+ LOG_INF("diffusion_params: - %-25s llama_token = %d\n", "mask_token_id", mask_token_id);
+ LOG_INF("diffusion_params: - %-25s u32 = %d\n", "steps", diff_params.steps);
+ LOG_INF("diffusion_params: - %-25s u32 = %d\n", "max_length", diff_params.max_length);
+ LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "algorithm", diff_params.algorithm, alg_name);
+ LOG_INF("diffusion_params: - %-25s enum = %d (%s)\n", "schedule", diff_params.schedule, sched_name);
+ LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "temperature", diff_params.temperature);
+ if (diff_params.schedule == TIMESTEP_BASED) {
+ LOG_INF("diffusion_params: - %-25s f32 = %.6f\n", "eps", diff_params.eps);
+ LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "alg_temp", diff_params.alg_temp);
+ }
+ if (diff_params.schedule == BLOCK_BASED) {
+ LOG_INF("diffusion_params: - %-25s u32 = %d\n", "block_length", diff_params.block_length);
+ LOG_INF("diffusion_params: - %-25s f32 = %.3f\n", "cfg_scale", diff_params.cfg_scale);
+ }
+
+ diffusion_generate(ctx, input_tokens.data(), output_tokens.data(), n_input, diff_params, n_generated);
if (n_generated > 0) {
- if (params.diffusion.visual_mode) {
+ if (visual_mode) {
//clear screen and move cursor to top-left
LOG_INF("\033[2J\033[H");
}
+
output_tokens.erase(output_tokens.begin(), output_tokens.begin() + n_input);
std::string output_data = common_detokenize(vocab, output_tokens, false);
LOG_INF("\n%s\n", output_data.c_str());
hparams.causal_attn = false;
}
break;
+ case LLM_ARCH_LLADA:
+ {
+ ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+ // LLaDA-8B has 32 layers, similar to LLaMA but for diffusion
+ switch (hparams.n_layer) {
+ case 32:
+ type = LLM_TYPE_8B;
+ break;
+ default:
+ type = LLM_TYPE_UNKNOWN;
+ }
+ // Set non-causal attention for diffusion models
+ hparams.causal_attn = false;
+ }
+ break;
case LLM_ARCH_QWEN2MOE:
{
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
}
}
} break;
+ case LLM_ARCH_LLADA:
+ {
+ tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
+
+ // output
+ output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
+ output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
+
+ // if output is NULL, init from the input tok embed
+ if (output == NULL) {
+ output =
+ create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
+ }
+
+ for (int i = 0; i < n_layer; ++i) {
+ auto & layer = layers[i];
+
+ layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
+
+ // Use separate Q, K, V projections without bias, matching LLaDALlamaBlock
+ layer.wq =
+ create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0);
+ layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
+ layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
+ // No bias for QKV projections as per config: include_bias=false, include_qkv_bias=false
+ layer.wo =
+ create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
+ layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED);
+
+ layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0);
+
+ layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot / 2 },
+ TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0));
+
+ layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), { n_embd, n_ff }, 0);
+ layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0);
+ layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0);
+
+ // optional MLP bias
+ layer.ffn_gate_b =
+ create_tensor(tn(LLM_TENSOR_FFN_GATE, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED);
+ layer.ffn_down_b =
+ create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED);
+ layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), { n_ff }, TENSOR_NOT_REQUIRED);
+ }
+ }
+ break;
case LLM_ARCH_LLAMA4:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
}
};
+struct llm_build_llada : public llm_graph_context {
+ llm_build_llada(const llama_model & model, const llm_graph_params & params) :
+ llm_graph_context(params) {
+ // LLaDA is similar to LLaMA but uses non-causal attention for diffusion
+ const int64_t n_embd_head = hparams.n_embd_head_v;
+
+ GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
+ GGML_ASSERT(n_embd_head == hparams.n_rot);
+
+ ggml_tensor * cur;
+ ggml_tensor * inpL;
+
+ inpL = build_inp_embd(model.tok_embd);
+
+ // inp_pos - contains the positions
+ ggml_tensor * inp_pos = build_inp_pos();
+
+ // Non-causal attention for diffusion
+ auto * inp_attn = build_attn_inp_no_cache();
+
+ ggml_tensor * inp_out_ids = build_inp_out_ids();
+
+ for (int il = 0; il < n_layer; ++il) {
+ ggml_tensor * inpSA = inpL;
+
+ // norm
+ cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
+ cb(cur, "attn_norm", il);
+
+ // self-attention
+ {
+ // compute separate Q, K, V projections without bias, matching LLaDALlamaBlock
+ ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
+ ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
+ ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
+ Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
+ Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
+
+ Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+
+ Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
+ ext_factor, attn_factor, beta_fast, beta_slow);
+
+ cb(Qcur, "Qcur", il);
+ cb(Kcur, "Kcur", il);
+ cb(Vcur, "Vcur", il);
+
+ cur = build_attn(inp_attn, model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr,
+ 1.0f / sqrtf(float(n_embd_head)), il);
+ }
+
+ if (il == n_layer - 1 && inp_out_ids) {
+ cur = ggml_get_rows(ctx0, cur, inp_out_ids);
+ inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
+ }
+
+ ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
+ cb(ffn_inp, "ffn_inp", il);
+
+ // feed-forward network
+ cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
+ cb(cur, "ffn_norm", il);
+
+ cur = build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL,
+ model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
+ cb(cur, "ffn_out", il);
+
+ cur = ggml_add(ctx0, cur, ffn_inp);
+
+ cur = build_cvec(cur, il);
+ cb(cur, "l_out", il);
+
+ // input for next layer
+ inpL = cur;
+ }
+
+ cur = inpL;
+
+ cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
+
+ cb(cur, "result_norm", -1);
+ res->t_embd = cur;
+
+ // lm_head
+ cur = build_lora_mm(model.output, cur);
+
+ cb(cur, "result_output", -1);
+ res->t_logits = cur;
+
+ ggml_build_forward_expand(gf, cur);
+ }
+};
+
struct llm_build_qwen2vl : public llm_graph_context {
llm_build_qwen2vl(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
const int64_t n_embd_head = hparams.n_embd_head_v;
case LLM_ARCH_NEO_BERT:
case LLM_ARCH_WAVTOKENIZER_DEC:
case LLM_ARCH_DREAM:
+ case LLM_ARCH_LLADA:
{
res = nullptr;
} break;
llm = std::make_unique<llm_build_dream>(*this, params);
}
break;
+ case LLM_ARCH_LLADA:
+ {
+ llm = std::make_unique<llm_build_llada>(*this, params);
+ }
+ break;
case LLM_ARCH_QWEN2VL:
{
llm = std::make_unique<llm_build_qwen2vl>(*this, params);
// use what we call a normal RoPE, operating on pairs of consecutive head values
case LLM_ARCH_LLAMA:
+ case LLM_ARCH_LLADA:
case LLM_ARCH_LLAMA4:
case LLM_ARCH_DECI:
case LLM_ARCH_BAICHUAN:
return llm_arch_is_recurrent(model->arch);
}
+bool llama_model_is_diffusion(const llama_model * model) {
+ return llm_arch_is_diffusion(model->arch);
+}
+
const std::vector<std::pair<std::string, ggml_tensor *>> & llama_internal_get_tensor_map(const llama_model * model) {
return model->tensors_by_name;
}