]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : add missing barriers for mul-mat (#2699)
authorShouzheng Liu <redacted>
Tue, 22 Aug 2023 06:18:40 +0000 (02:18 -0400)
committerGitHub <redacted>
Tue, 22 Aug 2023 06:18:40 +0000 (09:18 +0300)
ggml-metal.metal

index 88d48f6c6a2ebddc085d10063714c8f7ca20cad3..ce3541f4bb55f6a06178975ad6b0fa51a2435906 100644 (file)
@@ -1850,6 +1850,7 @@ kernel void kernel_mul_mm(device const  uchar * src0,
         //load data and store to threadgroup memory
         half4x4 temp_a;
         dequantize_func(x, il, temp_a);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
         #pragma unroll(16)
         for (int i = 0; i < 16; i++) {
             *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
@@ -1895,14 +1896,14 @@ kernel void kernel_mul_mm(device const  uchar * src0,
         }
     } else {
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
+        threadgroup_barrier(mem_flags::mem_threadgroup);
         threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
         for (int i = 0; i < 8; i++) {
-            threadgroup_barrier(mem_flags::mem_device);
             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
         }
 
-        threadgroup_barrier(mem_flags::mem_device);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
         device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
         if (sgitg==0) {
             for (int i = 0; i < n_rows; i++) {