]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
vulkan: skip integer div/mod in get_offsets for batch_idx==0 (llama/10506)
authorJeff Bolz <redacted>
Wed, 27 Nov 2024 07:08:54 +0000 (01:08 -0600)
committerGeorgi Gerganov <redacted>
Tue, 3 Dec 2024 19:05:37 +0000 (21:05 +0200)
src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.comp

index 8d0a5791374d3b88de0e447063d2aa4f995381ca..2ec1af5c75542f9130c4ecdb8cc1b000289024a4 100644 (file)
@@ -52,13 +52,16 @@ void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
 #endif
 
 #ifndef MUL_MAT_ID
-    const uint i13 = batch_idx / p.ne12;
-    const uint i12 = batch_idx % p.ne12;
+    uint batch_idx_a = 0;
+    if (batch_idx != 0) {
+        const uint i13 = batch_idx / p.ne12;
+        const uint i12 = batch_idx % p.ne12;
 
-    const uint i03 = i13 / p.broadcast3;
-    const uint i02 = i12 / p.broadcast2;
+        const uint i03 = i13 / p.broadcast3;
+        const uint i02 = i12 / p.broadcast2;
 
-    const uint batch_idx_a = i03 * p.ne02 + i02;
+        batch_idx_a = i03 * p.ne02 + i02;
+    }
 #else
     const uint expert_id = data_ids[expert_idx];
 #endif