//}
}
+template<typename idx_t>
static void ggml_compute_forward_set_rows_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const int64_t i11 = i02%ne11;
const int64_t i10 = i;
- const int64_t i1 = *(int64_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+ const int64_t i1 = *(idx_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
GGML_ASSERT(i1 >= 0 && i1 < ne1);
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
switch (src0->type) {
case GGML_TYPE_F32:
{
- ggml_compute_forward_set_rows_f32(params, dst);
+ if (src1->type == GGML_TYPE_I64) {
+ ggml_compute_forward_set_rows_f32<int64_t>(params, dst);
+ } else if (src1->type == GGML_TYPE_I32) {
+ ggml_compute_forward_set_rows_f32<int32_t>(params, dst);
+ } else {
+ GGML_ABORT("src1->type = %d (%s) not supported", src1->type, ggml_type_name(src1->type));
+ }
} break;
default:
{
op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q5_0 ||
op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_IQ4_NL) &&
op->src[0]->type == GGML_TYPE_F32 &&
- op->src[1]->type == GGML_TYPE_I64;
+ (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
} break;
case GGML_OP_CPY:
{
typedef void (*set_rows_kernel_t)(const char * src, char * dst);
// Generic quantized set_rows kernel template
-template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
+template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static __global__ void k_set_rows_quant(
- const float * __restrict__ src0, const int64_t * __restrict__ src1, block_type * __restrict__ dst,
+ const float * __restrict__ src0, const idx_t * __restrict__ src1, block_type * __restrict__ dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const int64_t s01, const int64_t s02, const int64_t s03,
}
// Template dispatch function for quantized set_rows
-template<typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
+template<typename idx_t, typename block_type, int qk, void (*quantize_func)(const float*, block_type*)>
static void set_rows_cuda_quant(
- const float * src0_d, const int64_t * src1_d, block_type * dst_d,
+ const float * src0_d, const idx_t * src1_d, block_type * dst_d,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const size_t nb01, const size_t nb02, const size_t nb03,
const int64_t s01 = nb01/sizeof(float);
const int64_t s02 = nb02/sizeof(float);
const int64_t s03 = nb03/sizeof(float);
- const int64_t s10 = nb10/sizeof(int64_t);
- const int64_t s11 = nb11/sizeof(int64_t);
- const int64_t s12 = nb12/sizeof(int64_t);
+ const int64_t s10 = nb10/sizeof(idx_t);
+ const int64_t s11 = nb11/sizeof(idx_t);
+ const int64_t s12 = nb12/sizeof(idx_t);
const int64_t s1 = nb1;
const int64_t s2 = nb2;
const int64_t s3 = nb3;
if (ne_total > 0) {
- k_set_rows_quant<block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
+ k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
src0_d, src1_d, dst_d,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
}
}
-template<typename src_t, typename dst_t>
+template<typename src_t, typename idx_t, typename dst_t>
static __global__ void k_set_rows(
- const src_t * __restrict__ src0, const int64_t * __restrict__ src1, dst_t * __restrict__ dst,
+ const src_t * __restrict__ src0, const idx_t * __restrict__ src1, dst_t * __restrict__ dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const int64_t s01, const int64_t s02, const int64_t s03,
GGML_UNUSED(ne13);
}
-template<typename src_t, typename dst_t>
+template<typename src_t, typename idx_t, typename dst_t>
static void set_rows_cuda(
- const src_t * src0_d, const int64_t * src1_d, dst_t * dst_d,
+ const src_t * src0_d, const idx_t * src1_d, dst_t * dst_d,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t ne13,
const size_t nb01, const size_t nb02, const size_t nb03,
const int64_t s01 = nb01/sizeof(src_t);
const int64_t s02 = nb02/sizeof(src_t);
const int64_t s03 = nb03/sizeof(src_t);
- const int64_t s10 = nb10/sizeof(int64_t);
- const int64_t s11 = nb11/sizeof(int64_t);
- const int64_t s12 = nb12/sizeof(int64_t);
+ const int64_t s10 = nb10/sizeof(idx_t);
+ const int64_t s11 = nb11/sizeof(idx_t);
+ const int64_t s12 = nb12/sizeof(idx_t);
const int64_t s1 = nb1/sizeof(dst_t);
const int64_t s2 = nb2/sizeof(dst_t);
const int64_t s3 = nb3/sizeof(dst_t);
}
}
-
-void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
- const ggml_tensor * src0 = dst->src[0];
- const ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(src0->type == GGML_TYPE_F32);
- GGML_ASSERT(src1->type == GGML_TYPE_I64);
+template<typename src_t, typename idx_t>
+static void set_rows_cuda(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ const src_t * src0_d = (const src_t *)src0->data;
+ const idx_t * src1_d = (const idx_t *)src1->data;
GGML_TENSOR_BINARY_OP_LOCALS
- const float * src0_d = (const float *)src0->data;
- const int64_t * src1_d = (const int64_t *)src1->data;
-
cudaStream_t stream = ctx.stream();
-
if (dst->type == GGML_TYPE_F32) {
set_rows_cuda(
src0_d, src1_d, (float*)dst->data,
stream
);
} else if (dst->type == GGML_TYPE_Q4_0) {
- set_rows_cuda_quant<block_q4_0, QK4_0, quantize_f32_q4_0_block>(
+ set_rows_cuda_quant<idx_t, block_q4_0, QK4_0, quantize_f32_q4_0_block>(
src0_d, src1_d, (block_q4_0*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
stream
);
} else if (dst->type == GGML_TYPE_Q4_1) {
- set_rows_cuda_quant<block_q4_1, QK4_1, quantize_f32_q4_1_block>(
+ set_rows_cuda_quant<idx_t, block_q4_1, QK4_1, quantize_f32_q4_1_block>(
src0_d, src1_d, (block_q4_1*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
stream
);
} else if (dst->type == GGML_TYPE_Q5_0) {
- set_rows_cuda_quant<block_q5_0, QK5_0, quantize_f32_q5_0_block>(
+ set_rows_cuda_quant<idx_t, block_q5_0, QK5_0, quantize_f32_q5_0_block>(
src0_d, src1_d, (block_q5_0*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
stream
);
} else if (dst->type == GGML_TYPE_Q5_1) {
- set_rows_cuda_quant<block_q5_1, QK5_1, quantize_f32_q5_1_block>(
+ set_rows_cuda_quant<idx_t, block_q5_1, QK5_1, quantize_f32_q5_1_block>(
src0_d, src1_d, (block_q5_1*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
stream
);
} else if (dst->type == GGML_TYPE_Q8_0) {
- set_rows_cuda_quant<block_q8_0, QK8_0, quantize_f32_q8_0_block>(
+ set_rows_cuda_quant<idx_t, block_q8_0, QK8_0, quantize_f32_q8_0_block>(
src0_d, src1_d, (block_q8_0*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
stream
);
} else if (dst->type == GGML_TYPE_IQ4_NL) {
- set_rows_cuda_quant<block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
+ set_rows_cuda_quant<idx_t, block_iq4_nl, QK4_NL, quantize_f32_iq4_nl_block>(
src0_d, src1_d, (block_iq4_nl*)dst->data,
ne00, ne01, ne02, ne03,
ne10, ne11, ne12, ne13,
GGML_ABORT("unsupported type %s", ggml_type_name(dst->type));
}
}
+
+
+void ggml_cuda_op_set_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);
+
+ if (src1->type == GGML_TYPE_I64) {
+ set_rows_cuda<float, int64_t>(ctx, src0, src1, dst);
+ } else {
+ set_rows_cuda<float, int32_t>(ctx, src0, src1, dst);
+ }
+}
return res;
}
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tdst) {
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {
char base[256];
char name[256];
- snprintf(base, 256, "kernel_set_rows_%s", ggml_type_name(tdst));
+ snprintf(base, 256, "kernel_set_rows_%s_%s", ggml_type_name(tdst), ggml_type_name(tidx));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
-ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tdst);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
- ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->type);
+ ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
const int32_t nk0 = ne0/ggml_blck_size(op->type);
}
}
-template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
+template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
kernel void kernel_set_rows_q32(
constant ggml_metal_kargs_set_rows & args,
device const void * src0,
}
const int32_t i10 = i01;
- const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+ const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
}
}
-template<typename T>
+template<typename T, typename TI>
kernel void kernel_set_rows_f(
constant ggml_metal_kargs_set_rows & args,
device const void * src0,
}
const int32_t i10 = i01;
- const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
+ const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
// set rows
//
-typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
+typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;
-template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f<float>;
-template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f<half>;
+template [[host_name("kernel_set_rows_f32_i64")]] kernel set_rows_f_t kernel_set_rows_f<float, int64_t>;
+template [[host_name("kernel_set_rows_f32_i32")]] kernel set_rows_f_t kernel_set_rows_f<float, int32_t>;
+template [[host_name("kernel_set_rows_f16_i64")]] kernel set_rows_f_t kernel_set_rows_f<half, int64_t>;
+template [[host_name("kernel_set_rows_f16_i32")]] kernel set_rows_f_t kernel_set_rows_f<half, int32_t>;
#if defined(GGML_METAL_HAS_BF16)
-template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
+template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>;
+template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>;
#endif
-typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
-
-template [[host_name("kernel_set_rows_q8_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0, quantize_q8_0>;
-template [[host_name("kernel_set_rows_q4_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0, quantize_q4_0>;
-template [[host_name("kernel_set_rows_q4_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1, quantize_q4_1>;
-template [[host_name("kernel_set_rows_q5_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0, quantize_q5_0>;
-template [[host_name("kernel_set_rows_q5_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1, quantize_q5_1>;
-template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
+typedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t;
+
+template [[host_name("kernel_set_rows_q8_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>;
+template [[host_name("kernel_set_rows_q8_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0, quantize_q8_0>;
+template [[host_name("kernel_set_rows_q4_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0, quantize_q4_0>;
+template [[host_name("kernel_set_rows_q4_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0, quantize_q4_0>;
+template [[host_name("kernel_set_rows_q4_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1, quantize_q4_1>;
+template [[host_name("kernel_set_rows_q4_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1, quantize_q4_1>;
+template [[host_name("kernel_set_rows_q5_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0, quantize_q5_0>;
+template [[host_name("kernel_set_rows_q5_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0, quantize_q5_0>;
+template [[host_name("kernel_set_rows_q5_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1, quantize_q5_1>;
+template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_1, quantize_q5_1>;
+template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;
+template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;
//
// matrix-matrix multiplication
std::map<std::pair<int, int>, int> kernels_flash_attn_bm;
std::map<std::pair<int, int>, int> kernels_flash_attn_bn;
cl_kernel kernel_get_rows_f32, kernel_get_rows_f16, kernel_get_rows_q4_0;
- cl_kernel kernel_set_rows_f32, kernel_set_rows_f16;
+ cl_kernel kernel_set_rows_f32_i64, kernel_set_rows_f32_i32, kernel_set_rows_f16_i64, kernel_set_rows_f16_i32;
cl_kernel kernel_rope_norm_f32, kernel_rope_norm_f16, kernel_rope_neox_f32, kernel_rope_neox_f16;
cl_kernel kernel_rope_multi_f32, kernel_rope_multi_f16, kernel_rope_vision_f32, kernel_rope_vision_f16;
cl_kernel kernel_cpy_f16_f16, kernel_cpy_f16_f32, kernel_cpy_f32_f16, kernel_cpy_f32_f32;
backend_ctx->program_set_rows =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
- CL_CHECK((backend_ctx->kernel_set_rows_f32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f32", &err), err));
- CL_CHECK((backend_ctx->kernel_set_rows_f16 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f16", &err), err));
+ CL_CHECK((backend_ctx->kernel_set_rows_f32_i64 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f32_i64", &err), err));
+ CL_CHECK((backend_ctx->kernel_set_rows_f32_i32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f32_i32", &err), err));
+ CL_CHECK((backend_ctx->kernel_set_rows_f16_i64 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f16_i64", &err), err));
+ CL_CHECK((backend_ctx->kernel_set_rows_f16_i32 = clCreateKernel(backend_ctx->program_set_rows, "kernel_set_rows_f16_i32", &err), err));
GGML_LOG_CONT(".");
}
switch (op->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
- return true;
+ return (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32);
default:
return false;
}
GGML_ASSERT(src1->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
+ GGML_ASSERT(src1->type == GGML_TYPE_I64 || src1->type == GGML_TYPE_I32);
// ne0 = ne00
// ne2 = ne02
switch (dst->type) {
case GGML_TYPE_F32:
- kernel = backend_ctx->kernel_set_rows_f32;
+ if (src1->type == GGML_TYPE_I64) {
+ kernel = backend_ctx->kernel_set_rows_f32_i64;
+ } else {
+ kernel = backend_ctx->kernel_set_rows_f32_i32;
+ }
break;
case GGML_TYPE_F16:
- kernel = backend_ctx->kernel_set_rows_f16;
+ if (src1->type == GGML_TYPE_I64) {
+ kernel = backend_ctx->kernel_set_rows_f16_i64;
+ } else {
+ kernel = backend_ctx->kernel_set_rows_f16_i32;
+ }
break;
default:
GGML_ABORT("not implemented");
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
-kernel void kernel_set_rows_f32(
+kernel void kernel_set_rows_f32_i64(
global char * src0,
ulong offset0,
global char * src1,
}
}
-kernel void kernel_set_rows_f16(
+kernel void kernel_set_rows_f16_i64(
global char * src0,
ulong offset0,
global char * src1,
dst_row[ind] = src_row[ind];
}
}
+
+kernel void kernel_set_rows_f32_i32(
+ global char * src0,
+ ulong offset0,
+ global char * src1,
+ ulong offset1,
+ global char * dst,
+ ulong offsetd,
+ int ne01,
+ ulong nb01,
+ ulong nb02,
+ ulong nb03,
+ int ne11,
+ int ne12,
+ ulong nb10,
+ ulong nb11,
+ ulong nb12,
+ int nblk0,
+ ulong nb1,
+ ulong nb2,
+ ulong nb3
+) {
+ src0 = src0 + offset0;
+ src1 = src1 + offset1;
+ dst = dst + offsetd;
+
+ int i03 = get_group_id(2);
+ int i02 = get_group_id(1);
+ int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
+
+ if (i01 >= ne01) {
+ return;
+ }
+
+ int i12 = i03%ne12;
+ int i11 = i02%ne11;
+
+ int i10 = i01;
+ int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
+
+ global float * dst_row = (global float *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
+ global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ for (int ind = get_local_id(0); ind < nblk0; ind += get_local_size(0)) {
+ dst_row[ind] = (float)src_row[ind];
+ }
+}
+
+kernel void kernel_set_rows_f16_i32(
+ global char * src0,
+ ulong offset0,
+ global char * src1,
+ ulong offset1,
+ global char * dst,
+ ulong offsetd,
+ int ne01,
+ ulong nb01,
+ ulong nb02,
+ ulong nb03,
+ int ne11,
+ int ne12,
+ ulong nb10,
+ ulong nb11,
+ ulong nb12,
+ int nblk0,
+ ulong nb1,
+ ulong nb2,
+ ulong nb3
+) {
+ src0 = src0 + offset0;
+ src1 = src1 + offset1;
+ dst = dst + offsetd;
+
+ int i03 = get_group_id(2);
+ int i02 = get_group_id(1);
+ int i01 = get_group_id(0)*get_local_size(1) + get_local_id(1);
+
+ if (i01 >= ne01) {
+ return;
+ }
+
+ int i12 = i03%ne12;
+ int i11 = i02%ne11;
+
+ int i10 = i01;
+ int i1 = ((global int *)(src1 + i10*nb10 + i11*nb11 + i12*nb12))[0];
+
+ global half * dst_row = (global half *) (dst + i1*nb1 + i02*nb2 + i03*nb3);
+ global float * src_row = (global float *) (src0 + i01*nb01 + i02*nb02 + i03*nb03);
+
+ for (int ind = get_local_id(0); ind < nblk0; ind += get_local_size(0)) {
+ dst_row[ind] = src_row[ind];
+ }
+}
return ((op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16 || op->type == GGML_TYPE_BF16 ||
op->type == GGML_TYPE_Q8_0 || op->type == GGML_TYPE_Q5_1 || op->type == GGML_TYPE_Q5_0 ||
op->type == GGML_TYPE_Q4_1 || op->type == GGML_TYPE_Q4_0 || op->type == GGML_TYPE_IQ4_NL) &&
- (op->src[1]->type == GGML_TYPE_I64));
+ (op->src[1]->type == GGML_TYPE_I64 || op->src[1]->type == GGML_TYPE_I32));
}
break;
case GGML_OP_CPY:
*reinterpret_cast<TOut*>(dst) = dst_val;
}
-template <typename blockType, int qk, cpy_kernel_t cpyblck>
+template <typename TIdx, typename blockType, int qk, cpy_kernel_t cpyblck>
static void set_rows_sycl_q(const char * __restrict__ src0_d,
- const int64_t * __restrict__ src1_d,
+ const TIdx * __restrict__ src1_d,
blockType * __restrict__ dst_d,
// tensor dimensions src0 and src1
const int64_t ne00,
const size_t src_offset = calculate_offset<3>({ nb01, nb02, nb03 }, { i01, i02, i03 });
const char * src_block = src0_d + src_offset + i00 * sizeof(float);
const size_t src1_offset = calculate_offset<3>({ nb10, nb11, nb12 }, { i10, i11, i12 });
- const int64_t dst_row = src1_d[src1_offset / sizeof(int64_t)];
+ const int64_t dst_row = src1_d[src1_offset / sizeof(TIdx)];
const size_t dst_offset =
calculate_offset<3>({ nb1, nb2, nb3 }, { dst_row, i02, i03 }) + (i00 / qk) * sizeof(blockType);
char * dst_block = reinterpret_cast<char *>(reinterpret_cast<char *>(dst_d) + dst_offset);
GGML_UNUSED(nb13);
}
-template<typename TIn, typename TOut>
+template<typename TIn, typename TIdx, typename TOut>
static void k_set_rows(
- const char * __restrict__ src0, const int64_t * __restrict__ src1, char * __restrict__ dst,
+ const char * __restrict__ src0, const TIdx * __restrict__ src1, char * __restrict__ dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t ne11, const int64_t ne12,
const size_t nb01, const size_t nb02, const size_t nb03,
const int64_t i11 = i02 % ne11;
const int64_t i10 = i01;
- const int64_t dst_row = *(const int64_t *)((const char *)src1 + calculate_offset<3>({nb10, nb11, nb12}, {i10, i11, i12}));
+ const int64_t dst_row = *(const TIdx *)((const char *)src1 + calculate_offset<3>({nb10, nb11, nb12}, {i10, i11, i12}));
const char * src0_row = src0 + calculate_offset<3>({nb01, nb02, nb03}, {i01, i02, i03});
const char * src_elem = src0_row + i00 * src_type_size;
convert<TIn, TOut>(src_elem, dst_elem);
}
-template<typename TIn, typename TOut>
+template<typename TIn, typename TIdx, typename TOut>
static void set_rows_sycl(
- const char * src0_d, const int64_t * src1_d, char * dst_d,
+ const char * src0_d, const TIdx * src1_d, char * dst_d,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t ne11, const int64_t ne12, const size_t nb01, const size_t nb02, const size_t nb03,
const size_t nb10, const size_t nb11, const size_t nb12,
stream->parallel_for(
sycl::nd_range<1>(grid_size * block_size, block_size),
[=](sycl::nd_item<1> item_ct1) {
- k_set_rows<TIn, TOut>(
+ k_set_rows<TIn, TIdx, TOut>(
src0_d, src1_d, dst_d,
ne00, ne01, ne02,
ne11, ne12,
);
}
-void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
- scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
- const ggml_tensor * src0 = dst->src[0];
- const ggml_tensor * src1 = dst->src[1];
-
- GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
- GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I64);
+template<typename TIn, typename TIdx>
+static void set_rows_sycl(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ const char * src0_d = (const char *)src0->data;
+ const TIdx * src1_d = (const TIdx *)src1->data;
GGML_TENSOR_BINARY_OP_LOCALS
- const int64_t * src1_dd = static_cast<const int64_t *>(src1->data);
-
dpct::queue_ptr stream = ctx.stream();
switch (dst->type) {
case GGML_TYPE_F32:
- set_rows_sycl<float, float>(
- (const char *)src0->data, src1_dd, (char *)dst->data,
+ set_rows_sycl<TIn, TIdx, float>(
+ src0_d, src1_d, (char *)dst->data,
ne00, ne01, ne02, ne03,
ne11, ne12,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
- sizeof(float), sizeof(float),
+ sizeof(TIn), sizeof(float),
stream
);
break;
case GGML_TYPE_F16:
dpct::has_capability_or_fail(stream->get_device(), { sycl::aspect::fp16 });
- set_rows_sycl<float, sycl::half>(
- (const char *)src0->data, src1_dd, (char *)dst->data,
+ set_rows_sycl<TIn, TIdx, sycl::half>(
+ src0_d, src1_d, (char *)dst->data,
ne00, ne01, ne02, ne03,
ne11, ne12,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
- sizeof(float), sizeof(sycl::half),
+ sizeof(TIn), sizeof(sycl::half),
stream
);
break;
case GGML_TYPE_BF16:
- set_rows_sycl<float, sycl::ext::oneapi::bfloat16>(
- (const char *)src0->data, src1_dd, (char *)dst->data,
+ set_rows_sycl<TIn, TIdx, sycl::ext::oneapi::bfloat16>(
+ src0_d, src1_d, (char *)dst->data,
ne00, ne01, ne02, ne03,
ne11, ne12,
nb01, nb02, nb03,
nb10, nb11, nb12,
nb1, nb2, nb3,
- sizeof(float), sizeof(sycl::ext::oneapi::bfloat16),
+ sizeof(TIn), sizeof(sycl::ext::oneapi::bfloat16),
stream
);
break;
case GGML_TYPE_Q8_0:
- set_rows_sycl_q<block_q8_0, QK8_0, cpy_blck_f32_q8_0>((const char *)src0->data, src1_dd, (block_q8_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
+ set_rows_sycl_q<TIdx, block_q8_0, QK8_0, cpy_blck_f32_q8_0>(src0_d, src1_d, (block_q8_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q5_1:
- set_rows_sycl_q<block_q5_1, QK5_1, cpy_blck_f32_q5_1>((const char *)src0->data, src1_dd, (block_q5_1 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
+ set_rows_sycl_q<TIdx, block_q5_1, QK5_1, cpy_blck_f32_q5_1>(src0_d, src1_d, (block_q5_1 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q5_0:
- set_rows_sycl_q<block_q5_0, QK5_0, cpy_blck_f32_q5_0>((const char *)src0->data, src1_dd, (block_q5_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
+ set_rows_sycl_q<TIdx, block_q5_0, QK5_0, cpy_blck_f32_q5_0>(src0_d, src1_d, (block_q5_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q4_1:
- set_rows_sycl_q<block_q4_1, QK4_1, cpy_blck_f32_q4_1>((const char *)src0->data, src1_dd, (block_q4_1 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
+ set_rows_sycl_q<TIdx, block_q4_1, QK4_1, cpy_blck_f32_q4_1>(src0_d, src1_d, (block_q4_1 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_Q4_0:
- set_rows_sycl_q<block_q4_0, QK4_0, cpy_blck_f32_q4_0>((const char *)src0->data, src1_dd, (block_q4_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
+ set_rows_sycl_q<TIdx, block_q4_0, QK4_0, cpy_blck_f32_q4_0>(src0_d, src1_d, (block_q4_0 *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
break;
case GGML_TYPE_IQ4_NL:
- set_rows_sycl_q<block_iq4_nl, QK4_NL, cpy_blck_f32_iq4_nl>((const char *)src0->data, src1_dd, (block_iq4_nl *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
+ set_rows_sycl_q<TIdx, block_iq4_nl, QK4_NL, cpy_blck_f32_iq4_nl>(src0_d, src1_d, (block_iq4_nl *)dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb1, nb2, nb3, stream);
break;
default:
break;
}
}
+
+void ggml_sycl_op_set_rows(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
+ scope_op_debug_print scope_dbg_print(__func__, dst, /*num_src=*/2);
+ const ggml_tensor * src0 = dst->src[0];
+ const ggml_tensor * src1 = dst->src[1];
+
+ GGML_ASSERT(dst->src[0]->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->src[1]->type == GGML_TYPE_I64 || dst->src[1]->type == GGML_TYPE_I32);
+
+ if (src1->type == GGML_TYPE_I64) {
+ set_rows_sycl<float, int64_t>(ctx, src0, src1, dst);
+ } else {
+ set_rows_sycl<float, int32_t>(ctx, src0, src1, dst);
+ }
+}
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
- vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT];
+ vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT];
+ vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT];
vk_pipeline pipeline_norm_f32;
vk_pipeline pipeline_group_norm_f32;
vk_pipeline pipeline_rms_norm_f32;
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
}
+#define SET_ROWS(itype, rte) \
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \
+ ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+
if (device->float_controls_rte_fp16) {
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_rte_len, set_rows_f32_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_rte_len, set_rows_f16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_rte_len, set_rows_bf16_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_rte_len, set_rows_q4_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_rte_len, set_rows_q4_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_rte_len, set_rows_q5_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_rte_len, set_rows_q5_1_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_rte_len, set_rows_q8_0_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_rte_len, set_rows_iq4_nl_rte_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+ SET_ROWS(_i32, _rte)
+ SET_ROWS(_i64, _rte)
} else {
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F32], "set_rows_f32", set_rows_f32_len, set_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_F16], "set_rows_f16", set_rows_f16_len, set_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_BF16], "set_rows_bf16", set_rows_bf16_len, set_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_0], "set_rows_q4_0", set_rows_q4_0_len, set_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q4_1], "set_rows_q4_1", set_rows_q4_1_len, set_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_0], "set_rows_q5_0", set_rows_q5_0_len, set_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q5_1], "set_rows_q5_1", set_rows_q5_1_len, set_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_Q8_0], "set_rows_q8_0", set_rows_q8_0_len, set_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
- ggml_vk_create_pipeline(device, device->pipeline_set_rows[GGML_TYPE_IQ4_NL], "set_rows_iq4_nl", set_rows_iq4_nl_len, set_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true);
+ SET_ROWS(_i32, )
+ SET_ROWS(_i64, )
}
+#undef SET_ROWS
+
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1);
case GGML_OP_DUP:
return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type);
case GGML_OP_SET_ROWS:
- return ctx->device->pipeline_set_rows[dst->type];
+ if (src1->type == GGML_TYPE_I64) {
+ return ctx->device->pipeline_set_rows_i64[dst->type];
+ } else {
+ return ctx->device->pipeline_set_rows_i32[dst->type];
+ }
case GGML_OP_SILU_BACK:
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_silu_back_f32;
#if defined(SET_ROWS)
#include "generic_binary_head.comp"
-layout (binding = 1) readonly buffer C {uvec2 data_i[];};
+layout (binding = 1) readonly buffer C {B_TYPE data_i[];};
layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];};
+
+#if B_SIZE == 64
+#define DATA_I_SWIZZLE .x
+#else
+#define DATA_I_SWIZZLE
+#endif
+
#else
#include "generic_unary_head.comp"
layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];};
uint i11 = fastmod(i02, p.ne11);
uint i10 = i01;
- uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()].x;
+ uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()] DATA_I_SWIZZLE;
uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset();
uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset();
}
for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
- string_to_spv("set_rows_" + t, "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
- string_to_spv("set_rows_" + t + "_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
+ string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+ string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
+ string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
+ string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}});
}
auto get_type_str = [](bool f16) {
break;
case GGML_OP_CPY:
case GGML_OP_SET_ROWS:
- supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32);
+ supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64);
break;
case GGML_OP_GET_ROWS:
if (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 ||
GGML_ASSERT(b->ne[3] % c->ne[2] == 0);
GGML_ASSERT(c->ne[3] == 1);
GGML_ASSERT(b->type == GGML_TYPE_F32);
- GGML_ASSERT(c->type == GGML_TYPE_I64);
+ GGML_ASSERT(c->type == GGML_TYPE_I64 || c->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous_rows(a));
GGML_ASSERT(ggml_is_contiguous_rows(b));