]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : move mm_id indices to shared mem (llama/5982)
authorGeorgi Gerganov <redacted>
Sun, 10 Mar 2024 21:12:48 +0000 (23:12 +0200)
committerGeorgi Gerganov <redacted>
Thu, 14 Mar 2024 16:46:58 +0000 (18:46 +0200)
src/ggml-metal.m
src/ggml-metal.metal

index 00df2283821b2d89ca6908253c39f9aed80dd45c..3cf80de7bf2e075dc5f4d914633188f99cdafa97 100644 (file)
@@ -1642,8 +1642,8 @@ static enum ggml_status ggml_metal_graph_compute(
                         // TODO: make this more general
                         GGML_ASSERT(n_as <= 8);
 
-                        // max size of the src1ids array in the kernel stack
-                        GGML_ASSERT(ne11 <= 512);
+                        // max size of the src1ids array in the kernel shared buffer
+                        GGML_ASSERT(ne11 <= 4096);
 
                         const int64_t  ne20 = src2 ? src2->ne[0] : 0;
                         const int64_t  ne21 = src2 ? src2->ne[1] : 0;
@@ -1741,7 +1741,7 @@ static enum ggml_status ggml_metal_graph_compute(
                                 [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
                             }
 
-                            [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+                            [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
 
                             [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                         } else {
index 6ebbbd195e7cef09242f40562069c6fb8bb26719..50185ae4dea09567aaad97ddc4f0f1a20a1002db 100644 (file)
@@ -5386,7 +5386,7 @@ template<typename block_q, short nl, void (*dequantize_func)(device const block_
 void kernel_mul_mm_id_impl(
         device const  uchar * src0,
         device const  uchar * src1,
-        thread        short * src1ids,
+        threadgroup   short * src1ids,
         device        float * dst,
         constant    int64_t & ne00,
         constant    int64_t & ne02,
@@ -5589,9 +5589,9 @@ kernel void kernel_mul_mm_id(
     tgpig.z = tgpig.z%(ne12*ne13);
 
     // row indices of src1 for expert id
-    int64_t _ne1 = 0;
-    short src1ids[512];
+    threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192);
 
+    int64_t _ne1 = 0;
     for (int64_t i1 = 0; i1 < ne1; i1++) {
         if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
             src1ids[_ne1++] = i1;