]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: Add missing bounds checking to scalar/coopmat1 mul_mat_id (llama/15334)
authorJeff Bolz <redacted>
Sat, 16 Aug 2025 08:58:38 +0000 (03:58 -0500)
committerGeorgi Gerganov <redacted>
Mon, 18 Aug 2025 16:15:25 +0000 (19:15 +0300)
src/ggml-vulkan/vulkan-shaders/mul_mm.comp
tests/test-backend-ops.cpp

index 8c5114a79d23cb8529c23d0ea2ec4547b6ed5bb1..a61a464c7bef88593ff7b711190349f5d265b7b1 100644 (file)
@@ -801,7 +801,7 @@ void main() {
             }
 #else
             const uint row_i = ic * BN + loadc_b + l;
-            if (row_i < _ne1) {
+            if (row_i < _ne1 && block + loadr_b < end_k) {
                 const u16vec2 row_idx = row_ids[row_i];
                 buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
             } else {
@@ -875,7 +875,9 @@ void main() {
 
                 const u16vec2 row_idx = row_ids[row_i];
 
-                data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
+                if (dr + cm_row * TM + store_r < p.M) {
+                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
+                }
             }
         }
     }
@@ -925,7 +927,9 @@ void main() {
 #endif // MUL_MAT_ID
                 [[unroll]] for (uint cr = 0; cr < TM; cr++) {
 #ifdef MUL_MAT_ID
-                    data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+                    if (dr_warp + cr < p.M) {
+                        data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
+                    }
 #else
                     if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
                         data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]);
index f4565f9b7123865896d7516f192f3d1c02837332..39547f0649e8e64955185da67ef47a3dcdc17496 100644 (file)
@@ -5824,6 +5824,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16));
     }
 
+    test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));
+
     for (ggml_type type_a : base_types) {
         for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
             for (int n_mats : {4, 8}) {