]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: Use coopmat2 for conv2d (#14982)
authorJeff Bolz <redacted>
Sun, 3 Aug 2025 12:23:57 +0000 (07:23 -0500)
committerGitHub <redacted>
Sun, 3 Aug 2025 12:23:57 +0000 (14:23 +0200)
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp

index e095b26a48471caa67c2ec7cf20ea65dff6ff590..3682ee3804784b058941cfa874bceadbc3065324 100644 (file)
@@ -3096,6 +3096,10 @@ static void ggml_vk_load_shaders(vk_device& device) {
         uint32_t conv2d_SHMEM_PAD = 4;
         bool conv2d_UNROLL = true;
 
+        if (device->coopmat2) {
+            conv2d_SHMEM_PAD = 8; // 8 float16_t
+        }
+
         if (device->vendor_id == VK_VENDOR_ID_INTEL) {
             conv2d_SHMEM_PAD = 0;
             conv2d_UNROLL = false;
@@ -3154,7 +3158,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
         std::array<uint32_t, 3> wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 };
         std::vector<uint32_t> spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD };
 
-        if (conv2d_UNROLL) {
+        if (device->coopmat2) {
+            ggml_vk_create_pipeline(
+                device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_cm2_len, conv2d_f32_cm2_data, "main", 3,
+                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+            ggml_vk_create_pipeline(
+                device, device->pipeline_conv2d_f16_f32[s], "conv2d_f16_f32", conv2d_f16_f32_cm2_len, conv2d_f16_f32_cm2_data, "main", 3,
+                sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
+        } else if (conv2d_UNROLL) {
             ggml_vk_create_pipeline(
                 device, device->pipeline_conv2d_f32[s], "conv2d_f32", conv2d_f32_unroll_len, conv2d_f32_unroll_data, "main", 3,
                 sizeof(vk_op_conv2d_push_constants), wg_denoms, spec_constants, 1, true, use_collectives);
index 04a10c012f4fe9f1ca19df327661bb24dae66480..86bafba4a4398cc0d1446f3f63368ff4d3126df8 100644 (file)
@@ -1,6 +1,11 @@
 #version 450
 
 #extension GL_EXT_control_flow_attributes : enable
+#ifdef COOPMAT2
+#extension GL_NV_cooperative_matrix2 : enable
+#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
+#extension GL_KHR_memory_scope_semantics : enable
+#endif
 
 #ifdef USE_COLLECTIVES
 #    extension GL_KHR_shader_subgroup_shuffle : enable
@@ -91,6 +96,12 @@ uint32_t n_elems_out = K * NPQ;
 // Number of blocktiles per input
 uint32_t NB_CRS = splitWork(CRS, BS_CRS);
 
+#ifdef COOPMAT2
+#define SHMEM_TYPE float16_t
+#else
+#define SHMEM_TYPE float
+#endif
+
 const uint32_t Ash_stride = BS_CRS + SHMEM_PAD;
 const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD;
 
@@ -100,8 +111,8 @@ const uint32_t Bsh_numel = BS_CRS * BS_NPQ;
 const uint32_t Ash_len = BS_K * Ash_stride;
 const uint32_t Bsh_len = BS_CRS * Bsh_stride;
 
-shared float Ash[Ash_len];  // K x CRS
-shared float Bsh[Bsh_len];  // CRS x NPQ
+shared SHMEM_TYPE Ash[Ash_len];  // K x CRS
+shared SHMEM_TYPE Bsh[Bsh_len];  // CRS x NPQ
 
 // Threadtile sizes
 const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
@@ -110,10 +121,6 @@ const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K;
 const uint32_t NT_K   = BS_K / TS_K;
 const uint32_t NT_NPQ = BS_NPQ / TS_NPQ;
 
-float regA[TS_K];
-float regB[TS_NPQ];
-float regC[TS_K][TS_NPQ];
-
 /*
 Compute
 KxCRS @ CRSxNPQ = K x NPQ
@@ -145,12 +152,36 @@ uint fastdiv(uint n, uint mp, uint L) {
     return (msbs + n) >> L;
 }
 
+#ifdef COOPMAT2
+#define ACC_TYPE float16_t
+
+ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem)
+{
+    uint32_t K_idx   = B_idx_K * BS_K + r;
+    uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c;
+    uint32_t N_idx   = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW;
+    uint32_t OH_idx  = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW;
+    uint32_t OW_idx  = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW;
+    uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3;
+    if (K_idx < K && NPQ_idx < NPQ) {
+        dst_data[dst_idx] = D_TYPE(elem);
+    }
+    return elem;
+}
+#endif
+
 void main() {
+#ifdef COOPMAT2
+    coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator> matC;
+    matC = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BS_K, BS_NPQ, gl_MatrixUseAccumulator>(0.0);
+#else
+    float regC[TS_K][TS_NPQ];
     for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
         for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
             regC[T_ly][T_lx] = 0.0;
         }
     }
+#endif
     /* Advance block in CRS dim */
     for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) {
         uint32_t CRS_idx_a;
@@ -199,7 +230,7 @@ void main() {
             if (K_idx >= K || CRS_idx_a >= CRS) {
                 val = 0.0;
             }
-            Ash[B_ly * Ash_stride + B_lx] = val;
+            Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val);
         }
         /* Load input to B_block: (BS_CRS x BS_NPQ) */
         UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) {
@@ -244,11 +275,21 @@ void main() {
             if (CRS_idx_b >= CRS || NPQ_idx >= NPQ || H_idx < 0 || H_idx >= p.H || W_idx < 0 || W_idx >= p.W) {
                 val = 0.0;
             }
-            Bsh[B_ly * Bsh_stride + B_lx] = val;
+            Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val);
         }
         barrier();
+#ifdef COOPMAT2
+        coopmat<float16_t, gl_ScopeWorkgroup, BS_K, BS_CRS, gl_MatrixUseA> matA;
+        coopmat<float16_t, gl_ScopeWorkgroup, BS_CRS, BS_NPQ, gl_MatrixUseB> matB;
+
+        coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor);
+        coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor);
+        matC = coopMatMulAdd(matA, matB, matC);
+#else
         if (T_y * TS_K < K) {
             UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) {
+                float regA[TS_K];
+                float regB[TS_NPQ];
                 for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
                     regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx];
                 }
@@ -262,9 +303,13 @@ void main() {
                 }
             }
         }
+#endif
         barrier();
     }
     /* Save C* */
+#ifdef COOPMAT2
+    coopMatPerElementNV(matC, matC, perElemOpStore);
+#else
     if (T_y * TS_K < K) {
         for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) {
             for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) {
@@ -280,4 +325,5 @@ void main() {
             }
         }
     }
+#endif
 }
index b634e52d64d376b1a5a11ef9198f28caba116228..83e4a7c723d32f5c22fc17e944b99ed0d3a0eb25 100644 (file)
@@ -661,6 +661,9 @@ void process_shaders() {
     string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});
     string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", ""}});
 
+    string_to_spv("conv2d_f32", "conv2d_mm.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true);
+    string_to_spv("conv2d_f16_f32", "conv2d_mm.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"USE_COLLECTIVES", "1"}, {"UNROLL", "[[unroll]]"}, {"COOPMAT2", "1"}}, true, false, true);
+
     string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}}));
     string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}}));