]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: support GGML_OP_SET (llama/19584)
authorJeff Bolz <redacted>
Sat, 14 Feb 2026 05:36:38 +0000 (21:36 -0800)
committerGeorgi Gerganov <redacted>
Sat, 14 Feb 2026 22:20:18 +0000 (00:20 +0200)
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/acc.comp

index e919d2223e77219894ce9af39e64c2e751a1436a..a9f75f0d00dc98b504db2da8125018c5557b5882 100644 (file)
@@ -688,6 +688,7 @@ struct vk_device_struct {
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
     vk_pipeline pipeline_acc_f32;
+    vk_pipeline pipeline_set_f32;
 
     // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
     vk_pipeline pipeline_add[2][2][2];
@@ -4182,7 +4183,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -8822,6 +8824,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_acc_f32;
         }
         return nullptr;
+    case GGML_OP_SET:
+        if (src0->type == src1->type && src0->type == dst->type &&
+            (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) {
+            return ctx->device->pipeline_set_f32;
+        }
+        return nullptr;
     case GGML_OP_ADD:
     case GGML_OP_SUB:
     case GGML_OP_MUL:
@@ -9813,7 +9821,7 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const
     int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32
     int offset = dst->op_params[3] / src0_type_size; // offset in bytes
 
-    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, {
+    ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, {
         (uint32_t)ggml_nelements(src0),
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3,
         (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
@@ -12507,6 +12515,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
 
         break;
     case GGML_OP_ACC:
+    case GGML_OP_SET:
         ggml_vk_acc(ctx, compute_ctx, src0, src1, node);
 
         break;
@@ -14967,7 +14976,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
             }
             return op->src[0]->type == GGML_TYPE_F32;
         case GGML_OP_ACC:
-            return op->src[0]->type == GGML_TYPE_F32;
+            return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
+        case GGML_OP_SET:
+            return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type &&
+                   (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32);
         case GGML_OP_CONCAT:
             return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
         case GGML_OP_ADD1:
@@ -15618,6 +15630,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
             tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]);
         } else if (tensor->op == GGML_OP_ACC) {
             tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
+        } else if (tensor->op == GGML_OP_SET) {
+            tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]);
         } else if (tensor->op == GGML_OP_NORM) {
             tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params);
         } else if (tensor->op == GGML_OP_GROUP_NORM) {
index 3d61168b56f09f97a76042f62370279dfe73a204..6ba3d1d89e0a96421c499b2a9c3791d6c2ee67f2 100644 (file)
@@ -3,6 +3,9 @@
 #include "types.glsl"
 #include "generic_binary_head.glsl"
 
+// false for SET, true for ACC
+layout(constant_id = 1) const bool ACC = true;
+
 layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
 
 void main() {
@@ -23,7 +26,11 @@ void main() {
     uint i00, i01, i02, i03;
 
     if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) {
-        data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
+        if (ACC) {
+            data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
+        } else {
+            data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)]));
+        }
     } else {
         data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]));
     }