GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
- if (tensor->view_src != NULL && tensor->view_offs == 0) {
+ if (tensor->view_src != NULL) {
assert(tensor->view_src->buffer->buft == buffer->buft);
- tensor->backend = tensor->view_src->backend;
- tensor->extra = tensor->view_src->extra;
return;
}
}
}
-#if 0
-template<typename ... Srcs>
-static __global__ void k_compute_batched_ptrs_id(
- const void ** ptrs_src, void ** ptrs_dst,
- int ne12, int ne13,
- int ne23,
- int nb02, int nb03,
- int nb12, int nb13,
- int nb2, int nb3,
- int r2, int r3,
- ggml_type src0_type, half * src0_as_f16, int64_t src0_ne,
- const half * src1_f16, half * dst_f16,
- const int32_t * ids, const int id,
- Srcs... src0s) {
-
- int i = ids[id];
-
- half * src0_f16;
- const void * srcs_ar[] = { (const half *) src0s... };
- if (src0_type == GGML_TYPE_F16) {
- src0_f16 = (half *) srcs_ar[i];
- } else {
- src0_f16 = src0_as_f16;
- if (threadIdx.x == 0 && threadIdx.y == 0) {
- const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0_type);
- to_fp16(srcs_ar[i], src0_f16, src0_ne, cudaStreamFireAndForget);
- }
- }
-
- int i13 = blockIdx.x * blockDim.x + threadIdx.x;
- int i12 = blockIdx.y * blockDim.y + threadIdx.y;
-
- if (i13 >= ne13 || i12 >= ne12) {
- return;
- }
-
- int i03 = i13 / r3;
- int i02 = i12 / r2;
-
- ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02 + i03*nb03;
- ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2;
- ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
-}
-
-static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) {
- const struct ggml_tensor * ids = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
- const struct ggml_tensor * src00 = dst->src[2];
-
- const int id = dst->op_params[0];
-
- GGML_ASSERT(!ggml_is_transposed(src00));
- GGML_ASSERT(!ggml_is_transposed(src1));
-
- GGML_ASSERT(src00->backend != GGML_BACKEND_TYPE_GPU_SPLIT);
- GGML_ASSERT(src1->type == GGML_TYPE_F32);
-
- const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00);
- const int64_t ne01 = src00->ne[1];
- const int64_t ne02 = src00->ne[2];
- const int64_t ne03 = src00->ne[3];
-
- //const int64_t nb01 = src00->nb[1];
- const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02);
- const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03);
-
- const int64_t ne10 = src1->ne[0];
- const int64_t ne11 = src1->ne[1];
- const int64_t ne12 = src1->ne[2];
- const int64_t ne13 = src1->ne[3];
-
- //const int64_t nb11 = src1->nb[1];
- const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
- const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
-
- const int64_t ne1 = ggml_nelements(src1);
- const int64_t ne = ggml_nelements(dst);
-
- ggml_cuda_set_device(g_main_device);
- cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
-
- CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream));
-
- //ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
- //void * src0_ddq = src0_extra->data_device[g_main_device];
- //half * src0_as_f16 = (half *) src0_ddq;
-
- ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
- float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
-
- ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
- float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
-
- // convert src1 to fp16
- const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
- GGML_ASSERT(to_fp16_cuda != nullptr);
-
- size_t src1_as = 0;
- half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
- to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
-
- size_t dst_as = 0;
- half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
-
- GGML_ASSERT(ne12 % ne02 == 0);
- GGML_ASSERT(ne13 % ne03 == 0);
-
- // broadcast factors
- const int64_t r2 = ne12/ne02;
- const int64_t r3 = ne13/ne03;
-
- const half alpha_f16 = 1.0f;
- const half beta_f16 = 0.0f;
-
- // use cublasGemmBatchedEx
- const int ne23 = ne12*ne13;
-
- const void ** ptrs_src = nullptr;
- void ** ptrs_dst = nullptr;
-
- size_t ptrs_src_s = 0;
- size_t ptrs_dst_s = 0;
-
- ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s);
- ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s);
-
- int64_t src0_ne = ggml_nelements(src00);
- half * src0_as_f16 = nullptr;
- size_t src0_as = 0;
- if (src00->type != GGML_TYPE_F16) {
- src0_as_f16 = (half *) ggml_cuda_pool_malloc(src0_ne * sizeof(half), &src0_as);
- }
-
- static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6");
- dim3 block_dims(ne13, ne12);
- k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>(
- ptrs_src, ptrs_dst,
- ne12, ne13,
- ne23,
- ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half),
- nb12, nb13,
- dst->nb[2], dst->nb[3],
- r2, r3,
- src00->type, src0_as_f16, src0_ne,
- src1_as_f16, dst_f16,
- (const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id,
- dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr,
- dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr,
- dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr,
- dst->src[5] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr
- );
- CUDA_CHECK(cudaGetLastError());
-
- CUBLAS_CHECK(
- cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N,
- ne01, ne11, ne10,
- &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, ne00,
- (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, ne10,
- &beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01,
- ne23,
- CUBLAS_COMPUTE_16F,
- CUBLAS_GEMM_DEFAULT_TENSOR_OP));
-
- if (src0_as != 0) {
- ggml_cuda_pool_free(src0_as_f16, src0_as);
- }
- if (ptrs_src_s != 0) {
- ggml_cuda_pool_free(ptrs_src, ptrs_src_s);
- }
- if (ptrs_dst_s != 0) {
- ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s);
- }
-
- const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
- to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
-
- ggml_cuda_pool_free(src1_as_f16, src1_as);
- ggml_cuda_pool_free(dst_f16, dst_as);
-}
-#endif
-
static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-#if 0
- ggml_cuda_mul_mat_id_cublas(dst);
- // TODO: mmq/mmv support
-#endif
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
+ const ggml_tensor * ids = dst->src[2];
+
+ GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
cudaStream_t stream = ctx.stream();
const size_t nb11 = src1->nb[1];
const size_t nb1 = dst->nb[1];
- const struct ggml_tensor * ids = src0;
const int32_t id = ((int32_t *) dst->op_params)[0];
- const int32_t n_as = ((int32_t *) dst->op_params)[1];
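+ // the number of experts is now taken from src0, which packs one 2D weight matrix per expert along dim 2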
+ const int32_t n_as = src0->ne[2];
std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data;
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
+ ggml_tensor src0_row = *src0;
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;
+ char * src0_original = (char *) src0->data;
char * src1_original = (char *) src1->data;
char * dst_original = (char *) dst->data;
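+ // reduce src0 to a 2D view of a single expert matrix; selecting an expert is
+ // then just an offset of row_id*nb[2] into the packed weights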
+ src0_row.ne[2] = 1;
+ src0_row.ne[3] = 1;
+ src0_row.nb[3] = src0->nb[2];
+
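+ // for single-row src1 (one token per expert lookup), multiply each row against its
+ // expert directly; otherwise the rows assigned to each expert are first gathered
+ // into a contiguous buffer below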
if (src1->ne[1] == 1) {
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(row_id >= 0 && row_id < n_as);
- const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
+ src0_row.data = src0_original + row_id*src0->nb[2];
src1_row.data = src1_original + i01*src1->nb[1];
dst_row.data = dst_original + i01*dst->nb[1];
- ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row);
+ ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
}
} else {
ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
dst_row.data = dst_contiguous.get();
for (int32_t row_id = 0; row_id < n_as; ++row_id) {
- const struct ggml_tensor * src0_row = dst->src[row_id + 2];
-
int64_t num_src1_rows = 0;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
continue;
}
+ src0_row.data = src0_original + row_id*src0->nb[2];
+
src1_row.ne[1] = num_src1_rows;
dst_row.ne[1] = num_src1_rows;
dst_row.nb[2] = num_src1_rows*nb1;
dst_row.nb[3] = num_src1_rows*nb1;
- ggml_cuda_mul_mat(ctx, src0_row, &src1_row, &dst_row);
+ ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
num_src1_rows = 0;
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
cudaError_t err = cudaGetLastError();
if (err != cudaSuccess) {
fprintf(stderr, "%s: %s failed\n", __func__, ggml_op_desc(dst));
- GGML_ASSERT(false);
+ CUDA_CHECK(err);
}
return true;
{
//GGML_ASSERT(ne00 == ne10);
//GGML_ASSERT(ne03 == ne13);
-
- GGML_ASSERT(src0t == GGML_TYPE_I32);
-
- const int n_as = ((int32_t *) dst->op_params)[1];
-
- // TODO: make this more general
- GGML_ASSERT(n_as <= 8);
+ const int n_as = src0->ne[2];
// max size of the src1ids array in the kernel shared buffer
GGML_ASSERT(ne11 <= 4096);
- const int64_t ne20 = src2 ? src2->ne[0] : 0;
- const int64_t ne21 = src2 ? src2->ne[1] : 0;
- const int64_t ne22 = src2 ? src2->ne[2] : 0;
- const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
+ // src2 = ids
+ const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20);
+ const int64_t ne21 = src2->ne[1];
+ const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
+ const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
- const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
- const uint64_t nb21 = src2 ? src2->nb[1] : 0;
- const uint64_t nb22 = src2 ? src2->nb[2] : 0;
- const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
+
+ const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
+ const uint64_t nb21 = src2->nb[1];
+ const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
+ const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
- const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+ const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
- GGML_ASSERT(!ggml_is_transposed(src2));
+ GGML_ASSERT(src2t == GGML_TYPE_I32);
+ GGML_ASSERT(!ggml_is_transposed(src0));
GGML_ASSERT(!ggml_is_transposed(src1));
GGML_ASSERT(src1t == GGML_TYPE_F32);
- const uint r2 = ne12/ne22;
- const uint r3 = ne13/ne23;
-
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
int ne11_mm_min = n_as;
const int idx = ((int32_t *) dst->op_params)[0];
// batch size
- GGML_ASSERT(ne01 == ne11);
+ GGML_ASSERT(ne21 == ne11); // ?
+ GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
+ const uint r2 = 1;
+ const uint r3 = 1;
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
// indirect matrix multiplication
// !!!
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
- ne20 % 32 == 0 && ne20 >= 64 &&
+ ne00 % 32 == 0 && ne00 >= 64 &&
ne11 > ne11_mm_min) {
// some Metal matrix data types require aligned pointers
id<MTLComputePipelineState> pipeline = nil;
- switch (src2->type) {
+ switch (src0->type) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
- [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
- // TODO: how to make this an array? read Metal docs
- for (int j = 0; j < 8; ++j) {
- // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
- struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
-
- size_t offs_src_cur = 0;
- id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
-
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
- }
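+ // bind the ids tensor once at index 3 in place of the 8 per-expert source
+ // buffers previously bound at indices 19..26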
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:19];
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {
int nth0 = 32;
int nth1 = 1;
id<MTLComputePipelineState> pipeline = nil;
// use custom matrix x vector kernel
- switch (src2t) {
+ switch (src0t) {
case GGML_TYPE_F32:
{
GGML_ASSERT(src1t == GGML_TYPE_F32);
}
};
- if (ggml_is_quantized(src2t)) {
- GGML_ASSERT(ne20 >= nth0*nth1);
+ if (ggml_is_quantized(src0t)) {
+ GGML_ASSERT(ne00 >= nth0*nth1);
}
const int64_t _ne1 = 1; // kernel needs a reference in constant memory
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
- [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
- [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
- [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
- // TODO: how to make this an array? read Metal docs
- for (int j = 0; j < 8; ++j) {
- // NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
- struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
-
- size_t offs_src_cur = 0;
- id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
-
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
- }
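+ // as in the mul_mm path: bind ids once instead of the 8 per-expert buffers
+ // previously bound at indices 23..30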
+ [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:21];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:22];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:23];
- if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 || src2t == GGML_TYPE_Q5_0 ||
- src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 || src2t == GGML_TYPE_Q2_K ||
- src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ1_M || src2t == GGML_TYPE_IQ2_S) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
+ src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
+ src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
- else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
- const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
+ else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
+ const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
- else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
- const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
+ else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
+ const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
- else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
+ else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
const int mem_size = 32*sizeof(float);
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
- else if (src2t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ else if (src0t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
- else if (src2t == GGML_TYPE_Q3_K) {
+ else if (src0t == GGML_TYPE_Q3_K) {
#ifdef GGML_QKK_64
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#else
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#endif
}
- else if (src2t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ else if (src0t == GGML_TYPE_Q5_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
- else if (src2t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ else if (src0t == GGML_TYPE_Q6_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
const int64_t ny = (_ne1 + nrows - 1)/nrows;
- [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
} break;
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
+ // bitonic sort requires the number of elements to be a power of 2
+ int64_t ne00_padded = 1;
+ while (ne00_padded < ne00) {
+ ne00_padded *= 2;
+ }
+
+ // Metal kernels require the buffer size to be a multiple of 16 bytes
+ // https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
+ const int mem_size = GGML_PAD(ne00_padded*sizeof(int32_t), 16);
+
id<MTLComputePipelineState> pipeline = nil;
switch (order) {
};
[encoder setComputePipelineState:pipeline];
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
+ [encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
} break;
case GGML_OP_LEAKY_RELU:
{
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
enum ggml_sort_order {
- GGML_SORT_ASC,
- GGML_SORT_DESC,
+ GGML_SORT_ORDER_ASC,
+ GGML_SORT_ORDER_DESC,
};
// general-purpose kernel for addition, multiplication and division of two tensors
// bitonic sort implementation following the CUDA kernels as reference
typedef void (argsort_t)(
- device const float * x,
- device int32_t * dst,
- constant int64_t & ncols,
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ constant int64_t & ncols_pad,
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]);
device const float * x,
device int32_t * dst,
constant int64_t & ncols,
+ constant int64_t & ncols_pad,
+ threadgroup int32_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]]) {
// bitonic sort
int col = tpitg[0];
int row = tgpig[1];
- if (col >= ncols) return;
+ if (col >= ncols_pad) return;
- device const float * x_row = x + row * ncols;
- device int32_t * dst_row = dst + row * ncols;
+ device const float * x_row = x + row * ncols;
+ threadgroup int32_t * dst_row = shared_values;
// initialize indices
- if (col < ncols) {
- dst_row[col] = col;
- }
+ dst_row[col] = col;
+
threadgroup_barrier(mem_flags::mem_threadgroup);
- for (int k = 2; k <= ncols; k *= 2) {
+ for (int k = 2; k <= ncols_pad; k *= 2) {
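+ // padded entries (index >= ncols) are always ordered after the real elements,
+ // so the first ncols slots hold the result that is copied out below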
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
- if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
+ if (dst_row[col] >= ncols ||
+ (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] > x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] < x_row[dst_row[ixj]]))
+ ) {
SWAP(dst_row[col], dst_row[ixj]);
}
} else {
- if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
+ if (dst_row[ixj] >= ncols ||
+ (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
+ x_row[dst_row[col]] < x_row[dst_row[ixj]] :
+ x_row[dst_row[col]] > x_row[dst_row[ixj]]))
+ ) {
SWAP(dst_row[col], dst_row[ixj]);
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
}
+
+ // copy the result to dst without the padding
+ if (col < ncols) {
+ dst[row * ncols + col] = dst_row[col];
+ }
}
-template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
-template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
+template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
+template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
kernel void kernel_leaky_relu_f32(
device const float * src0,
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm_id(
- device const uchar * ids,
+ device const uchar * src0s,
device const uchar * src1,
device float * dst,
+ device const uchar * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const uchar * src00,
- device const uchar * src01,
- device const uchar * src02,
- device const uchar * src03,
- device const uchar * src04,
- device const uchar * src05,
- device const uchar * src06,
- device const uchar * src07,
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
// expert id
const int32_t id = tgpig.z/(ne12*ne13);
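+ // nb02 is the byte stride between consecutive expert matrices in src0s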
+ device const uchar * src0 = src0s + id*nb02;
tgpig.z = tgpig.z%(ne12*ne13);
}
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
- src0s[id],
+ src0,
src1,
src1ids,
dst,
//
typedef void (mat_mm_id_t)(
- device const uchar * ids,
+ device const uchar * src0s,
device const uchar * src1,
device float * dst,
+ device const uchar * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const uchar * src00,
- device const uchar * src01,
- device const uchar * src02,
- device const uchar * src03,
- device const uchar * src04,
- device const uchar * src05,
- device const uchar * src06,
- device const uchar * src07,
threadgroup uchar *,
uint3, uint, uint);
[[host_name("kernel_mul_mv_id_f32_f32")]]
kernel void kernel_mul_mv_id_f32_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
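+ // id selects the expert for this batch; nb02 is the per-expert byte stride
+ // (the same lookup repeats in all the mul_mv_id kernels below)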
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_f32_f32_impl(
- src0[id],
+ src0,
src1 + bid*nb11,
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_f16_f32")]]
kernel void kernel_mul_mv_id_f16_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_f16_f32_impl(
- src0[id],
+ src0,
src1 + bid*nb11,
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_q8_0_f32")]]
kernel void kernel_mul_mv_id_q8_0_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_q8_0_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_q4_0_f32")]]
kernel void kernel_mul_mv_id_q4_0_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_q4_1_f32")]]
kernel void kernel_mul_mv_id_q4_1_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_q5_0_f32")]]
kernel void kernel_mul_mv_id_q5_0_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_q5_1_f32")]]
kernel void kernel_mul_mv_id_q5_1_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_q2_K_f32")]]
kernel void kernel_mul_mv_id_q2_K_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_q2_K_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
kernel void kernel_mul_mv_id_q3_K_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_q3_K_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
kernel void kernel_mul_mv_id_q4_K_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_q4_K_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
kernel void kernel_mul_mv_id_q5_K_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_q5_K_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
kernel void kernel_mul_mv_id_q6_K_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_q6_K_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
kernel void kernel_mul_mv_id_iq2_xxs_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_iq2_xxs_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
kernel void kernel_mul_mv_id_iq2_xs_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_iq2_xs_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
kernel void kernel_mul_mv_id_iq3_xxs_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_iq3_xxs_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
kernel void kernel_mul_mv_id_iq3_s_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_iq3_s_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
kernel void kernel_mul_mv_id_iq2_s_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_iq2_s_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
kernel void kernel_mul_mv_id_iq1_s_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_iq1_s_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
kernel void kernel_mul_mv_id_iq1_m_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_iq1_m_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
kernel void kernel_mul_mv_id_iq4_nl_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
threadgroup float * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
kernel_mul_mv_iq4_nl_f32_impl(
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,
[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
kernel void kernel_mul_mv_id_iq4_xs_f32(
- device const char * ids,
+ device const char * src0s,
device const char * src1,
device float * dst,
+ device const char * ids,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint & r2,
constant uint & r3,
constant int & idx,
- device const char * src00,
- device const char * src01,
- device const char * src02,
- device const char * src03,
- device const char * src04,
- device const char * src05,
- device const char * src06,
- device const char * src07,
threadgroup float * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
-
const int64_t bid = tgpig.z/(ne12*ne13);
tgpig.z = tgpig.z%(ne12*ne13);
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+ device const char * src0 = src0s + id*nb02;
#if QK_K == 64
kernel_mul_mv_iq4_nl_f32_impl(
#else
kernel_mul_mv_iq4_xs_f32_impl(
#endif
- src0[id],
+ src0,
(device const float *) (src1 + bid*nb11),
dst + bid*ne0,
ne00,