#include <algorithm>
+#include <assert.h>
+#include <atomic>
+#include <cinttypes>
#include <cstddef>
#include <cstdint>
-#include <cinttypes>
#include <float.h>
#include <limits>
#include <stdint.h>
#include <stdio.h>
-#include <atomic>
-#include <assert.h>
+#include <vector>
+
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
}
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) {
- const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2;
- const int row = blockDim.y*blockIdx.y + threadIdx.y;
-
- if (col >= ncols) {
+static __global__ void k_get_rows(
+ const void * src0, const int32_t * src1, dst_t * dst,
+ int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+ /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+ /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+ size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
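+    // each thread produces two dst values (one dequantized pair): x indexes the
+    // columns (i00), y the gathered rows (i10), z the i11/i12 batch dimensions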
+ const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
+ const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+ const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+ const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+ if (i00 >= ne00) {
return;
}
- const int r = y[row];
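+    // src1 holds the indices of the rows to gather from src0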
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
- // copy x[r*ncols + col] to dst[row*ncols + col]
- const int xi = r*ncols + col;
- const int di = row*ncols + col;
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
- const int ib = xi/qk; // block index
- const int iqs = (xi%qk)/qr; // quant index
- const int iybs = di - di%qk; // y block start index
+ const int ib = i00/qk; // block index
+ const int iqs = (i00%qk)/qr; // quant index
+ const int iybs = i00 - i00%qk; // dst block start index
const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize
dfloat2 v;
- dequantize_kernel(x, ib, iqs, v);
+ dequantize_kernel(src0_row, ib, iqs, v);
- dst[iybs + iqs + 0] = v.x;
- dst[iybs + iqs + y_offset] = v.y;
+ dst_row[iybs + iqs + 0] = v.x;
+ dst_row[iybs + iqs + y_offset] = v.y;
+}
+
+template<typename src0_t, typename dst_t>
+static __global__ void k_get_rows_float(
+ const src0_t * src0, const int32_t * src1, dst_t * dst,
+ int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+ /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+ /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+ size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
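+    // same index mapping as k_get_rows, but one element per thread and a plain
+    // type-converting copy instead of dequantization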
+ const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
+ const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+ const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+ const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+ if (i00 >= ne00) {
+ return;
+ }
+
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
+
+ dst_row[i00] = src0_row[i00];
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
}
template<int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
+static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
- const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
- const dim3 block_nums(block_num_x, nrows, 1);
- k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
+ const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
+ const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+ // strides in elements
+ //const size_t s0 = nb0 / ggml_element_size(dst);
+ const size_t s1 = nb1 / ggml_element_size(dst);
+ const size_t s2 = nb2 / ggml_element_size(dst);
+ const size_t s3 = nb3 / ggml_element_size(dst);
+
+ const size_t s10 = nb10 / ggml_element_size(src1);
+ const size_t s11 = nb11 / ggml_element_size(src1);
+ const size_t s12 = nb12 / ggml_element_size(src1);
+ //const size_t s13 = nb13 / ggml_element_size(src1);
+
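+    // the quantized kernel writes a pair of values per thread, so rows must have an even length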
+ GGML_ASSERT(ne00 % 2 == 0);
+
+ k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd,
+ ne00, /*ne01, ne02, ne03,*/
+ /*ne10, ne11,*/ ne12, /*ne13,*/
+ /* s0,*/ s1, s2, s3,
+ /* nb00,*/ nb01, nb02, nb03,
+ s10, s11, s12/*, s13*/);
+
+ (void) dst;
+}
+
+template<typename src0_t>
+static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+ const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
+ const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+ // strides in elements
+ //const size_t s0 = nb0 / ggml_element_size(dst);
+ const size_t s1 = nb1 / ggml_element_size(dst);
+ const size_t s2 = nb2 / ggml_element_size(dst);
+ const size_t s3 = nb3 / ggml_element_size(dst);
+
+ const size_t s10 = nb10 / ggml_element_size(src1);
+ const size_t s11 = nb11 / ggml_element_size(src1);
+ const size_t s12 = nb12 / ggml_element_size(src1);
+ //const size_t s13 = nb13 / ggml_element_size(src1);
+
+ k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd,
+ ne00, /*ne01, ne02, ne03,*/
+ /*ne10, ne11,*/ ne12, /*ne13,*/
+ /* s0,*/ s1, s2, s3,
+ /* nb00,*/ nb01, nb02, nb03,
+ s10, s11, s12/*, s13*/);
+
+ (void) dst;
}
template<float (*bin_op)(const float, const float)>
GGML_TENSOR_BINARY_OP_LOCALS
-
int nr0 = ne10/ne0;
int nr1 = ne11/ne1;
int nr2 = ne12/ne2;
int64_t ne12 = cne1[2];
int64_t ne13 = cne1[3];
- //size_t nb0 = cnb0[0];
+ size_t nb0 = cnb0[0];
size_t nb1 = cnb0[1];
size_t nb2 = cnb0[2];
size_t nb3 = cnb0[3];
- //size_t nb10 = cnb1[0];
+ size_t nb10 = cnb1[0];
size_t nb11 = cnb1[1];
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
- //size_t s0 = nb0 / sizeof(src1_t);
- size_t s1 = nb1 / sizeof(src1_t);
- size_t s2 = nb2 / sizeof(src1_t);
- size_t s3 = nb3 / sizeof(src1_t);
+ size_t s0 = nb0 / sizeof(dst_t);
+ size_t s1 = nb1 / sizeof(dst_t);
+ size_t s2 = nb2 / sizeof(dst_t);
+ size_t s3 = nb3 / sizeof(dst_t);
- //size_t s10 = nb10 / sizeof(src1_t);
+ size_t s10 = nb10 / sizeof(src1_t);
size_t s11 = nb11 / sizeof(src1_t);
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
+ GGML_ASSERT(s0 == 1);
+ GGML_ASSERT(s10 == 1);
const int block_size = 128;
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
- GGML_ASSERT(ggml_is_contiguous(dst));
- const int ncols = src0->ne[0];
- const int nrows = ggml_nelements(src1);
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+ GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+ GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
const int32_t * src1_i32 = (const int32_t *) src1_d;
switch (src0->type) {
case GGML_TYPE_F16:
- get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_F32:
- get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q4_0:
- get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q4_1:
- get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q5_0:
- get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q5_1:
- get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q8_0:
- get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
default:
// TODO: k-quants
(void) src0_dd;
}
+
inline void ggml_cuda_op_sum_rows(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
}
#endif
-static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#if 0
-//#ifdef CUDA_USE_TENSOR_CORES
-// const bool use_tensor_cores = true;
-//#else
-// const bool use_tensor_cores = false;
-//#endif
-
ggml_cuda_mul_mat_id_cublas(dst);
-
// TODO: mmq/mmv support
-#else
- const struct ggml_tensor * ids = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
- const int id = dst->op_params[0];
+#endif
- int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+ GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
- int32_t a_id;
- CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
- CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+ const struct ggml_tensor * ids = src0;
+ const int32_t id = ((int32_t *) dst->op_params)[0];
+ const int32_t n_as = ((int32_t *) dst->op_params)[1];
- GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
- const struct ggml_tensor * src0 = dst->src[a_id + 2];
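+    // the row ids are needed on the host to pick the src0 matrix (dst->src[row_id + 2]) for each row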
+ std::vector<char> ids_host(ggml_nbytes(ids));
+
+ if (ids->backend == GGML_BACKEND_GPU) {
+ const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+ CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+ } else {
+ memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
+ }
+
+ const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
+ const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
+
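+    // single-row views of src1 and dst; their data pointers are re-targeted per row in the loop below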
+ ggml_tensor_extra_gpu src1_row_extra;
+ ggml_tensor_extra_gpu dst_row_extra;
+
+ ggml_tensor src1_row = *src1;
+ ggml_tensor dst_row = *dst;
+
+ src1_row.ne[1] = 1;
+ dst_row.ne[1] = 1;
+
+ src1_row.nb[2] = src1_row.nb[1];
+ dst_row.nb[2] = dst_row.nb[1];
+
+ src1_row.nb[3] = src1_row.nb[1];
+ dst_row.nb[3] = dst_row.nb[1];
+
+ src1_row.extra = &src1_row_extra;
+ dst_row.extra = &dst_row_extra;
- ggml_cuda_mul_mat(src0, src1, dst);
-#endif
- (void) _src0;
- (void) _src1;
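+    // for every row, read its id and multiply the corresponding src0 matrix with that src1 row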
+ for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+ //int32_t row_id;
+ //CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+ //CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+
+ const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+
+ GGML_ASSERT(row_id >= 0 && row_id < n_as);
+
+ const struct ggml_tensor * src0_row = dst->src[row_id + 2];
+
+ src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
+ src1_row.data = (char *) src1->data + i01*src1->nb[1];
+
+ dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
+ dst_row.data = (char *) dst->data + i01*dst->nb[1];
+
+ ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
+ }
}
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
}
return true;
} break;
+ case GGML_OP_GET_ROWS:
+ {
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ return true;
+ default:
+ return false;
+ }
+ } break;
+ case GGML_OP_CPY:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ ggml_type src1_type = op->src[1]->type;
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+ return true;
+ }
+ return false;
+ } break;
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE:
case GGML_OP_NORM:
case GGML_OP_REPEAT:
- case GGML_OP_GET_ROWS:
case GGML_OP_DUP:
case GGML_OP_ADD:
case GGML_OP_MUL:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_CLAMP:
- case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
UNUSED(params);
}
-extern "C" int ggml_backend_cuda_reg_devices() {
+extern "C" int ggml_backend_cuda_reg_devices();
+
+int ggml_backend_cuda_reg_devices() {
int device_count = ggml_cuda_get_device_count();
//int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
for (int i = 0; i < device_count; i++) {
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+ GGML_METAL_DECL_KERNEL(cpy_f16_f32);
GGML_METAL_DECL_KERNEL(concat);
GGML_METAL_DECL_KERNEL(sqr);
GGML_METAL_DECL_KERNEL(sum_rows);
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+ GGML_METAL_ADD_KERNEL(cpy_f16_f32);
GGML_METAL_ADD_KERNEL(concat);
GGML_METAL_ADD_KERNEL(sqr);
GGML_METAL_ADD_KERNEL(sum_rows);
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+ GGML_METAL_DEL_KERNEL(cpy_f16_f32);
GGML_METAL_DEL_KERNEL(concat);
GGML_METAL_DEL_KERNEL(sqr);
GGML_METAL_DEL_KERNEL(sum_rows);
case GGML_OP_PAD:
case GGML_OP_ARGSORT:
case GGML_OP_LEAKY_RELU:
- case GGML_OP_DUP:
- case GGML_OP_CPY:
- case GGML_OP_CONT:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
return true;
+ case GGML_OP_CPY:
+ case GGML_OP_DUP:
+ case GGML_OP_CONT:
+ {
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
+ switch (op->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ return true;
+ default:
+ return false;
+ }
+ case GGML_TYPE_F16:
+ switch (op->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ return true;
+ default:
+ return false;
+ }
+ default:
+ return false;
+ };
+ }
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_GET_ROWS:
{
return op->ne[3] == 1;
- } break;
+ }
default:
return false;
}
case GGML_OP_MUL:
case GGML_OP_DIV:
{
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
-
const size_t offs = 0;
bool bcast_row = false;
int64_t nb = ne00;
- if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
// src1 is a row
GGML_ASSERT(ne11 == 1);
nb = ne00 / 4;
switch (dst->op) {
- case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
- case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
- case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
default: GGML_ASSERT(false);
}
bcast_row = true;
} else {
switch (dst->op) {
- case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
- case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
- case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
default: GGML_ASSERT(false);
}
}
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} else {
- const int nth = MIN(1024, ne0);
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
- int64_t ny = (ne11 + nrows - 1)/nrows;
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
GGML_ASSERT(src0t == GGML_TYPE_I32);
- const int n_as = ne00;
+ const int n_as = ((int32_t *) dst->op_params)[1];
// TODO: make this more general
GGML_ASSERT(n_as <= 8);
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
- int ne11_mm_min = 0;
+ int ne11_mm_min = 1;
const int idx = ((int32_t *) dst->op_params)[0];
+ // batch size
+ GGML_ASSERT(ne01 == ne11);
+
+ const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
- ne11 > ne11_mm_min) {
+ // !!!
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+ // indirect matrix multiplication
+ // !!!
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
switch (src2->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
- [encoder setBytes:&idx length:sizeof(idx) atIndex:15];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
// TODO: how to make this an array? read Metal docs
for (int j = 0; j < n_as; ++j) {
struct ggml_tensor * src_cur = dst->src[2 + j];
size_t offs_src_cur = 0;
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
}
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+
+ // TODO: processing one row at a time (ne11 -> 1) is not efficient
+ [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ } else {
+ int nth0 = 32;
+ int nth1 = 1;
+ int nrows = 1;
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+ // use custom matrix x vector kernel
+ switch (src2t) {
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
+ } break;
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ nth0 = 32;
+ nth1 = 1;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
+ } break;
+ case GGML_TYPE_Q4_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
+ } break;
+ case GGML_TYPE_Q4_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
+ } break;
+ case GGML_TYPE_Q5_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
+ } break;
+ case GGML_TYPE_Q2_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
+ } break;
+ case GGML_TYPE_Q3_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
+ } break;
+ case GGML_TYPE_Q4_K:
+ {
+ nth0 = 4; //1;
+ nth1 = 8; //32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
+ } break;
+ case GGML_TYPE_Q5_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
+ } break;
+ case GGML_TYPE_Q6_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
+ } break;
+ default:
+ {
+                        GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t);
+ GGML_ASSERT(false && "not implemented");
+ }
+ };
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
+ [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
+ // TODO: how to make this an array? read Metal docs
+ for (int j = 0; j < n_as; ++j) {
+ struct ggml_tensor * src_cur = dst->src[2 + j];
+
+ size_t offs_src_cur = 0;
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
+ }
+
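+            // dispatch shapes mirror the regular mul_mv kernels; the z dimension also carries
+            // the ne01 rows, since each src1 row is processed on its own (_ne1 == 1)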
+ if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
+ src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+ src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
+ }
+ else if (src2t == GGML_TYPE_Q5_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q6_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else {
+ const int64_t ny = (_ne1 + nrows - 1)/nrows;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
}
} break;
case GGML_OP_GET_ROWS:
default: GGML_ASSERT(false && "not implemented");
}
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
-
- const int64_t n = ggml_nelements(src1);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break;
case GGML_OP_RMS_NORM:
{
{
switch (dstt) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
- case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
default: GGML_ASSERT(false && "not implemented");
};
} break;
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
- device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
// parallel max
float lmax = -INFINITY;
pdst[i00] = exp_psrc0;
}
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
float sum = simd_sum(lsum);
+
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
- threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
float sum = simd_sum(lsum);
+
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
// guard against the number of rows not being divisible by
// N_DST, so this is another explicit assumption of the implementation.
template<typename block_q_type, int nr, int nsg, int nw>
-void mul_vec_q_n_f32(
+void mul_vec_q_n_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q4_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q5_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q5_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
}
#define NB_Q8_0 8
-kernel void kernel_mul_mv_q8_0_f32(
+void kernel_mul_mv_q8_0_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nr = N_DST;
const int nsg = N_SIMDGROUP;
const int nw = N_SIMDWIDTH;
}
}
+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
#define N_F32_F32 4
-kernel void kernel_mul_mv_f32_f32(
+void kernel_mul_mv_f32_f32_impl(
device const char * src0,
device const char * src1,
device float * dst,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
}
}
+[[host_name("kernel_mul_mv_f32_f32")]]
+kernel void kernel_mul_mv_f32_f32(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
#define N_F16_F16 4
kernel void kernel_mul_mv_f16_f16(
}
}
-kernel void kernel_mul_mv_f16_f32_1row(
+void kernel_mul_mv_f16_f32_1row_impl(
device const char * src0,
device const char * src1,
device float * dst,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
+}
+[[host_name("kernel_mul_mv_f16_f32_1row")]]
+kernel void kernel_mul_mv_f16_f32_1row(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
}
#define N_F16_F32 4
-kernel void kernel_mul_mv_f16_f32(
+void kernel_mul_mv_f16_f32_impl(
device const char * src0,
device const char * src1,
device float * dst,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
}
}
+[[host_name("kernel_mul_mv_f16_f32")]]
+kernel void kernel_mul_mv_f16_f32(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
// Assumes row size (ne00) is a multiple of 4
kernel void kernel_mul_mv_f16_f32_l4(
device const char * src0,
}
kernel void kernel_cpy_f16_f16(
- device const half * src0,
- device half * dst,
+ device const half * src0,
+ device half * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
}
}
+kernel void kernel_cpy_f16_f32(
+ device const half * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
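+    // linear element index of the start of this src row, decomposed below into dst
+    // coordinates so the copy also works when dst has a different shape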
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ dst_data[i00] = src[0];
+ }
+}
+
kernel void kernel_cpy_f32_f16(
device const float * src0,
device half * dst,
//====================================== dot products =========================
-kernel void kernel_mul_mv_q2_K_f32(
+void kernel_mul_mv_q2_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
}
}
-#if QK_K == 256
-kernel void kernel_mul_mv_q3_K_f32(
+[[host_name("kernel_mul_mv_q2_K_f32")]]
+kernel void kernel_mul_mv_q2_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant uint & r2 [[buffer(17)]],
constant uint & r3 [[buffer(18)]],
uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+#if QK_K == 256
+void kernel_mul_mv_q3_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK_K;
}
}
#else
-kernel void kernel_mul_mv_q3_K_f32(
+void kernel_mul_mv_q3_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
}
#endif
+[[host_name("kernel_mul_mv_q3_K_f32")]]
+kernel void kernel_mul_mv_q3_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
#if QK_K == 256
-kernel void kernel_mul_mv_q4_K_f32(
+void kernel_mul_mv_q4_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01 [[buffer(4)]],
- constant int64_t & ne02 [[buffer(5)]],
- constant int64_t & ne10 [[buffer(9)]],
- constant int64_t & ne12 [[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
}
}
#else
-kernel void kernel_mul_mv_q4_K_f32(
+void kernel_mul_mv_q4_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
}
#endif
-kernel void kernel_mul_mv_q5_K_f32(
+[[host_name("kernel_mul_mv_q4_K_f32")]]
+kernel void kernel_mul_mv_q4_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q5_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
const int nb = ne00/QK_K;
const int64_t r0 = tgpig.x;
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
}
}
-
}
-kernel void kernel_mul_mv_q6_K_f32(
+[[host_name("kernel_mul_mv_q5_K_f32")]]
+kernel void kernel_mul_mv_q5_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant uint & r2 [[buffer(17)]],
constant uint & r3 [[buffer(18)]],
uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q6_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const uint8_t kmask1 = 0x03;
const uint8_t kmask2 = 0x0C;
}
}
+[[host_name("kernel_mul_mv_q6_K_f32")]]
+kernel void kernel_mul_mv_q6_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
//============================= templates and their specializations =============================
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
- const half d = xb->d;
- const half min = xb->dmin;
+ const float d = xb->d;
+ const float min = xb->dmin;
device const uint8_t * q = (device const uint8_t *)xb->qs;
- half dl, ml;
+ float dl, ml;
uint8_t sc = xb->scales[il];
#if QK_K == 256
q = q + (il/4) * 32 + 16 * (il&1);
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
- const half d = il < 2 ? xb->d : xb->d / 16.h;
- const half min = xb->dmin;
- const half dl = d * sc[0];
- const half ml = min * sc[1];
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
+ const float min = xb->dmin;
+ const float dl = d * sc[0];
+ const float ml = min * sc[1];
#else
q = q + 16 * (il&1);
device const uint8_t * s = xb->scales;
uint8_t ul = 1 << (il/2);
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
- const half d = il < 2 ? xb->d : xb->d / 16.h;
- const half min = xb->dmin;
- const half dl = d * sc[0];
- const half ml = min * sc[1];
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
+ const float min = xb->dmin;
+ const float dl = d * sc[0];
+ const float ml = min * sc[1];
- const ushort mask = il<2 ? 0x0F : 0xF0;
- const half qh_val = il<2 ? 16.h : 256.h;
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ const float qh_val = il<2 ? 16.f : 256.f;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows(
device const void * src0,
- device const int * src1,
+ device const char * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
constant uint64_t & nb1,
- uint tgpig[[threadgroup_position_in_grid]],
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
- uint tptg[[threads_per_threadgroup]]) {
- const int i = tgpig;
- const int r = ((device int32_t *) src1)[i];
+ uint3 tptg [[threads_per_threadgroup]]) {
+ //const int64_t i = tgpig;
+ //const int64_t r = ((device int32_t *) src1)[i];
+
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
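+    // src0's dim-2 index follows src1's second index (one src0 matrix per batch)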
+ const int64_t i02 = i11;
- for (int ind = tiitg; ind < ne00/16; ind += tptg) {
+ for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
float4x4 temp;
dequantize_func(
- ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
- *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
+ ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+ *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+ }
+}
+
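+// non-quantized get_rows variants: copy the gathered rows element by element
+// instead of dequantizing 4x4 blocks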
+kernel void kernel_get_rows_f32(
+ device const void * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+ }
+}
+
+kernel void kernel_get_rows_f16(
+ device const void * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
}
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm_id(
- device const int32_t * ids,
+ device const uchar * ids,
device const uchar * src1,
- device float * dst,
+ device uchar * dst,
+ constant int64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & ne12,
+ constant int64_t & ne13,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
+ constant int64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
uint sgitg[[simdgroup_index_in_threadgroup]]) {
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
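+    // tgpig.z encodes both the row of the ids tensor (bid) and the regular batch index;
+    // the id used to select src0 for this row is read from ids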
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
- src0[ids[idx]],
- src1,
- dst,
+ src0[id],
+ src1 + bid*nb11,
+ (device float *) (dst + bid*nb1),
ne00,
ne02,
nb01,
#define QK_NL 4
#endif
+//
+// get rows
+//
+
typedef void (get_rows_t)(
device const void * src0,
- device const int * src1,
+ device const char * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
constant uint64_t & nb1,
- uint, uint, uint);
+ constant uint64_t & nb2,
+ uint3, uint, uint3);
-template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
-template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
+//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
+//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
+//
+// matrix-matrix multiplication
+//
+
typedef void (mat_mm_t)(
device const uchar * src0,
device const uchar * src1,
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
+//
+// indirect matrix-matrix multiplication
+//
+
typedef void (mat_mm_id_t)(
- device const int32_t * ids,
+ device const uchar * ids,
device const uchar * src1,
- device float * dst,
+ device uchar * dst,
+ constant int64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & ne12,
+ constant int64_t & ne13,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
+ constant int64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
+
+//
+// matrix-vector multiplication
+//
+
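+// each kernel_mul_mv_id_* entry point reads the expert index from row bid of
+// the ids tensor, selects the corresponding src0 matrix out of src00..src07,
+// offsets src1/dst by the batch index bid, and forwards to the regular
+// mul_mv implementation for that type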
+[[host_name("kernel_mul_mv_id_f32_f32")]]
+kernel void kernel_mul_mv_id_f32_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_f32_f32_impl(
+ src0[id],
+ src1 + bid*nb11,
+ (device float *) (dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ nb00,
+ nb01,
+ nb02,
+ ne10,
+ ne11,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_f16_f32")]]
+kernel void kernel_mul_mv_id_f16_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_f16_f32_impl(
+ src0[id],
+ src1 + bid*nb11,
+ (device float *) (dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ nb00,
+ nb01,
+ nb02,
+ ne10,
+ ne11,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_q8_0_f32")]]
+kernel void kernel_mul_mv_id_q8_0_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q8_0_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_0_f32")]]
+kernel void kernel_mul_mv_id_q4_0_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_1_f32")]]
+kernel void kernel_mul_mv_id_q4_1_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_0_f32")]]
+kernel void kernel_mul_mv_id_q5_0_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_1_f32")]]
+kernel void kernel_mul_mv_id_q5_1_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q2_K_f32")]]
+kernel void kernel_mul_mv_id_q2_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q2_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q3_K_f32")]]
+kernel void kernel_mul_mv_id_q3_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q3_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_K_f32")]]
+kernel void kernel_mul_mv_id_q4_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q4_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_K_f32")]]
+kernel void kernel_mul_mv_id_q5_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q5_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q6_K_f32")]]
+kernel void kernel_mul_mv_id_q6_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q6_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
size_t size = ggml_nelements(tensor);
std::vector<float> data(size);
- std::random_device rd;
-
#if 0
std::default_random_engine generator(rd());
std::uniform_real_distribution<float> distribution(min, max);
}
#endif
auto init_thread = [&](size_t start, size_t end) {
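+ // seed a per-thread random_device/engine so the initialization threads
+ // do not share generator state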
+ std::random_device rd;
std::default_random_engine generator(rd());
std::uniform_real_distribution<float> distribution(min, max);
t.join();
}
- if (tensor->type == GGML_TYPE_F32) {
+ if (tensor->type == GGML_TYPE_F32 || tensor->type == GGML_TYPE_I32) {
ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float));
} else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) {
GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0);
std::vector<uint8_t> buf(ggml_nbytes(t));
ggml_backend_tensor_get(t, buf.data(), 0, ggml_nbytes(t));
+ ggml_type_traits_t tt = ggml_internal_get_type_traits(t->type);
+ size_t bs = ggml_blck_size(t->type);
+
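+ // quantized tensors store values in blocks of bs elements, so step i0 by
+ // bs and dequantize one whole block at a time with the type's to_float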
// access elements by index to avoid gaps in views
for (int64_t i3 = 0; i3 < t->ne[3]; i3++) {
for (int64_t i2 = 0; i2 < t->ne[2]; i2++) {
for (int64_t i1 = 0; i1 < t->ne[1]; i1++) {
- for (int64_t i0 = 0; i0 < t->ne[0]; i0++) {
- size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0*t->nb[0];
- float v;
+ for (int64_t i0 = 0; i0 < t->ne[0]; i0 += bs) {
+ size_t i = i3*t->nb[3] + i2*t->nb[2] + i1*t->nb[1] + i0/bs*t->nb[0];
if (t->type == GGML_TYPE_F16) {
- v = (float) ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]);
+ tv.push_back(ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]));
} else if (t->type == GGML_TYPE_F32) {
- v = *(float *) &buf[i];
+ tv.push_back(*(float *) &buf[i]);
} else if (t->type == GGML_TYPE_I32) {
- v = *(int32_t *) &buf[i];
+ tv.push_back((float)*(int32_t *) &buf[i]);
+ } else if (ggml_is_quantized(t->type)) {
+ std::vector<float> vq(ggml_blck_size(t->type));
+ tt.to_float(&buf[i], vq.data(), ggml_blck_size(t->type));
+ tv.insert(tv.end(), vq.begin(), vq.end());
} else {
GGML_ASSERT(false);
}
- tv.push_back(v);
}
}
}
struct test_case {
virtual ~test_case() {}
+ virtual std::string op_desc(ggml_tensor * t) {
+ return ggml_op_desc(t);
+ }
+
virtual std::string vars() {
return "";
}
virtual ggml_tensor * build_graph(ggml_context * ctx) = 0;
virtual double max_nmse_err() {
- return 1e-6;
+ return 1e-7;
}
virtual void initialize_tensors(ggml_context * ctx) {
ggml_tensor * out = build_graph(ctx);
- if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
- //printf(" %s: skipping\n", ggml_op_desc(out));
+ if (op_name != nullptr && op_desc(out) != op_name) {
+ //printf(" %s: skipping\n", op_desc(out).c_str());
ggml_free(ctx);
return true;
}
- printf(" %s(%s): ", ggml_op_desc(out), vars().c_str());
+ printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
fflush(stdout);
// check if backends support op
for (size_t i = 0; i < f1.size(); i++) {
// check for nans
if (std::isnan(f1[i]) || std::isnan(f2[i])) {
- printf("NaN at index %zu ", i);
+ printf("[%s] NaN at index %zu (%f %f) ", ggml_op_desc(t1), i, f1[i], f2[i]);
ud->ok = false;
return true;
}
if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) {
if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) {
if (std::signbit(f1[i]) != std::signbit(f2[i])) {
- printf("inf sign mismatch: %f %f ", f1[i], f2[i]);
+ printf("[%s] inf sign mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
ud->ok = false;
return true;
}
} else {
- printf("inf mismatch: %f %f ", f1[i], f2[i]);
+ printf("[%s] inf mismatch: %f %f ", ggml_op_desc(t1), f1[i], f2[i]);
ud->ok = false;
return true;
}
double err = nmse(f1.data(), f2.data(), f1.size());
if (err > ud->max_err) {
- printf("NMSE = %f ", err);
+ printf("[%s] NMSE = %f ", ggml_op_desc(t1), err);
+ //for (int i = 0; i < f1.size(); i++) {
+ // printf("%5d %9.6f %9.6f, diff = %9.6f\n", i, f1[i], f2[i], f1[i] - f2[i]);
+ //}
+ //printf("\n");
+ //exit(1);
ud->ok = false;
}
return true;
+
+ GGML_UNUSED(index);
};
ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud);
ggml_tensor * out = build_graph(ctx);
- if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) {
- //printf(" %s: skipping\n", ggml_op_desc(out));
+ if (op_name != nullptr && op_desc(out) != op_name) {
+ //printf(" %s: skipping\n", op_desc(out).c_str());
ggml_free(ctx);
return true;
}
- int len = printf(" %s(%s): ", ggml_op_desc(out), vars().c_str());
+ int len = printf(" %s(%s): ", op_desc(out).c_str(), vars().c_str());
fflush(stdout);
// check if backends support op
return size;
};
for (int i = 0; i < gf->n_nodes; i++) {
- if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out)
+ if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out) {
continue;
+ }
mem += tensor_op_size(gf->nodes[i]);
}
const int n; // cols
const int m; // rows
const int r; // rows to get
+ const int b; // batch size
+ const bool v; // view (non-contiguous src1)
std::string vars() override {
- return VARS_TO_STR4(type, n, m, r);
+ return VARS_TO_STR6(type, n, m, r, b, v);
}
- test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3)
- : type(type), n(n), m(m), r(r) {}
+ test_get_rows(ggml_type type = GGML_TYPE_F32, int n = 10, int m = 5, int r = 3, int b = 1, bool v = false)
+ : type(type), n(n), m(m), r(r), b(b), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
- ggml_tensor * in = ggml_new_tensor_2d(ctx, type, n, m);
- ggml_tensor * rows = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, r);
+ ggml_tensor * in = ggml_new_tensor_3d(ctx, type, n, m, b);
+ ggml_tensor * rows = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, r, b);
+ if (v) {
+ rows = ggml_view_2d(ctx, rows, r/2, b, rows->nb[1], 0);
+ }
ggml_tensor * out = ggml_get_rows(ctx, in, rows);
return out;
}
void initialize_tensors(ggml_context * ctx) override {
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
+ if (ggml_is_view_op(t->op)) { continue; }
// rows
- std::vector<int> data(r);
- for (int i = 0; i < r; i++) {
+ std::vector<int> data(r*b);
+ for (int i = 0; i < r*b; i++) {
data[i] = rand() % m;
}
- ggml_backend_tensor_set(t, data.data(), 0, r * sizeof(int));
+ ggml_backend_tensor_set(t, data.data(), 0, r * b * sizeof(int));
} else {
init_tensor_uniform(t);
}
const int64_t m;
const int64_t n;
const int64_t k;
- const std::array<int64_t, 2> bs; // dims 3 and 4
- const std::array<int64_t, 2> nr; // repeat in dims 3 and 4
+ const bool v; // view (non-contiguous ids)
std::string vars() override {
- return VARS_TO_STR9(type_a, type_b, n_mats, id, m, n, k, bs, nr);
+ return VARS_TO_STR8(type_a, type_b, n_mats, id, m, n, k, v);
}
double max_nmse_err() override {
}
size_t op_size(ggml_tensor * t) override {
- size_t a = ggml_nbytes(t->src[2]) * n * nr[0] * nr[1];
+ size_t a = ggml_nbytes(t->src[2]) * n;
size_t b = ggml_nbytes(t->src[1]) * m;
size_t c = ggml_nbytes(t);
return a + b + c;
test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32,
int n_mats = 2, int id = 0,
- int64_t m = 32, int64_t n = 32, int64_t k = 32,
- std::array<int64_t, 2> bs = {10, 10},
- std::array<int64_t, 2> nr = {2, 2})
+ int64_t m = 32, int64_t n = 32, int64_t k = 32, bool v = false)
: type_a(type_a), type_b(type_b), n_mats(n_mats), id(id),
- m(m), n(n), k(k), bs(bs), nr(nr) {}
+ m(m), n(n), k(k), v(v) {}
ggml_tensor * build_graph(ggml_context * ctx) override {
// C^T = A * B^T: (k, m) * (k, n) => (m, n)
std::vector<ggml_tensor *> mats;
for (int i = 0; i < n_mats; i++) {
- ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]);
+ ggml_tensor * a = ggml_new_tensor_2d(ctx, type_a, k, m);
mats.push_back(a);
}
- ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_mats);
- ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]);
- ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), ids, id, b);
+ ggml_tensor * ids = ggml_new_tensor_2d(ctx, GGML_TYPE_I32, n_mats, n);
+ if (v) {
+ ids = ggml_view_2d(ctx, ids, n_mats/2, ids->ne[1], ids->nb[1], 0);
+ }
+ ggml_tensor * b = ggml_new_tensor_2d(ctx, type_b, k, n);
+ ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), n_mats, ids, v ? id/2 : id, b);
return out;
}
void initialize_tensors(ggml_context * ctx) override {
+ std::random_device rd;
+ std::default_random_engine rng(rd());
for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
if (t->type == GGML_TYPE_I32) {
+ if (ggml_is_view_op(t->op)) { continue; }
// ids
- std::vector<int> data(n_mats);
- for (int i = 0; i < n_mats; i++) {
- data[i] = i;
+ for (int64_t r = 0; r < ggml_nrows(t); r++) {
+ std::vector<int32_t> data(t->ne[0]);
+ for (int i = 0; i < t->ne[0]; i++) {
+ data[i] = i % n_mats;
+ }
+ std::shuffle(data.begin(), data.end(), rng);
+ ggml_backend_tensor_set(t, data.data(), r * t->nb[1], t->ne[0] * sizeof(int32_t));
}
- std::shuffle(data.begin(), data.end(), std::default_random_engine(std::random_device()()));
- ggml_backend_tensor_set(t, data.data(), 0, n_mats * sizeof(int));
} else {
init_tensor_uniform(t);
}
}
};
+// Mixtral MoE (mixture of experts)
+struct test_moe : public test_case {
+ const int n_experts;
+ const int n_experts_per_tok;
+ const int n_tokens;
+ const int n_embd;
+ const int n_ff;
+
+ std::string op_desc(ggml_tensor * t) override {
+ return "MOE";
+
+ GGML_UNUSED(t);
+ }
+
+ std::string vars() override {
+ return VARS_TO_STR5(n_experts, n_experts_per_tok, n_tokens, n_embd, n_ff);
+ }
+
+ test_moe(int n_experts = 8, int n_experts_per_tok = 2, int n_tokens = 1, int n_embd = 4096, int n_ff = 14336)
+ : n_experts(n_experts), n_experts_per_tok(n_experts_per_tok), n_tokens(n_tokens), n_embd(n_embd), n_ff(n_ff) {
+ }
+
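+ // builds a single Mixtral-style MoE feed-forward block: ffn_gate_inp
+ // produces per-expert logits, softmax + top-k selects n_experts_per_tok
+ // experts per token, and their SiLU-gated FFN outputs are summed weighted
+ // by the normalized selection probabilities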
+ ggml_tensor * build_graph(ggml_context * ctx) override {
+ ggml_tensor * ffn_gate_inp = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_experts);
+
+ std::vector<ggml_tensor *> ffn_up_exp(n_experts);
+ std::vector<ggml_tensor *> ffn_gate_exp(n_experts);
+ std::vector<ggml_tensor *> ffn_down_exp(n_experts);
+
+ for (int i = 0; i < n_experts; ++i) {
+ ffn_up_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
+ ffn_gate_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_ff);
+ ffn_down_exp[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_ff, n_embd);
+ }
+
+ ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
+
+ ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur);
+ ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, 1.0f/sqrtf(n_embd));
+
+ // select experts
+ ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok);
+
+ ggml_tensor * weights = ggml_get_rows(ctx,
+ ggml_reshape_3d(ctx, probs, 1, n_experts, n_tokens), selected_experts);
+
+ weights = ggml_reshape_2d(ctx, weights, n_experts_per_tok, n_tokens);
+
+ ggml_tensor * weights_sum = ggml_sum_rows(ctx, weights);
+
+ weights = ggml_div(ctx, weights, weights_sum);
+
+ // compute expert outputs
+ ggml_tensor * moe_out = nullptr;
+
+ for (int i = 0; i < n_experts_per_tok; ++i) {
+ ggml_tensor * cur_expert;
+
+ ggml_tensor * cur_up = ggml_mul_mat_id(ctx, ffn_up_exp.data(), n_experts, selected_experts, i, cur);
+
+ ggml_tensor * cur_gate = ggml_mul_mat_id(ctx, ffn_gate_exp.data(), n_experts, selected_experts, i, cur);
+
+ cur_gate = ggml_silu(ctx, cur_gate);
+
+ cur_expert = ggml_mul(ctx, cur_up, cur_gate);
+
+ cur_expert = ggml_mul_mat_id(ctx, ffn_down_exp.data(), n_experts, selected_experts, i, cur_expert);
+
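+ // weight this expert's output by its normalized router probability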
+ cur_expert = ggml_mul(ctx, cur_expert,
+ ggml_view_2d(ctx, weights, 1, n_tokens, weights->nb[1], i*weights->nb[0]));
+
+ if (i == 0) {
+ moe_out = cur_expert;
+ } else {
+ moe_out = ggml_add(ctx, moe_out, cur_expert);
+ }
+ }
+
+ cur = moe_out;
+
+ return cur;
+ }
+};
+
static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) {
std::vector<std::unique_ptr<test_case>> test_cases;
+ const ggml_type all_types[] = {
+ GGML_TYPE_F32, GGML_TYPE_F16,
+ GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
+ GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
+ GGML_TYPE_Q8_0,
+ GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
+ GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
+ GGML_TYPE_Q6_K
+ };
+
// unary ops
for (int op = 0; op < GGML_UNARY_OP_COUNT; op++) {
test_cases.emplace_back(new test_unary((ggml_unary_op) op));
}
- for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
- test_cases.emplace_back(new test_get_rows(type, 10, 5, 3));
- test_cases.emplace_back(new test_get_rows(type, 16, 5, 3));
+ test_cases.emplace_back(new test_get_rows(GGML_TYPE_F32, 1, 8, 2, 1, false));
+ for (ggml_type type : all_types) {
+ for (int b : {1, 7}) {
+ for (bool v : {false, true}) {
+ test_cases.emplace_back(new test_get_rows(type, 256, 5, 4, b, v));
+ }
+ }
}
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1}));
test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 2}));
test_cases.emplace_back(new test_dup());
- test_cases.emplace_back(new test_cpy());
+
+ for (ggml_type type : all_types) {
+ test_cases.emplace_back(new test_cpy(GGML_TYPE_F32, type, {256, 10, 10, 1}));
+ }
+
test_cases.emplace_back(new test_cont());
auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
};
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1});
+ add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1, 1}, {32, 1, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 640, 1}, {32, 32, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {5120, 1, 1, 1}, {1, 256, 1, 1});
add_test_bin_bcast(GGML_TYPE_F32, {640, 1, 1, 1}, {1, 1, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
- add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});
+ //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1});
+ //add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1});
test_cases.emplace_back(new test_scale());
test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps));
}
- const ggml_type all_types[] = {
- GGML_TYPE_F32, GGML_TYPE_F16,
- GGML_TYPE_Q4_0, GGML_TYPE_Q4_1,
- GGML_TYPE_Q5_0, GGML_TYPE_Q5_1,
- GGML_TYPE_Q8_0,
- GGML_TYPE_Q2_K, GGML_TYPE_Q3_K,
- GGML_TYPE_Q4_K, GGML_TYPE_Q5_K,
- GGML_TYPE_Q6_K
- };
-
for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
// FIXME: CPU crashes on f16xf16
for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) {
- for (int n_mats : {1, 2, 4}) {
+ for (int n_mats : {2, 4, 8}) {
for (int id = 0; id < n_mats; id++) {
- test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, {1, 1}, {1, 1}));
+ for (bool v : {false, true}) {
+ test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, v));
+ }
}
}
}
test_cases.emplace_back(new test_concat());
for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) {
+ test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {8, 1, 1, 1}, order));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order));
}
test_cases.emplace_back(new test_pad());
test_cases.emplace_back(new test_leaky_relu());
+#if !defined(__SANITIZE_THREAD__)
+ // FIXME: these tests use too much memory with thread sanitizer
+ test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024));
+ //test_cases.emplace_back(new test_moe(8, 2, 8, 4096, 14336));
+#endif
+
// run tests
if (mode == MODE_TEST) {
ggml_backend_t backend_cpu = ggml_backend_cpu_init();
ggml_backend_free(backend_cpu);
return n_ok == test_cases.size();
- } else if (mode == MODE_PERF) {
+ }
+
+ if (mode == MODE_PERF) {
for (auto & test : test_cases) {
test->eval_perf(backend, op_name);
}
return true;
- } else {
- GGML_ASSERT(false);
}
+
+ GGML_ASSERT(false);
+ return false;
}
static void usage(char ** argv) {
}
printf("%zu/%zu backends passed\n", n_ok, ggml_backend_reg_get_count());
+
if (n_ok != ggml_backend_reg_get_count()) {
printf("\033[1;31mFAIL\033[0m\n");
return 1;
- } else {
- printf("\033[1;32mOK\033[0m\n");
- return 0;
}
+
+ printf("\033[1;32mOK\033[0m\n");
+ return 0;
}