]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan : support conv_2d_dw with f16 weights (llama/15392)
authorAcly <redacted>
Thu, 21 Aug 2025 15:01:51 +0000 (17:01 +0200)
committerGeorgi Gerganov <redacted>
Fri, 5 Sep 2025 09:54:00 +0000 (12:54 +0300)
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 2d5254dfd2a387cfa12a38076f791fe5fa8f425b..fb18a55cdad2c64c6309dc44e16cf47719746f7c 100644 (file)
@@ -530,8 +530,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_opt_step_sgd_f32;
     vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT];
     vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT];
-    vk_pipeline pipeline_conv2d_dw_whcn_f32;
-    vk_pipeline pipeline_conv2d_dw_cwhn_f32;
+    vk_pipeline pipeline_conv2d_dw_whcn_f32, pipeline_conv2d_dw_whcn_f16_f32;
+    vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32;
 
     // [2][2][2] is for {f16acc,f32acc}x{large,small_rows}x{unaligned, aligned}
     vk_pipeline pipeline_flash_attn_f32_f16_cm2[GGML_TYPE_COUNT][FA_HEAD_SIZE_COUNT][2][2][2];
@@ -3257,6 +3257,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
 
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1);
 
     for (auto &c : compiles) {
         c.wait();
@@ -7346,6 +7348,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
             } else if (ggml_is_contiguous_channels(src1)) {
                 return ctx->device->pipeline_conv2d_dw_cwhn_f32;
             }
+        } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
+            if (ggml_is_contiguous(src1)) {
+                return ctx->device->pipeline_conv2d_dw_whcn_f16_f32;
+            } else if (ggml_is_contiguous_channels(src1)) {
+                return ctx->device->pipeline_conv2d_dw_cwhn_f16_f32;
+            }
         }
         return nullptr;
     default:
index 4e0bffe6da698460067df6038bcd6e9cd3b9ad60..123ae044914ed0873ca00b04202eff3524e9b5c1 100644 (file)
@@ -680,6 +680,8 @@ void process_shaders() {
 
     string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
     string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
+    string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
+    string_to_spv("conv2d_dw_cwhn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));
 
     string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));