]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: support CPY from any type to itself (#13695)
authorJeff Bolz <redacted>
Fri, 23 May 2025 04:45:02 +0000 (00:45 -0400)
committerGitHub <redacted>
Fri, 23 May 2025 04:45:02 +0000 (06:45 +0200)
Reuse the f16/f32 copy shaders, and just scale the number of elements
according to the type size.

ggml/src/ggml-vulkan/ggml-vulkan.cpp

index 5b62a6e0bec0a5b94b59252f682e0845fa0cdb1f..cf6dd9d44b8a3d2dc89f158c05148b23fb02e0ec 100644 (file)
@@ -4676,6 +4676,19 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
         }
     }
 
+    if (src->type == to) {
+        // Copy two or four bytes at a time, depending on block size.
+        // For quantized types, we scale by block size/type size. But
+        // this path is also used for bf16->bf16 for example, where the
+        // type size must be exactly 2 or 4.
+        GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4);
+        if ((ggml_type_size(src->type) % 4) == 0) {
+            return ctx->device->pipeline_contig_cpy_f32_f32;
+        } else {
+            return ctx->device->pipeline_contig_cpy_f16_f16;
+        }
+    }
+
     std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl;
     GGML_ABORT("fatal error");
 }
@@ -6737,7 +6750,16 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
     case GGML_OP_UNARY:
     case GGML_OP_CONV_2D_DW:
         {
-            const uint32_t ne = ggml_nelements(dst);
+            uint32_t ne = ggml_nelements(dst);
+            if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
+                // Convert from number of logical elements to 2- or 4-byte units.
+                ne /= ggml_blck_size(src0->type);
+                if ((ggml_type_size(src0->type) % 4) == 0) {
+                    ne *= ggml_type_size(src0->type) / 4;
+                } else {
+                    ne *= ggml_type_size(src0->type) / 2;
+                }
+            }
             if (ne > 262144) {
                 elements = { 512, 512, CEIL_DIV(ne, 262144) };
             } else if (ne > 512) {
@@ -7287,8 +7309,19 @@ static void ggml_vk_cpy(ggml_backend_vk_context * ctx, vk_context& subctx, const
     const uint32_t src0_type_size = ggml_type_size(src0->type);
     const uint32_t dst_type_size = ggml_type_size(dst->type);
 
+    uint32_t ne = (uint32_t)ggml_nelements(src0);
+    if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) {
+        // Convert from number of logical elements to 2- or 4-byte units.
+        ne /= ggml_blck_size(src0->type);
+        if ((ggml_type_size(src0->type) % 4) == 0) {
+            ne *= ggml_type_size(src0->type) / 4;
+        } else {
+            ne *= ggml_type_size(src0->type) / 2;
+        }
+    }
+
     ggml_vk_op_f32<vk_op_unary_push_constants>(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, {
-        (uint32_t)ggml_nelements(src0),
+        ne,
         (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], (uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
         (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], (uint32_t) dst->nb[0] /  dst_type_size, (uint32_t) dst->nb[1] /  dst_type_size, (uint32_t) dst->nb[2] /  dst_type_size, (uint32_t) dst->nb[3] /  dst_type_size,
         0,
@@ -9872,6 +9905,15 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
                     return true;
                 }
+
+                // We can handle copying from a type to the same type if it's
+                // contiguous (memcpy). We use f16 or f32 shaders to do the copy,
+                // so the type/block size must be a multiple of 4.
+                if (src0_type == src1_type &&
+                    ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) &&
+                    (ggml_type_size(src0_type) % 2) == 0) {
+                    return true;
+                }
                 return false;
             } break;
         case GGML_OP_REPEAT: