]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: Fix multi_add invalid descriptor usage (llama/16899)
authorJeff Bolz <redacted>
Sat, 1 Nov 2025 05:52:14 +0000 (00:52 -0500)
committerGeorgi Gerganov <redacted>
Sun, 9 Nov 2025 16:30:22 +0000 (18:30 +0200)
src/ggml-vulkan/ggml-vulkan.cpp
src/ggml-vulkan/vulkan-shaders/multi_add.comp

index 6a46d0889bdb955025afc84240c710b3ab559aa1..8d1a85c96939b0b06a7a80bdde13edccb861381b 100644 (file)
@@ -4274,8 +4274,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
 
         device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 &&
                             device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) &&
-                            vk12_features.runtimeDescriptorArray &&
-                            device->vendor_id != VK_VENDOR_ID_INTEL &&
                             getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr;
 
         device->shader_int64 = device_features2.features.shaderInt64;
index 1e8f694a72470f6094fc6c1f8a1ca80ec6a1b3e3..10cf5202a4a37da22aad66f7ee564067122920bb 100644 (file)
@@ -23,16 +23,100 @@ layout (push_constant) uniform parameter2
     uint rms_partials;
 } p;
 
-// Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
-// layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[];
-// layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[];
-layout (binding = 0) buffer A {A_TYPE data_a[];} a[];
-layout (binding = 0) buffer D {D_TYPE data_d[];} d[];
-
-layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[];
+// No readonly/writeonly decorations. Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498
+layout (binding = 0)  buffer A0 {A_TYPE data_a[];} a0;
+layout (binding = 1)  buffer A1 {A_TYPE data_a[];} a1;
+layout (binding = 2)  buffer A2 {A_TYPE data_a[];} a2;
+layout (binding = 3)  buffer A3 {A_TYPE data_a[];} a3;
+layout (binding = 4)  buffer A4 {A_TYPE data_a[];} a4;
+layout (binding = 5)  buffer A5 {A_TYPE data_a[];} a5;
+layout (binding = 6)  buffer A6 {A_TYPE data_a[];} a6;
+layout (binding = 7)  buffer A7 {A_TYPE data_a[];} a7;
+layout (binding = 8)  buffer A8 {A_TYPE data_a[];} a8;
+layout (binding = 9)  buffer A9 {A_TYPE data_a[];} a9;
+layout (binding = 10) buffer A10 {A_TYPE data_a[];} a10;
+layout (binding = 11) buffer A11 {A_TYPE data_a[];} a11;
+layout (binding = 0)  buffer D0 {D_TYPE data_d[];} d0;
+layout (binding = 1)  buffer D1 {D_TYPE data_d[];} d1;
+layout (binding = 2)  buffer D2 {D_TYPE data_d[];} d2;
+layout (binding = 3)  buffer D3 {D_TYPE data_d[];} d3;
+layout (binding = 4)  buffer D4 {D_TYPE data_d[];} d4;
+layout (binding = 5)  buffer D5 {D_TYPE data_d[];} d5;
+layout (binding = 6)  buffer D6 {D_TYPE data_d[];} d6;
+layout (binding = 7)  buffer D7 {D_TYPE data_d[];} d7;
+layout (binding = 8)  buffer D8 {D_TYPE data_d[];} d8;
+layout (binding = 9)  buffer D9 {D_TYPE data_d[];} d9;
+layout (binding = 10) buffer D10 {D_TYPE data_d[];} d10;
+layout (binding = 11) buffer D11 {D_TYPE data_d[];} d11;
+layout (binding = 0, std430)  buffer PartialBuf0 {float partial_sums[];} partials0;
+layout (binding = 1, std430)  buffer PartialBuf1 {float partial_sums[];} partials1;
+layout (binding = 2, std430)  buffer PartialBuf2 {float partial_sums[];} partials2;
+layout (binding = 3, std430)  buffer PartialBuf3 {float partial_sums[];} partials3;
+layout (binding = 4, std430)  buffer PartialBuf4 {float partial_sums[];} partials4;
+layout (binding = 5, std430)  buffer PartialBuf5 {float partial_sums[];} partials5;
+layout (binding = 6, std430)  buffer PartialBuf6 {float partial_sums[];} partials6;
+layout (binding = 7, std430)  buffer PartialBuf7 {float partial_sums[];} partials7;
+layout (binding = 8, std430)  buffer PartialBuf8 {float partial_sums[];} partials8;
+layout (binding = 9, std430)  buffer PartialBuf9 {float partial_sums[];} partials9;
+layout (binding = 10, std430) buffer PartialBuf10 {float partial_sums[];} partials10;
+layout (binding = 11, std430) buffer PartialBuf11 {float partial_sums[];} partials11;
 
 layout(constant_id = 0) const uint num_srcs = 2;
 
+FLOAT_TYPE load_a(uint b, uint i) {
+    switch (b) {
+    case 0:  return FLOAT_TYPE(a0.data_a[i]);
+    case 1:  return FLOAT_TYPE(a1.data_a[i]);
+    case 2:  return FLOAT_TYPE(a2.data_a[i]);
+    case 3:  return FLOAT_TYPE(a3.data_a[i]);
+    case 4:  return FLOAT_TYPE(a4.data_a[i]);
+    case 5:  return FLOAT_TYPE(a5.data_a[i]);
+    case 6:  return FLOAT_TYPE(a6.data_a[i]);
+    case 7:  return FLOAT_TYPE(a7.data_a[i]);
+    case 8:  return FLOAT_TYPE(a8.data_a[i]);
+    case 9:  return FLOAT_TYPE(a9.data_a[i]);
+    case 10: return FLOAT_TYPE(a10.data_a[i]);
+    case 11: return FLOAT_TYPE(a11.data_a[i]);
+    default: return FLOAT_TYPE(0);
+    }
+}
+
+void store_d(uint b, uint i, FLOAT_TYPE v) {
+    switch (b) {
+    case 0:  d0.data_d[i] = D_TYPE(v); break;
+    case 1:  d1.data_d[i] = D_TYPE(v); break;
+    case 2:  d2.data_d[i] = D_TYPE(v); break;
+    case 3:  d3.data_d[i] = D_TYPE(v); break;
+    case 4:  d4.data_d[i] = D_TYPE(v); break;
+    case 5:  d5.data_d[i] = D_TYPE(v); break;
+    case 6:  d6.data_d[i] = D_TYPE(v); break;
+    case 7:  d7.data_d[i] = D_TYPE(v); break;
+    case 8:  d8.data_d[i] = D_TYPE(v); break;
+    case 9:  d9.data_d[i] = D_TYPE(v); break;
+    case 10: d10.data_d[i] = D_TYPE(v); break;
+    case 11: d11.data_d[i] = D_TYPE(v); break;
+    default: break;
+    }
+}
+
+void store_partial(uint b, uint i, float v) {
+    switch (b) {
+    case 0:  partials0.partial_sums[i] = v; break;
+    case 1:  partials1.partial_sums[i] = v; break;
+    case 2:  partials2.partial_sums[i] = v; break;
+    case 3:  partials3.partial_sums[i] = v; break;
+    case 4:  partials4.partial_sums[i] = v; break;
+    case 5:  partials5.partial_sums[i] = v; break;
+    case 6:  partials6.partial_sums[i] = v; break;
+    case 7:  partials7.partial_sums[i] = v; break;
+    case 8:  partials8.partial_sums[i] = v; break;
+    case 9:  partials9.partial_sums[i] = v; break;
+    case 10: partials10.partial_sums[i] = v; break;
+    case 11: partials11.partial_sums[i] = v; break;
+    default: break;
+    }
+}
+
 uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) {
     return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0];
 }
@@ -78,10 +162,10 @@ void main() {
 
         FLOAT_TYPE sum = FLOAT_TYPE(0);
         [[unroll]] for (uint s = 0; s < num_srcs; ++s) {
-            sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]);
+            sum += load_a(s, src_idx(s, i00, i01, i02, i03));
         }
         sum_sq += sum*sum;
-        d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum);
+        store_d(num_srcs, dst_idx(i00, i01, i02, i03), sum);
 
         idx += num_threads;
     }
@@ -104,7 +188,7 @@ void main() {
         }
 
         if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) {
-            partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq;
+            store_partial(num_srcs + 1, orig_idx / (num_iter * num_threads), sum_sq);
         }
     }
 #endif