]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sync : llama.cpp (Metal + OpenCL + minor alibi) (#558)
authorGeorgi Gerganov <redacted>
Sun, 8 Oct 2023 16:44:14 +0000 (19:44 +0300)
committerGitHub <redacted>
Sun, 8 Oct 2023 16:44:14 +0000 (19:44 +0300)
ggml-ci

scripts/sync-llama.sh
src/ggml-metal.m
src/ggml-metal.metal
src/ggml-opencl.cpp
src/ggml.c

index db7ee49a653431fca0e9250acd9c9af981daf663..b9b7aed14f1a52da3cf3a05c1c50f9574bbe51dc 100755 (executable)
@@ -2,6 +2,7 @@
 
 cp -rpv ../llama.cpp/ggml.c           src/ggml.c
 cp -rpv ../llama.cpp/ggml-alloc.c     src/ggml-alloc.c
+cp -rpv ../llama.cpp/ggml-backend.c   src/ggml-backend.c
 cp -rpv ../llama.cpp/ggml-cuda.h      src/ggml-cuda.h
 cp -rpv ../llama.cpp/ggml-cuda.cu     src/ggml-cuda.cu
 cp -rpv ../llama.cpp/ggml-opencl.h    src/ggml-opencl.h
@@ -11,6 +12,7 @@ cp -rpv ../llama.cpp/ggml-metal.m     src/ggml-metal.m
 cp -rpv ../llama.cpp/ggml-metal.metal src/ggml-metal.metal
 cp -rpv ../llama.cpp/ggml.h           include/ggml/ggml.h
 cp -rpv ../llama.cpp/ggml-alloc.h     include/ggml/ggml-alloc.h
+cp -rpv ../llama.cpp/ggml-backend.h   include/ggml/ggml-backend.h
 
 cp -rpv ../llama.cpp/tests/test-opt.cpp           tests/test-opt.cpp
 cp -rpv ../llama.cpp/tests/test-grad0.cpp         tests/test-grad0.cpp
index 47b37819be8dcd0f7ad4993acdbb8f13376cede9..29cb3c922daeba4e912f1b2defcca6fbc30e18f4 100644 (file)
@@ -81,18 +81,18 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_q6_K);
     GGML_METAL_DECL_KERNEL(rms_norm);
     GGML_METAL_DECL_KERNEL(norm);
-    GGML_METAL_DECL_KERNEL(mul_mat_f32_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
-    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
-    GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q2_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q3_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
-    GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_1row);
+    GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
+    GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
@@ -109,6 +109,8 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(cpy_f32_f16);
     GGML_METAL_DECL_KERNEL(cpy_f32_f32);
     GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DECL_KERNEL(concat);
+    GGML_METAL_DECL_KERNEL(sqr);
 
 #undef GGML_METAL_DECL_KERNEL
 };
@@ -151,6 +153,8 @@ static void ggml_metal_log(enum ggml_log_level level, const char* format, ...){
     }
 }
 
+
+
 struct ggml_metal_context * ggml_metal_init(int n_cb) {
     GGML_METAL_LOG_INFO("%s: allocating\n", __func__);
 
@@ -181,56 +185,44 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 
     ctx->d_queue = dispatch_queue_create("ggml-metal", DISPATCH_QUEUE_CONCURRENT);
 
-#ifdef GGML_SWIFT
-    // load the default.metallib file
+    // load library
     {
-        NSError * error = nil;
-
-        NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
-        NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
-        NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
-        NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
-        NSURL * libURL = [NSURL fileURLWithPath:libPath];
-
-        // Load the metallib file into a Metal library
-        ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
-
-        if (error) {
-            GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
-            return NULL;
-        }
-    }
+        NSBundle * bundle = nil;
+#ifdef SWIFT_PACKAGE
+        bundle = SWIFTPM_MODULE_BUNDLE;
 #else
-    UNUSED(msl_library_source);
-
-    // read the source from "ggml-metal.metal" into a string and use newLibraryWithSource
-    {
+        bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+#endif
         NSError * error = nil;
+        NSString * libPath = [bundle pathForResource:@"default" ofType:@"metallib"];
+        if (libPath != nil) {
+            NSURL * libURL = [NSURL fileURLWithPath:libPath];
+            GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [libPath UTF8String]);
+            ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
+        } else {
+            GGML_METAL_LOG_INFO("%s: default.metallib not found, loading from source\n", __func__);
+
+            NSString * sourcePath = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
+            GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [sourcePath UTF8String]);
+            NSString * src = [NSString stringWithContentsOfFile:sourcePath encoding:NSUTF8StringEncoding error:&error];
+            if (error) {
+                GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
+                return NULL;
+            }
 
-        //NSString * path = [[NSBundle mainBundle] pathForResource:@"../../examples/metal/metal" ofType:@"metal"];
-        NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
-        NSString * path   = [bundle pathForResource:@"ggml-metal" ofType:@"metal"];
-        GGML_METAL_LOG_INFO("%s: loading '%s'\n", __func__, [path UTF8String]);
-
-        NSString * src  = [NSString stringWithContentsOfFile:path encoding:NSUTF8StringEncoding error:&error];
-        if (error) {
-            GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
-            return NULL;
-        }
-
+            MTLCompileOptions* options = nil;
 #ifdef GGML_QKK_64
-        MTLCompileOptions* options = [MTLCompileOptions new];
-        options.preprocessorMacros = @{ @"QK_K" : @(64) };
-        ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
-#else
-        ctx->library = [ctx->device newLibraryWithSource:src options:nil error:&error];
+            options = [MTLCompileOptions new];
+            options.preprocessorMacros = @{ @"QK_K" : @(64) };
 #endif
+            ctx->library = [ctx->device newLibraryWithSource:src options:options error:&error];
+        }
+
         if (error) {
             GGML_METAL_LOG_ERROR("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
         }
     }
-#endif
 
     // load kernels
     {
@@ -270,40 +262,57 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_q6_K);
         GGML_METAL_ADD_KERNEL(rms_norm);
         GGML_METAL_ADD_KERNEL(norm);
-        GGML_METAL_ADD_KERNEL(mul_mat_f32_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
-        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
-        GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q2_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q3_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
-        GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_1row);
+        GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
+        GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+        if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+            GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
+        }
         GGML_METAL_ADD_KERNEL(rope_f32);
         GGML_METAL_ADD_KERNEL(rope_f16);
         GGML_METAL_ADD_KERNEL(alibi_f32);
         GGML_METAL_ADD_KERNEL(cpy_f32_f16);
         GGML_METAL_ADD_KERNEL(cpy_f32_f32);
         GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+        GGML_METAL_ADD_KERNEL(concat);
+        GGML_METAL_ADD_KERNEL(sqr);
 
 #undef GGML_METAL_ADD_KERNEL
     }
 
-    GGML_METAL_LOG_INFO("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
 #if TARGET_OS_OSX
+    // print MTL GPU family:
+    GGML_METAL_LOG_INFO("%s: GPU name:   %s\n", __func__, [[ctx->device name] UTF8String]);
+
+    // determine max supported GPU family
+    // https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf
+    // https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf
+    for (int i = MTLGPUFamilyApple1 + 20; i >= MTLGPUFamilyApple1; --i) {
+        if ([ctx->device supportsFamily:i]) {
+            GGML_METAL_LOG_INFO("%s: GPU family: MTLGPUFamilyApple%d (%d)\n", __func__, i - MTLGPUFamilyApple1 + 1, i);
+            break;
+        }
+    }
+
+    GGML_METAL_LOG_INFO("%s: hasUnifiedMemory              = %s\n",       __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
     GGML_METAL_LOG_INFO("%s: recommendedMaxWorkingSetSize  = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
     if (ctx->device.maxTransferRate != 0) {
         GGML_METAL_LOG_INFO("%s: maxTransferRate               = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
@@ -345,34 +354,38 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(get_rows_q6_K);
     GGML_METAL_DEL_KERNEL(rms_norm);
     GGML_METAL_DEL_KERNEL(norm);
-    GGML_METAL_DEL_KERNEL(mul_mat_f32_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
-    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
-    GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q2_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q3_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q4_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q5_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mat_q6_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
-    GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_1row);
+    GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
+    GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+    if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
+        GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32);
+    }
     GGML_METAL_DEL_KERNEL(rope_f32);
     GGML_METAL_DEL_KERNEL(rope_f16);
     GGML_METAL_DEL_KERNEL(alibi_f32);
     GGML_METAL_DEL_KERNEL(cpy_f32_f16);
     GGML_METAL_DEL_KERNEL(cpy_f32_f32);
     GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+    GGML_METAL_DEL_KERNEL(concat);
+    GGML_METAL_DEL_KERNEL(sqr);
 
 #undef GGML_METAL_DEL_KERNEL
 
@@ -429,7 +442,7 @@ static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, stru
     for (int i = 0; i < ctx->n_buffers; ++i) {
         const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
 
-        //metal_printf("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
+        //GGML_METAL_LOG_INFO("ioffs = %10ld, tsize = %10ld, sum = %10ld, ctx->buffers[%d].size = %10ld, name = %s\n", ioffs, tsize, ioffs + tsize, i, ctx->buffers[i].size, ctx->buffers[i].name);
         if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
             *offs = (size_t) ioffs;
 
@@ -764,6 +777,43 @@ void ggml_metal_graph_compute(
                         {
                             // noop
                         } break;
+                    case GGML_OP_CONCAT:
+                        {
+
+                            int64_t nb = ne00;
+                            [encoder setComputePipelineState:ctx->pipeline_concat];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:2];
+                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+                            [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+                            [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+                            [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
+                            [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
+                            [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
+                            [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+                            [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+                            [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+                            [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+                            [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+                            [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+                            [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+                            [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+                            [encoder setBytes:&ne0  length:sizeof(ne0)  atIndex:19];
+                            [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:20];
+                            [encoder setBytes:&ne2  length:sizeof(ne2)  atIndex:21];
+                            [encoder setBytes:&ne3  length:sizeof(ne3)  atIndex:22];
+                            [encoder setBytes:&nb0  length:sizeof(nb0)  atIndex:23];
+                            [encoder setBytes:&nb1  length:sizeof(nb1)  atIndex:24];
+                            [encoder setBytes:&nb2  length:sizeof(nb2)  atIndex:25];
+                            [encoder setBytes:&nb3  length:sizeof(nb3)  atIndex:26];
+                            [encoder setBytes:&nb   length:sizeof(nb)   atIndex:27];
+
+                            const int nth = MIN(1024, ne0);
+                            [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+                        } break;
                     case GGML_OP_ADD:
                         {
                             GGML_ASSERT(ggml_is_contiguous(src0));
@@ -901,6 +951,17 @@ void ggml_metal_graph_compute(
                                     GGML_ASSERT(false);
                                 }
                         } break;
+                    case GGML_OP_SQR:
+                        {
+                            GGML_ASSERT(ggml_is_contiguous(src0));
+
+                            [encoder setComputePipelineState:ctx->pipeline_sqr];
+                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst atIndex:1];
+
+                            const int64_t n = ggml_nelements(dst);
+                            [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                        } break;
                     case GGML_OP_SOFT_MAX:
                         {
                             const int nth = MIN(32, ne00);
@@ -942,21 +1003,46 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_MUL_MAT:
                         {
-                            // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
-
                             GGML_ASSERT(ne00 == ne10);
-                            // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
-                            uint gqa = ne12/ne02;
                             GGML_ASSERT(ne03 == ne13);
 
+                            const uint gqa = ne12/ne02;
+
+                            // find the break-even point where the matrix-matrix kernel becomes more efficient compared
+                            // to the matrix-vector kernel
+                            int ne11_mm_min = 1;
+
+#if 0
+                            // the numbers below are measured on M2 Ultra for 7B and 13B models
+                            // these numbers do not translate to other devices or model sizes
+                            // TODO: need to find a better approach
+                            if ([ctx->device.name isEqualToString:@"Apple M2 Ultra"]) {
+                                switch (src0t) {
+                                    case GGML_TYPE_F16:  ne11_mm_min = 2;  break;
+                                    case GGML_TYPE_Q8_0: ne11_mm_min = 7;  break;
+                                    case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
+                                    case GGML_TYPE_Q3_K: ne11_mm_min = 7;  break;
+                                    case GGML_TYPE_Q4_0:
+                                    case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
+                                    case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
+                                    case GGML_TYPE_Q5_0:                          // not tested yet
+                                    case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
+                                    case GGML_TYPE_Q5_K: ne11_mm_min = 7;  break;
+                                    case GGML_TYPE_Q6_K: ne11_mm_min = 7;  break;
+                                    default:             ne11_mm_min = 1;  break;
+                                }
+                            }
+#endif
+
                             // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
                             // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
-                            if (!ggml_is_transposed(src0) &&
+                            if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+                                !ggml_is_transposed(src0) &&
                                 !ggml_is_transposed(src1) &&
                                 src1t == GGML_TYPE_F32 &&
-                                [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
-                                ne00%32 == 0 &&
-                                ne11 > 2) {
+                                ne00 % 32 == 0 &&
+                                ne11 > ne11_mm_min) {
+                                //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
                                 switch (src0->type) {
                                     case GGML_TYPE_F32:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f32_f32];  break;
                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
@@ -985,17 +1071,18 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&ne1     length:sizeof(ne1)  atIndex:12];
                                 [encoder setBytes:&gqa     length:sizeof(gqa)  atIndex:13];
                                 [encoder setThreadgroupMemoryLength:8192 atIndex:0];
-                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+                                [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
                             } else {
                                 int nth0 = 32;
                                 int nth1 = 1;
                                 int nrows = 1;
+                                //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
 
                                 // use custom matrix x vector kernel
                                 switch (src0t) {
                                     case GGML_TYPE_F32:
                                         {
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_f32_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_f32_f32];
                                             nrows = 4;
                                         } break;
                                     case GGML_TYPE_F16:
@@ -1003,12 +1090,12 @@ void ggml_metal_graph_compute(
                                             nth0 = 32;
                                             nth1 = 1;
                                             if (ne11 * ne12 < 4) {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_1row];
                                             } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32_l4];
                                                 nrows = ne11;
                                             } else {
-                                                [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                                [encoder setComputePipelineState:ctx->pipeline_mul_mv_f16_f32];
                                                 nrows = 4;
                                             }
                                         } break;
@@ -1019,7 +1106,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 8;
                                             nth1 = 8;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32];
                                         } break;
                                     case GGML_TYPE_Q4_1:
                                         {
@@ -1028,7 +1115,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 8;
                                             nth1 = 8;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
                                         } break;
                                     case GGML_TYPE_Q8_0:
                                         {
@@ -1037,7 +1124,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 8;
                                             nth1 = 8;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q8_0_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32];
                                         } break;
                                     case GGML_TYPE_Q2_K:
                                         {
@@ -1046,7 +1133,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 2;
                                             nth1 = 32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32];
                                         } break;
                                     case GGML_TYPE_Q3_K:
                                         {
@@ -1055,7 +1142,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 2;
                                             nth1 = 32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_K_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32];
                                         } break;
                                     case GGML_TYPE_Q4_K:
                                         {
@@ -1064,7 +1151,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 4; //1;
                                             nth1 = 8; //32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32];
                                         } break;
                                     case GGML_TYPE_Q5_K:
                                         {
@@ -1073,7 +1160,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 2;
                                             nth1 = 32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32];
                                         } break;
                                     case GGML_TYPE_Q6_K:
                                         {
@@ -1082,7 +1169,7 @@ void ggml_metal_graph_compute(
 
                                             nth0 = 2;
                                             nth1 = 32;
-                                            [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32];
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32];
                                         } break;
                                     default:
                                         {
@@ -1111,7 +1198,7 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17];
 
                                 if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
-                                    src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
+                                    src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
                                 else if (src0t == GGML_TYPE_Q4_K) {
index 5a860098f157c49c131101f9d2b36d96d60d6e60..b6288db28660dc9e8902cb7a27369e9af30db695 100644 (file)
@@ -13,8 +13,8 @@ typedef struct {
 
 #define QK4_1 32
 typedef struct {
-    half d;          // delta
-    half m;          // min
+    half d;                 // delta
+    half m;                 // min
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
 
@@ -132,6 +132,13 @@ kernel void kernel_relu(
     dst[tpig] = max(0.0f, src0[tpig]);
 }
 
+kernel void kernel_sqr(
+        device const float * src0,
+        device       float * dst,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * src0[tpig];
+}
+
 constant float GELU_COEF_A    = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
@@ -416,8 +423,8 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
 }
 
 // putting them in the kernel cause a significant performance penalty
-#define N_DST 4 // each SIMD group works on 4 rows
-#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
+#define N_DST 4        // each SIMD group works on 4 rows
+#define N_SIMDGROUP 2  // number of SIMD groups in a thread group
 #define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 //Note: This is a template, but strictly speaking it only applies to
 //      quantizations where the block size is 32. It also does not
@@ -428,18 +435,23 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
                     int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
                     uint3 tgpig, uint tiisg, uint sgitg) {
     const int nb = ne00/QK4_0;
+
     const int r0 = tgpig.x;
     const int r1 = tgpig.y;
     const int im = tgpig.z;
+
     const int first_row = (r0 * nsg + sgitg) * nr;
+
     const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+
     device const block_q_type * x = (device const block_q_type *) src0 + offset0;
     device const float        * y = (device const float        *) src1 + r1*ne10 + im*ne00*ne1;
-    float yl[16];       // src1 vector cache
-    float sumf[nr]={0.f};
 
-    const int ix = tiisg/2;
-    const int il = 8*(tiisg%2);
+    float yl[16]; // src1 vector cache
+    float sumf[nr] = {0.f};
+
+    const int ix = (tiisg/2);
+    const int il = (tiisg%2)*8;
 
     device const float * yb = y + ix * QK4_0 + il;
 
@@ -450,6 +462,7 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
             sumy += yb[i] + yb[i+1];
             yl[i+0] = yb[i+ 0];
             yl[i+1] = yb[i+ 1]/256.f;
+
             sumy += yb[i+16] + yb[i+17];
             yl[i+8] = yb[i+16]/16.f;
             yl[i+9] = yb[i+17]/4096.f;
@@ -465,12 +478,12 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device
     for (int row = 0; row < nr; ++row) {
         const float tot = simd_sum(sumf[row]);
         if (tiisg == 0 && first_row + row < ne01) {
-            dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+            dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
         }
     }
 }
 
-kernel void kernel_mul_mat_q4_0_f32(
+kernel void kernel_mul_mv_q4_0_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -483,12 +496,12 @@ kernel void kernel_mul_mat_q4_0_f32(
         constant   int64_t & ne1[[buffer(16)]],
         constant   uint    & gqa[[buffer(17)]],
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]],
-        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
     mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
-kernel void kernel_mul_mat_q4_1_f32(
+kernel void kernel_mul_mv_q4_1_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -508,7 +521,7 @@ kernel void kernel_mul_mat_q4_1_f32(
 
 #define NB_Q8_0 8
 
-kernel void kernel_mul_mat_q8_0_f32(
+kernel void kernel_mul_mv_q8_0_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -572,7 +585,7 @@ kernel void kernel_mul_mat_q8_0_f32(
 
 #define N_F32_F32 4
 
-kernel void kernel_mul_mat_f32_f32(
+kernel void kernel_mul_mv_f32_f32(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -643,7 +656,7 @@ kernel void kernel_mul_mat_f32_f32(
     }
 }
 
-kernel void kernel_mul_mat_f16_f32_1row(
+kernel void kernel_mul_mv_f16_f32_1row(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -662,7 +675,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         uint3 tgpig[[threadgroup_position_in_grid]],
-        uint tiisg[[thread_index_in_simdgroup]]) {
+        uint  tiisg[[thread_index_in_simdgroup]]) {
 
     const int64_t r0 = tgpig.x;
     const int64_t r1 = tgpig.y;
@@ -697,7 +710,7 @@ kernel void kernel_mul_mat_f16_f32_1row(
 
 #define N_F16_F32 4
 
-kernel void kernel_mul_mat_f16_f32(
+kernel void kernel_mul_mv_f16_f32(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -769,7 +782,7 @@ kernel void kernel_mul_mat_f16_f32(
 }
 
 // Assumes row size (ne00) is a multiple of 4
-kernel void kernel_mul_mat_f16_f32_l4(
+kernel void kernel_mul_mv_f16_f32_l4(
         device const  char * src0,
         device const  char * src1,
         device       float * dst,
@@ -1098,6 +1111,62 @@ kernel void kernel_cpy_f32_f32(
     }
 }
 
+kernel void kernel_concat(
+    device const char * src0,
+    device const char * src1,
+    device       char * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant   int64_t & ne03,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant  uint64_t & nb03,
+    constant   int64_t & ne10,
+    constant   int64_t & ne11,
+    constant   int64_t & ne12,
+    constant   int64_t & ne13,
+    constant  uint64_t & nb10,
+    constant  uint64_t & nb11,
+    constant  uint64_t & nb12,
+    constant  uint64_t & nb13,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    constant   int64_t & ne2,
+    constant   int64_t & ne3,
+    constant  uint64_t & nb0,
+    constant  uint64_t & nb1,
+    constant  uint64_t & nb2,
+    constant  uint64_t & nb3,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3   ntg[[threads_per_threadgroup]]) {
+
+    const int64_t i03 = tgpig.z;
+    const int64_t i02 = tgpig.y;
+    const int64_t i01 = tgpig.x;
+
+    const int64_t i13 = i03 % ne13;
+    const int64_t i12 = i02 % ne12;
+    const int64_t i11 = i01 % ne11;
+
+    device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
+    device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
+    device       char * dst_ptr  = dst  + i03*nb3  + i02*nb2  + i01*nb1  + tpitg.x*nb0;
+
+    for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+        if (i02 < ne02) {
+            ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0];
+            src0_ptr += ntg.x*nb00;
+        } else {
+            ((device float *)dst_ptr)[0] = ((device float *)src1_ptr)[0];
+            src1_ptr += ntg.x*nb10;
+        }
+        dst_ptr += ntg.x*nb0;
+    }
+}
+
 //============================================ k-quants ======================================================
 
 #ifndef QK_K
@@ -1190,7 +1259,7 @@ static inline uchar4 get_scale_min_k4(int j, device const uint8_t * q) {
 
 //====================================== dot products =========================
 
-kernel void kernel_mul_mat_q2_K_f32(
+kernel void kernel_mul_mv_q2_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1334,7 +1403,7 @@ kernel void kernel_mul_mat_q2_K_f32(
 }
 
 #if QK_K == 256
-kernel void kernel_mul_mat_q3_K_f32(
+kernel void kernel_mul_mv_q3_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1486,7 +1555,7 @@ kernel void kernel_mul_mat_q3_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mat_q3_K_f32(
+kernel void kernel_mul_mv_q3_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1557,7 +1626,7 @@ kernel void kernel_mul_mat_q3_K_f32(
 #endif
 
 #if QK_K == 256
-kernel void kernel_mul_mat_q4_K_f32(
+kernel void kernel_mul_mv_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1663,7 +1732,7 @@ kernel void kernel_mul_mat_q4_K_f32(
     }
 }
 #else
-kernel void kernel_mul_mat_q4_K_f32(
+kernel void kernel_mul_mv_q4_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1752,7 +1821,7 @@ kernel void kernel_mul_mat_q4_K_f32(
 }
 #endif
 
-kernel void kernel_mul_mat_q5_K_f32(
+kernel void kernel_mul_mv_q5_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -1925,7 +1994,7 @@ kernel void kernel_mul_mat_q5_K_f32(
 
 }
 
-kernel void kernel_mul_mat_q6_K_f32(
+kernel void kernel_mul_mv_q6_K_f32(
         device const  void * src0,
         device const float * src1,
         device       float * dst,
@@ -2263,7 +2332,7 @@ kernel void kernel_get_rows(
 }
 
 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
-#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
 #define BLOCK_SIZE_K 32
 #define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
 #define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
@@ -2300,9 +2369,11 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     const uint r0 = tgpig.y;
     const uint r1 = tgpig.x;
     const uint im = tgpig.z;
+
     // if this block is of 64x32 shape or smaller
     short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
     short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+
     // a thread shouldn't load data outside of the matrix
     short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
     short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
@@ -2326,26 +2397,30 @@ kernel void kernel_mul_mm(device const  uchar * src0,
         + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
 
     for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
-        //load data and store to threadgroup memory
+        // load data and store to threadgroup memory
         half4x4 temp_a;
         dequantize_func(x, il, temp_a);
         threadgroup_barrier(mem_flags::mem_threadgroup);
+
         #pragma unroll(16)
         for (int i = 0; i < 16; i++) {
             *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
-            + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
-            + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+            +                     (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \
+            +                     (tiitg / THREAD_PER_ROW) % 8  + (i & 7) * 8) = temp_a[i/4][i%4];
         }
-        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
-                = *((device float2x4 *)y);
+
+        *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y);
+
         il = (il + 2 < nl) ? il + 2 : il % 2;
         x  = (il < 2) ? x + (2+nl-1)/nl : x;
         y += BLOCK_SIZE_K;
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        //load matrices from threadgroup memory and conduct outer products
+
+        // load matrices from threadgroup memory and conduct outer products
         threadgroup half  * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
         threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+
         #pragma unroll(4)
         for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
             #pragma unroll(4)
@@ -2360,6 +2435,7 @@ kernel void kernel_mul_mm(device const  uchar * src0,
 
             lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
             lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+
             #pragma unroll(8)
             for (int i = 0; i < 8; i++){
                 simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
@@ -2368,25 +2444,26 @@ kernel void kernel_mul_mm(device const  uchar * src0,
     }
 
     if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
-        device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
-                          + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
+        device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg &  1)) \
+                               + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
         }
     } else {
         // block is smaller than 64x32, we should avoid writing data outside of the matrix
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
+        threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
                                       + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
         for (int i = 0; i < 8; i++) {
             simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
         }
 
         threadgroup_barrier(mem_flags::mem_threadgroup);
-        device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
-        if (sgitg==0) {
+
+        device float * C = dst + (BLOCK_SIZE_M * r0) + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+        if (sgitg == 0) {
             for (int i = 0; i < n_rows; i++) {
-                for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
+                for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
                     *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
                 }
             }
index 7e4069d76b2595737edfeb6c3415a4d8ef8949fe..4a331f24a92ae340f7c910f219344d8d0467d7af 100644 (file)
@@ -202,14 +202,14 @@ inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8
 
 __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __global float *yy)
 {
-    const int i = get_group_id(0);
+    const int i = get_group_id(0) + get_global_offset(0);
     const int tid = get_local_id(0);
     const int n = tid / 32;
     const int l = tid - 32 * n;
     const int is = 8 * n + l / 16;
 
     const uint8_t q = x[i].qs[32 * n + l];
-    __global float *y = yy + i * QK_K + 128 * n;
+    __global float *y = yy + get_group_id(0) * QK_K + 128 * n;
 
     const float dall = vload_half(0, &x[i].d);
     const float dmin = vload_half(0, &x[i].dmin);
@@ -223,7 +223,7 @@ __kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __globa
 __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __global float *yy)
 {
     int r = get_local_id(0) / 4;
-    int i = get_group_id(0);
+    int i = get_group_id(0) + get_global_offset(0);
     int tid = r / 2;
     int is0 = r % 2;
     int l0 = 16 * is0 + 4 * (get_local_id(0) % 4);
@@ -241,7 +241,7 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __globa
     float d_all = vload_half(0, &x[i].d);
     float dl = d_all * (us - 32);
 
-    __global float *y = yy + i * QK_K + 128 * n + 32 * j;
+    __global float *y = yy + get_group_id(0) * QK_K + 128 * n + 32 * j;
     const __global uint8_t *q = x[i].qs + 32 * n;
     const __global uint8_t *hm = x[i].hmask;
 
@@ -251,14 +251,14 @@ __kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __globa
 
 __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __global float *yy)
 {
-    const int i = get_group_id(0);
+    const int i = get_group_id(0) + get_global_offset(0);
     const int tid = get_local_id(0);
     const int il = tid / 8;
     const int ir = tid % 8;
     const int is = 2 * il;
     const int n = 4;
 
-    __global float *y = yy + i * QK_K + 64 * il + n * ir;
+    __global float *y = yy + get_group_id(0) * QK_K + 64 * il + n * ir;
 
     const float dall = vload_half(0, &x[i].d);
     const float dmin = vload_half(0, &x[i].dmin);
@@ -281,13 +281,13 @@ __kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __globa
 
 __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __global float *yy)
 {
-    const int i = get_group_id(0);
+    const int i = get_group_id(0) + get_global_offset(0);
     const int tid = get_local_id(0);
     const int il = tid / 16;
     const int ir = tid % 16;
     const int is = 2 * il;
 
-    __global float *y = yy + i * QK_K + 64 * il + 2 * ir;
+    __global float *y = yy + get_group_id(0) * QK_K + 64 * il + 2 * ir;
 
     const float dall = vload_half(0, &x[i].d);
     const float dmin = vload_half(0, &x[i].dmin);
@@ -313,13 +313,13 @@ __kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __globa
 
 __kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __global float *yy)
 {
-    const int i = get_group_id(0);
+    const int i = get_group_id(0) + get_global_offset(0);
     const int tid = get_local_id(0);
     const int ip = tid / 32;
     const int il = tid - 32 * ip;
     const int is = 8 * ip + il / 16;
 
-    __global float *y = yy + i * QK_K + 128 * ip + il;
+    __global float *y = yy + get_group_id(0) * QK_K + 128 * ip + il;
 
     const float d = vload_half(0, &x[i].d);
 
@@ -730,7 +730,7 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
     const uint qk = QUANT_K;
     const uint qr = QUANT_R;
 
-    const int ib = i/qk; // block index
+    const int ib = i/qk + get_global_offset(0); // block index
     const int iqs = (i%qk)/qr; // quant index
     const int iybs = i - i%qk; // y block start index
     const int y_offset = qr == 1 ? 1 : qk/2;
@@ -1349,30 +1349,42 @@ static cl_int ggml_cl_h2d_tensor_2d(cl_command_queue queue, cl_mem dst, size_t o
     const enum ggml_type type = src->type;
     const size_t ts = ggml_type_size(type);
     const size_t bs = ggml_blck_size(type);
+    const uint64_t row_size = ts*ne0/bs;
 
-    const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3);
-    if (nb0 == ts && nb1 == ts*ne0/bs) {
-        err = clEnqueueWriteBuffer(queue, dst, CL_FALSE, offset, ne1*nb1, x, 0, NULL, ev);
-        return err;
+    const char * x = (const char *) src->data + i2*nb2 + i3*nb3;
+    if (nb0 == ts && nb1 == row_size) {
+        return clEnqueueWriteBuffer(queue, dst, CL_FALSE, offset, ne1*row_size, x, 0, NULL, ev);
     }
     if (nb0 == ts) {
         const size_t buffer_origin[3] = { offset, 0, 0 };
         const size_t host_origin[3] = { 0, 0, 0 };
-        const size_t region[3] = { ts*ne0/bs, ne1, 1 };
-        err = clEnqueueWriteBufferRect(queue, dst, CL_FALSE, buffer_origin, host_origin, region, ts*ne0/bs, 0, nb1, 0, x, 0, NULL, ev);
-        return err;
+        const size_t region[3] = { row_size, ne1, 1 };
+        return clEnqueueWriteBufferRect(queue, dst, CL_FALSE, buffer_origin, host_origin, region, row_size, 0, nb1, 0, x, 0, NULL, ev);
     }
+    std::vector<cl_event> events;
+    if (ev && ne1>1) events.reserve(ne1-1);
     for (uint64_t i1 = 0; i1 < ne1; i1++) {
         // pretend the row is a matrix with cols=1
-        const size_t buffer_origin[3] = { offset, i1, 0 };
+        const size_t buffer_origin[3] = { offset + i1*row_size, 0, 0 };
         const size_t host_origin[3] = { 0, 0, 0 };
-        const size_t region[3] = { ts/bs, ne0, 1 };
-        err = clEnqueueWriteBufferRect(queue, dst, CL_FALSE, buffer_origin, host_origin, region, 0, 0, nb0, 0, ((const char *)x) + i1*nb0, 0, NULL, ev);
+        const size_t region[3] = { ts, ne0/bs, 1 };
+        // if an event is requested, make the last write wait for all previous writes to complete
+        if (ev && i1) {
+            events.push_back(*ev);
+        }
+        cl_uint nevents = i1 == ne1-1 ? events.size() : 0U;
+        err = clEnqueueWriteBufferRect(queue, dst, CL_FALSE, buffer_origin, host_origin, region, ts, 0, nb0, 0, x + i1*nb1, nevents, nevents ? events.data() : nullptr, ev);
         if (err != CL_SUCCESS) {
-            break;
+            for (auto event : events) {
+                clReleaseEvent(event);
+            }
+            return err;
         }
     }
-    return err;
+    for (auto event : events) {
+        CL_CHECK(clReleaseEvent(event));
+    }
+    return CL_SUCCESS;
 }
 
 static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
@@ -1503,6 +1515,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
     cl_mem d_Y = ggml_cl_pool_malloc(sizeof(float) * y_ne, &y_size);
     cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
 
+    size_t x_offset = 0;
     int64_t pi02 = -1;
     int64_t pi03 = -1;
 
@@ -1513,7 +1526,9 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
             int64_t i02 = i12 / r2;
 
             // copy data to device
-            if (src0->backend != GGML_BACKEND_GPU && (i02 != pi02 || i03 != pi03)) {
+            if (src0->backend == GGML_BACKEND_GPU) {
+                x_offset = (i03 * ne02 + i02) * x_ne;
+            } else if (i02 != pi02 || i03 != pi03) {
                 CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
                 pi02 = i02;
                 pi03 = i03;
@@ -1528,7 +1543,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
                                                        clblast::Transpose::kYes, clblast::Transpose::kNo,
                                                        ne01, ne11, ne10,
                                                        alpha,
-                                                       d_X, 0, ne00,
+                                                       d_X, x_offset, ne00,
                                                        d_Y, 0, ne10,
                                                        beta,
                                                        d_D, 0, ne01,
@@ -1596,6 +1611,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
     bool src1_cont_rows = nb10 == sizeof(float);
     bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
 
+    size_t x_offset = 0;
     int64_t pi02 = -1;
     int64_t pi03 = -1;
 
@@ -1606,7 +1622,9 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
             int64_t i02 = i12 / r2;
 
             // copy src0 to device
-            if (src0->backend != GGML_BACKEND_GPU && (i02 != pi02 || i03 != pi03)) {
+            if (src0->backend == GGML_BACKEND_GPU) {
+                x_offset = (i03 * ne02 + i02) * x_ne;
+            } else if (i02 != pi02 || i03 != pi03) {
                 CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
                 pi02 = i02;
                 pi03 = i03;
@@ -1646,7 +1664,7 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
                                                        clblast::Transpose::kYes, clblast::Transpose::kNo,
                                                        ne01, ne11, ne10,
                                                        alpha,
-                                                       d_X, 0, ne00,
+                                                       d_X, x_offset, ne00,
                                                        d_Y, 0, ne10,
                                                        beta,
                                                        d_D, 0, ne01,
@@ -1696,7 +1714,8 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
     const int x_ne = ne01 * ne00;
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
-    const size_t q_sz = ggml_type_size(type) * x_ne / ggml_blck_size(type);
+    const int x_bps = x_ne / ggml_blck_size(type); // blocks per 2D slice
+    const size_t q_sz = ggml_type_size(type) * x_bps;
 
     size_t x_size;
     size_t y_size;
@@ -1764,9 +1783,10 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
             } else { // general dequantization kernel + CLBlast matrix matrix multiplication
                 // convert src0 to fp32 on device
                 const size_t global = x_ne / global_denom;
+                const size_t offset = src0->backend == GGML_BACKEND_GPU ? (i03 * ne02 + i02) * x_bps : 0;
                 CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q));
                 CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X));
-                CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
+                CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, offset > 0 ? &offset : NULL, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
 
                 // copy src1 to device
                 CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
@@ -1888,17 +1908,19 @@ void ggml_cl_transform_tensor(void * data, ggml_tensor * tensor) {
     const int64_t ne3 = tensor->ne[3];
 
     const ggml_type type = tensor->type;
-    const size_t q_sz = ggml_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_blck_size(type);
+    const size_t s_sz = ggml_type_size(type) * (size_t) (ne0 * ne1 / ggml_blck_size(type));
+    const size_t q_sz = s_sz * (size_t) (ne2 * ne3);
 
     size_t q_size;
     cl_mem dst = ggml_cl_pool_malloc(q_sz, &q_size);
 
     tensor->data = data;
     // copy tensor to device
+    size_t offset = 0;
     for (int64_t i3 = 0; i3 < ne3; i3++) {
         for (int64_t i2 = 0; i2 < ne2; i2++) {
-            int i = i3*ne2 + i2;
-            CL_CHECK(ggml_cl_h2d_tensor_2d(queue, dst, i*ne0*ne1, tensor, i3, i2, NULL));
+            CL_CHECK(ggml_cl_h2d_tensor_2d(queue, dst, offset, tensor, i3, i2, NULL));
+            offset += s_sz;
         }
     }
 
index e2508d3c4d98d76e5a5bb61e932801f721d21579..6d1776ca46741f2d436a0c667e4609dd04c28a38 100644 (file)
@@ -13059,7 +13059,7 @@ static void ggml_compute_forward_alibi_f32(
         return;
     }
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+    const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
@@ -13080,7 +13080,6 @@ static void ggml_compute_forward_alibi_f32(
     //const int nb3 = src0->nb[3];
 
     GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(ne1 + n_past == ne0);
     GGML_ASSERT(n_head == ne2);
 
     // add alibi to src0 (KQ_scaled)