]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
metal : relax conditions on fast matrix multiplication kernel (#3168)
authorGeorgi Gerganov <redacted>
Fri, 15 Sep 2023 08:09:24 +0000 (11:09 +0300)
committerGitHub <redacted>
Fri, 15 Sep 2023 08:09:24 +0000 (11:09 +0300)
* metal : relax conditions on fast matrix multiplication kernel

* metal : revert the concurrnecy change because it was wrong

* llama : remove experimental stuff

ggml-metal.m
ggml-metal.metal
ggml.c
llama.cpp

index 4f3f14e24d9d8cf9e37b4bbf0d28edac43cfea76..3e3be98c5d91d112c520f560b7d5e61388eadf56 100644 (file)
@@ -66,6 +66,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(soft_max_4);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
     GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
+    GGML_METAL_DECL_KERNEL(get_rows_f32);
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
@@ -145,7 +146,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     ctx->n_buffers = 0;
     ctx->concur_list_len = 0;
 
-    ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
+    ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
 #ifdef GGML_SWIFT
     // load the default.metallib file
@@ -175,7 +176,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 
         //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
         NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
-        NSString * path = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+        NSString * path   = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
         metal_printf("%s: loading '%s'\n", __func__, [path UTF8String]);
 
         NSString * src  = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
@@ -224,6 +225,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(soft_max_4);
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
         GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
+        GGML_METAL_ADD_KERNEL(get_rows_f32);
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@@ -293,7 +295,9 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(gelu);
     GGML_METAL_DEL_KERNEL(soft_max);
     GGML_METAL_DEL_KERNEL(soft_max_4);
+    GGML_METAL_DEL_KERNEL(diag_mask_inf);
     GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
+    GGML_METAL_DEL_KERNEL(get_rows_f32);
     GGML_METAL_DEL_KERNEL(get_rows_f16);
     GGML_METAL_DEL_KERNEL(get_rows_q4_0);
     GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@@ -386,6 +390,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
 
+        //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
         if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
             *offs = (size_t) ioffs;
 
@@ -723,6 +728,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_ADD:
                         {
                             GGML_ASSERT(ggml_is_contiguous(src0));
+                            GGML_ASSERT(ggml_is_contiguous(src1));
 
                             // utilize float4
                             GGML_ASSERT(ne00 % 4 == 0);
@@ -730,6 +736,7 @@ void ggml_metal_graph_compute(
 
                             if (ggml_nelements(src1) == ne10) {
                                 // src1 is a row
+                                GGML_ASSERT(ne11 == 1);
                                 [encoder setComputePipelineState:ctx->pipeline_add_row];
                             } else {
                                 [encoder setComputePipelineState:ctx->pipeline_add];
@@ -746,6 +753,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_MUL:
                         {
                             GGML_ASSERT(ggml_is_contiguous(src0));
+                            GGML_ASSERT(ggml_is_contiguous(src1));
 
                             // utilize float4
                             GGML_ASSERT(ne00 % 4 == 0);
@@ -753,6 +761,7 @@ void ggml_metal_graph_compute(
 
                             if (ggml_nelements(src1) == ne10) {
                                 // src1 is a row
+                                GGML_ASSERT(ne11 == 1);
                                 [encoder setComputePipelineState:ctx->pipeline_mul_row];
                             } else {
                                 [encoder setComputePipelineState:ctx->pipeline_mul];
@@ -768,6 +777,8 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_SCALE:
                         {
+                            GGML_ASSERT(ggml_is_contiguous(src0));
+
                             const float scale = *(const float *) src1->data;
 
                             [encoder setComputePipelineState:ctx->pipeline_scale];
@@ -867,8 +878,8 @@ void ggml_metal_graph_compute(
 
                             // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                             // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                            if (ggml_is_contiguous(src0) &&
-                                ggml_is_contiguous(src1) &&
+                            if (!ggml_is_transposed(src0) &&
+                                !ggml_is_transposed(src1) &&
                                 src1t == GGML_TYPE_F32 &&
                                 [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
                                 ne00%32 == 0 &&
@@ -893,9 +904,12 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&nb01    length:sizeof(nb01) atIndex:5];
                                 [encoder setBytes:&nb02    length:sizeof(nb02) atIndex:6];
                                 [encoder setBytes:&ne12    length:sizeof(ne12) atIndex:7];
-                                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:8];
-                                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:9];
-                                [encoder setBytes:&gqa     length:sizeof(gqa)  atIndex:10];
+                                [encoder setBytes:&nb10    length:sizeof(nb10) atIndex:8];
+                                [encoder setBytes:&nb11    length:sizeof(nb11) atIndex:9];
+                                [encoder setBytes:&nb12    length:sizeof(nb12) atIndex:10];
+                                [encoder setBytes:&ne0     length:sizeof(ne0)  atIndex:11];
+                                [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12];
+                                [encoder setBytes:&gqa     length:sizeof(gqa)  atIndex:13];
                                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
                                 [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                             } else {
@@ -1045,6 +1059,7 @@ void ggml_metal_graph_compute(
                     case GGML_OP_GET_ROWS:
                         {
                             switch (src0->type) {
+                                case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f32];  break;
                                 case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                                 case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                                 case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
@@ -1060,9 +1075,9 @@ void ggml_metal_graph_compute(
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
-                            [encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
-                            [encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
-                            [encoder setBytes:&(dst->nb[1])  length:sizeof(uint64_t) atIndex:5];
+                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
+                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
+                            [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:5];
 
                             const int64_t n = ggml_nelements(src1);
 
index f45b1490f5b492c2b8eea6010f4b9bf40179da49..ea8b428448b8d1eb8ade086171a772763a393c2c 100644 (file)
@@ -38,7 +38,7 @@ kernel void kernel_add_row(
         device const float4 * src0,
         device const float4 * src1,
         device       float4 * dst,
-        constant   int64_t & nb,
+        constant    int64_t & nb,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] + src1[tpig % nb];
 }
@@ -1321,7 +1321,6 @@ kernel void kernel_mul_mat_q3_K_f32(
             dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
         }
     }
-
 }
 #else
 kernel void kernel_mul_mat_q3_K_f32(
@@ -1865,6 +1864,15 @@ kernel void kernel_mul_mat_q6_K_f32(
 
 //============================= templates and their specializations =============================
 
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+    float4x4 temp = *(((device float4x4 *)src));
+    for (int i = 0; i < 16; i++){
+        reg[i/4][i%4] = temp[i/4][i%4];
+    }
+}
+
 template <typename type4x4>
 void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
     half4x4 temp = *(((device half4x4 *)src));
@@ -1875,7 +1883,6 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
 
 template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
-
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
     const float d1 = il ? (xb->d / 16.h) : xb->d;
     const float d2 = d1 / 256.f;
@@ -1887,12 +1894,10 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
         reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
         reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
     }
-
 }
 
 template <typename type4x4>
 void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
-
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
     const float d1 = il ? (xb->d / 16.h) : xb->d;
     const float d2 = d1 / 256.f;
@@ -1964,7 +1969,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
     }
-
 #else
     float    kcoef = il&1 ? 1.f/16.f : 1.f;
     uint16_t kmask = il&1 ? 0xF0     : 0x0F;
@@ -2110,22 +2114,25 @@ kernel void kernel_get_rows(
 // each block_q contains 16*nl weights
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
 kernel void kernel_mul_mm(device const  uchar * src0,
-                           device const  float * src1,
-                           device        float * dst,
-                           constant    int64_t & ne00,
-                           constant    int64_t & ne02,
-                           constant    int64_t & nb01,
-                           constant    int64_t & nb02,
-                           constant    int64_t & ne12,
-                           constant    int64_t & ne0,
-                           constant    int64_t & ne1,
-                           constant    uint & gqa,
-                           threadgroup   uchar * shared_memory [[threadgroup(0)]],
-                           uint3                 tgpig[[threadgroup_position_in_grid]],
-                           uint                  tiitg[[thread_index_in_threadgroup]],
-                           uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
-
-    threadgroup half * sa = ((threadgroup half *)shared_memory);
+                          device const  uchar * src1,
+                          device        float * dst,
+                          constant    int64_t & ne00,
+                          constant    int64_t & ne02,
+                          constant    int64_t & nb01,
+                          constant    int64_t & nb02,
+                          constant    int64_t & ne12,
+                          constant    int64_t & nb10,
+                          constant    int64_t & nb11,
+                          constant    int64_t & nb12,
+                          constant    int64_t & ne0,
+                          constant    int64_t & ne1,
+                          constant       uint & gqa,
+                          threadgroup   uchar * shared_memory [[threadgroup(0)]],
+                          uint3                 tgpig[[threadgroup_position_in_grid]],
+                          uint                  tiitg[[thread_index_in_threadgroup]],
+                          uint                  sgitg[[simdgroup_index_in_threadgroup]]) {
+
+    threadgroup half  * sa = (threadgroup half  *)(shared_memory);
     threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
 
     const uint r0 = tgpig.y;
@@ -2138,7 +2145,7 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
 
-    simdgroup_half8x8 ma[4];
+    simdgroup_half8x8  ma[4];
     simdgroup_float8x8 mb[2];
     simdgroup_float8x8 c_res[8];
     for (int i = 0; i < 8; i++){
@@ -2146,10 +2153,15 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     }
 
     short il = (tiitg % THREAD_PER_ROW);
-    uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
-    device const block_q  * x = (device const block_q  *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
-    device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
-                             + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
+
+    uint   offset0 = im/gqa*nb02;
+    ushort offset1 = il/nl;
+
+    device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+    device const float   * y = (device const float   *)(src1
+        + nb12 * im
+        + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
+        + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
         //load data and store to threadgroup memory
@@ -2229,6 +2241,7 @@ kernel void kernel_mul_mm(device const  uchar * src0,
 typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
                           constant uint64_t &, constant uint64_t &, uint, uint, uint);
 
+template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows<float4x4,   1, dequantize_f32>;
 template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
@@ -2239,14 +2252,27 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
 template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
 
-typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
-                             constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
-                             constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
-
-template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1, dequantize_f16>;
-template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
-template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
-template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
+typedef void (mat_mm_t)(
+        device const  uchar * src0,
+        device const  uchar * src1,
+        device        float * dst,
+        constant    int64_t & ne00,
+        constant    int64_t & ne02,
+        constant    int64_t & nb01,
+        constant    int64_t & nb02,
+        constant    int64_t & ne12,
+        constant    int64_t & nb10,
+        constant    int64_t & nb11,
+        constant    int64_t & nb12,
+        constant    int64_t & ne0,
+        constant    int64_t & ne1,
+        constant       uint & gqa,
+        threadgroup uchar *, uint3, uint, uint);
+
+template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1,     dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2,     dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2,     dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2,     dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
 template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
diff --git a/ggml.c b/ggml.c
index a9cffb439a2e6131ad6b46cb7f4cdaceba86bfe1..96edebeb8e340bf59b0cda75212f2647a332bc57 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -4303,10 +4303,21 @@ int64_t ggml_nrows(const struct ggml_tensor * tensor) {
 }
 
 size_t ggml_nbytes(const struct ggml_tensor * tensor) {
-    size_t nbytes = tensor->ne[0]*tensor->nb[0]/ggml_blck_size(tensor->type);
-    for (int i = 1; i < GGML_MAX_DIMS; ++i) {
-        nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+    size_t nbytes;
+    size_t blck_size = ggml_blck_size(tensor->type);
+    if (blck_size == 1) {
+        nbytes = ggml_type_size(tensor->type);
+        for (int i = 0; i < GGML_MAX_DIMS; ++i) {
+            nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+        }
     }
+    else {
+        nbytes = tensor->ne[0]*tensor->nb[0]/blck_size;
+        for (int i = 1; i < GGML_MAX_DIMS; ++i) {
+            nbytes += (tensor->ne[i] - 1)*tensor->nb[i];
+        }
+    }
+
     return nbytes;
 }
 
@@ -18340,7 +18351,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) {
         GGML_PRINT(" - %3d: [ %5" PRId64 ", %5" PRId64 "] %8s\n",
                 i,
                 node->ne[0], node->ne[1],
-                ggml_op_name(node->op));
+                ggml_op_name(node->op),
+                ggml_get_name(node));
     }
 
     for (int i = 0; i < GGML_OP_COUNT; i++) {
index 30728b7cb5cb5ea310cd6fd1664e51f09d740c86..0cab18093a8485829fd6c3598b6a9d6e28cbbde5 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -3429,10 +3429,6 @@ static bool llama_eval_internal(
     if (lctx.ctx_metal) {
         ggml_metal_set_n_cb     (lctx.ctx_metal, n_threads);
         ggml_metal_graph_compute(lctx.ctx_metal, gf);
-        ggml_metal_get_tensor   (lctx.ctx_metal, res);
-        if (!lctx.embedding.empty()) {
-            ggml_metal_get_tensor(lctx.ctx_metal, embeddings);
-        }
     } else {
         ggml_graph_compute_helper(lctx.work_buffer, gf, n_threads);
     }