]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
metal : utilize max shared memory for mul_mat_id (llama/7935)
authorGeorgi Gerganov <redacted>
Fri, 14 Jun 2024 14:14:09 +0000 (17:14 +0300)
committerGeorgi Gerganov <redacted>
Sat, 15 Jun 2024 19:05:47 +0000 (22:05 +0300)
src/ggml-metal.m

index ec9e95302096c4f097d199b8a23c3884f3007b34..f894274cacc93e893614efc700c00c9274e6f326 100644 (file)
@@ -1862,9 +1862,10 @@ static enum ggml_status ggml_metal_graph_compute(
                         // ne21 = n_rows
                         const int dst_rows = ne20*ne21;
                         const int dst_rows_min = n_as;
+                        const int dst_rows_max = (ctx->device.maxThreadgroupMemoryLength - 32 - 8192)/4;
 
                         // max size of the rowids array in the kernel shared buffer
-                        GGML_ASSERT(dst_rows <= 2048);
+                        GGML_ASSERT(dst_rows <= dst_rows_max);
 
                         // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                         // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel