]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
ggml : add ggml_set_rows (llama/14274)
authorRadoslav Gerganov <redacted>
Fri, 27 Jun 2025 13:41:40 +0000 (16:41 +0300)
committerGeorgi Gerganov <redacted>
Tue, 1 Jul 2025 08:52:14 +0000 (11:52 +0300)
* ggml : add ggml_set_rows

Add ggml_set_rows(a, b, c) which copies rows from 'b' into 'a' using
indices from 'c'.

ref: #8366

* use I64 for indices

* ggml : add repeat impl for i64

* ggml : add ggml_is_contiguous_rows

* ggml : ggml_set_rows support broadcast

* ggml : ggml_set_rows support quantized dst

ggml-ci

* ggml : support GGML_TYPE_F32 ".from_float" trait

* ggml : ggml_set_rows update comment + better index name

* tests : add ggml_set_rows

* metal : add ggml_set_rows implementation

ggml-ci

* ggml : simplify forward_dup_f32

* ggml : fix supports_op

* tests : add comment to set_rows

* ggml : leave the repeat_i64 for a separate PR

ggml-ci

* ggml : set_rows use std::min instead of MIN

* ggml : better error message for set_rows unsupported type

* metal : perform op->type check only once

* tests : more consistent implementation + more tests

ggml-ci

---------

Co-authored-by: Georgi Gerganov <redacted>
include/ggml-cpu.h
include/ggml.h
src/ggml-cpu/ggml-cpu.c
src/ggml-cpu/ggml-cpu.cpp
src/ggml-cpu/ops.cpp
src/ggml-cpu/ops.h
src/ggml-metal/ggml-metal-impl.h
src/ggml-metal/ggml-metal.m
src/ggml-metal/ggml-metal.metal
src/ggml.c
tests/test-backend-ops.cpp

index e3b79d09bb66f12a8202683bc97e4f9f558c0e9f..be40b100979dee55265aa6e2a977f2ac354c59a6 100644 (file)
@@ -134,6 +134,7 @@ extern "C" {
 
     GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
 
+    GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *,       float *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
     GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);
index 4231a668eb0419c1723c2c200198ba4eabe762ea..042a13e055bd3e80c86d764d97f96255a86fc589 100644 (file)
@@ -470,6 +470,7 @@ extern "C" {
         GGML_OP_TRANSPOSE,
         GGML_OP_GET_ROWS,
         GGML_OP_GET_ROWS_BACK,
+        GGML_OP_SET_ROWS,
         GGML_OP_DIAG,
         GGML_OP_DIAG_MASK_INF,
         GGML_OP_DIAG_MASK_ZERO,
@@ -687,6 +688,9 @@ extern "C" {
     // true for tensor that is stored in memory as CxWxHxN and has been permuted to WxHxCxN
     GGML_API bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor);
 
+    // true if the elements in dimension 0 are contiguous, or there is just 1 block of elements
+    GGML_API bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor);
+
     GGML_API bool ggml_are_same_shape (const struct ggml_tensor * t0, const struct ggml_tensor * t1);
     GGML_API bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tensor * t1);
 
@@ -1375,6 +1379,23 @@ extern "C" {
             struct ggml_tensor  * b,  // row indices
             struct ggml_tensor  * c); // data for ggml_get_rows, only used for its shape
 
+    // a TD  [n_embd, ne1,    ne2,    ne3]
+    // b TS  [n_embd, n_rows, ne02,   ne03] | ne02 == ne2, ne03 == ne3
+    // c I64 [n_rows, ne11,   ne12,   1]    | c[i] in [0, ne1)
+    //
+    // undefined behavior if destination rows overlap
+    //
+    // broadcast:
+    //   ne2 % ne11 == 0
+    //   ne3 % ne12 == 0
+    //
+    // return view(a)
+    GGML_API struct ggml_tensor * ggml_set_rows(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,  // destination
+            struct ggml_tensor  * b,  // source
+            struct ggml_tensor  * c); // row indices
+
     GGML_API struct ggml_tensor * ggml_diag(
         struct ggml_context     * ctx,
         struct ggml_tensor      * a);
index 7cae96f4b4885ebbcd4d394d914083c677355418..2042ee71f1f804c647e715f2ba770721f25e25d2 100644 (file)
@@ -195,6 +195,7 @@ typedef pthread_t ggml_thread_t;
 
 static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
     [GGML_TYPE_F32] = {
+        .from_float               = (ggml_from_float_t) ggml_cpu_fp32_to_fp32,
         .vec_dot                  = (ggml_vec_dot_t) ggml_vec_dot_f32,
         .vec_dot_type             = GGML_TYPE_F32,
         .nrows                    = 1,
@@ -1817,6 +1818,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_get_rows_back(params, tensor);
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                ggml_compute_forward_set_rows(params, tensor);
+            } break;
         case GGML_OP_DIAG:
             {
                 ggml_compute_forward_diag(params, tensor);
@@ -2170,6 +2175,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_GET_ROWS:
+        case GGML_OP_SET_ROWS:
             {
                 // FIXME: get_rows can use additional threads, but the cost of launching additional threads
                 // decreases performance with GPU offloading
@@ -3124,6 +3130,10 @@ enum ggml_status ggml_graph_compute_with_ctx(struct ggml_context * ctx, struct g
     return ggml_graph_compute(cgraph, &cplan);
 }
 
+void ggml_cpu_fp32_to_fp32(const float * x, float * y, int64_t n) {
+    memcpy(y, x, n * sizeof(float));
+}
+
 void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
     int64_t i = 0;
 #if defined(__F16C__)
index a98866a2d80523a03c8e7d5c7bcd5d00438c8e20..c9daa4c39e83efcef2d2c33d855860c1d362c552 100644 (file)
@@ -416,6 +416,7 @@ static bool ggml_backend_cpu_device_supports_op(ggml_backend_dev_t dev, const st
 
     switch (op->op) {
         case GGML_OP_CPY:
+        case GGML_OP_SET_ROWS:
             return
                 op->type != GGML_TYPE_IQ3_XXS &&
                 op->type != GGML_TYPE_IQ3_S   &&
index 525fd178760f1144b969f481718f809b3a6bd5e5..7ead13f42a3a526f653d1fb4f5d22cd601d82985 100644 (file)
@@ -696,24 +696,8 @@ static void ggml_compute_forward_dup_f32(
     if (ggml_is_contiguous(dst)) {
         // TODO: simplify
         if (nb00 == sizeof(float)) {
-            if (dst->type == GGML_TYPE_F32) {
-                size_t id = 0;
-                const size_t rs = ne00 * nb00;
-                char * dst_ptr = (char *) dst->data;
-
-                for (int i03 = 0; i03 < ne03; i03++) {
-                    for (int i02 = 0; i02 < ne02; i02++) {
-                        id += rs * ir0;
-                        for (int i01 = ir0; i01 < ir1; i01++) {
-                            const char * src0_ptr = (char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03;
-                            memcpy(dst_ptr + id, src0_ptr, rs);
-                            id += rs;
-                        }
-                        id += rs * (ne01 - ir1);
-                    }
-                }
-            } else if (ggml_get_type_traits_cpu(dst->type)->from_float) {
-                ggml_from_float_t const quantize_row_q = ggml_get_type_traits_cpu(dst->type)->from_float;
+            if (ggml_get_type_traits_cpu(dst->type)->from_float) {
+                ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
 
                 size_t id = 0;
                 size_t rs = nb0 * (ne00 / ggml_blck_size(dst->type));
@@ -724,7 +708,7 @@ static void ggml_compute_forward_dup_f32(
                         id += rs * ir0;
                         for (int i01 = ir0; i01 < ir1; i01++) {
                             const float * src0_ptr = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
-                            quantize_row_q(src0_ptr, dst_ptr + id, ne00);
+                            from_float(src0_ptr, dst_ptr + id, ne00);
                             id += rs;
                         }
                         id += rs * (ne01 - ir1);
@@ -2300,6 +2284,12 @@ void ggml_compute_forward_repeat(
             {
                 ggml_compute_forward_repeat_f32(params, dst);
             } break;
+        // TODO: templateify the implemenation and support for I64
+        //       ref https://github.com/ggml-org/llama.cpp/pull/14274#discussion_r2169492225
+        //case GGML_TYPE_I64:
+        //    {
+        //        ggml_compute_forward_repeat_i64(params, dst);
+        //    } break;
         default:
             {
                 GGML_ABORT("fatal error");
@@ -4470,6 +4460,74 @@ void ggml_compute_forward_get_rows(
     //}
 }
 
+static void ggml_compute_forward_set_rows_f32(
+        const ggml_compute_params * params,
+              ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const int64_t nc = ne00;
+    const int64_t nr = ne01;
+
+    assert(ne0  == nc);
+    assert(ne2  == ne02);
+    assert(ne3  == ne03);
+    assert(src0->type == GGML_TYPE_F32);
+    assert(ne02 % ne11 == 0);
+    assert(ne03 % ne12 == 0);
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    // rows per thread
+    const int64_t dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int64_t ir0 = dr*ith;
+    const int64_t ir1 = std::min(ir0 + dr, nr);
+
+    ggml_from_float_t const from_float = ggml_get_type_traits_cpu(dst->type)->from_float;
+
+    for (int64_t i03 = 0; i03 < ne03; ++i03) {
+        for (int64_t i02 = 0; i02 < ne02; ++i02) {
+            for (int64_t i = ir0; i < ir1; ++i) {
+                const int64_t i12 = i03%ne12;
+                const int64_t i11 = i02%ne11;
+                const int64_t i10 = i;
+
+                const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+                GGML_ASSERT(i1 >= 0 && i1 < ne1);
+
+                from_float(
+                        (const float *) ((char *) src0->data +  i*nb01 + i02*nb02 + i03*nb03),
+                                        ((char *)  dst->data + i1*nb1  + i02*nb2  + i03*nb3), nc);
+            }
+        }
+    }
+}
+
+void ggml_compute_forward_set_rows(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    switch (src0->type) {
+        case GGML_TYPE_F32:
+            {
+                ggml_compute_forward_set_rows_f32(params, dst);
+            } break;
+        default:
+            {
+                GGML_ABORT("src0->type = %d (%s) not supported", src0->type, ggml_type_name(src0->type));
+            }
+    }
+}
+
 // ggml_compute_forward_get_rows_back
 
 static void ggml_compute_forward_get_rows_back_f32_f16(
index 2d8544d7d3d436afd00eae1a149ad308c1d58690..3a395fdcd9f048d76684b14da9de6ac5d6d16ee1 100644 (file)
@@ -53,6 +53,7 @@ void ggml_compute_forward_permute(const struct ggml_compute_params * params, str
 void ggml_compute_forward_transpose(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_get_rows_back(const struct ggml_compute_params * params, struct ggml_tensor * dst);
+void ggml_compute_forward_set_rows(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag_mask_inf(const struct ggml_compute_params * params, struct ggml_tensor * dst);
 void ggml_compute_forward_diag_mask_zero(const struct ggml_compute_params * params, struct ggml_tensor * dst);
index 17eab976f3ad114eaeb9ce108ecd635ce12fddc9..260440aedde00c05dce10b92eb158d1a600c4abc 100644 (file)
@@ -521,6 +521,22 @@ typedef struct {
     uint64_t nb2;
 } ggml_metal_kargs_get_rows;
 
+typedef struct {
+    int32_t  nk0;
+    int32_t  ne01;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t  ne11;
+    int32_t  ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_set_rows;
+
 typedef struct {
     int64_t  ne00;
     int64_t  ne01;
index d8d30cc0b41ca5b1b6fef3a623b0ed4cef8a6c73..349f0ff998ed13d7e2262e42cea2efdea19f5659 100644 (file)
@@ -202,6 +202,15 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,
     GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,
+    GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,
     GGML_METAL_KERNEL_TYPE_RMS_NORM,
     GGML_METAL_KERNEL_TYPE_L2_NORM,
     GGML_METAL_KERNEL_TYPE_GROUP_NORM,
@@ -1169,6 +1178,15 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_NL,                 get_rows_iq4_nl,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ4_XS,                 get_rows_iq4_xs,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32,                    get_rows_i32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F32,                    set_rows_f32,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_F16,                    set_rows_f16,                    true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16,                   set_rows_bf16,                   use_bfloat);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0,                   set_rows_q8_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0,                   set_rows_q4_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1,                   set_rows_q4_1,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0,                   set_rows_q5_0,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1,                   set_rows_q5_1,                   true);
+        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL,                 set_rows_iq4_nl,                 true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM,                        rms_norm,                        has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_L2_NORM,                         l2_norm,                         has_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM,                      group_norm,                      has_simdgroup_reduction);
@@ -1635,6 +1653,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     const bool use_bfloat              = ctx_dev->use_bfloat;
 
     if (!use_bfloat) {
+        if (op->type == GGML_TYPE_BF16) {
+            return false;
+        }
+
         for (size_t i = 0, n = 3; i < n; ++i) {
             if (op->src[i] != NULL && op->src[i]->type == GGML_TYPE_BF16) {
                 return false;
@@ -1804,6 +1826,27 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
             {
                 return op->ne[3] == 1;
             }
+        case GGML_OP_SET_ROWS:
+            {
+                if (op->src[0]->type != GGML_TYPE_F32) {
+                    return false;
+                }
+
+                switch (op->type) {
+                    case GGML_TYPE_F32:
+                    case GGML_TYPE_F16:
+                    case GGML_TYPE_BF16:
+                    case GGML_TYPE_Q8_0:
+                    case GGML_TYPE_Q4_0:
+                    case GGML_TYPE_Q4_1:
+                    case GGML_TYPE_Q5_0:
+                    case GGML_TYPE_Q5_1:
+                    case GGML_TYPE_IQ4_NL:
+                        return true;
+                    default:
+                        return false;
+                };
+            }
         default:
             return false;
     }
@@ -3777,13 +3820,74 @@ static bool ggml_metal_encode_node(
                 };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0     offset:offs_src0 atIndex:0];
-                [encoder setBuffer:id_src1     offset:offs_src1 atIndex:1];
-                [encoder setBuffer:id_dst      offset:offs_dst  atIndex:2];
-                [encoder setBytes:&args length:sizeof(args) atIndex:3];
+                [encoder setBytes:&args    length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
             } break;
+        case GGML_OP_SET_ROWS:
+            {
+                id<MTLComputePipelineState> pipeline = nil;
+
+                switch (dst->type) {
+                    case GGML_TYPE_F32:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F32   ].pipeline; break;
+                    case GGML_TYPE_F16:    pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_F16   ].pipeline; break;
+                    case GGML_TYPE_BF16:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_BF16  ].pipeline; break;
+                    case GGML_TYPE_Q8_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q8_0  ].pipeline; break;
+                    case GGML_TYPE_Q4_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_0  ].pipeline; break;
+                    case GGML_TYPE_Q4_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q4_1  ].pipeline; break;
+                    case GGML_TYPE_Q5_0:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_0  ].pipeline; break;
+                    case GGML_TYPE_Q5_1:   pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_Q5_1  ].pipeline; break;
+                    case GGML_TYPE_IQ4_NL: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SET_ROWS_IQ4_NL].pipeline; break;
+                    default: GGML_ABORT("not implemented");
+                }
+
+                const int32_t nk0 = ne0/ggml_blck_size(dst->type);
+
+                int nth = 32; // SIMD width
+
+                while (nth < nk0 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                    nth *= 2;
+                }
+
+                int nrptg = 1;
+                if (nth > nk0) {
+                    nrptg = (nth + nk0 - 1)/nk0;
+                    nth   = nk0;
+
+                    if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
+                        nrptg--;
+                    }
+                }
+
+                nth = MIN(nth, nk0);
+
+                ggml_metal_kargs_set_rows args = {
+                    /*.nk0  =*/ nk0,
+                    /*.ne01 =*/ ne01,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne11 =*/ ne11,
+                    /*.ne12 =*/ ne12,
+                    /*.nb10 =*/ nb10,
+                    /*.nb11 =*/ nb11,
+                    /*.nb12 =*/ nb12,
+                    /*.nb1  =*/ nb1,
+                    /*.nb2  =*/ nb2,
+                    /*.nb3  =*/ nb3,
+                };
+
+                [encoder setComputePipelineState:pipeline];
+                [encoder setBytes:&args    length:sizeof(args) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0    atIndex:1];
+                [encoder setBuffer:id_src1 offset:offs_src1    atIndex:2];
+                [encoder setBuffer:id_dst  offset:offs_dst     atIndex:3];
+
+                [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
+            } break;
         case GGML_OP_RMS_NORM:
             {
                 GGML_ASSERT(ne00 % 4 == 0);
index 5f004a856bde63897ce34552d2eb75a9ad7ac698..984a0ab503e7d9a0aecab223ac5e5f193c8cdc84 100644 (file)
@@ -35,6 +35,17 @@ constexpr constant static float kvalues_iq4nl_f[16] = {
     -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
 };
 
+static inline int best_index_int8(int n, constant float * val, float x) {
+    if (x <= val[0]) return 0;
+    if (x >= val[n-1]) return n-1;
+    int ml = 0, mu = n-1;
+    while (mu-ml > 1) {
+        int mav = (ml+mu)/2;
+        if (x < val[mav]) mu = mav; else ml = mav;
+    }
+    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
+}
+
 // NOTE: this is not dequantizing - we are simply fitting the template
 template <typename type4x4>
 void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
@@ -97,6 +108,173 @@ void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & r
     }
 }
 
+void quantize_q4_0(device const float * src, device block_q4_0 & dst) {
+    float amax = 0.0f; // absolute max
+    float max  = 0.0f;
+
+    for (int j = 0; j < QK4_0; j++) {
+        const float v = src[j];
+        if (amax < fabs(v)) {
+            amax = fabs(v);
+            max  = v;
+        }
+    }
+
+    const float d = max / -8;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    for (int j = 0; j < QK4_0/2; ++j) {
+        const float x0 = src[0       + j]*id;
+        const float x1 = src[QK4_0/2 + j]*id;
+
+        const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
+        const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
+
+        dst.qs[j]  = xi0;
+        dst.qs[j] |= xi1 << 4;
+    }
+}
+
+void quantize_q4_1(device const float * src, device block_q4_1 & dst) {
+    float min = FLT_MAX;
+    float max = -FLT_MAX;
+
+    for (int j = 0; j < QK4_1; j++) {
+        const float v = src[j];
+        if (min > v) min = v;
+        if (max < v) max = v;
+    }
+
+    const float d = (max - min) / ((1 << 4) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+    dst.m = min;
+
+    for (int j = 0; j < QK4_1/2; ++j) {
+        const float x0 = (src[0       + j] - min)*id;
+        const float x1 = (src[QK4_1/2 + j] - min)*id;
+
+        const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
+        const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
+
+        dst.qs[j]  = xi0;
+        dst.qs[j] |= xi1 << 4;
+    }
+}
+
+void quantize_q5_0(device const float * src, device block_q5_0 & dst) {
+    float amax = 0.0f; // absolute max
+    float max  = 0.0f;
+
+    for (int j = 0; j < QK5_0; j++) {
+        const float v = src[j];
+        if (amax < fabs(v)) {
+            amax = fabs(v);
+            max  = v;
+        }
+    }
+
+    const float d = max / -16;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_0/2; ++j) {
+        const float x0 = src[0       + j]*id;
+        const float x1 = src[QK5_0/2 + j]*id;
+
+        const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
+        const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
+
+        dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
+    }
+
+    thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+
+    for (int j = 0; j < 4; ++j) {
+        dst.qh[j] = qh8[j];
+    }
+}
+
+void quantize_q5_1(device const float * src, device block_q5_1 & dst) {
+    float max = src[0];
+    float min = src[0];
+
+    for (int j = 1; j < QK5_1; j++) {
+        const float v = src[j];
+        min = v < min ? v : min;
+        max = v > max ? v : max;
+    }
+
+    const float d = (max - min) / 31;
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+    dst.m = min;
+
+    uint32_t qh = 0;
+    for (int j = 0; j < QK5_1/2; ++j) {
+        const float x0 = (src[0       + j] - min)*id;
+        const float x1 = (src[QK5_1/2 + j] - min)*id;
+
+        const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
+        const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
+
+        dst.qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
+        qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
+        qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
+    }
+
+    thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
+
+    for (int j = 0; j < 4; ++j) {
+        dst.qh[j] = qh8[j];
+    }
+}
+
+void quantize_iq4_nl(device const float * src, device block_iq4_nl & dst) {
+    float amax = 0.0f; // absolute max
+    float max  = 0.0f;
+
+    for (int j = 0; j < QK4_NL; j++) {
+        const float v = src[j];
+        if (amax < fabs(v)) {
+            amax = fabs(v);
+            max  = v;
+        }
+    }
+
+    const float d = max / kvalues_iq4nl_f[0];
+    const float id = d ? 1.0f/d : 0.0f;
+
+    float sumqx = 0, sumq2 = 0;
+    for (int j = 0; j < QK4_NL/2; ++j) {
+        const float x0 = src[0        + j]*id;
+        const float x1 = src[QK4_NL/2 + j]*id;
+
+        const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
+        const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
+
+        dst.qs[j] = xi0 | (xi1 << 4);
+
+        const float v0 = kvalues_iq4nl_f[xi0];
+        const float v1 = kvalues_iq4nl_f[xi1];
+        const float w0 = src[0        + j]*src[0        + j];
+        const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
+        sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
+        sumq2 += w0*v0*v0 + w1*v1*v1;
+
+    }
+
+    dst.d = sumq2 > 0 ? sumqx/sumq2 : d;
+}
+
 template <typename type4x4>
 void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
@@ -279,6 +457,26 @@ void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & re
     }
 }
 
+void quantize_q8_0(device const float * src, device block_q8_0 & dst) {
+    float amax = 0.0f; // absolute max
+
+    for (int j = 0; j < QK8_0; j++) {
+        const float v = src[j];
+        amax = MAX(amax, fabs(v));
+    }
+
+    const float d = amax / ((1 << 7) - 1);
+    const float id = d ? 1.0f/d : 0.0f;
+
+    dst.d = d;
+
+    for (int j = 0; j < QK8_0; ++j) {
+        const float x0 = src[j]*id;
+
+        dst.qs[j] = round(x0);
+    }
+}
+
 template <typename type4x4>
 void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
     const float d = xb->d;
@@ -4410,6 +4608,7 @@ template [[host_name("kernel_cpy_bf16_f32")]]  kernel kernel_cpy_t kernel_cpy<bf
 template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy<bfloat, bfloat>;
 #endif
 
+// TODO: templetify these kernels
 kernel void kernel_cpy_f32_q8_0(
         constant ggml_metal_kargs_cpy & args,
         device const char * src0,
@@ -4433,23 +4632,7 @@ kernel void kernel_cpy_f32_q8_0(
     for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float amax = 0.0f; // absolute max
-
-        for (int j = 0; j < QK8_0; j++) {
-            const float v = src[j];
-            amax = MAX(amax, fabs(v));
-        }
-
-        const float d = amax / ((1 << 7) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK8_0].d = d;
-
-        for (int j = 0; j < QK8_0; ++j) {
-            const float x0 = src[j]*id;
-
-            dst_data[i00/QK8_0].qs[j] = round(x0);
-        }
+        quantize_q8_0(src, dst_data[i00/QK8_0]);
     }
 }
 
@@ -4476,32 +4659,7 @@ kernel void kernel_cpy_f32_q4_0(
     for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float amax = 0.0f; // absolute max
-        float max  = 0.0f;
-
-        for (int j = 0; j < QK4_0; j++) {
-            const float v = src[j];
-            if (amax < fabs(v)) {
-                amax = fabs(v);
-                max  = v;
-            }
-        }
-
-        const float d = max / -8;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK4_0].d = d;
-
-        for (int j = 0; j < QK4_0/2; ++j) {
-            const float x0 = src[0       + j]*id;
-            const float x1 = src[QK4_0/2 + j]*id;
-
-            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 8.5f));
-            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 8.5f));
-
-            dst_data[i00/QK4_0].qs[j]  = xi0;
-            dst_data[i00/QK4_0].qs[j] |= xi1 << 4;
-        }
+        quantize_q4_0(src, dst_data[i00/QK4_0]);
     }
 }
 
@@ -4528,31 +4686,7 @@ kernel void kernel_cpy_f32_q4_1(
     for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float min = FLT_MAX;
-        float max = -FLT_MAX;
-
-        for (int j = 0; j < QK4_1; j++) {
-            const float v = src[j];
-            if (min > v) min = v;
-            if (max < v) max = v;
-        }
-
-        const float d = (max - min) / ((1 << 4) - 1);
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK4_1].d = d;
-        dst_data[i00/QK4_1].m = min;
-
-        for (int j = 0; j < QK4_1/2; ++j) {
-            const float x0 = (src[0       + j] - min)*id;
-            const float x1 = (src[QK4_1/2 + j] - min)*id;
-
-            const uint8_t xi0 = MIN(15, (int8_t)(x0 + 0.5f));
-            const uint8_t xi1 = MIN(15, (int8_t)(x1 + 0.5f));
-
-            dst_data[i00/QK4_1].qs[j]  = xi0;
-            dst_data[i00/QK4_1].qs[j] |= xi1 << 4;
-        }
+        quantize_q4_1(src, dst_data[i00/QK4_1]);
     }
 }
 
@@ -4579,38 +4713,7 @@ kernel void kernel_cpy_f32_q5_0(
     for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float amax = 0.0f; // absolute max
-        float max  = 0.0f;
-
-        for (int j = 0; j < QK5_0; j++) {
-            const float v = src[j];
-            if (amax < fabs(v)) {
-                amax = fabs(v);
-                max  = v;
-            }
-        }
-
-        const float d = max / -16;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK5_0].d = d;
-
-        uint32_t qh = 0;
-        for (int j = 0; j < QK5_0/2; ++j) {
-            const float x0 = src[0       + j]*id;
-            const float x1 = src[QK5_0/2 + j]*id;
-
-            const uint8_t xi0 = MIN(31, (int8_t)(x0 + 16.5f));
-            const uint8_t xi1 = MIN(31, (int8_t)(x1 + 16.5f));
-
-            dst_data[i00/QK5_0].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
-            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
-            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
-        }
-        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
-        for (int j = 0; j < 4; ++j) {
-            dst_data[i00/QK5_0].qh[j] = qh8[j];
-        }
+        quantize_q5_0(src, dst_data[i00/QK5_0]);
     }
 }
 
@@ -4637,49 +4740,8 @@ kernel void kernel_cpy_f32_q5_1(
     for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float max = src[0];
-        float min = src[0];
-
-        for (int j = 1; j < QK5_1; j++) {
-            const float v = src[j];
-            min = v < min ? v : min;
-            max = v > max ? v : max;
-        }
-
-        const float d = (max - min) / 31;
-        const float id = d ? 1.0f/d : 0.0f;
-
-        dst_data[i00/QK5_1].d = d;
-        dst_data[i00/QK5_1].m = min;
-
-        uint32_t qh = 0;
-        for (int j = 0; j < QK5_1/2; ++j) {
-            const float x0 = (src[0       + j] - min)*id;
-            const float x1 = (src[QK5_1/2 + j] - min)*id;
-
-            const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
-            const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
-
-            dst_data[i00/QK5_1].qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
-            qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
-            qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
-        }
-        thread const uint8_t * qh8 = (thread const uint8_t *)&qh;
-        for (int j = 0; j < 4; ++j) {
-            dst_data[i00/QK5_1].qh[j] = qh8[j];
-        }
-    }
-}
-
-static inline int best_index_int8(int n, constant float * val, float x) {
-    if (x <= val[0]) return 0;
-    if (x >= val[n-1]) return n-1;
-    int ml = 0, mu = n-1;
-    while (mu-ml > 1) {
-        int mav = (ml+mu)/2;
-        if (x < val[mav]) mu = mav; else ml = mav;
+        quantize_q5_1(src, dst_data[i00/QK5_1]);
     }
-    return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
 }
 
 kernel void kernel_cpy_f32_iq4_nl(
@@ -4705,40 +4767,7 @@ kernel void kernel_cpy_f32_iq4_nl(
     for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
         device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
 
-        float amax = 0.0f; // absolute max
-        float max  = 0.0f;
-
-        for (int j = 0; j < QK4_NL; j++) {
-            const float v = src[j];
-            if (amax < fabs(v)) {
-                amax = fabs(v);
-                max  = v;
-            }
-        }
-
-        const float d = max / kvalues_iq4nl_f[0];
-        const float id = d ? 1.0f/d : 0.0f;
-
-        float sumqx = 0, sumq2 = 0;
-        for (int j = 0; j < QK4_NL/2; ++j) {
-            const float x0 = src[0        + j]*id;
-            const float x1 = src[QK4_NL/2 + j]*id;
-
-            const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl_f, x0);
-            const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl_f, x1);
-
-            dst_data[i00/QK4_NL].qs[j] = xi0 | (xi1 << 4);
-
-            const float v0 = kvalues_iq4nl_f[xi0];
-            const float v1 = kvalues_iq4nl_f[xi1];
-            const float w0 = src[0        + j]*src[0        + j];
-            const float w1 = src[QK4_NL/2 + j]*src[QK4_NL/2 + j];
-            sumqx += w0*v0*src[j] + w1*v1*src[QK4_NL/2 + j];
-            sumq2 += w0*v0*v0 + w1*v1*v1;
-
-        }
-
-        dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
+        quantize_iq4_nl(src, dst_data[i00/QK4_NL]);
     }
 }
 
@@ -6419,10 +6448,10 @@ kernel void kernel_mul_mv_iq4_xs_f32(
 
 template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
 kernel void kernel_get_rows_q(
+        constant ggml_metal_kargs_get_rows & args,
         device const  void * src0,
         device const  void * src1,
         device       float * dst,
-        constant ggml_metal_kargs_get_rows & args,
         uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
         uint3                tptg [[threads_per_threadgroup]]) {
@@ -6442,10 +6471,10 @@ kernel void kernel_get_rows_q(
 
 template<typename T>
 kernel void kernel_get_rows_f(
+        constant ggml_metal_kargs_get_rows & args,
         device const  void * src0,
         device const  void * src1,
         device       float * dst,
-        constant ggml_metal_kargs_get_rows & args,
         uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
         uint3                tptg [[threads_per_threadgroup]]) {
@@ -6463,10 +6492,10 @@ kernel void kernel_get_rows_f(
 }
 
 kernel void kernel_get_rows_i32(
+        constant ggml_metal_kargs_get_rows & args,
         device const  void * src0,
         device const  void * src1,
         device     int32_t * dst,
-        constant ggml_metal_kargs_get_rows & args,
         uint3                tgpig[[threadgroup_position_in_grid]],
         uint                 tiitg[[thread_index_in_threadgroup]],
         uint3                tptg [[threads_per_threadgroup]]) {
@@ -6483,6 +6512,67 @@ kernel void kernel_get_rows_i32(
     }
 }
 
+template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
+kernel void kernel_set_rows_q32(
+        constant ggml_metal_kargs_set_rows & args,
+        device const  void * src0,
+        device const  void * src1,
+        device       float * dst,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+
+    const int32_t i12 = i03%args.ne12;
+    const int32_t i11 = i02%args.ne11;
+
+    const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
+    if (i01 >= args.ne01) {
+        return;
+    }
+
+    const int32_t i10 = i01;
+    const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+
+          device block_q * dst_row = (      device block_q *) ((      device char *) dst  +  i1*args.nb1  + i02*args.nb2  + i03*args.nb3);
+    const device float   * src_row = (const device float   *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+
+    for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
+        quantize_func(src_row + 32*ind, dst_row[ind]);
+    }
+}
+
+template<typename T>
+kernel void kernel_set_rows_f(
+        constant ggml_metal_kargs_set_rows & args,
+        device const  void * src0,
+        device const  void * src1,
+        device       float * dst,
+        uint3                tgpig[[threadgroup_position_in_grid]],
+        uint                 tiitg[[thread_index_in_threadgroup]],
+        uint3                tptg [[threads_per_threadgroup]]) {
+    const int32_t i03 = tgpig.z;
+    const int32_t i02 = tgpig.y;
+
+    const int32_t i12 = i03%args.ne12;
+    const int32_t i11 = i02%args.ne11;
+
+    const int32_t i01 = tgpig.x*tptg.y + tiitg/tptg.x;
+    if (i01 >= args.ne01) {
+        return;
+    }
+
+    const int32_t i10 = i01;
+    const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+
+    device T     * dst_row = (      device T     *) ((      device char *) dst  +  i1*args.nb1  + i02*args.nb2  + i03*args.nb3);
+    const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
+
+    for (int ind = tiitg%tptg.x; ind < args.nk0; ind += tptg.x) {
+        dst_row[ind] = (T) src_row[ind];
+    }
+}
 
 #define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
 #define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
@@ -6906,6 +6996,27 @@ template [[host_name("kernel_get_rows_iq1_m")]]   kernel get_rows_q_t kernel_get
 template [[host_name("kernel_get_rows_iq4_nl")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_nl,  2,     dequantize_iq4_nl>;
 template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_q_t kernel_get_rows_q<block_iq4_xs,  QK_NL, dequantize_iq4_xs>;
 
+//
+// set rows
+//
+
+typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
+
+template [[host_name("kernel_set_rows_f32")]]  kernel set_rows_f_t kernel_set_rows_f<float>;
+template [[host_name("kernel_set_rows_f16")]]  kernel set_rows_f_t kernel_set_rows_f<half>;
+#if defined(GGML_METAL_USE_BF16)
+template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
+#endif
+
+typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
+
+template [[host_name("kernel_set_rows_q8_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0,   quantize_q8_0>;
+template [[host_name("kernel_set_rows_q4_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0,   quantize_q4_0>;
+template [[host_name("kernel_set_rows_q4_1")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1,   quantize_q4_1>;
+template [[host_name("kernel_set_rows_q5_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0,   quantize_q5_0>;
+template [[host_name("kernel_set_rows_q5_1")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1,   quantize_q5_1>;
+template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
+
 //
 // matrix-matrix multiplication
 //
index c963b338a18a64c6683dfc20ed9a2b9877dbcfef..be8a9c4488b14648640b7e8abd02529ea705ad94 100644 (file)
@@ -933,6 +933,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "TRANSPOSE",
     "GET_ROWS",
     "GET_ROWS_BACK",
+    "SET_ROWS",
     "DIAG",
     "DIAG_MASK_INF",
     "DIAG_MASK_ZERO",
@@ -983,7 +984,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1029,6 +1030,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "transpose(x)",
     "get_rows(x)",
     "get_rows_back(x)",
+    "set_rows(x)",
     "diag(x)",
     "diag_mask_inf(x)",
     "diag_mask_zero(x)",
@@ -1079,7 +1081,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
+static_assert(GGML_OP_COUNT == 84, "GGML_OP_COUNT != 84");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1348,6 +1350,12 @@ bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
         tensor->nb[2] == ggml_type_size(tensor->type);
 }
 
+bool ggml_is_contiguous_rows(const struct ggml_tensor * tensor) {
+    return
+        tensor->ne[0] == ggml_blck_size(tensor->type) ||
+        tensor->nb[0] == ggml_type_size(tensor->type);
+}
+
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
 
@@ -3384,6 +3392,35 @@ struct ggml_tensor * ggml_get_rows_back(
     return result;
 }
 
+// ggml_set_rows
+
+struct ggml_tensor * ggml_set_rows(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * b,
+        struct ggml_tensor  * c) {
+    GGML_ASSERT(a->ne[0] == b->ne[0]);
+    GGML_ASSERT(a->ne[2] == b->ne[2]);
+    GGML_ASSERT(a->ne[3] == b->ne[3]);
+    GGML_ASSERT(b->ne[1] == c->ne[0]);
+    GGML_ASSERT(b->ne[2] % c->ne[1] == 0);
+    GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
+    GGML_ASSERT(c->ne[3] == 1);
+    GGML_ASSERT(b->type == GGML_TYPE_F32);
+    GGML_ASSERT(c->type == GGML_TYPE_I64);
+
+    GGML_ASSERT(ggml_is_contiguous_rows(a));
+    GGML_ASSERT(ggml_is_contiguous_rows(b));
+
+    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
+
+    result->op     = GGML_OP_SET_ROWS;
+    result->src[0] = b;
+    result->src[1] = c;
+
+    return result;
+}
+
 // ggml_diag
 
 struct ggml_tensor * ggml_diag(
index c3118d416b12dada3229f346f114d6c0d2e84f88..2e68c35813713b8c99d7781e38811beda4af76c7 100644 (file)
@@ -1213,6 +1213,76 @@ struct test_get_rows_back : public test_case {
     }
 };
 
+// GGML_OP_SET_ROWS
+struct test_set_rows : public test_case {
+    const ggml_type type;
+    const std::array<int64_t, 4> ne;
+    const std::array<int, 2> nr23; // broadcast only dims 2 and 3
+    const int r; // rows to set
+    const bool v; // view (non-contiguous src1)
+
+    std::string vars() override {
+        return VARS_TO_STR5(type, ne, nr23, r, v);
+    }
+
+    test_set_rows(ggml_type type,
+            std::array<int64_t, 4> ne,
+            std::array<int, 2> nr23,
+            int r, bool v = false)
+        : type(type), ne(ne), nr23(nr23), r(r), v(v) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * dst = ggml_new_tensor_4d(ctx, type,          ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
+        ggml_set_name(dst, "dst");
+
+        ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], r,     ne[2]*nr23[0], ne[3]*nr23[1]);
+        ggml_set_name(src, "src");
+
+        ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, r, ne[2], ne[3]);
+        ggml_set_name(row_idxs, "row_idxs");
+
+        if (v) {
+            src      = ggml_view_4d(ctx, src, ne[0], r/2, ne[2]*nr23[0], ne[3]*nr23[1], src->nb[1], src->nb[2], src->nb[3], 0);
+            row_idxs = ggml_view_3d(ctx, row_idxs, r/2, ne[2], ne[3], row_idxs->nb[1], row_idxs->nb[2], 0);
+            ggml_set_name(row_idxs, "view_of_rows");
+        }
+
+        ggml_tensor * out = ggml_set_rows(ctx, dst, src, row_idxs);
+        ggml_set_name(out, "out");
+
+        return out;
+    }
+
+    void initialize_tensors(ggml_context * ctx) override {
+        std::random_device rd;
+        std::default_random_engine rng(rd());
+        for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
+            if (t->type == GGML_TYPE_I64) {
+                if (ggml_is_view_op(t->op)) {
+                    continue;
+                }
+
+                for (int i2 = 0; i2 < t->ne[2]; i2++) {
+                    for (int i1 = 0; i1 < t->ne[1]; i1++) {
+                        // generate a shuffled subset of row indices
+                        std::vector<int64_t> data(ne[1]);
+                        for (int i = 0; i < ne[1]; i++) {
+                            data[i] = i;
+                        }
+                        std::shuffle(data.begin(), data.end(), rng);
+                        data.resize(t->ne[0]);
+
+                        const size_t offs = i1*t->nb[1] + i2*t->nb[2];
+                        ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
+                    }
+                }
+            } else {
+                init_tensor_uniform(t);
+            }
+        }
+    }
+};
+
 // GGML_OP_ARGMAX
 struct test_argmax : public test_case {
     const ggml_type type;
@@ -3984,6 +4054,23 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v));
     }
 
+    test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, { 1, 8, 1, 3 }, { 1, 1 }, 2, false));
+    for (ggml_type type : all_types) {
+        for (int b : {1, 7}) {
+            for (bool v : {false, true}) {
+                test_cases.emplace_back(new test_set_rows(type, { 256, 5,  b, 3 }, { 1, 1, }, 1, v));
+                test_cases.emplace_back(new test_set_rows(type, { 256, 11, 1, b }, { 2, 3, }, 7, v));
+
+                test_cases.emplace_back(new test_set_rows(type, { 3*ggml_blck_size(type), 3, b, 1 }, { 2, 3, }, 2, v));
+
+                if (ggml_blck_size(type) == 1) {
+                    test_cases.emplace_back(new test_set_rows(type, { 31, 3, b, 1 }, { 2, 3, }, 2, v));
+                    test_cases.emplace_back(new test_set_rows(type, { 33, 5, 1, b }, { 2, 3, }, 1, v));
+                }
+            }
+        }
+    }
+
     for (ggml_type type_input : {GGML_TYPE_F32}) {
         for (ggml_op_pool pool_type : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) {
             for (int k0 : {1, 3}) {