]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: Fix OOB accesses in soft_max_back (#15861)
authorJeff Bolz <redacted>
Tue, 9 Sep 2025 12:41:15 +0000 (07:41 -0500)
committerGitHub <redacted>
Tue, 9 Sep 2025 12:41:15 +0000 (14:41 +0200)
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp
tests/test-backend-ops.cpp

index 6f130d47f2c921d28a62647cdd80aecc70ce1503..c0c2239c2ddb1c4987112f156df49505c61ecca9 100644 (file)
@@ -3387,7 +3387,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
     ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
-    ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
+    ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);
 
     ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@@ -9145,7 +9145,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
 
 static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
     float * op_params = (float *)dst->op_params;
-    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun);
+    ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
 }
 
 static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
index 29bd77d7e1c886f6d806ffef23fc07015bdba99c..144ea58e6fa8fd0771f1d604fdbeb7706b7582bc 100644 (file)
@@ -20,6 +20,10 @@ void main() {
     const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
     const uint tid = gl_LocalInvocationID.x;
 
+    if (row >= p.KY) {
+        return;
+    }
+
     FLOAT_TYPE scale = p.param1;
 
     // partial sums for thread in warp
index a4776ee7f15dab6d416f81cc5d3f6755159e653a..adf91ab6f9edd822f2b420237d29c3a3ce7f9bcc 100644 (file)
@@ -6391,6 +6391,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
                 for (int64_t ne1 : {16, 1024}) {
                     test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0,   ne1,   1, 1}, scale, max_bias));
                     test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, scale, max_bias));
+                    test_cases.emplace_back(new test_soft_max_back(GGML_TYPE_F32, {ne0,   ne1,   2, 3}, scale, max_bias));
                 }
             }
         }