From: Jeff Bolz Date: Sat, 16 Aug 2025 08:58:38 +0000 (-0500) Subject: vulkan: Add missing bounds checking to scalar/coopmat1 mul_mat_id (#15334) X-Git-Tag: upstream/0.0.6199~19 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=2e2b22ba6607414a5d619ac6d2f034b5b02214e5;p=pkg%2Fggml%2Fsources%2Fllama.cpp vulkan: Add missing bounds checking to scalar/coopmat1 mul_mat_id (#15334) --- diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 8c5114a7..a61a464c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -801,7 +801,7 @@ void main() { } #else const uint row_i = ic * BN + loadc_b + l; - if (row_i < _ne1) { + if (row_i < _ne1 && block + loadr_b < end_k) { const u16vec2 row_idx = row_ids[row_i]; buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); } else { @@ -875,7 +875,9 @@ void main() { const u16vec2 row_idx = row_ids[row_i]; - data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + if (dr + cm_row * TM + store_r < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } } } } @@ -925,7 +927,9 @@ void main() { #endif // MUL_MAT_ID [[unroll]] for (uint cr = 0; cr < TM; cr++) { #ifdef MUL_MAT_ID - data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + if (dr_warp + cr < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + } #else if (dr_warp + cr < p.M && dc_warp + cc < p.N) { data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index f4565f9b..39547f06 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -5824,6 +5824,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16)); } + test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1)); + for (ggml_type type_a : base_types) { for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { for (int n_mats : {4, 8}) {