]> git.djapps.eu Git - pkg/ggml/sources/whisper.cpp/commitdiff
vulkan: Additional type support for unary, binary, and copy (llama/13266)
authorJeff Bolz <redacted>
Sun, 4 May 2025 05:17:16 +0000 (00:17 -0500)
committerGeorgi Gerganov <redacted>
Wed, 7 May 2025 18:00:32 +0000 (21:00 +0300)
Support f16->f32 copy.
Support f16->f16 and f32->f32 unary ops.
Support all combinations of f16/f32 for src0/src1/dst for add/sub/mul/div.

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/relu.comp
ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp
ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index c4df47abae4ea392a69bb2424b8c72dbf07919f0..c61a8cf0af4e9bbecfb503963c08f30830d63045 100644 (file)
@@ -340,11 +340,17 @@ struct vk_device_struct {
     vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT];
     vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT];
     vk_pipeline pipeline_acc_f32;
-    vk_pipeline pipeline_add_f32, pipeline_add_f32_norepeat;
-    vk_pipeline pipeline_add_f16_f32_f16, pipeline_add_f16_f32_f16_norepeat;
-    vk_pipeline pipeline_sub_f32, pipeline_sub_f32_norepeat;
-    vk_pipeline pipeline_mul_f32, pipeline_mul_f32_norepeat;
-    vk_pipeline pipeline_div_f32, pipeline_div_f32_norepeat;
+
+    // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16]
+    vk_pipeline pipeline_add[2][2][2];
+    vk_pipeline pipeline_add_norepeat[2][2][2];
+    vk_pipeline pipeline_sub[2][2][2];
+    vk_pipeline pipeline_sub_norepeat[2][2][2];
+    vk_pipeline pipeline_mul[2][2][2];
+    vk_pipeline pipeline_mul_norepeat[2][2][2];
+    vk_pipeline pipeline_div[2][2][2];
+    vk_pipeline pipeline_div_norepeat[2][2][2];
+
     vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32;
     vk_pipeline pipeline_upscale_f32;
     vk_pipeline pipeline_scale_f32;
@@ -354,8 +360,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_clamp_f32;
     vk_pipeline pipeline_pad_f32;
     vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
-    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f32_bf16;
-    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f32_bf16;
+    vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16;
+    vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16;
     vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
     vk_pipeline pipeline_norm_f32;
@@ -363,14 +369,17 @@ struct vk_device_struct {
     vk_pipeline pipeline_rms_norm_f32;
     vk_pipeline pipeline_rms_norm_back_f32;
     vk_pipeline pipeline_l2_norm_f32;
-    vk_pipeline pipeline_gelu_f32;
-    vk_pipeline pipeline_gelu_quick_f32;
-    vk_pipeline pipeline_silu_f32;
-    vk_pipeline pipeline_silu_back_f32;
-    vk_pipeline pipeline_relu_f32;
+
+    // [src/dst 0=fp32,1=fp16]
+    vk_pipeline pipeline_gelu[2];
+    vk_pipeline pipeline_gelu_quick[2];
+    vk_pipeline pipeline_silu[2];
+    vk_pipeline pipeline_relu[2];
+    vk_pipeline pipeline_tanh[2];
+    vk_pipeline pipeline_sigmoid[2];
+
     vk_pipeline pipeline_leaky_relu_f32;
-    vk_pipeline pipeline_tanh_f32;
-    vk_pipeline pipeline_sigmoid_f32;
+    vk_pipeline pipeline_silu_back_f32;
     vk_pipeline pipeline_diag_mask_inf_f32;
     vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16;
     vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512;
@@ -2508,11 +2517,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
     if (device->float_controls_rte_fp16) {
@@ -2538,19 +2549,31 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_add_f32, "add_f32", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_add_f32_norepeat, "add_f32_norepeat", add_f32_len, add_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16, "add_f16_f32_f16", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_add_f16_f32_f16_norepeat, "add_f16_f32_f16_norepeat", add_f16_f32_f16_len, add_f16_f32_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
+    auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
+        std::string s;
+        s += std::string(src0_f16 ? "_f16" : "_f32");
+        s += std::string(src1_f16 ? "_f16" : "_f32");
+        s += std::string(dst_f16 ? "_f16" : "_f32");
+        return s;
+    };
 
-    ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
+#define CREATE_BINARY(name, namemod, spec) \
+    for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \
+        ggml_vk_create_pipeline(device, device->pipeline_ ## name ## namemod[s0][s1][d], \
+                                #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d], name ## _data[s0][s1][d], \
+                                "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1);
+
+    CREATE_BINARY(add, , {0})
+    CREATE_BINARY(add, _norepeat, {1})
+    CREATE_BINARY(sub, , {0})
+    CREATE_BINARY(sub, _norepeat, {1})
+    CREATE_BINARY(mul, , {0})
+    CREATE_BINARY(mul, _norepeat, {1})
+    CREATE_BINARY(div, , {0})
+    CREATE_BINARY(div, _norepeat, {1})
+#undef CREATE_BINARY
 
-    ggml_vk_create_pipeline(device, device->pipeline_sub_f32, "sub_f32", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_sub_f32_norepeat, "sub_f32_norepeat", sub_f32_len, sub_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_mul_f32, "mul_f32", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_mul_f32_norepeat, "mul_f32_norepeat", mul_f32_len, mul_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_div_f32, "div_f32", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_div_f32_norepeat, "div_f32_norepeat", div_f32_len, div_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {1}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1);
@@ -2571,14 +2594,20 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
 
-    ggml_vk_create_pipeline(device, device->pipeline_gelu_f32, "gelu_f32", gelu_f32_len, gelu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_gelu_quick_f32, "gelu_quick_f32", gelu_quick_f32_len, gelu_quick_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_silu_f32, "silu_f32", silu_f32_len, silu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_relu_f32, "relu_f32", relu_f32_len, relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+#define CREATE_UNARY(name)  \
+    ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);  \
+    ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+
+    CREATE_UNARY(gelu)
+    CREATE_UNARY(gelu_quick)
+    CREATE_UNARY(silu)
+    CREATE_UNARY(relu)
+    CREATE_UNARY(tanh)
+    CREATE_UNARY(sigmoid)
+#undef CREATE_UNARY
+
     ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_tanh_f32, "tanh_f32", tanh_f32_len, tanh_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_sigmoid_f32, "sigmoid_f32", sigmoid_f32_len, sigmoid_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1);
 
     ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true);
 
@@ -4504,6 +4533,13 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
             return ctx->device->pipeline_cpy_f16_f16;
         }
     }
+    if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) {
+        if (contig) {
+            return ctx->device->pipeline_contig_cpy_f16_f32;
+        } else {
+            return ctx->device->pipeline_cpy_f16_f32;
+        }
+    }
     if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) {
         if (contig) {
             return ctx->device->pipeline_contig_cpy_f32_bf16;
@@ -5894,26 +5930,37 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     case GGML_OP_ADD:
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f32_norepeat : ctx->device->pipeline_add_f32;
+    case GGML_OP_SUB:
+    case GGML_OP_MUL:
+    case GGML_OP_DIV:
+        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
+            (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) ||
+            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) {
+            return nullptr;
         }
-        if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
-            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_f16_f32_f16_norepeat : ctx->device->pipeline_add_f16_f32_f16;
+        switch (op) {
+        case GGML_OP_ADD:
+        {
+            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add;
+            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
         }
-        return nullptr;
-    case GGML_OP_SUB:
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_f32_norepeat : ctx->device->pipeline_sub_f32;
+        case GGML_OP_SUB:
+        {
+            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub;
+            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
         }
-        return nullptr;
-    case GGML_OP_MUL:
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_f32_norepeat : ctx->device->pipeline_mul_f32;
+        case GGML_OP_MUL:
+        {
+            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul;
+            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
         }
-        return nullptr;
-    case GGML_OP_DIV:
-        if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-            return ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_f32_norepeat : ctx->device->pipeline_div_f32;
+        case GGML_OP_DIV:
+        {
+            auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div;
+            return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16];
+        }
+        default:
+            break;
         }
         return nullptr;
     case GGML_OP_CONCAT:
@@ -6007,37 +6054,25 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         }
         return nullptr;
     case GGML_OP_UNARY:
+        if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) ||
+            (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) ||
+            (src0->type != dst->type)) {
+            return nullptr;
+        }
+
         switch (ggml_get_unary_op(dst)) {
             case GGML_UNARY_OP_SILU:
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_silu_f32;
-                }
-                break;
+                return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_GELU:
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_gelu_f32;
-                }
-                break;
+                return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_GELU_QUICK:
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_gelu_quick_f32;
-                }
-                break;
+                return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_RELU:
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_relu_f32;
-                }
-                break;
+                return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_TANH:
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_tanh_f32;
-                }
-                break;
+                return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16];
             case GGML_UNARY_OP_SIGMOID:
-                if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-                    return ctx->device->pipeline_sigmoid_f32;
-                }
-                break;
+                return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16];
             default:
                 break;
         }
@@ -9423,7 +9458,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 case GGML_UNARY_OP_RELU:
                 case GGML_UNARY_OP_TANH:
                 case GGML_UNARY_OP_SIGMOID:
-                    return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
+                    return ggml_is_contiguous(op->src[0]) &&
+                           (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+                           (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
+                           (op->src[0]->type == op->type);
                 default:
                     return false;
             }
@@ -9603,6 +9641,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 }
                 if (src1_type == GGML_TYPE_F32) {
                     switch (src0_type) {
+                    case GGML_TYPE_F16:
                     case GGML_TYPE_Q4_0:
                     case GGML_TYPE_Q4_1:
                     case GGML_TYPE_Q5_0:
@@ -9641,6 +9680,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
         case GGML_OP_SUB:
         case GGML_OP_MUL:
         case GGML_OP_DIV:
+            return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+                   (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) &&
+                   (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
         case GGML_OP_SILU_BACK:
         case GGML_OP_RMS_NORM_BACK:
         case GGML_OP_SQR:
index 52a19b62a67db71e58a5eb47a4b4fbb3705c78c7..4f806270c7799a5e62f28db338f8752ecb0beac3 100644 (file)
@@ -17,5 +17,5 @@ void main() {
         return;
     }
 
-    data_d[i] = max(float(data_a[i]), 0);
+    data_d[i] = D_TYPE(max(float(data_a[i]), 0));
 }
index 776581e2c4e2694c112d196498a4d5f2e5efad9d..5c9e5c350323b2e306005bcefeed16d5c79fc703 100644 (file)
@@ -16,5 +16,5 @@ void main() {
     if (i >= p.KX) {
         return;
     }
-    data_d[i] = D_TYPE(1. / (1 + exp(-1. *data_a[i])));
+    data_d[i] = D_TYPE(1. / (1 + exp(-1. * float(data_a[i]))));
 }
index 495f966bdc6e5a556baffd6d3c5d49531ac434e4..8a6f868f58a7c82234780abe78fb8d30a9525d9f 100644 (file)
@@ -16,5 +16,5 @@ void main() {
     if (i >= p.KX) {
         return;
     }
-    data_d[i] = D_TYPE(1. - 2. / (exp(2.*data_a[i]) + 1.));
+    data_d[i] = D_TYPE(1. - 2. / (exp(2.*float(data_a[i])) + 1.));
 }
index 759112b94629a90051451cc6ec39a829d1b94baa..382f190f287bb6e3ed30bc2cb97c13f61741de2b 100644 (file)
@@ -485,10 +485,12 @@ void process_shaders() {
     string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
     string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+    string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
     string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
     string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
     string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
     string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
+    string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
     string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
 
     for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
@@ -497,8 +499,26 @@ void process_shaders() {
         string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
     }
 
-    string_to_spv("add_f32", "add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-    string_to_spv("add_f16_f32_f16", "add.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"FLOAT_TYPE", "float"}});
+    auto get_type_str = [](bool f16) {
+        return f16 ? "float16_t" : "float";
+    };
+    auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) {
+        std::string s;
+        s += std::string(src0_f16 ? "_f16" : "_f32");
+        s += std::string(src1_f16 ? "_f16" : "_f32");
+        s += std::string(dst_f16 ? "_f16" : "_f32");
+        return s;
+    };
+    for (std::string op : {"add", "sub", "mul", "div"}) {
+    for (auto src0_f16 : {false, true}) {
+    for (auto src1_f16 : {false, true}) {
+    for (auto dst_f16  : {false, true}) {
+        auto name = op + get_suffix(src0_f16, src1_f16, dst_f16);
+        string_to_spv(name.c_str(), op + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}});
+    }
+    }
+    }
+    }
 
     string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
 
@@ -533,14 +553,21 @@ void process_shaders() {
 
     string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 
-    string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
-    string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("gelu_f16",       "gelu.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("gelu_f32",       "gelu.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    string_to_spv("gelu_quick_f16", "gelu_quick.comp",  {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("gelu_quick_f32", "gelu_quick.comp",  {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    string_to_spv("silu_f16",       "silu.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("silu_f32",       "silu.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    string_to_spv("relu_f16",       "relu.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("relu_f32",       "relu.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    string_to_spv("tanh_f16",       "tanh.comp",        {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("tanh_f32",       "tanh.comp",        {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+    string_to_spv("sigmoid_f16",    "sigmoid.comp",     {{"A_TYPE", "float16_t"},   {"D_TYPE", "float16_t"}});
+    string_to_spv("sigmoid_f32",    "sigmoid.comp",     {{"A_TYPE", "float"},       {"D_TYPE", "float"}});
+
+    string_to_spv("leaky_relu_f32", "leaky_relu.comp",  {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
+    string_to_spv("silu_back_f32",  "silu_back.comp",   {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});
 
     string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
 
@@ -641,7 +668,12 @@ void write_output_files() {
             std::remove(path.c_str());
         }
     }
-
+    for (const char *op : {"add", "sub", "mul", "div"}) {
+        fprintf(hdr, "extern unsigned char *%s_data[2][2][2];\n", op);
+        fprintf(hdr, "extern uint64_t %s_len[2][2][2];\n", op);
+        fprintf(src, "unsigned char *%s_data[2][2][2] = {{{%s_f32_f32_f32_data, %s_f32_f32_f16_data}, {%s_f32_f16_f32_data, %s_f32_f16_f16_data}}, {{%s_f16_f32_f32_data, %s_f16_f32_f16_data}, {%s_f16_f16_f32_data, %s_f16_f16_f16_data}}};\n", op, op, op, op, op, op, op, op, op);
+        fprintf(src, "uint64_t %s_len[2][2][2] = {{{%s_f32_f32_f32_len, %s_f32_f32_f16_len}, {%s_f32_f16_f32_len, %s_f32_f16_f16_len}}, {{%s_f16_f32_f32_len, %s_f16_f32_f16_len}, {%s_f16_f16_f32_len, %s_f16_f16_f16_len}}};\n", op, op, op, op, op, op, op, op, op);
+    }
     fclose(hdr);
     fclose(src);
 }