VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
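+    // Total bytes of matmul weight matrices in the graph, used below to size the per-submit work threshold.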
+    uint64_t total_mat_mul_bytes = 0;
    for (int i = 0; i < cgraph->n_nodes; i++) {
        ggml_vk_build_graph(ctx, cgraph->nodes[i], i, nullptr, 0, true, false, false);
+        if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
+            total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
+        }
    }
    if (ctx->device->need_compiles) {
        ggml_vk_load_shaders(ctx->device);
    bool first_node_in_batch = true; // true if next node will be first node in a batch
    int submit_node_idx = 0; // index to first node in a batch
-    // Submit work every nodes_per_submit nodes to overlap CPU cmdbuffer generation with GPU execution.
-    // Start with a smaller count to get work submitted right away, and increase it after each submit.
-    int nodes_per_submit = 20;
+    // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution.
+    // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB
+    // (scaled down based on model size, so smaller models submit earlier).
+    // Also submit at least every 100 nodes, in case there are workloads with less matmul work.
+    int nodes_per_submit = 100;
    int submitted_nodes = 0;
    int submit_count = 0;
+    uint64_t mul_mat_bytes = 0;
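+    // Per-submit matmul work threshold: 100MB, or 1/40 of the graph's total matmul bytes for
+    // smaller models (e.g. ~1GB of weights gives an initial threshold of ~25MB per submit).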
+    uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), total_mat_mul_bytes / 40u);
    for (int i = 0; i < cgraph->n_nodes; i++) {
        if (first_node_in_batch) {
            submit_node_idx = i;
        }
-        bool submit = (submitted_nodes >= nodes_per_submit) || (i == last_node);
+        if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
+            mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]);
+        }
+
+        bool submit = (submitted_nodes >= nodes_per_submit) ||
+                      (mul_mat_bytes >= mul_mat_bytes_per_submit) ||
+                      (i == last_node);
        bool enqueued = ggml_vk_build_graph(ctx, cgraph->nodes[i], i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i == last_node, submit);
        if (submit) {
            first_node_in_batch = true;
            submitted_nodes = 0;
-            switch (submit_count) {
-                case 0:
-                    nodes_per_submit = 50;
-                    break;
-                default:
-                    nodes_per_submit = 100;
-                    break;
+            mul_mat_bytes = 0;
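+            // Start with a smaller threshold so the first batches reach the GPU quickly, then
+            // double it after each of the first few submits (settling at 8x the initial value).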
+            if (submit_count < 3) {
+                mul_mat_bytes_per_submit *= 2;
            }
            submit_count++;
        }