]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : reorder write loop in mul mat kernel + style (#10231)
authorGeorgi Gerganov <redacted>
Sat, 9 Nov 2024 09:53:13 +0000 (11:53 +0200)
committerGitHub <redacted>
Sat, 9 Nov 2024 09:53:13 +0000 (11:53 +0200)
* metal : reorder write loop

* metal : int -> short, style

ggml-ci

ggml/src/ggml-metal.metal

index 779f459681fa1069dedb63df56c72de30b968052..413661c8a5d4280ef7166cdc94bf037ffb1cd9f3 100644 (file)
@@ -6318,8 +6318,8 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     const uint im = tgpig.z;
 
     // if this block is of 64x32 shape or smaller
-    short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
-    short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+    short n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
+    short n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
 
     // a thread shouldn't load data outside of the matrix
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
@@ -6327,9 +6327,10 @@ kernel void kernel_mul_mm(device const  uchar * src0,
 
     simdgroup_T8x8     ma[4];
     simdgroup_float8x8 mb[2];
-    simdgroup_float8x8 c_res[8];
-    for (int i = 0; i < 8; i++){
-        c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+    simdgroup_float8x8 mc[8];
+
+    for (short i = 0; i < 8; i++){
+        mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
     }
 
     short il = (tiitg % THREAD_PER_ROW);
@@ -6340,7 +6341,7 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     uint   offset0 = (i12/r2)*nb02 + (i13/r3)*nb03;
     ushort offset1 = il/nl;
 
-    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*nb01 + offset0) + offset1;
     device const float   * y = (device const float   *)(src1
         + nb13 * i13
         + nb12 * i12
@@ -6354,13 +6355,13 @@ kernel void kernel_mul_mm(device const  uchar * src0,
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         #pragma unroll(16)
-        for (int i = 0; i < 16; i++) {
-            *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
-            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
-            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
+        for (short i = 0; i < 16; i++) {
+            *(sa + SG_MAT_SIZE * ((tiitg/THREAD_PER_ROW/8) \
+            +                     (tiitg%THREAD_PER_ROW)*16 + (i/8)*8) \
+            +                     (tiitg/THREAD_PER_ROW)%8  + (i&7)*8) = temp_a[i/4][i%4];
         }
 
-        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL)*8*32 + 8*(tiitg/THREAD_PER_COL)) = *((device float2x4 *) y);
 
         il = (il + 2 < nl) ? il + 2 : il % 2;
         x  = (il < 2) ? x + (2+nl-1)/nl : x;
@@ -6369,27 +6370,27 @@ kernel void kernel_mul_mm(device const  uchar * src0,
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
         // load matrices from threadgroup memory and conduct outer products
-        threadgroup T     * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
-        threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+        threadgroup T     * lsma = (sa + THREAD_MAT_M*SG_MAT_SIZE*(sgitg%2));
+        threadgroup float * lsmb = (sb + THREAD_MAT_N*SG_MAT_SIZE*(sgitg/2));
 
         #pragma unroll(4)
-        for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+        for (short ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
             #pragma unroll(4)
-            for (int i = 0; i < 4; i++) {
-                simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+            for (short i = 0; i < 4; i++) {
+                simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
             }
             simdgroup_barrier(mem_flags::mem_none);
             #pragma unroll(2)
-            for (int i = 0; i < 2; i++) {
-                simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+            for (short i = 0; i < 2; i++) {
+                simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
             }
 
-            lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
-            lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+            lsma += BLOCK_SIZE_M/SG_MAT_ROW * SG_MAT_SIZE;
+            lsmb += BLOCK_SIZE_N/SG_MAT_ROW * SG_MAT_SIZE;
 
             #pragma unroll(8)
-            for (int i = 0; i < 8; i++){
-                simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+            for (short i = 0; i < 8; i++){
+                simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
             }
         }
     }
@@ -6397,25 +6398,36 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
         device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1)) \
                                + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
-        for (int i = 0; i < 8; i++) {
-            simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
+        for (short i = 0; i < 8; i++) {
+            simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
         }
     } else {
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
-                                      + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
-        for (int i = 0; i < 8; i++) {
-            simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+        threadgroup float * temp_str = ((threadgroup float *) shared_memory) \
+                                      + 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M;
+        for (short i = 0; i < 8; i++) {
+            simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
 
-        device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
         if (sgitg == 0) {
-            for (int i = 0; i < n_rows; i++) {
-                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
-                    *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+            for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+                device float  * D  = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0;
+                device float4 * D4 = (device float4 *) D;
+
+                threadgroup float  * C  = temp_str + (j*BLOCK_SIZE_M);
+                threadgroup float4 * C4 = (threadgroup float4 *) C;
+
+                int i = 0;
+                for (; i < n_rows/4; i++) {
+                    *(D4 + i) = *(C4 + i);
+                }
+
+                i *= 4;
+                for (; i < n_rows; i++) {
+                    *(D + i) = *(C + i);
                 }
             }
         }