]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: fix coopmat2 validation failures (#11284)
authorJeff Bolz <redacted>
Mon, 20 Jan 2025 16:38:32 +0000 (10:38 -0600)
committerGitHub <redacted>
Mon, 20 Jan 2025 16:38:32 +0000 (10:38 -0600)
mul mat and flash attention shaders were loading f32 types directly into
A/B matrices, which happens to work but is technically invalid usage.
For FA, we can load it as an Accumulator matrix and convert and this
is not in the inner loop and is cheap enough. For mul mat, it's more
efficient to do this conversion in a separate pass and have the input(s)
be f16.

coopmat2 requires SPIR-V 1.6 (related using to LocalSizeId). LocalSizeId
requires maintenance4 be enabled, and SPIR-V 1.6 requires Vulkan 1.3.

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp
ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index 437e9cdcc35e04ff8cae3db0302f642b1fc9ad12..78c2f5c45ade3bd2247d746f6fbe00a6c673a603 100644 (file)
@@ -29,8 +29,6 @@
 
 #include "ggml-vulkan-shaders.hpp"
 
-#define VK_API_VERSION VK_API_VERSION_1_2
-
 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
 
 #define VK_VENDOR_ID_AMD 0x1002
@@ -1614,11 +1612,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT)   \
         CREATE_MM(PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT)   \
 
-        CREATE_MM(pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
-        CREATE_MM(pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3)
-
         CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
-        CREATE_MM2(pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1].f16acc, matmul_q4_1_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0].f16acc, matmul_q5_0_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
@@ -1631,21 +1625,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K].f16acc, matmul_q6_k_f16, _f16acc, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3)
         CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
 
-        CREATE_MM(pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
         CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM2(pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
-
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
-        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f16acc, matmul_id_q4_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f16acc, matmul_id_q5_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f16acc, matmul_id_q5_1_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f16acc, matmul_id_q8_0_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f16acc, matmul_id_q2_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f16acc, matmul_id_q3_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f16acc, matmul_id_q4_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f16acc, matmul_id_q5_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f16acc, matmul_id_q6_k_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
+        CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
 #undef CREATE_MM
 #undef CREATE_MM2
     } else
@@ -2287,6 +2278,14 @@ static vk_device ggml_vk_get_device(size_t idx) {
         }
 #endif
 
+        VkPhysicalDeviceMaintenance4Features maint4_features {};
+        maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES;
+        if (maintenance4_support) {
+            last_struct->pNext = (VkBaseOutStructure *)&maint4_features;
+            last_struct = (VkBaseOutStructure *)&maint4_features;
+            device_extensions.push_back("VK_KHR_maintenance4");
+        }
+
         vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
 
         device->fp16 = device->fp16 && vk12_features.shaderFloat16;
@@ -2662,7 +2661,14 @@ void ggml_vk_instance_init() {
 
     vk_instance_initialized = true;
 
-    vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
+    uint32_t api_version = vk::enumerateInstanceVersion();
+
+    if (api_version < VK_API_VERSION_1_2) {
+        std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." << std::endl;
+        GGML_ABORT("fatal error");
+    }
+
+    vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version };
 
     const std::vector<vk::ExtensionProperties> instance_extensions = vk::enumerateInstanceExtensionProperties();
     const bool validation_ext = ggml_vk_instance_validation_ext_available(instance_extensions);
@@ -2972,7 +2978,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
         }
     }
 
-    GGML_ASSERT(src1_type == GGML_TYPE_F32);
+    GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16));
 
     switch (src0_type) {
         case GGML_TYPE_Q4_0:
@@ -3812,8 +3818,9 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub
         src1_uma = d_Qy != nullptr;
     }
 
-    const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
-    // Reformat and convert to fp16 if src1 is non-contiguous, or for coopmat2 for better perf
+    // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
+    const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
+                              !ggml_vk_dim01_contiguous(src0);
     const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
                               !ggml_vk_dim01_contiguous(src1);
 
@@ -4393,8 +4400,11 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
         ids_uma = d_ids != nullptr;
     }
 
-    const bool x_non_contig = !ggml_vk_dim01_contiguous(src0);
-    const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
+    // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf
+    const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) ||
+                              !ggml_vk_dim01_contiguous(src0);
+    const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) ||
+                              !ggml_vk_dim01_contiguous(src1);
 
     const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig;
 
@@ -4404,7 +4414,8 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
     const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !y_f32_kernel) || y_non_contig;
 
     if (qx_needs_dequant) {
-        GGML_ABORT("fatal error");
+        // Fall back to dequant + f16 mulmat
+        mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, GGML_TYPE_F16, y_f32_kernel ? GGML_TYPE_F32 : GGML_TYPE_F16, (ggml_prec)dst->op_params[0]);
     }
 
     // Not implemented
index ca3a59b8faaab973d7728f0c497a20e39b6a6aa3..3735d0dbbae1d71bb7bd1d2794c3410af23c2976 100644 (file)
@@ -166,7 +166,7 @@ void main() {
     tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
     tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);
 
-    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Q;
+    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;
     coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16;
 
     uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
index cbfa5dce1b091852498e95e4394f7d55b8b185c4..57f9e724513608a94a814c8544f0f7567205a192 100644 (file)
@@ -57,17 +57,13 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
 
 #if QUANT_K > 1
 #define DECODEFUNCA , dequantFuncA
-#define MAT_A_TYPE float16_t
 
 #include "dequant_funcs_cm2.comp"
 
 #else
 #define DECODEFUNCA
-#define MAT_A_TYPE A_TYPE
 #endif
 
-#define MAT_B_TYPE B_TYPE
-
 #ifdef MUL_MAT_ID
 layout (binding = 3) readonly buffer IDS {int data_ids[];};
 
@@ -236,16 +232,13 @@ void main() {
 
         for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
 
-            coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-            coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
+            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
 
             coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
-
             coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose);
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
 
-            sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+            sum = coopMatMulAdd(mat_a, mat_b, sum);
         }
     } else
 #endif // !defined(MUL_MAT_ID)
@@ -261,10 +254,8 @@ void main() {
         [[dont_unroll]]
         for (uint block_k = start_k; block_k < end_k; block_k += BK) {
 
-            coopmat<MAT_A_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
-            coopmat<MAT_B_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a_ft;
-            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b_ft;
+            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
+            coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB> mat_b;
 
             // Clamping is expensive, so detect different code paths for each combination
             // of A and B needing clamping.
@@ -281,16 +272,12 @@ void main() {
 #else
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
 #endif
-                mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
-                mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
-                sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
             } else if (unclampedA && !unclampedB) {
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, (block_k & ~7), BK) DECODEFUNCA);
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
 
-                mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
-                mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
-                sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
             } else if (!unclampedA && unclampedB) {
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
 #ifdef MUL_MAT_ID
@@ -298,16 +285,12 @@ void main() {
 #else
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, (block_k & ~7), BK), tensorViewTranspose);
 #endif
-                mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
-                mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
-                sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
             } else if (!unclampedA && !unclampedB) {
                 coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
                 coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose);
 
-                mat_a_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA>(mat_a);
-                mat_b_ft = coopmat<FLOAT_TYPE, gl_ScopeWorkgroup, BK, BN, gl_MatrixUseB>(mat_b);
-                sum = coopMatMulAdd(mat_a_ft, mat_b_ft, sum);
+                sum = coopMatMulAdd(mat_a, mat_b, sum);
             }
         }
     }
index b7890ef1e741bc99edd84a7c27004c0280d357a6..8bcb64101b150428e316257da54ca64414c85f63 100644 (file)
@@ -316,8 +316,11 @@ void matmul_shaders(bool fp16, bool matmul_id, bool coopmat, bool coopmat2, bool
         // For aligned matmul loads
         std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16") ? load_vec : "2";
 
-        string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
-        string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+        // don't generate f32 variants for coopmat2
+        if (!coopmat2) {
+            string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);
+            string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
+        }
 
         if (tname != "f16" && tname != "f32") {
             string_to_spv(shader_name + "_" + tname + "_f16", source_name,          merge_maps(base_dict, {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned},                           {"B_TYPE", "float16_t"},        {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc);