if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
- ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool());
+ ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
if (src0->type != GGML_TYPE_F16) {
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src0->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
size_t ne = row_diff*ne00;
src0_as_f16.alloc(ne);
to_fp16_cuda(src0_dd_i, src0_as_f16.get(), ne, stream);
}
const half * src0_ptr = src0->type == GGML_TYPE_F16 ? (const half *) src0_dd_i : src0_as_f16.get();
- ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool());
+ ggml_cuda_pool_alloc<half> src1_as_f16(ctx.pool(id));
if (src1->type != GGML_TYPE_F16) {
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
size_t ne = src1_ncols*ne10;
src1_as_f16.alloc(ne);
to_fp16_cuda(src1_ddf_i, src1_as_f16.get(), ne, stream);
}
const half * src1_ptr = src1->type == GGML_TYPE_F16 ? (const half *) src1_ddf_i : src1_as_f16.get();
- ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(), row_diff*src1_ncols);
+ ggml_cuda_pool_alloc<half> dst_f16(ctx.pool(id), row_diff*src1_ncols);
const half alpha_f16 = 1.0f;
const half beta_f16 = 0.0f;
}
}
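+// maps a compacted row index in the contiguous src1 buffer back to its position in ids:
+// i1 = expert slot within the token, i2 = token index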
+struct mmid_row_mapping {
+ int32_t i1;
+ int32_t i2;
+};
+
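+// gathers the src1 rows routed to expert i02 into a contiguous buffer; one block per
+// (token, expert slot) pair, appending rows via an atomic counter and recording the
+// origin of each gathered row in row_mapping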
+static __global__ void k_copy_src1_to_contiguous(const char * __restrict__ src1_original, char * __restrict__ src1_contiguous,
+ int * __restrict__ cur_src1_row, mmid_row_mapping * __restrict__ row_mapping,
+ const char * __restrict__ ids, int64_t i02, size_t ids_nb1, size_t ids_nb0,
+ int64_t ne11, int64_t ne10,
+ size_t nb11, size_t nb12) {
+ int32_t iid1 = blockIdx.x;
+ int32_t id = blockIdx.y;
+
+ const int32_t row_id_i = *(const int32_t *) (ids + iid1*ids_nb1 + id*ids_nb0);
+
+ if (row_id_i != i02) {
+ return;
+ }
+
+ const int64_t i11 = id % ne11;
+ const int64_t i12 = iid1;
+
+ __shared__ int src1_row;
+ if (threadIdx.x == 0) {
+ src1_row = atomicAdd(cur_src1_row, 1);
+ row_mapping[src1_row] = {id, iid1};
+ }
+ __syncthreads();
+
+ const float * src1_row_original = (const float *)(src1_original + i11*nb11 + i12*nb12);
+ float * src1_row_contiguous = (float *)(src1_contiguous + src1_row*nb11);
+
+ for (int i = threadIdx.x; i < ne10; i += blockDim.x) {
+ src1_row_contiguous[i] = src1_row_original[i];
+ }
+}
+
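+// inverse of the gather: block i copies contiguous dst row i back to its original
+// (expert slot, token) position recorded in row_mapping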
+static __global__ void k_copy_dst_from_contiguous(char * __restrict__ dst_original, const char * __restrict__ dst_contiguous,
+ const mmid_row_mapping * __restrict__ row_mapping,
+ int64_t ne0,
+ size_t nb1, size_t nb2) {
+ int32_t i = blockIdx.x;
+
+ const int32_t i1 = row_mapping[i].i1;
+ const int32_t i2 = row_mapping[i].i2;
+
+ const float * dst_row_contiguous = (const float *)(dst_contiguous + i*nb1);
+ float * dst_row_original = (float *)(dst_original + i1*nb1 + i2*nb2);
+
+ for (int j = threadIdx.x; j < ne0; j += blockDim.x) {
+ dst_row_original[j] = dst_row_contiguous[j];
+ }
+}
+
static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const ggml_tensor * ids = dst->src[2];
+ GGML_TENSOR_BINARY_OP_LOCALS
+
GGML_ASSERT(!ggml_backend_buffer_is_cuda_split(src0->buffer) && "mul_mat_id does not support split buffers");
cudaStream_t stream = ctx.stream();
- const size_t nb11 = src1->nb[1];
- const size_t nb1 = dst->nb[1];
-
- const int32_t id = ((int32_t *) dst->op_params)[0];
- const int32_t n_as = src0->ne[2];
+ const int64_t n_as = ne02; // n_expert
+ const int64_t n_ids = ids->ne[0]; // n_expert_used
std::vector<char> ids_host(ggml_nbytes(ids));
const char * ids_dev = (const char *) ids->data;
CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, stream));
CUDA_CHECK(cudaStreamSynchronize(stream));
ggml_tensor src0_row = *src0;
ggml_tensor src1_row = *src1;
- ggml_tensor dst_row = *dst;
+ ggml_tensor dst_row = *dst;
char * src0_original = (char *) src0->data;
char * src1_original = (char *) src1->data;
char * dst_original = (char *) dst->data;
src0_row.ne[2] = 1;
src0_row.ne[3] = 1;
- src0_row.nb[3] = src0->nb[2];
+ src0_row.nb[3] = nb02;
- if (src1->ne[1] == 1) {
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
- const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+ src1_row.ne[1] = 1;
+ src1_row.ne[2] = 1;
+ src1_row.ne[3] = 1;
+ src1_row.nb[2] = nb11;
+ src1_row.nb[3] = nb11;
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
+ dst_row.ne[1] = 1;
+ dst_row.ne[2] = 1;
+ dst_row.ne[3] = 1;
+ dst_row.nb[2] = nb1;
+ dst_row.nb[3] = nb1;
- src0_row.data = src0_original + row_id*src0->nb[2];
- src1_row.data = src1_original + i01*src1->nb[1];
- dst_row.data = dst_original + i01*dst->nb[1];
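+ // a single token (ne12 == 1): dispatch one matrix-vector mul_mat per used expert directly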
+ if (ne12 == 1) {
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+ for (int64_t id = 0; id < n_ids; id++) {
+ const int32_t i02 = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
- ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+ GGML_ASSERT(i02 >= 0 && i02 < n_as);
+
+ const int64_t i11 = id % ne11;
+ const int64_t i12 = iid1;
+
+ const int64_t i1 = id;
+ const int64_t i2 = i12;
+
+ src0_row.data = src0_original + i02*nb02;
+ src1_row.data = src1_original + i11*nb11 + i12*nb12;
+ dst_row.data = dst_original + i1*nb1 + i2*nb2;
+
+ ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
+ }
}
} else {
ggml_cuda_pool_alloc<char> src1_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(src1));
ggml_cuda_pool_alloc<char> dst_contiguous(ctx.pool(), sizeof(float)*ggml_nelements(dst));
src1_row.data = src1_contiguous.get();
dst_row.data = dst_contiguous.get();
- for (int32_t row_id = 0; row_id < n_as; ++row_id) {
+ for (int64_t i02 = 0; i02 < n_as; i02++) {
int64_t num_src1_rows = 0;
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
- const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
- if (row_id_i != row_id) {
- continue;
- }
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
+ for (int64_t id = 0; id < n_ids; id++) {
+ const int32_t row_id_i = *(const int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
+ GGML_ASSERT(row_id_i >= 0 && row_id_i < n_as);
- CUDA_CHECK(cudaMemcpyAsync(src1_contiguous.get() + num_src1_rows*nb11, src1_original + i01*nb11,
- nb11, cudaMemcpyDeviceToDevice, stream));
- num_src1_rows++;
+ if (row_id_i != i02) {
+ continue;
+ }
+
+ num_src1_rows++;
+ }
}
if (num_src1_rows == 0) {
continue;
}
- src0_row.data = src0_original + row_id*src0->nb[2];
+ ggml_cuda_pool_alloc<int> dev_cur_src1_row(ctx.pool(), 1);
+ ggml_cuda_pool_alloc<mmid_row_mapping> dev_row_mapping(ctx.pool(), num_src1_rows);
+ CUDA_CHECK(cudaMemsetAsync(dev_cur_src1_row.get(), 0, sizeof(int), stream));
- src1_row.ne[1] = num_src1_rows;
- dst_row.ne[1] = num_src1_rows;
+ {
+ dim3 block_dims(std::min((unsigned int)ne10, 768u));
+ dim3 grid_dims(ids->ne[1], n_ids);
+ k_copy_src1_to_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+ src1_original, src1_contiguous.get(),
+ dev_cur_src1_row.get(), dev_row_mapping.get(),
+ ids_dev, i02, ids->nb[1], ids->nb[0],
+ ne11, ne10,
+ nb11, nb12);
+ CUDA_CHECK(cudaGetLastError());
+ }
+
+ src0_row.data = src0_original + i02*nb02;
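+ // the gather/scatter kernels assume contiguous f32 rows, hence the layout asserts below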
+ GGML_ASSERT(nb11 == sizeof(float)*ne10);
+ GGML_ASSERT(nb1 == sizeof(float)*ne0);
+
+ src1_row.ne[1] = num_src1_rows;
src1_row.nb[1] = nb11;
src1_row.nb[2] = num_src1_rows*nb11;
src1_row.nb[3] = num_src1_rows*nb11;
+ dst_row.ne[1] = num_src1_rows;
dst_row.nb[1] = nb1;
dst_row.nb[2] = num_src1_rows*nb1;
dst_row.nb[3] = num_src1_rows*nb1;
ggml_cuda_mul_mat(ctx, &src0_row, &src1_row, &dst_row);
- num_src1_rows = 0;
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
- const int32_t row_id_i = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
-
- if (row_id_i != row_id) {
- continue;
- }
-
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
-
- CUDA_CHECK(cudaMemcpyAsync(dst_original + i01*nb1, dst_contiguous.get() + num_src1_rows*nb1,
- nb1, cudaMemcpyDeviceToDevice, stream));
- num_src1_rows++;
+ {
+ dim3 block_dims(std::min((unsigned int)ne0, 768u));
+ dim3 grid_dims(num_src1_rows);
+ k_copy_dst_from_contiguous<<<grid_dims, block_dims, 0, stream>>>(
+ dst_original, dst_contiguous.get(),
+ dev_row_mapping.get(),
+ ne0,
+ nb1, nb2);
+ CUDA_CHECK(cudaGetLastError());
}
}
}
GGML_CALL static bool ggml_backend_cuda_offload_op(ggml_backend_t backend, const ggml_tensor * op) {
const int min_batch_size = 32;
- return op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS;
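+ // for MUL_MAT_ID the batch (token) dimension is ne[2], not ne[1]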
+ return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) ||
+ (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID);
GGML_UNUSED(backend);
}
} break;
case GGML_OP_MUL_MAT_ID:
{
- //GGML_ASSERT(ne00 == ne10);
- //GGML_ASSERT(ne03 == ne13);
const int n_as = src0->ne[2];
- // max size of the src1ids array in the kernel shared buffer
- GGML_ASSERT(ne11 <= 4096);
-
// src2 = ids
- const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20);
+ const int64_t ne20 = src2->ne[0];
const int64_t ne21 = src2->ne[1];
const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
- int ne11_mm_min = n_as;
-
- const int idx = ((int32_t *) dst->op_params)[0];
+ // ne20 = n_used_experts
+ // ne21 = n_rows
+ const int dst_rows = ne20*ne21;
+ const int dst_rows_min = n_as;
- // batch size
- GGML_ASSERT(ne21 == ne11); // ?
- GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
- const uint r2 = 1;
- const uint r3 = 1;
+ // max size of the rowids array in the kernel shared buffer
+ GGML_ASSERT(dst_rows <= 2048);
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
// !!!
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
ne00 % 32 == 0 && ne00 >= 64 &&
- ne11 > ne11_mm_min) {
+ dst_rows > dst_rows_min) {
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16];
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
- [encoder setBytes:&idx length:sizeof(idx) atIndex:19];
-
- [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
-
- [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
+
+ [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, n_as) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} else {
int nth0 = 32;
int nth1 = 1;
GGML_ASSERT(ne00 >= nth0*nth1);
}
- const int64_t _ne1 = 1; // kernels needs a reference in constant memory
-
[encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
- [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
- [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
- [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
- [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
- [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
- [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
- [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
- [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
- [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
- [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20];
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:21];
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:22];
- [encoder setBytes:&idx length:sizeof(idx) atIndex:23];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22];
+
+ const int64_t _ne1 = 1;
+ const int tgz = dst_rows;
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
const int mem_size = 32*sizeof(float);
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q3_K) {
#ifdef GGML_QKK_64
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#else
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#endif
}
else if (src0t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
- const int64_t ny = (_ne1 + nrows - 1)/nrows;
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ const int64_t ny = (_ne1 + nrows - 1)/nrows; // = _ne1
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, tgz) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
} break;
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values,
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
uint3 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0;
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nr = N_DST;
const int nsg = N_SIMDGROUP;
const int nw = N_SIMDWIDTH;
device const char * src0,
device const char * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ uint64_t nb00,
+ uint64_t nb01,
+ uint64_t nb02,
+ int64_t ne10,
+ int64_t ne11,
+ int64_t ne12,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ uint3 tgpig,
+ uint tiisg) {
const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F32_F32;
device const char * src0,
device const char * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ uint64_t nb00,
+ uint64_t nb01,
+ uint64_t nb02,
+ int64_t ne10,
+ int64_t ne11,
+ int64_t ne12,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ uint3 tgpig,
+ uint tiisg) {
const int64_t r0 = tgpig.x;
const int64_t rb = tgpig.y*N_F16_F32;
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const uint8_t kmask1 = 0x03;
const uint8_t kmask2 = 0x0C;
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values_i8 [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values_i8,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
const int nb = ne00/QK4_NL;
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values_i8 [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values_i8,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg) {
+
threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
}
}
-// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids
+// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
void kernel_mul_mm_id_impl(
device const uchar * src0,
device const uchar * src1,
- threadgroup short * src1ids,
+ threadgroup ushort2 * rowids,
device float * dst,
constant int64_t & ne00,
constant int64_t & ne02,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant int64_t & ne11,
constant int64_t & ne12,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
int64_t ne1,
- constant uint & r2,
- constant uint & r3,
+ int64_t ne0ne1,
threadgroup uchar * shared_memory,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
const uint r0 = tgpig.y;
const uint r1 = tgpig.x;
- const uint im = tgpig.z;
if (r1 * BLOCK_SIZE_N >= ne1) return;
for (int i = 0; i < 8; i++){
c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
}
-
short il = (tiitg % THREAD_PER_ROW);
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02);
ushort offset1 = il/nl;
- device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+ threadgroup const auto & id = rowids[r1 * BLOCK_SIZE_N + thread_col];
+
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01) + offset1;
device const float * y = (device const float *)(src1
- + nb12 * im
- + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col]
+ + nb12 * id[1]
+ + nb11 * (id[0] % ne11)
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
for (int i = 0; i < 4; i++) {
- simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+ simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
}
simdgroup_barrier(mem_flags::mem_none);
for (int i = 0; i < 2; i++) {
- simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+ simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
}
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
threadgroup_barrier(mem_flags::mem_threadgroup);
- device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0;
+ device float * C = dst + (BLOCK_SIZE_M * r0);
if (sgitg == 0) {
- for (int i = 0; i < n_rows; i++) {
- for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
- *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+ for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
+ threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
+ int joff = jid[0] * ne0 + jid[1] * ne0ne1;
+ for (int i = 0; i < n_rows; i++) {
+ *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
}
}
}
device const uchar * src1,
device float * dst,
device const uchar * ids,
+ constant int64_t & nei0,
+ constant int64_t & nei1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
constant uint64_t & nb01,
constant uint64_t & nb02,
+ constant int64_t & ne11,
constant int64_t & ne12,
constant int64_t & ne13,
constant uint64_t & nb10,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- // expert id
- const int32_t id = tgpig.z/(ne12*ne13);
- device const uchar * src0 = src0s + id*nb02;
+ const int32_t i02 = tgpig.z;
+ tgpig.z = 0;
- tgpig.z = tgpig.z%(ne12*ne13);
+ device const uchar * src0 = src0s + i02*nb02;
- // row indices of src1 for expert id
- threadgroup short * src1ids = (threadgroup short *)(shared_memory + 8192);
+ // row indices
+ threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
+ // TODO: parallelize this loop
int64_t _ne1 = 0;
- for (int64_t i1 = 0; i1 < ne1; i1++) {
- if (((device int32_t *) (ids + i1*nbi1))[idx] == id) {
- src1ids[_ne1++] = i1;
+ for (ushort ii1 = 0; ii1 < nei1; ii1++) {
+ for (ushort ii0 = 0; ii0 < nei0; ii0++) {
+ int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+ if (id == i02) {
+ //if (tiitg == 0) {
+ rowids[_ne1] = ushort2(ii0, ii1);
+ //}
+ _ne1++;
+ }
}
}
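+ // make sure the rowids written above are visible to all threads before the multiply reads them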
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
src0,
src1,
- src1ids,
+ rowids,
dst,
ne00,
ne02,
nb01,
nb02,
+ ne11,
ne12,
nb10,
nb11,
nb12,
ne0,
_ne1,
- r2,
- r3,
+ ne0*ne1,
shared_memory,
tgpig,
tiitg,
// matrix-matrix multiplication
//
-typedef void (mat_mm_t)(
- device const uchar * src0,
- device const uchar * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne02,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup uchar *,
- uint3, uint, uint);
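+// deduce the kernel type from one instantiation rather than restating the full signature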
+typedef decltype(kernel_mul_mm<float4x4, 1, dequantize_f32>) mat_mm_t;
template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm<float4x4, 1, dequantize_f32>;
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
// indirect matrix-matrix multiplication
//
-typedef void (mat_mm_id_t)(
- device const uchar * src0s,
- device const uchar * src1,
- device float * dst,
- device const uchar * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne02,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup uchar *,
- uint3, uint, uint);
+typedef decltype(kernel_mul_mm_id<float4x4, 1, dequantize_f32>) mat_mm_id_t;
template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<float4x4, 1, dequantize_f32>;
template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<half4x4, 1, dequantize_f16>;
device const char * src0,
device const char * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]);
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ uint64_t nb00,
+ uint64_t nb01,
+ uint64_t nb02,
+ int64_t ne10,
+ int64_t ne11,
+ int64_t ne12,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ uint3 tgpig,
+ uint tiisg);
typedef void (kernel_mul_mv2_impl_t)(
device const void * src0,
device const float * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne10,
- constant int64_t & ne12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]);
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ int64_t ne10,
+ int64_t ne12,
+ int64_t ne0,
+ int64_t ne1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiisg,
+ uint sgitg);
template<kernel_mul_mv_impl_t impl_fn>
void mmv_fn(
device const char * src0,
device const char * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ uint64_t nb00,
+ uint64_t nb01,
+ uint64_t nb02,
+ int64_t ne10,
+ int64_t ne11,
+ int64_t ne12,
+ int64_t ne13,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ int64_t ne0,
+ int64_t ne1,
+ uint64_t nb1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiitg,
+ uint tiisg,
+ uint sgitg) {
impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,ne10,ne11,ne12,nb10,nb11,nb12,ne0,ne1,r2,r3,tgpig,tiisg);
}
device const char * src0,
device const char * src1,
device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ int64_t ne00,
+ int64_t ne01,
+ int64_t ne02,
+ uint64_t nb00,
+ uint64_t nb01,
+ uint64_t nb02,
+ int64_t ne10,
+ int64_t ne11,
+ int64_t ne12,
+ int64_t ne13,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ int64_t ne0,
+ int64_t ne1,
+ uint64_t nb1,
+ uint r2,
+ uint r3,
+ threadgroup int8_t * shared_values,
+ uint3 tgpig,
+ uint tiitg,
+ uint tiisg,
+ uint sgitg) {
impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
}
-typedef void (mul_mv_impl_fn_t)(
- device const char * src0,
- device const char * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]);
+typedef decltype(mmv_fn<kernel_mul_mv_f32_f32_impl>) mul_mv_impl_fn_t;
template<mul_mv_impl_fn_t impl_fn>
kernel void kernel_mul_mv_id(
device const char * src1,
device float * dst,
device const char * ids,
+ constant int64_t & nei0,
+ constant int64_t & nei1,
constant uint64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne0,
constant int64_t & ne1,
constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
threadgroup int8_t * shared_values [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int64_t bid = tgpig.z/(ne12*ne13);
+ const int iid1 = tgpig.z/nei0;
+ const int idx = tgpig.z%nei0;
+
+ tgpig.z = 0;
- tgpig.z = tgpig.z%(ne12*ne13);
+ const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx];
- const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
- device const char * src0 = src0s + id*nb02;
+ const int64_t i11 = idx % ne11;
+ const int64_t i12 = iid1;
+
+ const int64_t i1 = idx;
+ const int64_t i2 = i12;
+
+ device const char * src0_cur = src0s + i02*nb02;
+ device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
+ device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0;
impl_fn(
- src0,
- src1 + bid*nb11,
- dst + bid*ne0,
- ne00,
- ne01,
- ne02,
- nb00,
- nb01,
- nb02,
- ne10,
- ne11,
- ne12,
- ne13,
- nb10,
- nb11,
- nb12,
- ne0,
- ne1,
- nb1,
- r2,
- r3,
+ /* src0 */ src0_cur,
+ /* src1 */ src1_cur,
+ /* dst */ dst_cur,
+ /* ne00 */ ne00,
+ /* ne01 */ ne01,
+ /* ne02 */ 1,//ne02,
+ /* nb00 */ nb00,
+ /* nb01 */ nb01,
+ /* nb02 */ nb02,
+ /* ne10 */ ne10,
+ /* ne11 */ 1,//ne11,
+ /* ne12 */ 1,//ne12,
+ /* ne13 */ 1,//ne13,
+ /* nb10 */ nb10,
+ /* nb11 */ nb11,
+ /* nb12 */ nb12,
+ /* ne0 */ ne0,
+ /* ne1 */ 1,//ne1,
+ /* nb1 */ nb1,
+ /* r2 */ 1,
+ /* r3 */ 1,
shared_values,
tgpig,
tiitg,
tiisg,
sgitg);
}
-typedef void (kernel_mul_mv_id_t)(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- constant uint & r2,
- constant uint & r3,
- constant int & idx,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]);
+typedef decltype(kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>) kernel_mul_mv_id_t;
template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f32_f32_impl>>;
template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id<mmv_fn<kernel_mul_mv_f16_f32_impl>>;
// ggml_mul_mat_id
-// NOTE: id will be removed in the future and instead all the experts listed in ids will be computed
-// this will allow computing all the used experts in a single matrix multiplication
+/*
+ c = ggml_mul_mat_id(ctx, as, b, ids);
+
+ as -> [cols, rows, n_expert]
+ ids -> [n_expert_used, n_tokens] (i32)
+ b -> [cols, n_expert_used, n_tokens]
+ c -> [rows, n_expert_used, n_tokens]
+
+ in b, n_expert_used can be broadcast to match the n_expert_used of ids
+
+ c ~= as[:,:,i] @ b[:,i%r,t], i = ids[e,t] for all e,t in ids
+*/
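+
+// example (illustrative shapes for a top-2 MoE layer; the names are placeholders):
+//
+//   as:  [n_embd, n_ff, n_expert]  one weight matrix per expert
+//   ids: [2, n_tokens] (i32)       top-2 expert indices per token
+//   b:   [n_embd, 2, n_tokens]     one input row per (expert slot, token)
+//
+//   struct ggml_tensor * c = ggml_mul_mat_id(ctx, as, b, ids);
+//   // c: [n_ff, 2, n_tokens]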
struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx,
struct ggml_tensor * as,
- struct ggml_tensor * ids,
- int id,
- struct ggml_tensor * b) {
-
+ struct ggml_tensor * b,
+ struct ggml_tensor * ids) {
+ GGML_ASSERT(!ggml_is_transposed(as));
GGML_ASSERT(ids->type == GGML_TYPE_I32);
+
+ GGML_ASSERT(as->ne[3] == 1); // as is 3d (one matrix per expert)
+ GGML_ASSERT(b->ne[3] == 1); // b is 3d
GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1); // ids is 2d
- GGML_ASSERT(ids->ne[1] == b->ne[1]); // must have an expert per b row
- GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
- GGML_ASSERT(id >= 0 && id < ids->ne[0]); // valid id
+ GGML_ASSERT(ids->ne[1] == b->ne[2]); // must have an expert list per b row
GGML_ASSERT(as->ne[0] == b->ne[0]); // can_mul_mat
+ GGML_ASSERT(ids->ne[0] % b->ne[1] == 0); // can broadcast
bool is_node = false;
is_node = true;
}
- const int64_t ne[4] = { as->ne[1], b->ne[1], b->ne[2], b->ne[3] };
+ const int64_t ne[4] = { as->ne[1], ids->ne[0], b->ne[2], 1 };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
- ggml_set_op_params_i32(result, 0, id);
-
result->op = GGML_OP_MUL_MAT_ID;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = as;
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
- GGML_ASSERT(ne0 == ne01);
- GGML_ASSERT(ne1 == ne11);
- GGML_ASSERT(ne2 == ne12);
- GGML_ASSERT(ne3 == ne13);
-
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
- // broadcast is not supported with mmid
- assert(ne12 == 1);
- assert(ne13 == 1);
-
// row groups
- const int id = ggml_get_op_params_i32(dst, 0);
- const int n_as = src0->ne[2];
+ const int n_ids = ids->ne[0]; // n_expert_used
+ const int n_as = ne02; // n_expert
char * wdata_src1_end = (src1->type == vec_dot_type) ?
(char *) params->wdata :
(char *) params->wdata + GGML_PAD(ggml_row_size(vec_dot_type, ggml_nelements(src1)), sizeof(int64_t));
- int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
- int64_t * matrix_rows = matrix_row_counts + n_as; // [n_as][ne11]
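+ // i1 = expert slot within the token, i2 = token index (same mapping as the CUDA backend)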
+ struct mmid_row_mapping {
+ int32_t i1;
+ int32_t i2;
+ };
- #define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne11 + (i1)]
+ int64_t * matrix_row_counts = (int64_t *) (wdata_src1_end); // [n_as]
+ struct mmid_row_mapping * matrix_rows = (struct mmid_row_mapping *)(matrix_row_counts + n_as); // [n_as][ne12]
if (params->type == GGML_TASK_TYPE_INIT) {
if (ith != 0) {
// initialize matrix_row_counts
memset(matrix_row_counts, 0, n_as*sizeof(int64_t));
+#define MMID_MATRIX_ROW(row_id, i1) matrix_rows[(row_id)*ne12 + (i1)]
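+// entry (row_id, i1) is the i1-th (expert slot, token) pair routed to expert row_id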
+
// group rows by src0 matrix
- for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
- const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
+ for (int64_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) {
+ for (int id = 0; id < n_ids; ++id) {
+ const int32_t i02 = *(const int32_t *) ((const char *) ids->data + iid1*ids->nb[1] + id*ids->nb[0]);
+
+ assert(i02 >= 0 && i02 < n_as);
- GGML_ASSERT(row_id >= 0 && row_id < n_as);
- MMID_MATRIX_ROW(row_id, matrix_row_counts[row_id]) = i01;
- matrix_row_counts[row_id] += 1;
+ MMID_MATRIX_ROW(i02, matrix_row_counts[i02]) = (struct mmid_row_mapping) {id, iid1};
+ matrix_row_counts[i02] += 1;
+ }
}
return;
continue;
}
- size_t src0_offset = cur_a*src0->nb[2];
+ const char * src0_cur = (const char *) src0->data + cur_a*nb02;
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
- const int64_t nr0 = ne01; // src0 rows
- const int64_t nr1 = cne1*ne12*ne13; // src1 rows
-
- //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
+ const int64_t nr0 = ne01; // src0 rows
+ const int64_t nr1 = cne1; // src1 rows
// distribute the thread work across the inner or outer loop based on which one is larger
const int64_t ir110 = dr1*ith1;
const int64_t ir111 = MIN(ir110 + dr1, nr1);
- //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
-
// threads with no work simply yield (not sure if it helps)
- if (ir010 >= ir011 || ir110 >= ir111) {
- sched_yield();
- continue;
- }
+ //if (ir010 >= ir011 || ir110 >= ir111) {
+ // sched_yield();
+ // continue;
+ //}
// block-tiling attempt
const int64_t blck_0 = 16;
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
- const int64_t i13 = (ir1/(ne12*cne1)); // Note: currently, src1 is always a matrix
- const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
- const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1);
- const int64_t i11 = MMID_MATRIX_ROW(cur_a, _i11);
+ const int64_t _i12 = ir1; // logical row index for this expert
- // broadcast src0 into src1
- //const int64_t i03 = i13/r3;
- //const int64_t i02 = i12/r2;
+ struct mmid_row_mapping row_mapping = MMID_MATRIX_ROW(cur_a, _i12);
+ const int id = row_mapping.i1; // selected expert index
- const int64_t i1 = i11;
- const int64_t i2 = i12;
- const int64_t i3 = i13;
+ const int64_t i11 = id % ne11;
+ const int64_t i12 = row_mapping.i2; // row index in src1
- const char * src0_row = (const char *) src0->data + src0_offset;
+ const int64_t i1 = id; // selected expert index
+ const int64_t i2 = i12; // row
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
// the original src1 data pointer, so we should index using the indices directly
// TODO: this is a bit of a hack, we should probably have a better way to handle this
const char * src1_col = (const char *) wdata +
(src1_cont || src1->type != vec_dot_type
- ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
- : (i11*nb11 + i12*nb12 + i13*nb13));
+ ? (i11 + i12*ne11)*row_size
+ : (i11*nb11 + i12*nb12));
- float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
+ float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2));
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
//}
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
- vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_row + ir0*nb01, 0, src1_col, 0, 1);
+ vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_cur + ir0*nb01, 0, src1_col, 0, 1);
}
+
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
}
}
}
}
- #undef MMID_MATRIX_ROW
+#undef MMID_MATRIX_ROW
}
// ggml_compute_forward_out_prod
const int n_as = src0->ne[2];
cur += GGML_PAD(cur, sizeof(int64_t)); // align
cur += n_as * sizeof(int64_t); // matrix_row_counts
- cur += n_as * src1->ne[1] * sizeof(int64_t); // matrix_rows
+ cur += n_as * src1->ne[2] * sizeof(int64_t); // matrix_rows
} break;
case GGML_OP_OUT_PROD:
{
ok = ok && cur != NULL;
- ggml_set_name(cur, ctx->infos[i].name.data);
-
if (!ok) {
break;
}
+ ggml_set_name(cur, ctx->infos[i].name.data);
+
// point the data member to the appropriate location in the binary blob using the tensor infos
if (!params.no_alloc) {
//cur->data = (char *) data->data + ctx->infos[i].offset - ctx->offset; // offset from start of file