]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
vulkan: add noncontiguous GLU support (#21081)
authorRuben Ortlam <redacted>
Sat, 28 Mar 2026 07:44:56 +0000 (08:44 +0100)
committerGitHub <redacted>
Sat, 28 Mar 2026 07:44:56 +0000 (08:44 +0100)
* vulkan: add noncontiguous GLU support

* fix compile issue

ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl
ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl

index 221e6fa04e957e481548ad2723964c33fe2d4fc5..15ed5b2a79df15fa3286ea9af578c74528fc6c0f 100644 (file)
@@ -1112,6 +1112,16 @@ struct vk_op_glu_push_constants {
     uint32_t mode;  // 0: default, 1: swapped, 2: split
     float alpha; // for swiglu_oai
     float limit;
+    uint32_t nb01;
+    uint32_t nb02;
+    uint32_t nb03;
+    uint32_t ne01;
+    uint32_t ne02;
+    uint32_t nb11;
+    uint32_t nb12;
+    uint32_t nb13;
+    uint32_t ne11;
+    uint32_t ne12;
 };
 
 struct vk_op_unary_push_constants {
@@ -5044,7 +5054,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
         } else {
             device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities});
         }
-        vk::DeviceCreateInfo device_create_info;
+        vk::DeviceCreateInfo device_create_info{};
         std::vector<const char *> device_extensions;
         vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures();
 
@@ -5413,12 +5423,10 @@ static vk_device ggml_vk_get_device(size_t idx) {
 #endif
         device->name = GGML_VK_NAME + std::to_string(idx);
 
-        device_create_info = {
-            vk::DeviceCreateFlags(),
-            device_queue_create_infos,
-            {},
-            device_extensions
-        };
+        device_create_info
+            .setFlags(vk::DeviceCreateFlags())
+            .setQueueCreateInfos(device_queue_create_infos)
+            .setPEnabledExtensionNames(device_extensions);
         device_create_info.setPNext(&device_features2);
         device->device = device->physical_device.createDevice(device_create_info);
 
@@ -11048,8 +11056,6 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
     const float alpha = op_params_f[2];
     const float limit = op_params_f[3];
 
-    GGML_ASSERT(ggml_is_contiguous(src0));
-
     if (!split) {
         GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]);
     } else {
@@ -11067,7 +11073,17 @@ static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const
             (uint32_t)dst->ne[0],
             mode,
             alpha,
-            limit
+            limit,
+            (uint32_t)(src0->nb[1] / src0->nb[0]),
+            (uint32_t)(src0->nb[2] / src0->nb[0]),
+            (uint32_t)(src0->nb[3] / src0->nb[0]),
+            (uint32_t)src0->ne[1],
+            (uint32_t)src0->ne[2],
+            (uint32_t)(dst->nb[1] / dst->nb[0]),
+            (uint32_t)(dst->nb[2] / dst->nb[0]),
+            (uint32_t)(dst->nb[3] / dst->nb[0]),
+            (uint32_t)dst->ne[1],
+            (uint32_t)dst->ne[2]
         });
 }
 
@@ -15217,8 +15233,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
                 case GGML_GLU_OP_SWIGLU_OAI:
                 case GGML_GLU_OP_GEGLU_ERF:
                 case GGML_GLU_OP_GEGLU_QUICK:
-                    return ggml_is_contiguous(op->src[0]) &&
-                           (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
+                    return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
                            (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) &&
                            (op->src[0]->type == op->type);
                 default:
index 2168989340b8c0863a54027e2625b92c1f69f822..95298922d83a1f247a1bb57fc61ad7ae670d3b86 100644 (file)
@@ -16,4 +16,14 @@ layout (push_constant) uniform parameter
     uint mode;
     float alpha;
     float limit;
+    uint nb01;
+    uint nb02;
+    uint nb03;
+    uint ne01;
+    uint ne02;
+    uint nb11;
+    uint nb12;
+    uint nb13;
+    uint ne11;
+    uint ne12;
 } p;
index 85cf65a9ecac8d8f1076ec36200a9090581e060b..359461306a5d1b39ed8d665a61eaa786322902f0 100644 (file)
@@ -8,22 +8,32 @@ void main() {
     const uint row = i / p.ne20;
     const uint col = i - row * p.ne20;
 
+    const uint i3 = row / (p.ne01 * p.ne02);
+    const uint i2 = (row % (p.ne01 * p.ne02)) / p.ne01;
+    const uint i1 = row % p.ne01;
+    const uint src_idx = i3 * p.nb03 + i2 * p.nb02 + i1 * p.nb01 + col;
+
+    const uint dst_i3 = row / (p.ne11 * p.ne12);
+    const uint dst_i2 = (row % (p.ne11 * p.ne12)) / p.ne11;
+    const uint dst_i1 = row % p.ne11;
+    const uint dst_idx = dst_i3 * p.nb13 + dst_i2 * p.nb12 + dst_i1 * p.nb11 + col;
+
     if (p.mode == 0) {
         // Default
         const uint offset = p.ne00 / 2;
-        const uint idx = row * p.ne00 + col;
+        const uint idx = src_idx;
 
-        data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
+        data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset])));
     } else if (p.mode == 1) {
         // Swapped
         const uint offset = p.ne00 / 2;
-        const uint idx = row * p.ne00 + col;
+        const uint idx = src_idx;
 
-        data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
+        data_d[dst_idx] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx])));
     } else {
         // Split
-        const uint idx = row * p.ne00 + col;
+        const uint idx = src_idx;
 
-        data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
+        data_d[dst_idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx])));
     }
 }