]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : implement set_rows with i32 index (#16159)
authorSigbjørn Skjæret <redacted>
Mon, 22 Sep 2025 17:13:00 +0000 (19:13 +0200)
committerGitHub <redacted>
Mon, 22 Sep 2025 17:13:00 +0000 (19:13 +0200)
* implement set_rows with i32 index

* template fix

* test quantized path

warnings--

* Apply suggestions from code review

Co-authored-by: Georgi Gerganov <redacted>
* forgotten name change

* deduplicate cuda/sycl and test-fix

* indent++

* vulkan: support set_rows with i32 index type (#16162)

* disable i32 index for webgpu for now

---------

Co-authored-by: Georgi Gerganov <redacted>
Co-authored-by: Jeff Bolz <redacted>
17 files changed:
ggml/src/ggml-cpu/ops.cpp
ggml/src/ggml-cuda/ggml-cuda.cu
ggml/src/ggml-cuda/set-rows.cu
ggml/src/ggml-metal/ggml-metal-device.cpp
ggml/src/ggml-metal/ggml-metal-device.h
ggml/src/ggml-metal/ggml-metal-ops.cpp
ggml/src/ggml-metal/ggml-metal.metal
ggml/src/ggml-opencl/ggml-opencl.cpp
ggml/src/ggml-opencl/kernels/set_rows.cl
ggml/src/ggml-sycl/ggml-sycl.cpp
ggml/src/ggml-sycl/set_rows.cpp
ggml/src/ggml-vulkan/ggml-vulkan.cpp
ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp
ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
ggml/src/ggml-webgpu/ggml-webgpu.cpp
ggml/src/ggml.c
tests/test-backend-ops.cpp

index 763ab099e31a656b2735fa41ad68434b956a98e6..14f7dcf4f41ad244f013fd18220461767c1f6228 100644 (file)
@@ -4739,6 +4739,7 @@ void ggml_compute_forward_get_rows(
     //}
 }
 
+template<typename idx_t>
 static void ggml_compute_forward_set_rows_f32(
         const ggml_compute_params * params,
               ggml_tensor * dst) {
@@ -4777,7 +4778,7 @@ static void ggml_compute_forward_set_rows_f32(
                 const int64_t i11 = i02%ne11;
                 const int64_t i10 = i;
 
-                const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+                const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
 
                 GGML_ASSERT(i1 >= 0 && i1 < ne1);
 
@@ -4794,11 +4795,18 @@ void ggml_compute_forward_set_rows(
         ggml_tensor * dst) {
 
     const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
 
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_set_rows_f32(params, dst);
+                if (src1->type == GGML_TYPE_I64) {
+                    ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
+                } else if (src1->type == GGML_TYPE_I32) {
+                    ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
+                } else {
+                    GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
+                }
             } break;
         default:
             {
index f3ba20fe3f747d641b203b2a7fd3b502498ad7d2..4d85c5dc083d1d056e50594458813d8f1bfad999 100644 (file)
@@ -3427,7 +3427,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                        op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
                        op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
                        op->src[0]->type == GGML_TYPE_F32 &&
-                       op->src[1]->type == GGML_TYPE_I64;
+                       (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
             } break;
         case GGML_OP_CPY:
             {
index b4115a43c2a3296103af86baa41a36eca8f0b8f5..1525a159527af36581c78933978972f6a3ccf0d2 100644 (file)
@@ -4,9 +4,9 @@
 typedef void (*set_rows_kernel_t)(const char * src, char * dst);
 
 // Generic quantized set_rows kernel template
-template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
+template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
 static __global__ void k_set_rows_quant(
-        const float * __restrict__ src0, const int64_t * __restrict__ src1, block_type * __restrict__ dst,
+        const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
         const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
         const int64_t s01, const int64_t s02, const int64_t s03,
@@ -45,9 +45,9 @@ static __global__ void k_set_rows_quant(
 }
 
 // Template dispatch function for quantized set_rows
-template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
+template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
 static void set_rows_cuda_quant(
-        const float * src0_d, const int64_t * src1_d, block_type * dst_d,
+        const float * src0_d, const idx_t * src1_d, block_type * dst_d,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
         const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
         const size_t nb01, const size_t nb02, const size_t nb03,
@@ -64,15 +64,15 @@ static void set_rows_cuda_quant(
     const int64_t s01 = nb01/sizeof(float);
     const int64_t s02 = nb02/sizeof(float);
     const int64_t s03 = nb03/sizeof(float);
-    const int64_t s10 = nb10/sizeof(int64_t);
-    const int64_t s11 = nb11/sizeof(int64_t);
-    const int64_t s12 = nb12/sizeof(int64_t);
+    const int64_t s10 = nb10/sizeof(idx_t);
+    const int64_t s11 = nb11/sizeof(idx_t);
+    const int64_t s12 = nb12/sizeof(idx_t);
     const int64_t s1  = nb1;
     const int64_t s2  = nb2;
     const int64_t s3  = nb3;
 
     if (ne_total > 0) {
-        k_set_rows_quant<block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
+        k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
             src0_d, src1_d, dst_d,
             ne00, ne01, ne02, ne03,
             ne10, ne11, ne12, ne13,
@@ -82,9 +82,9 @@ static void set_rows_cuda_quant(
     }
 }
 
-template<typename src_t, typename dst_t>
+template<typename src_t, typename idx_t, typename dst_t>
 static __global__ void k_set_rows(
-        const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
+        const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
         const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
         const int64_t s01, const int64_t s02, const int64_t s03,
@@ -118,9 +118,9 @@ static __global__ void k_set_rows(
     GGML_UNUSED(ne13);
 }
 
-template<typename src_t, typename dst_t>
+template<typename src_t, typename idx_t, typename dst_t>
 static void set_rows_cuda(
-        const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
+        const src_t * src0_d, const idx_t * src1_d, dst_t * dst_d,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
         const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
         const size_t nb01, const size_t nb02, const size_t nb03,
@@ -137,9 +137,9 @@ static void set_rows_cuda(
     const int64_t s01 = nb01/sizeof(src_t);
     const int64_t s02 = nb02/sizeof(src_t);
     const int64_t s03 = nb03/sizeof(src_t);
-    const int64_t s10 = nb10/sizeof(int64_t);
-    const int64_t s11 = nb11/sizeof(int64_t);
-    const int64_t s12 = nb12/sizeof(int64_t);
+    const int64_t s10 = nb10/sizeof(idx_t);
+    const int64_t s11 = nb11/sizeof(idx_t);
+    const int64_t s12 = nb12/sizeof(idx_t);
     const int64_t s1  = nb1/sizeof(dst_t);
     const int64_t s2  = nb2/sizeof(dst_t);
     const int64_t s3  = nb3/sizeof(dst_t);
@@ -155,23 +155,16 @@ static void set_rows_cuda(
     }
 }
 
-
-void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT(src1->type == GGML_TYPE_I64);
+template<typename src_t, typename idx_t>
+static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const src_t * src0_d = (const src_t *)src0->data;
+    const idx_t * src1_d = (const idx_t *)src1->data;
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    const float * src0_d   = (const float *)src0->data;
-    const int64_t * src1_d = (const int64_t *)src1->data;
-
     cudaStream_t stream = ctx.stream();
 
 
-
     if (dst->type == GGML_TYPE_F32) {
         set_rows_cuda(
             src0_d, src1_d, (float*)dst->data,
@@ -203,7 +196,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             stream
         );
     } else if (dst->type == GGML_TYPE_Q4_0) {
-        set_rows_cuda_quant<block_q4_0, QK4_0, quantize_f32_q4_0_block>(
+        set_rows_cuda_quant<idx_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(
             src0_d, src1_d, (block_q4_0*)dst->data,
             ne00, ne01, ne02, ne03,
             ne10, ne11, ne12, ne13,
@@ -213,7 +206,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             stream
         );
     } else if (dst->type == GGML_TYPE_Q4_1) {
-        set_rows_cuda_quant<block_q4_1, QK4_1, quantize_f32_q4_1_block>(
+        set_rows_cuda_quant<idx_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(
             src0_d, src1_d, (block_q4_1*)dst->data,
             ne00, ne01, ne02, ne03,
             ne10, ne11, ne12, ne13,
@@ -223,7 +216,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             stream
         );
     } else if (dst->type == GGML_TYPE_Q5_0) {
-        set_rows_cuda_quant<block_q5_0, QK5_0, quantize_f32_q5_0_block>(
+        set_rows_cuda_quant<idx_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(
             src0_d, src1_d, (block_q5_0*)dst->data,
             ne00, ne01, ne02, ne03,
             ne10, ne11, ne12, ne13,
@@ -233,7 +226,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             stream
         );
     } else if (dst->type == GGML_TYPE_Q5_1) {
-        set_rows_cuda_quant<block_q5_1, QK5_1, quantize_f32_q5_1_block>(
+        set_rows_cuda_quant<idx_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(
             src0_d, src1_d, (block_q5_1*)dst->data,
             ne00, ne01, ne02, ne03,
             ne10, ne11, ne12, ne13,
@@ -243,7 +236,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             stream
         );
     } else if (dst->type == GGML_TYPE_Q8_0) {
-        set_rows_cuda_quant<block_q8_0, QK8_0, quantize_f32_q8_0_block>(
+        set_rows_cuda_quant<idx_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(
             src0_d, src1_d, (block_q8_0*)dst->data,
             ne00, ne01, ne02, ne03,
             ne10, ne11, ne12, ne13,
@@ -253,7 +246,7 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
             stream
         );
     } else if (dst->type == GGML_TYPE_IQ4_NL) {
-        set_rows_cuda_quant<block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
+        set_rows_cuda_quant<idx_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
             src0_d, src1_d, (block_iq4_nl*)dst->data,
             ne00, ne01, ne02, ne03,
             ne10, ne11, ne12, ne13,
@@ -266,3 +259,18 @@ void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
         GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
     }
 }
+
+
+void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);
+
+    if (src1->type == GGML_TYPE_I64) {
+        set_rows_cuda<float, int64_t>(ctx, src0, src1, dst);
+    } else {
+        set_rows_cuda<float, int32_t>(ctx, src0, src1, dst);
+    }
+}
index fe015afc54aa9fe0a791db63eb53f1069af04ac9..9f91662cbd876d447b02eae2e968d6178748e6b5 100644 (file)
@@ -142,11 +142,11 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows(ggml_metal_librar
     return res;
 }
 
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tdst) {
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {
     char base[256];
     char name[256];
 
-    snprintf(base, 256, "kernel_set_rows_%s", ggml_type_name(tdst));
+    snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx));
     snprintf(name, 256, "%s", base);
 
     ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
index 044d6953f6779d37e1c1a7291629125f23d54a6c..da67bfab758e894e58e8bab6fd7f92930b5f84e6 100644 (file)
@@ -105,7 +105,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base              (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy               (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d           (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows          (ggml_metal_library_t lib, enum ggml_type tsrc);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows          (ggml_metal_library_t lib, enum ggml_type tdst);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows          (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat            (ggml_metal_library_t lib, enum ggml_type tsrc);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary             (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu               (ggml_metal_library_t lib, const struct ggml_tensor * op);
index 04665b3d6dbb6f55a244db3d19cdd3d0c9a040ec..3b163d9a38e75afbc9ab324650181264894401e5 100644 (file)
@@ -892,7 +892,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
     GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
     GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
 
-    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->type);
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
 
     const int32_t nk0 = ne0/ggml_blck_size(op->type);
 
index c7d97ba70b4c9a644127c5361c590ec887f2052f..2ba4cb50b9ec224121caa50410f599bc69f6ea18 100644 (file)
@@ -7743,7 +7743,7 @@ kernel void kernel_get_rows_i32(
     }
 }
 
-template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
+template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
 kernel void kernel_set_rows_q32(
         constant ggml_metal_kargs_set_rows & args,
         device const  void * src0,
@@ -7764,7 +7764,7 @@ kernel void kernel_set_rows_q32(
     }
 
     const int32_t i10 = i01;
-    const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+    const TI      i1  = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
 
           device block_q * dst_row = (      device block_q *) ((      device char *) dst  +  i1*args.nb1  + i02*args.nb2  + i03*args.nb3);
     const device float   * src_row = (const device float   *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
@@ -7774,7 +7774,7 @@ kernel void kernel_set_rows_q32(
     }
 }
 
-template<typename T>
+template<typename T, typename TI>
 kernel void kernel_set_rows_f(
         constant ggml_metal_kargs_set_rows & args,
         device const  void * src0,
@@ -7795,7 +7795,7 @@ kernel void kernel_set_rows_f(
     }
 
     const int32_t i10 = i01;
-    const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+    const TI      i1  = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
 
           device T     * dst_row = (      device T     *) ((      device char *) dst  +  i1*args.nb1  + i02*args.nb2  + i03*args.nb3);
     const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
@@ -8218,22 +8218,31 @@ template [[host_name("kernel_get_rows_iq4_xs")]]  kernel get_rows_q_t kernel_get
 // set rows
 //
 
-typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
+typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;
 
-template [[host_name("kernel_set_rows_f32")]]  kernel set_rows_f_t kernel_set_rows_f<float>;
-template [[host_name("kernel_set_rows_f16")]]  kernel set_rows_f_t kernel_set_rows_f<half>;
+template [[host_name("kernel_set_rows_f32_i64")]]  kernel set_rows_f_t kernel_set_rows_f<float, int64_t>;
+template [[host_name("kernel_set_rows_f32_i32")]]  kernel set_rows_f_t kernel_set_rows_f<float, int32_t>;
+template [[host_name("kernel_set_rows_f16_i64")]]  kernel set_rows_f_t kernel_set_rows_f<half, int64_t>;
+template [[host_name("kernel_set_rows_f16_i32")]]  kernel set_rows_f_t kernel_set_rows_f<half, int32_t>;
 #if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
+template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>;
+template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>;
 #endif
 
-typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
-
-template [[host_name("kernel_set_rows_q8_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0,   quantize_q8_0>;
-template [[host_name("kernel_set_rows_q4_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0,   quantize_q4_0>;
-template [[host_name("kernel_set_rows_q4_1")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1,   quantize_q4_1>;
-template [[host_name("kernel_set_rows_q5_0")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0,   quantize_q5_0>;
-template [[host_name("kernel_set_rows_q5_1")]]   kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1,   quantize_q5_1>;
-template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
+typedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t;
+
+template [[host_name("kernel_set_rows_q8_0_i64")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0,   quantize_q8_0>;
+template [[host_name("kernel_set_rows_q8_0_i32")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0,   quantize_q8_0>;
+template [[host_name("kernel_set_rows_q4_0_i64")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0,   quantize_q4_0>;
+template [[host_name("kernel_set_rows_q4_0_i32")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0,   quantize_q4_0>;
+template [[host_name("kernel_set_rows_q4_1_i64")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1,   quantize_q4_1>;
+template [[host_name("kernel_set_rows_q4_1_i32")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1,   quantize_q4_1>;
+template [[host_name("kernel_set_rows_q5_0_i64")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0,   quantize_q5_0>;
+template [[host_name("kernel_set_rows_q5_0_i32")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0,   quantize_q5_0>;
+template [[host_name("kernel_set_rows_q5_1_i64")]]   kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1,   quantize_q5_1>;
+template [[host_name("kernel_set_rows_q5_1_i32")]]   kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_1,   quantize_q5_1>;
+template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;
+template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;
 
 //
 // matrix-matrix multiplication
index 259b42e55920c08a70263b3fcc0eca10b86a4467..0cf3b92464c6e077d8adb26c8c01acc03ea9276b 100644 (file)
@@ -439,7 +439,7 @@ struct ggml_backend_opencl_context {
     std::map<std::pair<int, int>, int>       kernels_flash_attn_bm;
     std::map<std::pair<int, int>, int>       kernels_flash_attn_bn;
     cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
-    cl_kernel kernel_set_rows_f32, kernel_set_rows_f16;
+    cl_kernel kernel_set_rows_f32_i64, kernel_set_rows_f32_i32, kernel_set_rows_f16_i64, kernel_set_rows_f16_i32;
     cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
     cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16;
     cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32;
@@ -1710,8 +1710,10 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
         backend_ctx->program_set_rows =
             build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
 
-        CL_CHECK((backend_ctx->kernel_set_rows_f32  = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f32", &err), err));
-        CL_CHECK((backend_ctx->kernel_set_rows_f16  = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f16", &err), err));
+        CL_CHECK((backend_ctx->kernel_set_rows_f32_i64 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f32_i64", &err), err));
+        CL_CHECK((backend_ctx->kernel_set_rows_f32_i32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f32_i32", &err), err));
+        CL_CHECK((backend_ctx->kernel_set_rows_f16_i64 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f16_i64", &err), err));
+        CL_CHECK((backend_ctx->kernel_set_rows_f16_i32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f16_i32", &err), err));
         GGML_LOG_CONT(".");
     }
 
@@ -2803,7 +2805,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
                 switch (op->type) {
                     case GGML_TYPE_F16:
                     case GGML_TYPE_F32:
-                        return true;
+                        return (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
                     default:
                         return false;
                 }
@@ -4284,6 +4286,7 @@ static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, c
     GGML_ASSERT(src1->extra);
     GGML_ASSERT(dst);
     GGML_ASSERT(dst->extra);
+    GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);
 
     // ne0 = ne00
     // ne2 = ne02
@@ -4326,10 +4329,18 @@ static void ggml_cl_set_rows(ggml_backend_t backend, const ggml_tensor * src0, c
 
     switch (dst->type) {
         case GGML_TYPE_F32:
-            kernel = backend_ctx->kernel_set_rows_f32;
+            if (src1->type == GGML_TYPE_I64) {
+                kernel = backend_ctx->kernel_set_rows_f32_i64;
+            } else {
+                kernel = backend_ctx->kernel_set_rows_f32_i32;
+            }
             break;
         case GGML_TYPE_F16:
-            kernel = backend_ctx->kernel_set_rows_f16;
+            if (src1->type == GGML_TYPE_I64) {
+                kernel = backend_ctx->kernel_set_rows_f16_i64;
+            } else {
+                kernel = backend_ctx->kernel_set_rows_f16_i32;
+            }
             break;
         default:
             GGML_ABORT("not implemented");
index a94b4361b4d3311fbff5555dcd4a629d8a5bb9b2..dcdc1d1b6fdc89cce37223acaf5de3117ccb9d70 100644 (file)
@@ -1,6 +1,6 @@
 #pragma OPENCL EXTENSION cl_khr_fp16 : enable
 
-kernel void kernel_set_rows_f32(
+kernel void kernel_set_rows_f32_i64(
         global char * src0,
         ulong         offset0,
         global char * src1,
@@ -47,7 +47,7 @@ kernel void kernel_set_rows_f32(
     }
 }
 
-kernel void kernel_set_rows_f16(
+kernel void kernel_set_rows_f16_i64(
         global char * src0,
         ulong         offset0,
         global char * src1,
@@ -93,3 +93,97 @@ kernel void kernel_set_rows_f16(
         dst_row[ind] = src_row[ind];
     }
 }
+
+kernel void kernel_set_rows_f32_i32(
+        global char * src0,
+        ulong         offset0,
+        global char * src1,
+        ulong         offset1,
+        global char * dst,
+        ulong         offsetd,
+        int           ne01,
+        ulong         nb01,
+        ulong         nb02,
+        ulong         nb03,
+        int           ne11,
+        int           ne12,
+        ulong         nb10,
+        ulong         nb11,
+        ulong         nb12,
+        int           nblk0,
+        ulong         nb1,
+        ulong         nb2,
+        ulong         nb3
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst  + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
+
+    if (i01 >= ne01) {
+        return;
+    }
+
+    int i12 = i03%ne12;
+    int i11 = i02%ne11;
+
+    int i10 = i01;
+    int i1  = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
+
+    global float * dst_row = (global float *) (dst  +  i1*nb1  + i02*nb2  + i03*nb3);
+    global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
+
+    for (int ind = get_local_id(0); ind < nblk0; ind += get_local_size(0)) {
+        dst_row[ind] = (float)src_row[ind];
+    }
+}
+
+kernel void kernel_set_rows_f16_i32(
+        global char * src0,
+        ulong         offset0,
+        global char * src1,
+        ulong         offset1,
+        global char * dst,
+        ulong         offsetd,
+        int           ne01,
+        ulong         nb01,
+        ulong         nb02,
+        ulong         nb03,
+        int           ne11,
+        int           ne12,
+        ulong         nb10,
+        ulong         nb11,
+        ulong         nb12,
+        int           nblk0,
+        ulong         nb1,
+        ulong         nb2,
+        ulong         nb3
+) {
+    src0 = src0 + offset0;
+    src1 = src1 + offset1;
+    dst  = dst  + offsetd;
+
+    int i03 = get_group_id(2);
+    int i02 = get_group_id(1);
+    int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
+
+    if (i01 >= ne01) {
+        return;
+    }
+
+    int i12 = i03%ne12;
+    int i11 = i02%ne11;
+
+    int i10 = i01;
+    int i1  = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
+
+    global half  * dst_row = (global half  *) (dst  +  i1*nb1  + i02*nb2  + i03*nb3);
+    global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
+
+    for (int ind = get_local_id(0); ind < nblk0; ind += get_local_size(0)) {
+        dst_row[ind] = src_row[ind];
+    }
+}
index 78853eb67671c3136bd1599115bfa5c2c4774779..4ac919ea2d75757f412c50a0385afd3d6479270d 100644 (file)
@@ -4271,7 +4271,7 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
                 return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
                          op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q5_0 ||
                          op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_IQ4_NL) &&
-                        (op->src[1]->type == GGML_TYPE_I64));
+                        (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32));
             }
             break;
         case GGML_OP_CPY:
index fbe15ffdd77e7f66a0d0c97b75696c845d350846..a641c100913123e9bb973fb60b681358e4aa5fda 100644 (file)
@@ -16,9 +16,9 @@ convert (const char* src, char* dst) {
    *reinterpret_cast<TOut*>(dst) = dst_val;
 }
 
-template <typename blockType, int qk, cpy_kernel_t cpyblck>
+template <typename TIdx, typename blockType, int qk, cpy_kernel_t cpyblck>
 static void set_rows_sycl_q(const char * __restrict__ src0_d,
-                            const int64_t * __restrict__ src1_d,
+                            const TIdx * __restrict__ src1_d,
                             blockType * __restrict__ dst_d,
                             // tensor dimensions src0 and src1
                             const int64_t ne00,
@@ -66,7 +66,7 @@ static void set_rows_sycl_q(const char * __restrict__ src0_d,
         const size_t  src_offset  = calculate_offset<3>({ nb01, nb02, nb03 }, { i01, i02, i03 });
         const char *  src_block   = src0_d + src_offset + i00 * sizeof(float);
         const size_t  src1_offset = calculate_offset<3>({ nb10, nb11, nb12 }, { i10, i11, i12 });
-        const int64_t dst_row     = src1_d[src1_offset / sizeof(int64_t)];
+        const int64_t dst_row     = src1_d[src1_offset / sizeof(TIdx)];
         const size_t  dst_offset =
             calculate_offset<3>({ nb1, nb2, nb3 }, { dst_row, i02, i03 }) + (i00 / qk) * sizeof(blockType);
         char * dst_block = reinterpret_cast<char *>(reinterpret_cast<char *>(dst_d) + dst_offset);
@@ -78,9 +78,9 @@ static void set_rows_sycl_q(const char * __restrict__ src0_d,
     GGML_UNUSED(nb13);
 }
 
-template<typename TIn, typename TOut>
+template<typename TIn, typename TIdx, typename TOut>
 static void k_set_rows(
-        const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
+        const char * __restrict__ src0, const TIdx * __restrict__ src1, char * __restrict__ dst,
         const int64_t ne00, const int64_t ne01, const int64_t ne02,
         const int64_t ne11, const int64_t ne12,
         const size_t nb01, const size_t nb02, const size_t nb03,
@@ -104,7 +104,7 @@ static void k_set_rows(
     const int64_t i11 = i02 % ne11;
     const int64_t i10 = i01;
 
-    const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset<3>({nb10, nb11, nb12}, {i10, i11, i12}));
+    const int64_t dst_row = *(const TIdx *)((const char *)src1 + calculate_offset<3>({nb10, nb11, nb12}, {i10, i11, i12}));
 
     const char * src0_row = src0 + calculate_offset<3>({nb01, nb02, nb03}, {i01, i02, i03});
     const char * src_elem = src0_row + i00 * src_type_size;
@@ -114,9 +114,9 @@ static void k_set_rows(
     convert<TIn, TOut>(src_elem, dst_elem);
 }
 
-template<typename TIn, typename TOut>
+template<typename TIn, typename TIdx, typename TOut>
 static void set_rows_sycl(
-        const char * src0_d, const int64_t * src1_d, char * dst_d,
+        const char * src0_d, const TIdx * src1_d, char * dst_d,
         const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
         const int64_t ne11, const int64_t ne12, const size_t nb01, const size_t nb02, const size_t nb03,
         const size_t nb10, const size_t nb11, const size_t nb12,
@@ -132,7 +132,7 @@ static void set_rows_sycl(
     stream->parallel_for(
         sycl::nd_range<1>(grid_size * block_size, block_size),
         [=](sycl::nd_item<1> item_ct1) {
-            k_set_rows<TIn, TOut>(
+            k_set_rows<TIn, TIdx, TOut>(
                 src0_d, src1_d, dst_d,
                 ne00, ne01, ne02,
                 ne11, ne12,
@@ -147,74 +147,69 @@ static void set_rows_sycl(
     );
 }
 
-void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
-    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
-    const ggml_tensor * src0 = dst->src[0];
-    const ggml_tensor * src1 = dst->src[1];
-
-    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
-    GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I64);
+template<typename TIn, typename TIdx>
+static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    const char * src0_d = (const char *)src0->data;
+    const TIdx * src1_d = (const TIdx *)src1->data;
 
     GGML_TENSOR_BINARY_OP_LOCALS
 
-    const int64_t * src1_dd = static_cast<const int64_t *>(src1->data);
-
     dpct::queue_ptr stream = ctx.stream();
     switch (dst->type) {
         case GGML_TYPE_F32:
-            set_rows_sycl<float, float>(
-                (const char *)src0->data, src1_dd, (char *)dst->data,
+            set_rows_sycl<TIn, TIdx, float>(
+                src0_d, src1_d, (char *)dst->data,
                 ne00, ne01, ne02, ne03,
                 ne11, ne12,
                 nb01, nb02, nb03,
                 nb10, nb11, nb12,
                 nb1, nb2, nb3,
-                sizeof(float), sizeof(float),
+                sizeof(TIn), sizeof(float),
                 stream
             );
             break;
         case GGML_TYPE_F16:
             dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
-            set_rows_sycl<float, sycl::half>(
-                (const char *)src0->data, src1_dd, (char *)dst->data,
+            set_rows_sycl<TIn, TIdx, sycl::half>(
+                src0_d, src1_d, (char *)dst->data,
                 ne00, ne01, ne02, ne03,
                 ne11, ne12,
                 nb01, nb02, nb03,
                 nb10, nb11, nb12,
                 nb1, nb2, nb3,
-                sizeof(float), sizeof(sycl::half),
+                sizeof(TIn), sizeof(sycl::half),
                 stream
             );
             break;
         case GGML_TYPE_BF16:
-            set_rows_sycl<float, sycl::ext::oneapi::bfloat16>(
-                (const char *)src0->data, src1_dd, (char *)dst->data,
+            set_rows_sycl<TIn, TIdx, sycl::ext::oneapi::bfloat16>(
+                src0_d, src1_d, (char *)dst->data,
                 ne00, ne01, ne02, ne03,
                 ne11, ne12,
                 nb01, nb02, nb03,
                 nb10, nb11, nb12,
                 nb1, nb2, nb3,
-                sizeof(float), sizeof(sycl::ext::oneapi::bfloat16),
+                sizeof(TIn), sizeof(sycl::ext::oneapi::bfloat16),
                 stream
             );
             break;
         case GGML_TYPE_Q8_0:
-            set_rows_sycl_q<block_q8_0, QK8_0, cpy_blck_f32_q8_0>((const char *)src0->data, src1_dd, (block_q8_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
+            set_rows_sycl_q<TIdx, block_q8_0, QK8_0, cpy_blck_f32_q8_0>(src0_d, src1_d, (block_q8_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q5_1:
-            set_rows_sycl_q<block_q5_1, QK5_1, cpy_blck_f32_q5_1>((const char *)src0->data, src1_dd, (block_q5_1 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
+            set_rows_sycl_q<TIdx, block_q5_1, QK5_1, cpy_blck_f32_q5_1>(src0_d, src1_d, (block_q5_1 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q5_0:
-            set_rows_sycl_q<block_q5_0, QK5_0, cpy_blck_f32_q5_0>((const char *)src0->data, src1_dd, (block_q5_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
+            set_rows_sycl_q<TIdx, block_q5_0, QK5_0, cpy_blck_f32_q5_0>(src0_d, src1_d, (block_q5_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q4_1:
-            set_rows_sycl_q<block_q4_1, QK4_1, cpy_blck_f32_q4_1>((const char *)src0->data, src1_dd, (block_q4_1 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
+            set_rows_sycl_q<TIdx, block_q4_1, QK4_1, cpy_blck_f32_q4_1>(src0_d, src1_d, (block_q4_1 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_Q4_0:
-            set_rows_sycl_q<block_q4_0, QK4_0, cpy_blck_f32_q4_0>((const char *)src0->data, src1_dd, (block_q4_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
+            set_rows_sycl_q<TIdx, block_q4_0, QK4_0, cpy_blck_f32_q4_0>(src0_d, src1_d, (block_q4_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
             break;
         case GGML_TYPE_IQ4_NL:
-            set_rows_sycl_q<block_iq4_nl, QK4_NL, cpy_blck_f32_iq4_nl>((const char *)src0->data, src1_dd, (block_iq4_nl *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
+            set_rows_sycl_q<TIdx, block_iq4_nl, QK4_NL, cpy_blck_f32_iq4_nl>(src0_d, src1_d, (block_iq4_nl *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
             break;
 
         default:
@@ -222,3 +217,18 @@ void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
             break;
     }
 }
+
+void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+    scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
+    const ggml_tensor * src0 = dst->src[0];
+    const ggml_tensor * src1 = dst->src[1];
+
+    GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+    GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I64 || dst->src[1]->type == GGML_TYPE_I32);
+
+    if (src1->type == GGML_TYPE_I64) {
+        set_rows_sycl<float, int64_t>(ctx, src0, src1, dst);
+    } else {
+        set_rows_sycl<float, int32_t>(ctx, src0, src1, dst);
+    }
+}
index 0feaf4cb5d7c23383ff74f583a969aff5454bbab..ebbb412e55f4bfbe8d4784ec47310113ac27e168 100644 (file)
@@ -520,7 +520,8 @@ struct vk_device_struct {
     vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
     vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
     vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
-    vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT];
+    vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT];
+    vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT];
     vk_pipeline pipeline_norm_f32;
     vk_pipeline pipeline_group_norm_f32;
     vk_pipeline pipeline_rms_norm_f32;
@@ -3348,27 +3349,26 @@ static void ggml_vk_load_shaders(vk_device& device) {
         ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
     }
 
+#define SET_ROWS(itype, rte) \
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32],  "set_rows_f32" #itype,  set_rows_f32 ## itype ## rte ## _len,  set_rows_f32 ## itype ## rte ## _data,  "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16],  "set_rows_f16" #itype,  set_rows_f16 ## itype ## rte ## _len,  set_rows_f16 ## itype ## rte ## _data,  "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+        ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+
     if (device->float_controls_rte_fp16) {
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32],  "set_rows_f32",  set_rows_f32_rte_len,  set_rows_f32_rte_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16],  "set_rows_f16",  set_rows_f16_rte_len,  set_rows_f16_rte_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_rte_len, set_rows_bf16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_rte_len, set_rows_q4_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_rte_len, set_rows_q4_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_rte_len, set_rows_q5_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_rte_len, set_rows_q5_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_rte_len, set_rows_q8_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_rte_len, set_rows_iq4_nl_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        SET_ROWS(_i32, _rte)
+        SET_ROWS(_i64, _rte)
     } else {
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32],  "set_rows_f32",  set_rows_f32_len,  set_rows_f32_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16],  "set_rows_f16",  set_rows_f16_len,  set_rows_f16_data,  "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_len, set_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_len, set_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_len, set_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_len, set_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_len, set_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_len, set_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
-        ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_len, set_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+        SET_ROWS(_i32, )
+        SET_ROWS(_i64, )
     }
+#undef SET_ROWS
+
 
     ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
     ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
@@ -7772,7 +7772,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
     case GGML_OP_DUP:
         return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
     case GGML_OP_SET_ROWS:
-        return ctx->device->pipeline_set_rows[dst->type];
+        if (src1->type == GGML_TYPE_I64) {
+            return ctx->device->pipeline_set_rows_i64[dst->type];
+        } else {
+            return ctx->device->pipeline_set_rows_i32[dst->type];
+        }
     case GGML_OP_SILU_BACK:
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_silu_back_f32;
index 27d6b7464f62c0fc5fd12ee536792124f9c0da67..bc2e1f2df3e13498648b75741967ac2ebff1425c 100644 (file)
@@ -15,8 +15,15 @@ layout (binding = 0) readonly buffer S {float data_s[];};
 
 #if defined(SET_ROWS)
 #include "generic_binary_head.comp"
-layout (binding = 1) readonly buffer C {uvec2 data_i[];};
+layout (binding = 1) readonly buffer C {B_TYPE data_i[];};
 layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
+
+#if B_SIZE == 64
+#define DATA_I_SWIZZLE .x
+#else
+#define DATA_I_SWIZZLE
+#endif
+
 #else
 #include "generic_unary_head.comp"
 layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
@@ -259,7 +266,7 @@ void main() {
     uint i11 = fastmod(i02, p.ne11);
     uint i10 = i01;
 
-    uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;
+    uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()] DATA_I_SWIZZLE;
 
     uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
     uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();
index 2531610e47b8980b2939561a488a8093d18483c6..79701544ff6d324e3b3c41029f5966fbe1a01cac 100644 (file)
@@ -635,8 +635,10 @@ void process_shaders() {
     }
 
     for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
-        string_to_spv("set_rows_" + t, "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
-        string_to_spv("set_rows_" + t + "_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
+        string_to_spv("set_rows_" + t + "_i32",     "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+        string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
+        string_to_spv("set_rows_" + t + "_i64",     "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+        string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
     }
 
     auto get_type_str = [](bool f16) {
index a92ddc582a371a954e8081e505540802a9366f50..cee4b08366d79c387781a94b2e1bd5d7e87f8834 100644 (file)
@@ -1310,7 +1310,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
             break;
         case GGML_OP_CPY:
         case GGML_OP_SET_ROWS:
-            supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32);
+            supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64);
             break;
         case GGML_OP_GET_ROWS:
             if (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 ||
index 3584827dca7fcb14f002014147f9c9a75a8a81b8..fe36bab8362b2be68e956ed7bf3e8712092ef030 100644 (file)
@@ -3677,7 +3677,7 @@ struct ggml_tensor * ggml_set_rows(
     GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
     GGML_ASSERT(c->ne[3] == 1);
     GGML_ASSERT(b->type == GGML_TYPE_F32);
-    GGML_ASSERT(c->type == GGML_TYPE_I64);
+    GGML_ASSERT(c->type == GGML_TYPE_I64 || c->type == GGML_TYPE_I32);
 
     GGML_ASSERT(ggml_is_contiguous_rows(a));
     GGML_ASSERT(ggml_is_contiguous_rows(b));
index f11eecd8e71a5f86ec72141d83435725692555ca..592631f3ed21ad480b4f2fe5c7194e98fb29304a 100644 (file)
@@ -2064,20 +2064,22 @@ struct test_get_rows_back : public test_case {
 // GGML_OP_SET_ROWS
 struct test_set_rows : public test_case {
     const ggml_type type;
+    const ggml_type type_idx;
     const std::array<int64_t, 4> ne;
     const std::array<int, 2> nr23; // broadcast only dims 2 and 3
     const int r; // rows to set
     const bool v; // view (non-contiguous src1)
 
     std::string vars() override {
-        return VARS_TO_STR5(type, ne, nr23, r, v);
+        return VARS_TO_STR6(type, type_idx, ne, nr23, r, v);
     }
 
     test_set_rows(ggml_type type,
+            ggml_type type_idx,
             std::array<int64_t, 4> ne,
             std::array<int, 2> nr23,
             int r, bool v = false)
-        : type(type), ne(ne), nr23(nr23), r(r), v(v) {}
+        : type(type), type_idx(type_idx), ne(ne), nr23(nr23), r(r), v(v) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * dst = ggml_new_tensor_4d(ctx, type,          ne[0], ne[1], ne[2]*nr23[0], ne[3]*nr23[1]);
@@ -2086,7 +2088,7 @@ struct test_set_rows : public test_case {
         ggml_tensor * src = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, ne[0], r,     ne[2]*nr23[0], ne[3]*nr23[1]);
         ggml_set_name(src, "src");
 
-        ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, GGML_TYPE_I64, r, ne[2], ne[3]);
+        ggml_tensor * row_idxs = ggml_new_tensor_3d(ctx, type_idx, r, ne[2], ne[3]);
         ggml_set_name(row_idxs, "row_idxs");
 
         if (v) {
@@ -2105,7 +2107,7 @@ struct test_set_rows : public test_case {
         std::random_device rd;
         std::default_random_engine rng(rd());
         for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
-            if (t->type == GGML_TYPE_I64) {
+            if (t->type == GGML_TYPE_I64 || t->type == GGML_TYPE_I32) {
                 if (ggml_is_view_op(t->op)) {
                     continue;
                 }
@@ -2121,7 +2123,16 @@ struct test_set_rows : public test_case {
                         data.resize(t->ne[0]);
 
                         const size_t offs = i1*t->nb[1] + i2*t->nb[2];
-                        ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
+                        if (t->type == GGML_TYPE_I32) {
+                            // TODO: Make a template or something
+                            std::vector<int32_t> data_i32(t->ne[0]);
+                            for (int i = 0; i < t->ne[0]; i++) {
+                                data_i32[i] = static_cast<int32_t>(data[i]);
+                            }
+                            ggml_backend_tensor_set(t, data_i32.data(), offs, t->ne[0]*sizeof(int32_t));
+                        } else {
+                            ggml_backend_tensor_set(t, data.data(), offs, t->ne[0]*sizeof(int64_t));
+                        }
                     }
                 }
             } else {
@@ -5662,18 +5673,20 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
         test_cases.emplace_back(new test_get_rows_back(GGML_TYPE_I32, 256, 5, 4, 1, v));
     }
 
-    test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, { 1, 8, 1, 3 }, { 1, 1 }, 2, false));
+    test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, GGML_TYPE_I64, { 1, 8, 1, 3 }, { 1, 1 }, 2, false));
+    test_cases.emplace_back(new test_set_rows(GGML_TYPE_F32, GGML_TYPE_I32, { 1, 8, 1, 3 }, { 1, 1 }, 2, false));
+    test_cases.emplace_back(new test_set_rows(GGML_TYPE_Q8_0, GGML_TYPE_I32, { 256, 5, 1, 3 }, { 1, 1, }, 1, false));
     for (ggml_type type : all_types) {
         for (int b : {1, 7}) {
             for (bool v : {false, true}) {
-                test_cases.emplace_back(new test_set_rows(type, { 256, 5,  b, 3 }, { 1, 1, }, 1, v));
-                test_cases.emplace_back(new test_set_rows(type, { 256, 11, 1, b }, { 2, 3, }, 7, v));
+                test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 5,  b, 3 }, { 1, 1, }, 1, v));
+                test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 256, 11, 1, b }, { 2, 3, }, 7, v));
 
-                test_cases.emplace_back(new test_set_rows(type, { 3*ggml_blck_size(type), 3, b, 1 }, { 2, 3, }, 2, v));
+                test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 3*ggml_blck_size(type), 3, b, 1 }, { 2, 3, }, 2, v));
 
                 if (ggml_blck_size(type) == 1) {
-                    test_cases.emplace_back(new test_set_rows(type, { 31, 3, b, 1 }, { 2, 3, }, 2, v));
-                    test_cases.emplace_back(new test_set_rows(type, { 33, 5, 1, b }, { 2, 3, }, 1, v));
+                    test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 31, 3, b, 1 }, { 2, 3, }, 2, v));
+                    test_cases.emplace_back(new test_set_rows(type, GGML_TYPE_I64, { 33, 5, 1, b }, { 2, 3, }, 1, v));
                 }
             }
         }