]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : restore 363f0bf and fix reduce in F16_F32 kernels (#2986)
authorGeorgi Gerganov <redacted>
Sun, 3 Sep 2023 10:23:33 +0000 (13:23 +0300)
committerGitHub <redacted>
Sun, 3 Sep 2023 10:23:33 +0000 (13:23 +0300)
ggml-metal.metal

index 1d324e466a1469a7eca0cef368d86334c55944ee..119fcbeb623c11ca25a1c59c08564d7c3da32fba 100644 (file)
@@ -536,14 +536,27 @@ kernel void kernel_mul_mat_f16_f32_1row(
     device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
     float sumf = 0;
-    for (int i = tiisg; i < ne00; i += 32) {
-        sumf += (float) x[i] * (float) y[i];
+    if (ne00 < 128) {
+        for (int i = tiisg; i < ne00; i += 32) {
+            sumf += (float) x[i] * (float) y[i];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    } else {
+        device const half4  * x4 = (device const half4  *) x;
+        device const float4 * y4 = (device const float4 *) y;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
+        }
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
     }
 
-    float all_sum = simd_sum(sumf);
-    if (tiisg == 0) {
-        dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
-    }
 }
 
 #define N_F16_F32 4
@@ -570,29 +583,54 @@ kernel void kernel_mul_mat_f16_f32(
         uint tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
-    const int64_t rb = N_F16_F32*tgpig.y;
+    const int64_t rb = tgpig.y*N_F16_F32;
     const int64_t im = tgpig.z;
 
     device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
 
-    for (int row = 0; row < N_F16_F32; ++row) {
-        int r1 = rb + row;
-        if (r1 >= ne11) {
-            break;
-        }
+    if (ne00 < 128) {
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
 
-        device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+            device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
 
-        float sumf = 0;
-        for (int i = tiisg; i < ne00; i += 32) {
-            sumf += (float) x[i] * (float) y[i];
+            float sumf = 0;
+            for (int i = tiisg; i < ne00; i += 32) {
+                sumf += (float) x[i] * (float) y[i];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
         }
+    } else {
+        device const half4 * x4 = (device const half4 *)x;
+        for (int row = 0; row < N_F16_F32; ++row) {
+            int r1 = rb + row;
+            if (r1 >= ne11) {
+                break;
+            }
 
-        float all_sum = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            device const float  * y  = (device const float  *) (src1 + r1*nb11 + im*nb12);
+            device const float4 * y4 = (device const float4 *) y;
+
+            float sumf = 0;
+            for (int i = tiisg; i < ne00/4; i += 32) {
+                for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+            }
+
+            float all_sum = simd_sum(sumf);
+            if (tiisg == 0) {
+                for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+                dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+            }
         }
     }
+
 }
 
 kernel void kernel_alibi_f32(