typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
// Compare the output of two backends
- GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data);
+ GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node);
// Tensor initialization
GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
}
if (sched->debug > 1) {
ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
- GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s]:", i, ggml_op_name(node->op), node->name,
- fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node));
+ GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name,
+ fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node),
+ graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]);
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * src = node->src[j];
if (src == NULL) {
ggml_free(copy.ctx_unallocated);
}
-bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data) {
+bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) {
struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
if (copy.buffer == NULL) {
        return false;
    }

    struct ggml_cgraph * g1 = graph;
    struct ggml_cgraph * g2 = copy.graph;
assert(g1->n_nodes == g2->n_nodes);
- for (int i = 0; i < g1->n_nodes; i++) {
- struct ggml_tensor * t1 = g1->nodes[i];
- struct ggml_tensor * t2 = g2->nodes[i];
+ if (test_node != nullptr) {
+ // Compute the whole graph and only test the output for a specific tensor
+ ggml_backend_graph_compute(backend1, g1);
+ ggml_backend_graph_compute(backend2, g2);
- assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
+ int test_node_idx = -1;
+ for (int i = 0; i < g1->n_nodes; i++) {
+ struct ggml_tensor * t1 = g1->nodes[i];
+ if (t1 == test_node) {
+ test_node_idx = i;
+ break;
+ }
+ }
+ GGML_ASSERT(test_node_idx != -1);
- struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
- struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
+ callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data);
+ } else {
+ for (int i = 0; i < g1->n_nodes; i++) {
+ struct ggml_tensor * t1 = g1->nodes[i];
+ struct ggml_tensor * t2 = g2->nodes[i];
- ggml_backend_graph_compute(backend1, &g1v);
- ggml_backend_graph_compute(backend2, &g2v);
+ assert(t1->op == t2->op && ggml_are_same_layout(t1, t2));
- if (ggml_is_view_op(t1->op)) {
- continue;
- }
+ struct ggml_cgraph g1v = ggml_graph_view(g1, i, i + 1);
+ struct ggml_cgraph g2v = ggml_graph_view(g2, i, i + 1);
- // compare results, calculate rms etc
- if (!callback(i, t1, t2, user_data)) {
- break;
+ ggml_backend_graph_compute(backend1, &g1v);
+ ggml_backend_graph_compute(backend2, &g2v);
+
+ if (ggml_is_view_op(t1->op)) {
+ continue;
+ }
+
+ // compare results, calculate rms etc
+ if (!callback(i, t1, t2, user_data)) {
+ break;
+ }
}
}
-
ggml_backend_graph_copy_free(copy);
return true;
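With the new trailing parameter, callers choose between the old per-node comparison (pass NULL) and a whole-graph mode that computes both graphs end to end and invokes the callback once for a chosen tensor, which is what lets backends that fuse across nodes still be validated. A minimal usage sketch (a fragment; `backend1`, `backend2`, the built `graph`, and its output tensor `out` are assumed to exist and are illustrative, not from the patch):

static bool check_node(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data) {
    (void) user_data;
    // t1/t2 are the same graph node as computed by backend1 and backend2
    fprintf(stderr, "comparing node #%d (%s)\n", node_index, t1->name);
    return true; // in per-node mode, returning false stops the comparison early
}

// old behavior: compute and compare node by node
ggml_backend_compare_graph_backend(backend1, backend2, graph, check_node, NULL, NULL);
// new behavior: compute both graphs in full, then compare only `out`
ggml_backend_compare_graph_backend(backend1, backend2, graph, check_node, NULL, out);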
struct ggml_tensor ** grads; // the outputs of these tensors are the gradients of the nodes
struct ggml_tensor ** grad_accs; // accumulators for node gradients
struct ggml_tensor ** leafs; // tensors with constant data
+ int32_t * use_counts; // number of uses of each tensor, indexed by hash table slot
struct ggml_hash_set visited_hash_set;
#define GGML_FP32_TO_BF16(x) ggml_compute_fp32_to_bf16(x)
#define GGML_BF16_TO_FP32(x) ggml_compute_bf16_to_fp32(x)
+// Return true if the node's results are used by exactly n_uses other nodes
+// and can be fused into their calculations.
+static inline bool ggml_node_has_n_uses(const struct ggml_cgraph * cgraph, int node_idx, int32_t n_uses) {
+ const struct ggml_tensor * node = cgraph->nodes[node_idx];
+
+ // check the use count against how many we're replacing
+ size_t hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+ if (!ggml_bitset_get(cgraph->visited_hash_set.used, hash_pos) || cgraph->use_counts[hash_pos] != n_uses) {
+ return false;
+ }
+
+ // if node is a view, some other node might be using the intermediate result
+ // via the view source.
+ if (node->view_src) {
+ return false;
+ }
+
+ // If the user requested output for the node, can't fuse
+ if (node->flags & GGML_TENSOR_FLAG_OUTPUT) {
+ return false;
+ }
+
+ return true;
+}
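A quick illustration of the use-count check (our example, not from the patch): a node with a single consumer satisfies `n_uses == 1`, while a node consumed twice does not:

// t = ggml_rms_norm(ctx, a, eps);
// out = ggml_mul(ctx, t, b);                       // t has one use
// -> ggml_node_has_n_uses(gf, i_t, 1) == true
//
// t2 = ggml_rms_norm(ctx, a, eps);
// out2 = ggml_add(ctx, ggml_mul(ctx, t2, b), t2);  // t2 has two uses
// -> ggml_node_has_n_uses(gf, i_t2, 1) == false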
+
+// Returns true if nodes [node_idx, node_idx + num_ops) are the sequence of ggml_ops in ops[]
+// and are fusable. Nodes are considered fusable according to this function if:
+// - all nodes except the last have only one use and are not views/outputs (see ggml_node_has_n_uses).
+// - all nodes except the last are a src of the following node.
+// - all nodes are the same shape.
+// TODO: Consider allowing GGML_OP_NONE nodes in between
+static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, const enum ggml_op * ops, int num_ops) {
+ if (node_idx + num_ops > cgraph->n_nodes) {
+ return false;
+ }
+
+ for (int i = 0; i < num_ops; ++i) {
+ struct ggml_tensor * node = cgraph->nodes[node_idx + i];
+ if (node->op != ops[i]) {
+ return false;
+ }
+ if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idx + i, 1)) {
+ return false;
+ }
+ if (i > 0) {
+ struct ggml_tensor * prev = cgraph->nodes[node_idx + i - 1];
+ if (node->src[0] != prev && node->src[1] != prev) {
+ return false;
+ }
+ if (!ggml_are_same_shape(node, prev)) {
+ return false;
+ }
+ }
+ }
+ return true;
+}
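A backend's dispatch loop can probe for a fusable sequence and then skip the nodes it folded into one kernel; the Vulkan hunks below follow exactly this shape via `num_additional_fused_ops`. A sketch, with the `emit_*` functions as hypothetical placeholders:

const enum ggml_op seq[] = { GGML_OP_RMS_NORM, GGML_OP_MUL };
for (int i = 0; i < cgraph->n_nodes; i++) {
    if (ggml_can_fuse(cgraph, i, seq, 2)) {
        emit_fused_rms_norm_mul(cgraph->nodes[i], cgraph->nodes[i + 1]); // writes the MUL's output
        i += 1; // the MUL node was consumed by the fused kernel
    } else {
        emit_node(cgraph->nodes[i]);
    }
}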
+
#ifdef __cplusplus
}
#endif
#ifdef __cplusplus
+#include <initializer_list>
#include <vector>
+// nicer C++ syntax for ggml_can_fuse
+inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops) {
+ return ggml_can_fuse(cgraph, node_idx, ops.begin(), (int)ops.size());
+}
+
// expose GGUF internals for test code
GGML_API size_t gguf_type_size(enum gguf_type type);
GGML_API struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_params params);
vk_pipeline pipeline_norm_f32;
vk_pipeline pipeline_group_norm_f32;
vk_pipeline pipeline_rms_norm_f32;
+ vk_pipeline pipeline_rms_norm_mul_f32;
vk_pipeline pipeline_rms_norm_back_f32;
vk_pipeline pipeline_l2_norm_f32;
vk_command_pool compute_cmd_pool;
vk_command_pool transfer_cmd_pool;
+
+ // number of additional consecutive nodes that are being fused with the
+ // node currently being processed
+ uint32_t num_additional_fused_ops {};
};
static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
- ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1);
+ ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1);
return nullptr;
case GGML_OP_RMS_NORM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
- return ctx->device->pipeline_rms_norm_f32;
+ return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32;
}
return nullptr;
case GGML_OP_RMS_NORM_BACK:
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun);
}
-static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
+static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params;
const uint32_t src0_type_size = ggml_type_size(src0->type);
+ const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
- ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_RMS_NORM, {
+ ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, {
(uint32_t)ggml_nelements(src0),
- (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
- (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
+ (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
+ (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2], (uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
+ (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
- op_params[0], 0.0f,
- 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
+ op_params[0], 0.0f, 0,
}, dryrun);
}
// Returns true if the node has enqueued work into the queue, false otherwise
// If submit is true, all operations queued so far are submitted to Vulkan to overlap command-list creation and GPU execution.
-static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_tensor * node, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
+static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){
+ ggml_tensor * node = cgraph->nodes[node_idx];
if (ggml_is_empty(node) || !node->buffer) {
return false;
}
break;
case GGML_OP_RMS_NORM:
- ggml_vk_rms_norm(ctx, compute_ctx, src0, node, dryrun);
-
+ if (ctx->num_additional_fused_ops > 0) {
+ // fused rms_norm + mul
+ ggml_tensor *mul = cgraph->nodes[node_idx + 1];
+ ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0];
+ ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, dryrun);
+ } else {
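+ // no fused mul: bind src0 again as src1 so the (now 3-binding) pipeline gets a valid buffer; the shader never reads it when do_multiply is false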
+ ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, dryrun);
+ }
break;
case GGML_OP_RMS_NORM_BACK:
ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun);
uint64_t total_mat_mul_bytes = 0;
for (int i = 0; i < cgraph->n_nodes; i++) {
- ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false, false);
+ if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+ ctx->num_additional_fused_ops = 1;
+ }
+ ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false);
if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
}
+ i += ctx->num_additional_fused_ops;
+ ctx->num_additional_fused_ops = 0;
}
if (ctx->device->need_compiles) {
ggml_vk_load_shaders(ctx->device);
mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
}
+ if (ggml_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) {
+ ctx->num_additional_fused_ops = 1;
+ }
+
// Signal the almost_ready fence when the graph is mostly complete (< 20% remaining)
bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5;
bool submit = (submitted_nodes >= nodes_per_submit) ||
(mul_mat_bytes >= mul_mat_bytes_per_submit) ||
- (i == last_node) ||
+ (i + ctx->num_additional_fused_ops == last_node) ||
(almost_ready && !ctx->almost_ready_fence_pending);
- bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, almost_ready, submit);
+ bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit);
if (vk_perf_logger_enabled) {
if (ctx->compute_ctx.expired()) {
} else {
compute_ctx = ctx->compute_ctx.lock();
}
- compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+1);
+ // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple
+ for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) {
+ compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1);
+ }
}
if (enqueued) {
}
submit_count++;
}
+ i += ctx->num_additional_fused_ops;
+ ctx->num_additional_fused_ops = 0;
}
if (vk_perf_logger_enabled) {
#version 450
-#include "generic_unary_head.comp"
+#include "generic_binary_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
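+// overridden at pipeline creation: rms_norm_f32 keeps false, rms_norm_mul_f32 sets it to true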
+layout (constant_id = 1) const bool do_multiply = false;
+
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE sum[BLOCK_SIZE];
const uint stride_sample = p.nb03;
uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset();
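+ // src1 offset for the fused multiply; data_b is only read when do_multiply is set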
+ uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset();
uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset();
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
const FLOAT_TYPE mean = sum[0] / FLOAT_TYPE(ncols);
const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1));
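+ // i.e. dst = a * inversesqrt(mean(a^2) + eps), additionally scaled by b when do_multiply is set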
- [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
- data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
+ if (do_multiply) {
+ [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+ data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col]));
+ }
+ } else {
+ [[unroll]] for (uint col = tid; col < ncols; col += BLOCK_SIZE) {
+ data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]));
+ }
}
}
// Norms
string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
- string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
+ string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
}
-static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
+static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
// check if already visited
- if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) {
- return;
+ size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+ GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
+ if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
+ // This is the first time we see this node in the current graph.
+ cgraph->visited_hash_set.keys[node_hash_pos] = node;
+ ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
+ cgraph->use_counts[node_hash_pos] = 0;
+ } else {
+ // already visited
+ return node_hash_pos;
}
for (int i = 0; i < GGML_MAX_SRC; ++i) {
const int k =
(cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
(cgraph->order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? (GGML_MAX_SRC-1-i) :
- /* unknown order, just fall back to using i*/ i;
- if (node->src[k]) {
- ggml_visit_parents(cgraph, node->src[k]);
+ /* unknown order, just fall back to using i */ i;
+
+ struct ggml_tensor * src = node->src[k];
+ if (src) {
+ size_t src_hash_pos = ggml_visit_parents(cgraph, src);
+
+ // Update the use count for this operand.
+ cgraph->use_counts[src_hash_pos]++;
}
}
cgraph->nodes[cgraph->n_nodes] = node;
cgraph->n_nodes++;
}
+
+ return node_hash_pos;
}
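The new bookkeeping is easiest to verify on a tiny graph. A hypothetical trace (tensor names are ours, not from the patch):

// c = ggml_add(ctx, a, b);
// d = ggml_mul(ctx, c, c);
// ggml_build_forward_expand(gf, d);
//
// use_counts afterwards (indexed by hash slot):
//   a -> 1, b -> 1   (one consumer each: the add)
//   c -> 2           (both operands of the mul count separately)
//   d -> 0           (graph output, no consumers)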
static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
incr_ptr_aligned(&p, sizeof(struct ggml_cgraph), 1);
incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // nodes
incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs
+ incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t)); // use_counts
incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys
if (grads) {
incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads
void * p = cgraph + 1;
- struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
- struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
- struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
- struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
- struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+ struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+ struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+ int32_t * use_counts_ptr = incr_ptr_aligned(&p, hash_size * sizeof(int32_t), sizeof(int32_t));
+ struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *));
+ struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
+ struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL;
ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t));
/*.grads =*/ grads_ptr,
/*.grad_accs =*/ grad_accs_ptr,
/*.leafs =*/ leafs_ptr,
+ /*.use_counts =*/ use_counts_ptr,
/*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr },
/*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT,
};
/*.grads =*/ NULL, // gradients would need visited_hash_set
/*.grad_accs =*/ NULL,
/*.leafs =*/ NULL,
- /*.visited_hash_set =*/ { 0, NULL, NULL },
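+ // note: borrowed from cgraph0 rather than owned; the view must not outlive the parent graph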
+ /*.use_counts =*/ cgraph0->use_counts,
+ /*.visited_hash_set =*/ cgraph0->visited_hash_set,
/*.order =*/ cgraph0->order,
};
for (size_t i = 0; i < src->visited_hash_set.size; ++i) {
// copy all hashset keys (tensors) that are in use
if (ggml_bitset_get(src->visited_hash_set.used, i)) {
- ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
+ size_t new_hash_pos = ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]);
+ dst->use_counts[new_hash_pos] = src->use_counts[i];
}
}
return 0;
}
+ virtual bool run_whole_graph() { return false; }
+
ggml_cgraph * gf = nullptr;
ggml_cgraph * gb = nullptr;
GGML_UNUSED(index);
};
- const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
+ const bool cmp_ok = ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud, run_whole_graph() ? out : nullptr);
if (!cmp_ok) {
printf("compare failed ");
}
};
+// GGML_OP_RMS_NORM + GGML_OP_MUL
+struct test_rms_norm_mul : public test_case {
+ const ggml_type type;
+ const std::array<int64_t, 4> ne;
+ const float eps;
+
+ std::string op_desc(ggml_tensor * t) override {
+ GGML_UNUSED(t);
+ return "RMS_NORM_MUL";
+ }
+
+ bool run_whole_graph() override { return true; }
+
+ std::string vars() override {
+ return VARS_TO_STR3(type, ne, eps);
+ }
+
+ test_rms_norm_mul(ggml_type type = GGML_TYPE_F32,
+ std::array<int64_t, 4> ne = {64, 5, 4, 3},
+ float eps = 1e-6f)
+ : type(type), ne(ne), eps(eps) {}
+
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data());
+ ggml_set_param(a);
+ ggml_set_name(a, "a");
+ ggml_set_param(b);
+ ggml_set_name(b, "b");
+
+ // Use a and b early, so we don't end up with an OP_NONE between rms_norm and mul
+ a = ggml_add(ctx, a, b);
+ ggml_tensor * out = ggml_mul(ctx, ggml_rms_norm(ctx, a, eps), b);
+ ggml_set_name(out, "out");
+
+ return out;
+ }
+
+ void initialize_tensors(ggml_context * ctx) override {
+ for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+ init_tensor_uniform(t, -10.f, 10.f);
+ }
+ }
+
+ double max_nmse_err() override {
+ return 1e-6;
+ }
+
+ float grad_eps() override {
+ return 1.0f;
+ }
+
+ bool grad_precise() override {
+ return true;
+ }
+};
+
// GGML_OP_SSM_CONV
struct test_ssm_conv : public test_case {
const ggml_type type;
static constexpr float attn_factor = 1.0f;
static constexpr float beta_fast = 32.0f;
static constexpr float beta_slow = 1.0f;
+ bool fused;
std::string op_desc(ggml_tensor * t) override {
GGML_UNUSED(t);
return 2e-3;
}
- test_llama(int n_tokens = 1)
+ bool run_whole_graph() override { return fused; }
+
+ test_llama(int n_tokens = 1, bool fused = false)
: test_llm({
/*n_vocab =*/ 32000,
/*n_embd =*/ 3200,
/*f_norm_eps =*/ 0.f,
/*f_norm_rms_eps =*/ 1e-5f,
/*n_tokens =*/ n_tokens,
- }) {
+ })
+ , fused(fused)
+ {
}
ggml_tensor * build_graph(ggml_context * ctx) override {
test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps));
}
+ for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f}) {
+ test_cases.emplace_back(new test_rms_norm_mul(GGML_TYPE_F32, {64, 5, 4, 3}, eps));
+ }
test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, {64, 5, 4, 3}, 1e-12f));
test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}));
+ test_cases.emplace_back(new test_llama(2, true));
// these tests are disabled to save execution time, but they can be handy for debugging
#if 0
test_cases.emplace_back(new test_llama(1));