feat: add `GGML_UNARY_OP_ARGMAX` Metal kernel (ggml/1019)
author PAB <redacted>
Mon, 2 Dec 2024 18:27:24 +0000 (19:27 +0100)
committer Georgi Gerganov <redacted>
Tue, 3 Dec 2024 18:04:49 +0000 (20:04 +0200)
* implemented argmax kernel

* tpig -> tgpig

* change to strides

* contiguous assertions

* kernel working and tested

* argmax simd parallel implementation

* added 2 new tests for argmax in test-backend-ops

* cosmit

* added 3 test cases for perf eval

* add test_argmax in make_test_cases_perf

* Update test-backend-ops.cpp

Co-authored-by: Diego Devesa <redacted>
---------

Co-authored-by: Diego Devesa <redacted>
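
For reference, GGML_OP_ARGMAX returns, for every row of an F32 matrix, the index of that row's largest element; the new Metal kernel implements this per-row reduction on the GPU. Below is a minimal CPU-side sketch of the semantics using the public ggml C API (ggml_argmax and the graph helpers are the regular ggml.h entry points; the example itself is illustrative and not part of this commit):

#include "ggml.h"
#include <cstdio>
#include <cstring>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // 3 rows of 4 F32 values each
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);
    const float vals[12] = { 0, 1, 5, 2,   9, 3, 1, 0,   -4, -1, -7, -2 };
    memcpy(a->data, vals, sizeof(vals));

    // one I32 index per row
    struct ggml_tensor * out = ggml_argmax(ctx, a);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads=*/1);

    const int32_t * res = (const int32_t *) out->data;
    printf("argmax per row: %d %d %d\n", res[0], res[1], res[2]); // expected: 2 0 1

    ggml_free(ctx);
    return 0;
}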
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal
tests/test-backend-ops.cpp

index d374b65a4f080c28f30b5b64185cdbbf6bfe18db..093ae9000ab37d2f69767b3a8beb083bb865d6a7 100644 (file)
@@ -392,6 +392,7 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_SUM_ROWS,
     GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,
     GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,
+    GGML_METAL_KERNEL_TYPE_ARGMAX,
 
     GGML_METAL_KERNEL_TYPE_COUNT
 };
@@ -956,6 +957,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SIN,                           sin,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_COS,                           cos,                            true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS,                      sum_rows,                       true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGMAX,                        argmax,                         true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_AVG_F32,               pool_2d_avg_f32,                true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_POOL_2D_MAX_F32,               pool_2d_max_f32,                true);
     }
@@ -1086,6 +1088,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             return has_simdgroup_reduction;
         case GGML_OP_RMS_NORM:
             return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
+        case GGML_OP_ARGMAX:
         case GGML_OP_NORM:
         case GGML_OP_ROPE:
             return true;
@@ -3845,6 +3848,31 @@ static void ggml_metal_encode_node(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n_tg, 1, 1) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
             } break;
+            case GGML_OP_ARGMAX:
+            {
+                GGML_ASSERT(src0->type == GGML_TYPE_F32);
+                GGML_ASSERT(ggml_is_contiguous_1(src0));
+                GGML_ASSERT(nb00 == ggml_type_size(src0->type));
+
+                const int64_t nrows = ggml_nrows(src0);
+
+                int nth = 32; // SIMD width
+                while (nth < ne00 && nth*ne01*ne02*ne03 < 256) {
+                    nth *= 2;
+                }
+
+                id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARGMAX].pipeline;
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
+                [encoder setThreadgroupMemoryLength:32*sizeof(float)   atIndex:0];
+                [encoder setThreadgroupMemoryLength:32*sizeof(int32_t) atIndex:1];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
        default:
             {
                 GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(dst->op));
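
A note on the threadgroup sizing above: one threadgroup is launched per row, and the thread count per group starts at a single SIMD group (32 threads) and doubles only while the row is longer than the current width and the total thread count across all rows stays below 256, so tensors with many rows keep nth at 32 and rely on row-level parallelism instead. A small standalone sketch of that heuristic (reproduced here purely for illustration) with a few worked shapes:

#include <cstdint>
#include <cstdio>

// same doubling rule as in the GGML_OP_ARGMAX case above
static int pick_nth(int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03) {
    int nth = 32; // SIMD width
    while (nth < ne00 && (int64_t) nth*ne01*ne02*ne03 < 256) {
        nth *= 2;
    }
    return nth;
}

int main() {
    printf("%d\n", pick_nth( 100,  1, 1, 1)); // 128: one shortish row, widen the group
    printf("%d\n", pick_nth(1024, 10, 1, 1)); //  32: 10 rows already give 320 threads in flight
    printf("%d\n", pick_nth(5438,  3, 1, 1)); // 128: long rows, few of them
    return 0;
}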
index 8cb9a3414974cd793911424667c5d6da70a05c4f..5caa0846a8b537dfce529da4b1a9e87a8674e036 100644 (file)
@@ -1366,6 +1366,63 @@ kernel void kernel_ssm_scan_f32(
     }
 }
 
+kernel void kernel_argmax(
+        device   const void * x,
+        device      int32_t * dst,
+        constant    int64_t & ncols,
+        constant   uint64_t & nb01,
+        threadgroup   float * shared_maxval [[threadgroup(0)]],
+        threadgroup int32_t * shared_argmax [[threadgroup(1)]],
+        uint  tgpig[[threadgroup_position_in_grid]],
+        uint  tpitg[[thread_position_in_threadgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint    ntg[[threads_per_threadgroup]]) {
+    device const float * x_row = (device const float *) ((device const char *) x + tgpig * nb01);
+
+    float   lmax = -INFINITY;
+    int32_t larg = -1;
+
+    for (int i00 = tpitg; i00 < ncols; i00 += ntg) {
+        if (x_row[i00] > lmax) {
+            lmax = x_row[i00];
+            larg = i00;
+        }
+    }
+
+    // find the argmax value in the block
+    float max_val = simd_max(lmax);
+    int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
+
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            shared_maxval[tiisg] = -INFINITY;
+            shared_argmax[tiisg] = -1;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        if (tiisg == 0) {
+            shared_maxval[sgitg] = max_val;
+            shared_argmax[sgitg] = arg_val;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        max_val = shared_maxval[tiisg];
+        arg_val = shared_argmax[tiisg];
+
+        float max_val_reduced   = simd_max(max_val);
+        int32_t arg_val_reduced = simd_max(select(-1, arg_val, max_val == max_val_reduced));
+
+        dst[tgpig] = arg_val_reduced;
+
+        return;
+    }
+
+    dst[tgpig] = arg_val;
+}
+
 kernel void kernel_norm(
         constant ggml_metal_kargs_norm & args,
         device const char * src0,
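
The kernel reduces in two stages: each thread scans a strided slice of the row, simd_max over the local maxima then gives the simdgroup maximum, and simd_max(select(-1, larg, lmax == max_val)) recovers an index for it, because only the lanes holding the maximum contribute their index while every other lane contributes -1 (among equal maxima this picks the largest qualifying index). When the threadgroup spans more than one simdgroup, the per-simdgroup results go through threadgroup memory and the same trick is applied once more. A scalar walk-through of the select/simd_max step (illustrative only; the lane values below are made up):

#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // pretend these are the per-lane (lmax, larg) pairs of one simdgroup
    std::vector<float>   lmax = { 0.5f, 3.0f, 3.0f, -1.0f };
    std::vector<int32_t> larg = { 7,    12,   20,    3    };

    // float max_val = simd_max(lmax);
    float max_val = *std::max_element(lmax.begin(), lmax.end());

    // int32_t arg_val = simd_max(select(-1, larg, lmax == max_val));
    int32_t arg_val = -1;
    for (size_t i = 0; i < lmax.size(); ++i) {
        arg_val = std::max(arg_val, lmax[i] == max_val ? larg[i] : int32_t(-1));
    }

    printf("max = %g, argmax = %d\n", max_val, arg_val); // max = 3, argmax = 20
    return 0;
}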
index c786da4c35d7c920b21ac24830d1c4ff786b9af6..87c92dadd9bccfbb92bd95f800404c382185a4f3 100644 (file)
@@ -3460,13 +3460,14 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1));
     test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1));
 
-    test_cases.emplace_back(new test_argmax());
-    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32, 1, 1, 1}));
-    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100, 10, 1, 1}));
+    test_cases.emplace_back(new test_count_equal());
+
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {32,    1, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {100,  10, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 10, 1, 1}));
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {1024, 12, 1, 1}));
     test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {2000, 10, 1, 1}));
-
-    test_cases.emplace_back(new test_count_equal());
+    test_cases.emplace_back(new test_argmax(GGML_TYPE_F32, {5438,  3, 1, 1}));
 
     for (int ne3 : {1, 3}) { // CUDA backward pass only supports ne3 == 1
         test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1}));