]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml: add `GGML_SET` Metal kernel + i32 CPU kernel (ggml/1037)
authorPAB <redacted>
Wed, 4 Dec 2024 08:19:30 +0000 (09:19 +0100)
committerGeorgi Gerganov <redacted>
Thu, 5 Dec 2024 11:27:33 +0000 (13:27 +0200)
* implemented cpu kernel

* add i32 test cases in test-backend-ops

* typedef `ggml_metal_kargs_set`

* implemented `kernel_set`

* memcpy

ggml/src/ggml-cpu/ggml-cpu.c
ggml/src/ggml-metal/ggml-metal-impl.h
ggml/src/ggml-metal/ggml-metal.m
ggml/src/ggml-metal/ggml-metal.metal
tests/test-backend-ops.cpp

index f82e877f055067ef956f0ef33a32094c4f562bea..f12b62e3bd8b988ad9b3624fb3d4cbfef3b29e69 100644 (file)
@@ -1374,7 +1374,10 @@ struct ggml_compute_state {
 
 inline static void ggml_vec_set_i8(const int n, int8_t * x, const int8_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_set_i16(const int n, int16_t * x, const int16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
-inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
+
+inline static void ggml_vec_set_i32(const int n, int32_t * x, const int32_t   v) { for (int i = 0; i < n; ++i) x[i] = v;    }
+inline static void ggml_vec_cpy_i32(const int n, int32_t * y, const int32_t * x) { for (int i = 0; i < n; ++i) y[i] = x[i]; }
+
 inline static void ggml_vec_set_f16(const int n, ggml_fp16_t * x, const int32_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_set_bf16(const int n, ggml_bf16_t * x, const ggml_bf16_t v) { for (int i = 0; i < n; ++i) x[i] = v; }
 inline static void ggml_vec_add_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i]  = x[i] + y[i]; }
@@ -8248,6 +8251,77 @@ static void ggml_compute_forward_set_f32(
     }
 }
 
+static void ggml_compute_forward_set_i32(
+        const struct ggml_compute_params * params,
+        struct ggml_tensor * dst) {
+
+    const struct ggml_tensor * src0 = dst->src[0];
+    const struct ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+    GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+    // view src0 and dst with these strides and data offset inbytes during set
+    // nb0 is implicitly element_size because src0 and dst are contiguous
+    size_t nb1     = ((int32_t *) dst->op_params)[0];
+    size_t nb2     = ((int32_t *) dst->op_params)[1];
+    size_t nb3     = ((int32_t *) dst->op_params)[2];
+    size_t offset  = ((int32_t *) dst->op_params)[3];
+    bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+    if (!inplace) {
+        if (params->ith == 0) {
+            // memcpy needs to be synchronized across threads to avoid race conditions.
+            // => do it in INIT phase
+            memcpy(
+                ((char *)  dst->data),
+                ((char *) src0->data),
+                ggml_nbytes(dst));
+        }
+        ggml_barrier(params->threadpool);
+    }
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int nr = ggml_nrows(src1);
+    const int nc = src1->ne[0];
+
+    GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb1, src1, nb)
+
+    // src0 and dst as viewed during set
+    const size_t nb0 = ggml_element_size(src0);
+
+    const int im0 = (ne10 == 0 ? 0 : ne10-1);
+    const int im1 = (ne11 == 0 ? 0 : ne11-1);
+    const int im2 = (ne12 == 0 ? 0 : ne12-1);
+    const int im3 = (ne13 == 0 ? 0 : ne13-1);
+
+    GGML_ASSERT(offset + im0*nb0  + im1*nb1  + im2*nb2  + im3*nb3  <= ggml_nbytes(dst));
+
+    GGML_ASSERT(nb10 == sizeof(int32_t));
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // src0 and dst are viewed with shape of src1 and offset
+        // => same indices
+        const int i3 = ir/(ne12*ne11);
+        const int i2 = (ir - i3*ne12*ne11)/ne11;
+        const int i1 = (ir - i3*ne12*ne11 - i2*ne11);
+
+        ggml_vec_cpy_i32(nc,
+                (int32_t *) ((char *)  dst->data + i3*nb3  + i2*nb2  + i1*nb1  + offset),
+                (int32_t *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
+    }
+}
+
 static void ggml_compute_forward_set(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
@@ -8259,6 +8333,10 @@ static void ggml_compute_forward_set(
             {
                 ggml_compute_forward_set_f32(params, dst);
             } break;
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_set_i32(params, dst);
+            } break;
         case GGML_TYPE_F16:
         case GGML_TYPE_BF16:
         case GGML_TYPE_Q4_0:
index e195867192712b36b2c20b16a97453a0ad542339..e3dc25f1686fbf352c2b72a4e84aadaae1a2d987 100644 (file)
@@ -102,6 +102,21 @@ typedef struct {
     uint64_t nb3;
 } ggml_metal_kargs_cpy;
 
+typedef struct {
+    int64_t  ne10;
+    int64_t  ne11;
+    int64_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    uint64_t offs;
+    bool     inplace;
+} ggml_metal_kargs_set;
+
 typedef struct {
     int32_t  ne00;
     int32_t  ne01;
index 71f5525f38c36e6b44403d91499db1ce74551775..f80fda7a4ca3880741a561fcb86fb7fa703a3301 100644 (file)
@@ -372,6 +372,8 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,
     GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,
+    GGML_METAL_KERNEL_TYPE_SET_I32,
+    GGML_METAL_KERNEL_TYPE_SET_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,
@@ -940,6 +942,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H256,  flash_attn_ext_vec_q5_0_h256,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H256,  flash_attn_ext_vec_q5_1_h256,   has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H256,  flash_attn_ext_vec_q8_0_h256,   has_simdgroup_reduction);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_F32,                       set_f32,                        true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_I32,                       set_i32,                        true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32,                   cpy_f32_f32,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16,                   cpy_f32_f16,                    true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_BF16,                  cpy_f32_bf16,                   use_bfloat);
@@ -1159,6 +1163,16 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
                         return false;
                 };
             }
+        case GGML_OP_SET:
+            {
+                switch (op->src[0]->type) {
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_I32:
+                        return true;
+                    default:
+                        return false;
+                };
+            }
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_GET_ROWS:
             {
@@ -3824,6 +3838,68 @@ static void ggml_metal_encode_node(
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
+        case GGML_OP_SET:
+            {
+                GGML_ASSERT(ggml_are_same_shape(src0, dst));
+                GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
+
+                // src0 and dst as viewed during set
+                const size_t dst_nb0 = ggml_element_size(src0);
+
+                const size_t dst_nb1 = ((int32_t *) dst->op_params)[0];
+                const size_t dst_nb2 = ((int32_t *) dst->op_params)[1];
+                const size_t dst_nb3 = ((int32_t *) dst->op_params)[2];
+                const size_t offset  = ((int32_t *) dst->op_params)[3];
+                const bool   inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+                if (!inplace) {
+                    memcpy(((char *)  dst->data), ((char *) src0->data), ggml_nbytes(dst));
+                }
+
+                const int im0 = (ne10 == 0 ? 0 : ne10-1);
+                const int im1 = (ne11 == 0 ? 0 : ne11-1);
+                const int im2 = (ne12 == 0 ? 0 : ne12-1);
+                const int im3 = (ne13 == 0 ? 0 : ne13-1);
+
+                GGML_ASSERT(offset + im0*dst_nb0  + im1*dst_nb1  + im2*dst_nb2  + im3*dst_nb3  <= ggml_nbytes(dst));
+
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (src0t) {
+                    case GGML_TYPE_F32:
+                        GGML_ASSERT(nb10 == sizeof(float));
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_F32].pipeline; break;
+                    case GGML_TYPE_I32:
+                        GGML_ASSERT(nb10 == sizeof(int32_t));
+                        pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_I32].pipeline; break;
+                    default: GGML_ABORT("fatal error");
+                }
+
+                ggml_metal_kargs_set args = {
+                    /*.ne10    =*/ ne10,
+                    /*.ne11    =*/ ne11,
+                    /*.ne12    =*/ ne12,
+                    /*.nb10    =*/ nb10,
+                    /*.nb11    =*/ nb11,
+                    /*.nb12    =*/ nb12,
+                    /*.nb13    =*/ nb13,
+                    /*.nb1     =*/ dst_nb1,
+                    /*.nb2     =*/ dst_nb2,
+                    /*.nb3     =*/ dst_nb3,
+                    /*.offs    =*/ offset,
+                    /*.inplace =*/ inplace,
+                };
+
+                const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne10);
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args    length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3];
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+            } break;
         case GGML_OP_POOL_2D:
             {
                 GGML_ASSERT(ggml_is_contiguous(src0));
index ee3e353ab0d798d00e59f7941fc15859f690d6a5..8ba43904d0c1b2760f962110125b39457318c6cb 100644 (file)
@@ -3927,6 +3927,38 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_
 
 #undef FA_TYPES
 
+template<typename T>
+kernel void kernel_set(
+    constant ggml_metal_kargs_set & args,
+    device  const char * src0,
+    device  const char * src1,
+    device        char * dst,
+    uint3   tgpig[[threadgroup_position_in_grid]],
+    ushort3 tpitg[[thread_position_in_threadgroup]],
+    ushort3   ntg[[threads_per_threadgroup]]) {
+    const int i13 = tgpig[2];
+    const int i12 = tgpig[1];
+    const int i11 = tgpig[0];
+
+    const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10;
+
+    const int64_t i3 = n / (args.ne12*args.ne11*args.ne10);
+    const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10);
+    const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10;
+
+    device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs);
+
+    for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) {
+        device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10);
+        dst_data[i10] = (T) src[0];
+    }
+}
+
+typedef decltype(kernel_set<float>) kernel_set_t;
+
+template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set<float>;
+template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set<int32_t>;
+
 template<typename T0, typename T1>
 kernel void kernel_cpy(
         constant ggml_metal_kargs_cpy & args,
index 042eeaaaba5c77834d31845e6b8b6d5abb0b4b36..9dd41260a1f19f508c177518c5fda3eaefaac32a 100644 (file)
@@ -3521,6 +3521,10 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim));
     }
 
+    for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) {
+        test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim));
+    }
+
     for (ggml_type type_src : {GGML_TYPE_F16, GGML_TYPE_F32}) {
         for (ggml_type type_dst : all_types) {
            test_cases.emplace_back(new test_cpy(type_src, type_dst, {256, 4, 4, 4}));