]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : alibi for arbitrary number of heads (#3426)
authorJiahao Li <redacted>
Tue, 3 Oct 2023 16:55:21 +0000 (00:55 +0800)
committerGitHub <redacted>
Tue, 3 Oct 2023 16:55:21 +0000 (19:55 +0300)
ggml-metal.m
ggml-metal.metal

index b3c463f03ad3d2438a7cf0f4aea8cd4246276d64..866fed43442a8c487bec4e117dd955e01ee92e1d 100644 (file)
@@ -1213,12 +1213,9 @@ void ggml_metal_graph_compute(
                             float max_bias;
                             memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
-                            if (__builtin_popcount(n_head) != 1) {
-                                GGML_ASSERT(false && "only power-of-two n_head implemented");
-                            }
-
                             const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
                             const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+                            const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
 
                             [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@@ -1239,7 +1236,9 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
                             [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
                             [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
-                            [encoder setBytes:&m0  length:sizeof(    float) atIndex:18];
+                            [encoder setBytes:&m0   length:sizeof(   float) atIndex:18];
+                            [encoder setBytes:&m1   length:sizeof(   float) atIndex:19];
+                            [encoder setBytes:&n_heads_log2_floor   length:sizeof(int) atIndex:20];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
index 5e1af6a092aed940f3c14fbaca618d348181e8e7..5a860098f157c49c131101f9d2b36d96d60d6e60 100644 (file)
@@ -830,7 +830,9 @@ kernel void kernel_alibi_f32(
         constant  uint64_t & nb1,
         constant  uint64_t & nb2,
         constant  uint64_t & nb3,
-        constant      float & m0,
+        constant     float & m0,
+        constant     float & m1,
+        constant       int & n_heads_log2_floor,
         uint3 tgpig[[threadgroup_position_in_grid]],
         uint3 tpitg[[thread_position_in_threadgroup]],
         uint3   ntg[[threads_per_threadgroup]]) {
@@ -846,7 +848,12 @@ kernel void kernel_alibi_f32(
     const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
 
     device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
-    float m_k = pow(m0, i2 + 1);
+    float m_k;
+    if (i2 < n_heads_log2_floor) {
+        m_k = pow(m0, i2 + 1);
+    } else {
+        m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
+    }
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
         dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);