// ggml-backend v2 API
//
-// Seperate tensor and graph allocator objects
+// Separate tensor and graph allocator objects
// This is necessary for multi-backend allocation because the graph allocator needs to use multiple tensor allocators
// The original API is kept as a wrapper around the new API
#include <algorithm>
+#include <assert.h>
+#include <atomic>
+#include <cinttypes>
#include <cstddef>
#include <cstdint>
-#include <cinttypes>
#include <float.h>
#include <limits>
#include <stdint.h>
#include <stdio.h>
-#include <atomic>
-#include <assert.h>
+#include <vector>
+
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
#define CUDA_GELU_BLOCK_SIZE 256
#define CUDA_SILU_BLOCK_SIZE 256
+#define CUDA_TANH_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
#define CUDA_CPY_BLOCK_SIZE 32
#define CUDA_QUANTIZE_BLOCK_SIZE 256
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
#define CUDA_GET_ROWS_BLOCK_SIZE 256
+#define CUDA_UPSCALE_BLOCK_SIZE 256
+#define CUDA_CONCAT_BLOCK_SIZE 256
+#define CUDA_PAD_BLOCK_SIZE 256
+#define CUDA_ACC_BLOCK_SIZE 256
+#define CUDA_IM2COL_BLOCK_SIZE 256
// dmmv = dequantize_mul_mat_vec
#ifndef GGML_CUDA_DMMV_X
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
}
+static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
+ const int ne10, const int ne11, const int ne12,
+ const int nb1, const int nb2, int offset) {
+ const int i = blockDim.x * blockIdx.x + threadIdx.x;
+ if (i >= ne) {
+ return;
+ }
+ int src1_idx = i - offset;
+ int oz = src1_idx / nb2;
+ int oy = (src1_idx - (oz * nb2)) / nb1;
+ int ox = src1_idx % nb1;
+ if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
+ dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
+ } else {
+ dst[i] = x[i];
+ }
+}
+
static __global__ void gelu_f32(const float * x, float * dst, const int k) {
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
dst[i] = x[i] / (1.0f + expf(-x[i]));
}
+static __global__ void gelu_quick_f32(const float *x, float *dst, int k) {
+ const float GELU_QUICK_COEF = -1.702f;
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+ if (i >= k) {
+ return;
+ }
+ dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
+}
+
+static __global__ void tanh_f32(const float *x, float *dst, int k) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+ if (i >= k) {
+ return;
+ }
+ dst[i] = tanhf(x[i]);
+}
+
static __global__ void relu_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
dst[i] = fmaxf(x[i], 0);
}
+static __global__ void leaky_relu_f32(const float *x, float *dst, const int k, const float negative_slope) {
+ const int i = blockDim.x*blockIdx.x + threadIdx.x;
+ if (i >= k) {
+ return;
+ }
+ dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
+}
+
static __global__ void sqr_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
}
}
+static __global__ void concat_f32(const float *x,const float *y, float *dst, const int ne0, const int ne02) {
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (nidx >= ne0) {
+ return;
+ }
+ // operation
+ int offset_dst =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * gridDim.y;
+ if (blockIdx.z < ne02) { // src0
+ int offset_src =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * gridDim.y;
+ dst[offset_dst] = x[offset_src];
+ } else {
+ int offset_src =
+ nidx +
+ blockIdx.y * ne0 +
+ (blockIdx.z - ne02) * ne0 * gridDim.y;
+ dst[offset_dst] = y[offset_src];
+ }
+}
+
+static __global__ void upscale_f32(const float *x, float *dst, const int ne00, const int nb02, const int scale_factor) {
+ int ne0 = ne00 * scale_factor;
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (nidx >= ne0) {
+ return;
+ }
+ // operation
+ int i00 = nidx / scale_factor;
+ int i01 = blockIdx.y / scale_factor;
+ int offset_src =
+ i00 +
+ i01 * ne00 +
+ blockIdx.z * nb02;
+ int offset_dst =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * gridDim.y;
+ dst[offset_dst] = x[offset_src];
+}
+
+static __global__ void pad_f32(const float *x, float *dst, const int ne0, const int ne00, const int ne01, const int ne02) {
+ int nidx = threadIdx.x + blockIdx.x * blockDim.x;
+ if (nidx >= ne0) {
+ return;
+ }
+
+ // operation
+ int offset_dst =
+ nidx +
+ blockIdx.y * ne0 +
+ blockIdx.z * ne0 * gridDim.y;
+ if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02) {
+ int offset_src =
+ nidx +
+ blockIdx.y * ne00 +
+ blockIdx.z * ne00 * ne01;
+ dst[offset_dst] = x[offset_src];
+ } else {
+ dst[offset_dst] = 0.0f;
+ }
+}
+
+template <int block_size>
+static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
+ int start = blockIdx.x * group_size;
+ int end = start + group_size;
+
+ start += threadIdx.x;
+
+ if (end >= ne_elements) {
+ end = ne_elements;
+ }
+
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int j = start; j < end; j += block_size) {
+ tmp += x[j];
+ }
+
+ tmp = warp_reduce_sum(tmp);
+ if (block_size > WARP_SIZE) {
+ __shared__ float s_sum[32];
+ int warp_id = threadIdx.x / WARP_SIZE;
+ int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = tmp;
+ }
+ __syncthreads();
+ tmp = s_sum[lane_id];
+ tmp = warp_reduce_sum(tmp);
+ }
+
+ float mean = tmp / group_size;
+ tmp = 0.0f;
+
+ for (int j = start; j < end; j += block_size) {
+ float xi = x[j] - mean;
+ dst[j] = xi;
+ tmp += xi * xi;
+ }
+
+ tmp = warp_reduce_sum(tmp);
+ if (block_size > WARP_SIZE) {
+ __shared__ float s_sum[32];
+ int warp_id = threadIdx.x / WARP_SIZE;
+ int lane_id = threadIdx.x % WARP_SIZE;
+ if (lane_id == 0) {
+ s_sum[warp_id] = tmp;
+ }
+ __syncthreads();
+ tmp = s_sum[lane_id];
+ tmp = warp_reduce_sum(tmp);
+ }
+
+ float variance = tmp / group_size;
+ float scale = rsqrtf(variance + eps);
+ for (int j = start; j < end; j += block_size) {
+ dst[j] *= scale;
+ }
+}
+
template <int block_size>
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
}
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
-static __global__ void k_get_rows(const void * x, const int32_t * y, dst_t * dst, const int ncols) {
- const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2;
- const int row = blockDim.y*blockIdx.y + threadIdx.y;
-
- if (col >= ncols) {
+static __global__ void k_get_rows(
+ const void * src0, const int32_t * src1, dst_t * dst,
+ int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+ /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+ /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+ size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+ const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
+ const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+ const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+ const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+ if (i00 >= ne00) {
return;
}
- const int r = y[row];
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
- // copy x[r*ncols + col] to dst[row*ncols + col]
- const int xi = r*ncols + col;
- const int di = row*ncols + col;
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
- const int ib = xi/qk; // block index
- const int iqs = (xi%qk)/qr; // quant index
- const int iybs = di - di%qk; // y block start index
+ const int ib = i00/qk; // block index
+ const int iqs = (i00%qk)/qr; // quant index
+ const int iybs = i00 - i00%qk; // dst block start index
const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize
dfloat2 v;
- dequantize_kernel(x, ib, iqs, v);
+ dequantize_kernel(src0_row, ib, iqs, v);
+
+ dst_row[iybs + iqs + 0] = v.x;
+ dst_row[iybs + iqs + y_offset] = v.y;
+}
+
+template<typename src0_t, typename dst_t>
+static __global__ void k_get_rows_float(
+ const src0_t * src0, const int32_t * src1, dst_t * dst,
+ int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
+ /*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
+ /*size_t s0,*/ size_t s1, size_t s2, size_t s3,
+ /*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
+ size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
+
+ const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
+ const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
+ const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
+ const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
+
+ if (i00 >= ne00) {
+ return;
+ }
+
+ const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
+
+ dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
+ const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
- dst[iybs + iqs + 0] = v.x;
- dst[iybs + iqs + y_offset] = v.y;
+ dst_row[i00] = src0_row[i00];
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void im2col_f32_f16(
const float * x, half * dst,
- int ofs0, int ofs1, int IW, int IH, int CHW,
+ int offset_delta, int IW, int IH, int OW, int KW, int KH, int pelements, int CHW,
int s0, int s1, int p0, int p1, int d0, int d1) {
- const int iiw = blockIdx.z * s0 + threadIdx.z * d0 - p0;
- const int iih = blockIdx.y * s1 + threadIdx.y * d1 - p1;
+ const int i = threadIdx.x + blockIdx.x * blockDim.x;
+ if (i >= pelements) {
+ return;
+ }
+
+ const int ksize = OW * (KH > 1 ? KW : 1);
+ const int kx = i / ksize;
+ const int kd = kx * ksize;
+ const int ky = (i - kd) / OW;
+ const int ix = i % OW;
+
+ const int iiw = ix * s0 + kx * d0 - p0;
+ const int iih = blockIdx.y * s1 + ky * d1 - p1;
const int offset_dst =
- (threadIdx.x * gridDim.y * gridDim.z + blockIdx.y * gridDim.z + blockIdx.z) * CHW +
- (blockIdx.x * (blockDim.y * blockDim.z) + threadIdx.y * blockDim.z + threadIdx.z);
+ (blockIdx.y * OW + ix) * CHW +
+ (blockIdx.z * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = __float2half(0.0f);
} else {
- const int offset_src = threadIdx.x * ofs0 + blockIdx.x * ofs1;
+ const int offset_src = blockIdx.z * offset_delta;
dst[offset_dst] = __float2half(x[offset_src + iih * IW + iiw]);
}
}
template<int qk, int qr, dequantize_kernel_t dq>
-static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
+static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
- const int block_num_x = (ncols + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
- const dim3 block_nums(block_num_x, nrows, 1);
- k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols);
+ const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
+ const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+ // strides in elements
+ //const size_t s0 = nb0 / ggml_element_size(dst);
+ const size_t s1 = nb1 / ggml_element_size(dst);
+ const size_t s2 = nb2 / ggml_element_size(dst);
+ const size_t s3 = nb3 / ggml_element_size(dst);
+
+ const size_t s10 = nb10 / ggml_element_size(src1);
+ const size_t s11 = nb11 / ggml_element_size(src1);
+ const size_t s12 = nb12 / ggml_element_size(src1);
+ //const size_t s13 = nb13 / ggml_element_size(src1);
+
+ GGML_ASSERT(ne00 % 2 == 0);
+
+ k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd,
+ ne00, /*ne01, ne02, ne03,*/
+ /*ne10, ne11,*/ ne12, /*ne13,*/
+ /* s0,*/ s1, s2, s3,
+ /* nb00,*/ nb01, nb02, nb03,
+ s10, s11, s12/*, s13*/);
+
+ (void) dst;
+}
+
+template<typename src0_t>
+static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
+
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
+ const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
+ const dim3 block_nums(block_num_x, ne10, ne11*ne12);
+
+ // strides in elements
+ //const size_t s0 = nb0 / ggml_element_size(dst);
+ const size_t s1 = nb1 / ggml_element_size(dst);
+ const size_t s2 = nb2 / ggml_element_size(dst);
+ const size_t s3 = nb3 / ggml_element_size(dst);
+
+ const size_t s10 = nb10 / ggml_element_size(src1);
+ const size_t s11 = nb11 / ggml_element_size(src1);
+ const size_t s12 = nb12 / ggml_element_size(src1);
+ //const size_t s13 = nb13 / ggml_element_size(src1);
+
+ k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
+ src0_dd, src1_dd, dst_dd,
+ ne00, /*ne01, ne02, ne03,*/
+ /*ne10, ne11,*/ ne12, /*ne13,*/
+ /* s0,*/ s1, s2, s3,
+ /* nb00,*/ nb01, nb02, nb03,
+ s10, s11, s12/*, s13*/);
+
+ (void) dst;
}
template<float (*bin_op)(const float, const float)>
GGML_TENSOR_BINARY_OP_LOCALS
-
int nr0 = ne10/ne0;
int nr1 = ne11/ne1;
int nr2 = ne12/ne2;
int64_t ne12 = cne1[2];
int64_t ne13 = cne1[3];
- //size_t nb0 = cnb0[0];
+ size_t nb0 = cnb0[0];
size_t nb1 = cnb0[1];
size_t nb2 = cnb0[2];
size_t nb3 = cnb0[3];
- //size_t nb10 = cnb1[0];
+ size_t nb10 = cnb1[0];
size_t nb11 = cnb1[1];
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
- //size_t s0 = nb0 / sizeof(src1_t);
- size_t s1 = nb1 / sizeof(src1_t);
- size_t s2 = nb2 / sizeof(src1_t);
- size_t s3 = nb3 / sizeof(src1_t);
+ size_t s0 = nb0 / sizeof(dst_t);
+ size_t s1 = nb1 / sizeof(dst_t);
+ size_t s2 = nb2 / sizeof(dst_t);
+ size_t s3 = nb3 / sizeof(dst_t);
- //size_t s10 = nb10 / sizeof(src1_t);
+ size_t s10 = nb10 / sizeof(src1_t);
size_t s11 = nb11 / sizeof(src1_t);
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
+ GGML_ASSERT(s0 == 1);
+ GGML_ASSERT(s10 == 1);
const int block_size = 128;
}
};
+static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
+ const int ne10, const int ne11, const int ne12,
+ const int nb1, const int nb2, const int offset, cudaStream_t stream) {
+ int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
+ acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
+}
+
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
+static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
+ gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
+static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
+ tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
+}
+
static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
+static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
+ const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
+ leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
+}
+
static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
}
+static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
+ static const float eps = 1e-6f;
+ if (group_size < 1024) {
+ const dim3 block_dims(WARP_SIZE, 1, 1);
+ group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+ } else {
+ const dim3 block_dims(1024, 1, 1);
+ group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
+ }
+}
+
+static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) {
+ int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
+ dim3 gridDim(num_blocks, ne1, ne2);
+ concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
+}
+
+static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int scale_factor, cudaStream_t stream) {
+ int ne0 = (ne00 * scale_factor);
+ int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
+ dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02);
+ upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
+}
+
+static void pad_f32_cuda(const float * x, float * dst,
+ const int ne00, const int ne01, const int ne02,
+ const int ne0, const int ne1, const int ne2, cudaStream_t stream) {
+ int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
+ dim3 gridDim(num_blocks, ne1, ne2);
+ pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02);
+}
+
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
}
-static void im2col_f32_f16_cuda(const float * x, half * dst,
- int OH, int IW, int IH, int OW, int IC,
- int KH, int KW, int N, int ofs0, int ofs1,
- int s0, int s1, int p0, int p1, int d0, int d1, cudaStream_t stream) {
- dim3 block_nums(IC, OH, OW);
- dim3 block_dims(N, KH, KW);
- im2col_f32_f16<<<block_nums, block_dims, 0, stream>>>(x, dst, ofs0, ofs1, IW, IH, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
+static void im2col_f32_f16_cuda(const float* x, half* dst,
+ int IW, int IH, int OW, int OH, int KW, int KH, int IC,
+ int offset_delta,
+ int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
+ const int parallel_elements = OW * KW * KH;
+ const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
+ dim3 block_nums(num_blocks, OH, IC);
+ im2col_f32_f16<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, offset_delta, IW, IH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}
// buffer pool for cuda
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
- GGML_ASSERT(ggml_is_contiguous(dst));
- const int ncols = src0->ne[0];
- const int nrows = ggml_nelements(src1);
+ GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
+ GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
+ GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
const int32_t * src1_i32 = (const int32_t *) src1_d;
switch (src0->type) {
case GGML_TYPE_F16:
- get_rows_cuda<1, 1, convert_f16>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_F32:
- get_rows_cuda<1, 1, convert_f32>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q4_0:
- get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q4_1:
- get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q5_0:
- get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q5_1:
- get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q8_0:
- get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0_d, src1_i32, dst_d, nrows, ncols, stream);
+ get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
default:
// TODO: k-quants
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream);
}
+inline void ggml_cuda_op_acc(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
+
+ int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
+ int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
+ // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
+ int offset = dst->op_params[3] / 4; // offset in bytes
+
+ acc_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, main_stream);
+
+ (void) dst;
+}
+
inline void ggml_cuda_op_mul(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
(void) src1_dd;
}
+inline void ggml_cuda_op_gelu_quick(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ gelu_quick_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+inline void ggml_cuda_op_tanh(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ tanh_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
inline void ggml_cuda_op_relu(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
(void) src1_dd;
}
+inline void ggml_cuda_op_leaky_relu(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ float negative_slope;
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
+
+ leaky_relu_f32_cuda(src0_dd, dst_dd, ggml_nelements(src0), negative_slope, main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
inline void ggml_cuda_op_sqr(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
(void) src1_dd;
}
+
+inline void ggml_cuda_op_group_norm(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+ int num_groups = dst->op_params[0];
+ int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
+ group_norm_f32_cuda(src0_dd, dst_dd, num_groups, group_size, src0->ne[0] * src0->ne[1] * src0->ne[2], main_stream);
+
+ (void) src1;
+ (void) dst;
+ (void) src1_dd;
+}
+
+inline void ggml_cuda_op_concat(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(src1->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+
+ for (int i3 = 0; i3 < dst->ne[3]; i3++) {
+ concat_f32_cuda(src0_dd + i3 * (src0->nb[3] / 4), src1_dd + i3 * (src1->nb[3] / 4), dst_dd + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], main_stream);
+ }
+
+ (void) src1;
+ (void) dst;
+}
+
+inline void ggml_cuda_op_upscale(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+
+ const int scale_factor = dst->op_params[0];
+
+ upscale_f32_cuda(src0_dd, dst_dd, src0->ne[0], src0->ne[1], src0->ne[2], scale_factor, main_stream);
+
+ (void) src1;
+ (void) dst;
+}
+
+inline void ggml_cuda_op_pad(
+ const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+ const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+ GGML_ASSERT(dst->type == GGML_TYPE_F32);
+ GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
+
+ pad_f32_cuda(src0_dd, dst_dd,
+ src0->ne[0], src0->ne[1], src0->ne[2],
+ dst->ne[0], dst->ne[1], dst->ne[2], main_stream);
+
+ (void) src1;
+ (void) dst;
+}
+
inline void ggml_cuda_op_rms_norm(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
- const int64_t N = src1->ne[is_2D ? 3 : 2];
const int64_t IC = src1->ne[is_2D ? 2 : 1];
const int64_t IH = is_2D ? src1->ne[1] : 1;
const int64_t IW = src1->ne[0];
const int64_t OH = is_2D ? dst->ne[2] : 1;
const int64_t OW = dst->ne[1];
- const size_t ofs0 = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32
- const size_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
+ const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
- im2col_f32_f16_cuda(src1_dd, (half*) dst_dd,
- OH, IW, IH, OW, IC, KH, KW, N,
- ofs0, ofs1, s0, s1, p0, p1, d0, d1, main_stream);
+ im2col_f32_f16_cuda(src1_dd, (half*) dst_dd, IW, IH, OW, OH, KW, KH, IC, delta_offset, s0, s1, p0, p1, d0, d1, main_stream);
(void) src0;
(void) src0_dd;
}
+
inline void ggml_cuda_op_sum_rows(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_add);
}
+static void ggml_cuda_acc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_acc);
+}
+
static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul);
}
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_silu);
}
+static void ggml_cuda_gelu_quick(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu_quick);
+}
+
+static void ggml_cuda_tanh(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_tanh);
+}
+
static void ggml_cuda_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_relu);
}
+static void ggml_cuda_leaky_relu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_leaky_relu);
+}
+
static void ggml_cuda_sqr(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sqr);
}
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_norm);
}
+static void ggml_cuda_group_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_group_norm);
+}
+
+static void ggml_cuda_concat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_concat);
+}
+
+static void ggml_cuda_upscale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_upscale);
+}
+
+static void ggml_cuda_pad(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+ ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_pad);
+}
+
static void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_rms_norm);
}
}
#endif
-static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
#if 0
-//#ifdef CUDA_USE_TENSOR_CORES
-// const bool use_tensor_cores = true;
-//#else
-// const bool use_tensor_cores = false;
-//#endif
-
ggml_cuda_mul_mat_id_cublas(dst);
-
// TODO: mmq/mmv support
-#else
- const struct ggml_tensor * ids = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
- const int id = dst->op_params[0];
+#endif
- int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+ GGML_ASSERT(dst->backend == GGML_BACKEND_GPU);
- int32_t a_id;
- CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
- CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+ const struct ggml_tensor * ids = src0;
+ const int32_t id = ((int32_t *) dst->op_params)[0];
+ const int32_t n_as = ((int32_t *) dst->op_params)[1];
- GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
- const struct ggml_tensor * src0 = dst->src[a_id + 2];
+ std::vector<char> ids_host(ggml_nbytes(ids));
- ggml_cuda_mul_mat(src0, src1, dst);
-#endif
+ if (ids->backend == GGML_BACKEND_GPU) {
+ const char * ids_dev = (const char *)((const ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device];
+ CUDA_CHECK(cudaMemcpyAsync(ids_host.data(), ids_dev, ggml_nbytes(ids), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+ CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+ } else {
+ memcpy(ids_host.data(), ids->data, ggml_nbytes(ids));
+ }
+
+ const ggml_tensor_extra_gpu * src1_extra = (const ggml_tensor_extra_gpu *) src1->extra;
+ const ggml_tensor_extra_gpu * dst_extra = (const ggml_tensor_extra_gpu *) dst->extra;
+
+ ggml_tensor_extra_gpu src1_row_extra;
+ ggml_tensor_extra_gpu dst_row_extra;
+
+ ggml_tensor src1_row = *src1;
+ ggml_tensor dst_row = *dst;
+
+ src1_row.ne[1] = 1;
+ dst_row.ne[1] = 1;
+
+ src1_row.nb[2] = src1_row.nb[1];
+ dst_row.nb[2] = dst_row.nb[1];
+
+ src1_row.nb[3] = src1_row.nb[1];
+ dst_row.nb[3] = dst_row.nb[1];
+
+ src1_row.extra = &src1_row_extra;
+ dst_row.extra = &dst_row_extra;
+
+
+ for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+ //int32_t row_id;
+ //CUDA_CHECK(cudaMemcpyAsync(&row_id, ids_dev + i01*ids->nb[1] + id*ids->nb[0], sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0]));
+ //CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0]));
+
+ const int32_t row_id = *(const int32_t *) (ids_host.data() + i01*ids->nb[1] + id*ids->nb[0]);
+
+ GGML_ASSERT(row_id >= 0 && row_id < n_as);
- (void) _src0;
- (void) _src1;
+ const struct ggml_tensor * src0_row = dst->src[row_id + 2];
+
+ src1_row_extra.data_device[g_main_device] = (char *) src1_extra->data_device[g_main_device] + i01*src1->nb[1];
+ src1_row.data = (char *) src1->data + i01*src1->nb[1];
+
+ dst_row_extra.data_device[g_main_device] = (char *) dst_extra->data_device[g_main_device] + i01*dst->nb[1];
+ dst_row.data = (char *) dst->data + i01*dst->nb[1];
+
+ ggml_cuda_mul_mat(src0_row, &src1_row, &dst_row);
+ }
}
static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
case GGML_OP_ADD:
func = ggml_cuda_add;
break;
+ case GGML_OP_ACC:
+ func = ggml_cuda_acc;
+ break;
case GGML_OP_MUL:
func = ggml_cuda_mul;
break;
case GGML_UNARY_OP_SILU:
func = ggml_cuda_silu;
break;
+ case GGML_UNARY_OP_GELU_QUICK:
+ func = ggml_cuda_gelu_quick;
+ break;
+ case GGML_UNARY_OP_TANH:
+ func = ggml_cuda_tanh;
+ break;
case GGML_UNARY_OP_RELU:
func = ggml_cuda_relu;
break;
case GGML_OP_NORM:
func = ggml_cuda_norm;
break;
+ case GGML_OP_GROUP_NORM:
+ func = ggml_cuda_group_norm;
+ break;
+ case GGML_OP_CONCAT:
+ func = ggml_cuda_concat;
+ break;
+ case GGML_OP_UPSCALE:
+ func = ggml_cuda_upscale;
+ break;
+ case GGML_OP_PAD:
+ func = ggml_cuda_pad;
+ break;
+ case GGML_OP_LEAKY_RELU:
+ func = ggml_cuda_leaky_relu;
+ break;
case GGML_OP_RMS_NORM:
func = ggml_cuda_rms_norm;
break;
func = ggml_cuda_sqr;
break;
case GGML_OP_CLAMP:
- if (!any_on_device) {
- return false;
- }
func = ggml_cuda_clamp;
break;
case GGML_OP_CPY:
case GGML_OP_CONT:
func = ggml_cuda_dup;
break;
+ case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_UNARY_OP_GELU:
case GGML_UNARY_OP_SILU:
case GGML_UNARY_OP_RELU:
+ case GGML_UNARY_OP_GELU_QUICK:
+ case GGML_UNARY_OP_TANH:
return true;
default:
return false;
}
return true;
} break;
+ case GGML_OP_GET_ROWS:
+ {
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ case GGML_TYPE_Q5_0:
+ case GGML_TYPE_Q5_1:
+ case GGML_TYPE_Q8_0:
+ return true;
+ default:
+ return false;
+ }
+ } break;
+ case GGML_OP_CPY:
+ {
+ ggml_type src0_type = op->src[0]->type;
+ ggml_type src1_type = op->src[1]->type;
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q8_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_0) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_Q4_1) {
+ return true;
+ }
+ if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) {
+ return true;
+ }
+ return false;
+ } break;
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE:
case GGML_OP_NORM:
case GGML_OP_REPEAT:
- case GGML_OP_GET_ROWS:
case GGML_OP_DUP:
case GGML_OP_ADD:
case GGML_OP_MUL:
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_CLAMP:
- case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX:
case GGML_OP_IM2COL:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
+ case GGML_OP_ACC:
+ case GGML_OP_CONCAT:
+ case GGML_OP_GROUP_NORM:
+ case GGML_OP_UPSCALE:
+ case GGML_OP_PAD:
+ case GGML_OP_LEAKY_RELU:
return true;
default:
return false;
UNUSED(params);
}
-extern "C" int ggml_backend_cuda_reg_devices() {
+extern "C" int ggml_backend_cuda_reg_devices();
+
+int ggml_backend_cuda_reg_devices() {
int device_count = ggml_cuda_get_device_count();
//int device_count = 1; // DEBUG: some tools require delaying CUDA initialization
for (int i = 0; i < device_count; i++) {
GGML_METAL_DECL_KERNEL(div_row);
GGML_METAL_DECL_KERNEL(scale);
GGML_METAL_DECL_KERNEL(scale_4);
- GGML_METAL_DECL_KERNEL(silu);
+ GGML_METAL_DECL_KERNEL(tanh);
GGML_METAL_DECL_KERNEL(relu);
GGML_METAL_DECL_KERNEL(gelu);
+ GGML_METAL_DECL_KERNEL(gelu_quick);
+ GGML_METAL_DECL_KERNEL(silu);
GGML_METAL_DECL_KERNEL(soft_max);
GGML_METAL_DECL_KERNEL(soft_max_4);
GGML_METAL_DECL_KERNEL(diag_mask_inf);
GGML_METAL_DECL_KERNEL(get_rows_q5_K);
GGML_METAL_DECL_KERNEL(get_rows_q6_K);
GGML_METAL_DECL_KERNEL(rms_norm);
+ GGML_METAL_DECL_KERNEL(group_norm);
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mv_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mv_f16_f16);
GGML_METAL_DECL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_DECL_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mv_id_q6_K_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
GGML_METAL_DECL_KERNEL(rope_f16);
GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(im2col_f16);
+ GGML_METAL_DECL_KERNEL(upscale_f32);
+ GGML_METAL_DECL_KERNEL(pad_f32);
GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc);
GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc);
+ GGML_METAL_DECL_KERNEL(leaky_relu_f32);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
GGML_METAL_DECL_KERNEL(cpy_f32_q8_0);
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_0);
//GGML_METAL_DECL_KERNEL(cpy_f32_q5_1);
GGML_METAL_DECL_KERNEL(cpy_f16_f16);
+ GGML_METAL_DECL_KERNEL(cpy_f16_f32);
GGML_METAL_DECL_KERNEL(concat);
GGML_METAL_DECL_KERNEL(sqr);
GGML_METAL_DECL_KERNEL(sum_rows);
GGML_METAL_ADD_KERNEL(div_row);
GGML_METAL_ADD_KERNEL(scale);
GGML_METAL_ADD_KERNEL(scale_4);
- GGML_METAL_ADD_KERNEL(silu);
+ GGML_METAL_ADD_KERNEL(tanh);
GGML_METAL_ADD_KERNEL(relu);
GGML_METAL_ADD_KERNEL(gelu);
+ GGML_METAL_ADD_KERNEL(gelu_quick);
+ GGML_METAL_ADD_KERNEL(silu);
GGML_METAL_ADD_KERNEL(soft_max);
GGML_METAL_ADD_KERNEL(soft_max_4);
GGML_METAL_ADD_KERNEL(diag_mask_inf);
GGML_METAL_ADD_KERNEL(get_rows_q5_K);
GGML_METAL_ADD_KERNEL(get_rows_q6_K);
GGML_METAL_ADD_KERNEL(rms_norm);
+ GGML_METAL_ADD_KERNEL(group_norm);
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mv_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mv_f16_f16);
GGML_METAL_ADD_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_ADD_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mv_id_q6_K_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
GGML_METAL_ADD_KERNEL(rope_f16);
GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(im2col_f16);
+ GGML_METAL_ADD_KERNEL(upscale_f32);
+ GGML_METAL_ADD_KERNEL(pad_f32);
GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc);
GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc);
+ GGML_METAL_ADD_KERNEL(leaky_relu_f32);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
GGML_METAL_ADD_KERNEL(cpy_f32_q8_0);
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_0);
//GGML_METAL_ADD_KERNEL(cpy_f32_q5_1);
GGML_METAL_ADD_KERNEL(cpy_f16_f16);
+ GGML_METAL_ADD_KERNEL(cpy_f16_f32);
GGML_METAL_ADD_KERNEL(concat);
GGML_METAL_ADD_KERNEL(sqr);
GGML_METAL_ADD_KERNEL(sum_rows);
GGML_METAL_DEL_KERNEL(div_row);
GGML_METAL_DEL_KERNEL(scale);
GGML_METAL_DEL_KERNEL(scale_4);
- GGML_METAL_DEL_KERNEL(silu);
+ GGML_METAL_DEL_KERNEL(tanh);
GGML_METAL_DEL_KERNEL(relu);
GGML_METAL_DEL_KERNEL(gelu);
+ GGML_METAL_DEL_KERNEL(gelu_quick);
+ GGML_METAL_DEL_KERNEL(silu);
GGML_METAL_DEL_KERNEL(soft_max);
GGML_METAL_DEL_KERNEL(soft_max_4);
GGML_METAL_DEL_KERNEL(diag_mask_inf);
GGML_METAL_DEL_KERNEL(get_rows_q5_K);
GGML_METAL_DEL_KERNEL(get_rows_q6_K);
GGML_METAL_DEL_KERNEL(rms_norm);
+ GGML_METAL_DEL_KERNEL(group_norm);
GGML_METAL_DEL_KERNEL(norm);
GGML_METAL_DEL_KERNEL(mul_mv_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mv_f16_f16);
GGML_METAL_DEL_KERNEL(mul_mv_q4_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q5_K_f32);
GGML_METAL_DEL_KERNEL(mul_mv_q6_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f32_f32);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f16);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_1row);
+ //GGML_METAL_DEL_KERNEL(mul_mv_id_f16_f32_l4);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_1_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q8_0_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q2_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q3_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q4_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q5_K_f32);
+ GGML_METAL_DEL_KERNEL(mul_mv_id_q6_K_f32);
if ([ctx->device supportsFamily:MTLGPUFamilyApple7]) {
GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
GGML_METAL_DEL_KERNEL(rope_f16);
GGML_METAL_DEL_KERNEL(alibi_f32);
GGML_METAL_DEL_KERNEL(im2col_f16);
+ GGML_METAL_DEL_KERNEL(upscale_f32);
+ GGML_METAL_DEL_KERNEL(pad_f32);
GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc);
GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc);
+ GGML_METAL_DEL_KERNEL(leaky_relu_f32);
GGML_METAL_DEL_KERNEL(cpy_f32_f16);
GGML_METAL_DEL_KERNEL(cpy_f32_f32);
GGML_METAL_DEL_KERNEL(cpy_f32_q8_0);
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_0);
//GGML_METAL_DEL_KERNEL(cpy_f32_q5_1);
GGML_METAL_DEL_KERNEL(cpy_f16_f16);
+ GGML_METAL_DEL_KERNEL(cpy_f16_f32);
GGML_METAL_DEL_KERNEL(concat);
GGML_METAL_DEL_KERNEL(sqr);
GGML_METAL_DEL_KERNEL(sum_rows);
switch (op->op) {
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
- case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_RELU:
case GGML_UNARY_OP_GELU:
+ case GGML_UNARY_OP_GELU_QUICK:
+ case GGML_UNARY_OP_SILU:
return true;
default:
return false;
case GGML_OP_PERMUTE:
case GGML_OP_CONCAT:
case GGML_OP_ADD:
+ case GGML_OP_ACC:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_SCALE:
case GGML_OP_SUM_ROWS:
case GGML_OP_SOFT_MAX:
case GGML_OP_RMS_NORM:
+ case GGML_OP_GROUP_NORM:
case GGML_OP_NORM:
case GGML_OP_ALIBI:
case GGML_OP_ROPE:
case GGML_OP_IM2COL:
+ case GGML_OP_UPSCALE:
+ case GGML_OP_PAD:
case GGML_OP_ARGSORT:
- case GGML_OP_DUP:
- case GGML_OP_CPY:
- case GGML_OP_CONT:
+ case GGML_OP_LEAKY_RELU:
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
return true;
+ case GGML_OP_CPY:
+ case GGML_OP_DUP:
+ case GGML_OP_CONT:
+ {
+ switch (op->src[0]->type) {
+ case GGML_TYPE_F32:
+ switch (op->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ case GGML_TYPE_Q8_0:
+ case GGML_TYPE_Q4_0:
+ case GGML_TYPE_Q4_1:
+ return true;
+ default:
+ return false;
+ }
+ case GGML_TYPE_F16:
+ switch (op->type) {
+ case GGML_TYPE_F16:
+ case GGML_TYPE_F32:
+ return true;
+ default:
+ return false;
+ }
+ default:
+ return false;
+ };
+ }
case GGML_OP_DIAG_MASK_INF:
case GGML_OP_GET_ROWS:
{
- return op->ne[0] % 4 == 0;
+ return op->ne[3] == 1;
}
default:
return false;
} break;
}
- GGML_ASSERT(ggml_metal_supports_op(dst));
+ if (!ggml_metal_supports_op(dst)) {
+ GGML_METAL_LOG_ERROR("%s: error: unsupported op '%s'\n", __func__, ggml_op_desc(dst));
+ GGML_ASSERT(!"unsupported op");
+ }
const int64_t ne00 = src0 ? src0->ne[0] : 0;
const int64_t ne01 = src0 ? src0->ne[1] : 0;
case GGML_OP_MUL:
case GGML_OP_DIV:
{
- GGML_ASSERT(ggml_is_contiguous(src0));
- GGML_ASSERT(ggml_is_contiguous(src1));
+ const size_t offs = 0;
bool bcast_row = false;
int64_t nb = ne00;
- if (ggml_nelements(src1) == ne10 && ne00 % 4 == 0) {
+ id<MTLComputePipelineState> pipeline = nil;
+
+ if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) {
+ GGML_ASSERT(ggml_is_contiguous(src0));
+
// src1 is a row
GGML_ASSERT(ne11 == 1);
nb = ne00 / 4;
switch (dst->op) {
- case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break;
- case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul_row]; break;
- case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break;
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add_row; break;
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul_row; break;
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div_row; break;
default: GGML_ASSERT(false);
}
bcast_row = true;
} else {
switch (dst->op) {
- case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break;
- case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break;
- case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break;
+ case GGML_OP_ADD: pipeline = ctx->pipeline_add; break;
+ case GGML_OP_MUL: pipeline = ctx->pipeline_mul; break;
+ case GGML_OP_DIV: pipeline = ctx->pipeline_div; break;
default: GGML_ASSERT(false);
}
}
+
+ [encoder setComputePipelineState:pipeline];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
- [encoder setBytes:&nb length:sizeof(nb) atIndex:27];
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+ [encoder setBytes:&nb length:sizeof(nb) atIndex:28];
if (bcast_row) {
const int64_t n = ggml_nelements(dst)/4;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} else {
- const int nth = MIN(1024, ne0);
+ const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
}
} break;
+ case GGML_OP_ACC:
+ {
+ GGML_ASSERT(src0t == GGML_TYPE_F32);
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ GGML_ASSERT(dstt == GGML_TYPE_F32);
+
+ GGML_ASSERT(ggml_is_contiguous(src0));
+ GGML_ASSERT(ggml_is_contiguous(src1));
+
+ const size_t pnb1 = ((int32_t *) dst->op_params)[0];
+ const size_t pnb2 = ((int32_t *) dst->op_params)[1];
+ const size_t pnb3 = ((int32_t *) dst->op_params)[2];
+ const size_t offs = ((int32_t *) dst->op_params)[3];
+
+ const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
+
+ if (!inplace) {
+ // run a separete kernel to cpy src->dst
+ // not sure how to avoid this
+ // TODO: make a simpler cpy_bytes kernel
+
+ const int nth = MIN(1024, ne00);
+
+ [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ }
+
+ [encoder setComputePipelineState:ctx->pipeline_add];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8];
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9];
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
+ [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
+ [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
+ [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24];
+ [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25];
+ [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26];
+ [encoder setBytes:&offs length:sizeof(offs) atIndex:27];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne11, ne12, ne13) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
case GGML_OP_SCALE:
{
GGML_ASSERT(ggml_is_contiguous(src0));
} break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(gf->nodes[i])) {
- case GGML_UNARY_OP_SILU:
+ case GGML_UNARY_OP_TANH:
{
- [encoder setComputePipelineState:ctx->pipeline_silu];
+ [encoder setComputePipelineState:ctx->pipeline_tanh];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst);
- GGML_ASSERT(n % 4 == 0);
- [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_UNARY_OP_RELU:
{
const int64_t n = ggml_nelements(dst);
GGML_ASSERT(n % 4 == 0);
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_GELU_QUICK:
+ {
+ [encoder setComputePipelineState:ctx->pipeline_gelu_quick];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(n % 4 == 0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
+ case GGML_UNARY_OP_SILU:
+ {
+ [encoder setComputePipelineState:ctx->pipeline_silu];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+
+ const int64_t n = ggml_nelements(dst);
+ GGML_ASSERT(n % 4 == 0);
+
[encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
default:
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
if (id_src1) {
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ } else {
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
}
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
- int64_t ny = (ne11 + nrows - 1)/nrows;
+ const int64_t ny = (ne11 + nrows - 1)/nrows;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
GGML_ASSERT(src0t == GGML_TYPE_I32);
- const int n_as = ne00;
+ const int n_as = ((int32_t *) dst->op_params)[1];
// TODO: make this more general
GGML_ASSERT(n_as <= 8);
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
// to the matrix-vector kernel
- int ne11_mm_min = 0;
+ int ne11_mm_min = 1;
const int idx = ((int32_t *) dst->op_params)[0];
+ // batch size
+ GGML_ASSERT(ne01 == ne11);
+
+ const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory
+
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
- if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
- ne11 > ne11_mm_min) {
+ // !!!
+ // TODO: for now, always use mat-vec kernels until we figure out how to improve the
+ // indirect matrix multiplication
+ // !!!
+ if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) {
switch (src2->type) {
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break;
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
- [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3];
- [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4];
- [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5];
- [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6];
- [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
- [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8];
- [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9];
- [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10];
- [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11];
- [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12];
- [encoder setBytes:&r2 length:sizeof(r2) atIndex:13];
- [encoder setBytes:&r3 length:sizeof(r3) atIndex:14];
- [encoder setBytes:&idx length:sizeof(idx) atIndex:15];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:18];
// TODO: how to make this an array? read Metal docs
for (int j = 0; j < n_as; ++j) {
struct ggml_tensor * src_cur = dst->src[2 + j];
size_t offs_src_cur = 0;
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
- [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j];
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
}
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+
+ // TODO: processing one row at a time (ne11 -> 1) is not efficient
+ [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
+ } else {
+ int nth0 = 32;
+ int nth1 = 1;
+ int nrows = 1;
+ //printf("vector: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
+
+ // use custom matrix x vector kernel
+ switch (src2t) {
+ case GGML_TYPE_F32:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f32_f32];
+ } break;
+ case GGML_TYPE_F16:
+ {
+ GGML_ASSERT(src1t == GGML_TYPE_F32);
+ nth0 = 32;
+ nth1 = 1;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_f16_f32];
+ } break;
+ case GGML_TYPE_Q4_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_0_f32];
+ } break;
+ case GGML_TYPE_Q4_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_1_f32];
+ } break;
+ case GGML_TYPE_Q5_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_0_f32];
+ } break;
+ case GGML_TYPE_Q5_1:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_1_f32];
+ } break;
+ case GGML_TYPE_Q8_0:
+ {
+ nth0 = 8;
+ nth1 = 8;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q8_0_f32];
+ } break;
+ case GGML_TYPE_Q2_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q2_K_f32];
+ } break;
+ case GGML_TYPE_Q3_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q3_K_f32];
+ } break;
+ case GGML_TYPE_Q4_K:
+ {
+ nth0 = 4; //1;
+ nth1 = 8; //32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q4_K_f32];
+ } break;
+ case GGML_TYPE_Q5_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q5_K_f32];
+ } break;
+ case GGML_TYPE_Q6_K:
+ {
+ nth0 = 2;
+ nth1 = 32;
+ [encoder setComputePipelineState:ctx->pipeline_mul_mv_id_q6_K_f32];
+ } break;
+ default:
+ {
+ GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t);
+ GGML_ASSERT(false && "not implemented");
+ }
+ };
+
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
+ [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
+ [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
+ [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
+ [encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
+ [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
+ [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
+ [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
+ [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
+ [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
+ [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
+ [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
+ [encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
+ [encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
+ [encoder setBytes:&idx length:sizeof(idx) atIndex:22];
+ // TODO: how to make this an array? read Metal docs
+ for (int j = 0; j < n_as; ++j) {
+ struct ggml_tensor * src_cur = dst->src[2 + j];
+
+ size_t offs_src_cur = 0;
+ id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur);
+
+ [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
+ }
+
+ if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
+ src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
+ src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q4_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q3_K) {
+#ifdef GGML_QKK_64
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#else
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+#endif
+ }
+ else if (src2t == GGML_TYPE_Q5_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
+ else if (src2t == GGML_TYPE_Q6_K) {
+ [encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ } else {
+ const int64_t ny = (_ne1 + nrows - 1)/nrows;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
}
} break;
case GGML_OP_GET_ROWS:
default: GGML_ASSERT(false && "not implemented");
}
- [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
- [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
- [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:4];
- [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:5];
-
- const int64_t n = ggml_nelements(src1);
-
- [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&nb10 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb11 length:sizeof( int64_t) atIndex:8];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:10];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne10, ne11, 1) threadsPerThreadgroup:MTLSizeMake(32, 1, 1)];
} break;
case GGML_OP_RMS_NORM:
{
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
+ case GGML_OP_GROUP_NORM:
+ {
+ GGML_ASSERT(ne00 % 4 == 0);
+
+ //float eps;
+ //memcpy(&eps, dst->op_params, sizeof(float));
+
+ const float eps = 1e-6f; // TODO: temporarily hardcoded
+
+ const int32_t n_groups = ((int32_t *) dst->op_params)[0];
+
+ int nth = 32; // SIMD width
+
+ //while (nth < ne00/4 && nth < 1024) {
+ // nth *= 2;
+ //}
+
+ [encoder setComputePipelineState:ctx->pipeline_group_norm];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:5];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&n_groups length:sizeof( int32_t) atIndex:8];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:9];
+ [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n_groups, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
case GGML_OP_NORM:
{
float eps;
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
} break;
+ case GGML_OP_UPSCALE:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ const int sf = dst->op_params[0];
+
+ [encoder setComputePipelineState:ctx->pipeline_upscale_f32];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+ [encoder setBytes:&sf length:sizeof(sf) atIndex:18];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_PAD:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ [encoder setComputePipelineState:ctx->pipeline_pad_f32];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17];
+
+ const int nth = MIN(1024, ne0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
case GGML_OP_ARGSORT:
{
GGML_ASSERT(src0->type == GGML_TYPE_F32);
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
} break;
+ case GGML_OP_LEAKY_RELU:
+ {
+ GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+ float slope;
+ memcpy(&slope, dst->op_params, sizeof(float));
+
+ [encoder setComputePipelineState:ctx->pipeline_leaky_relu_f32];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&slope length:sizeof(slope) atIndex:2];
+
+ const int64_t n = ggml_nelements(dst);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ } break;
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:
{
switch (dstt) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
- case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+ case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f32]; break;
default: GGML_ASSERT(false && "not implemented");
};
} break;
constant int64_t & nb1,
constant int64_t & nb2,
constant int64_t & nb3,
+ constant int64_t & offs,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
const int i10 = i0 % ne10;
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
- constant int64_t & nb [[buffer(27)]],
+ constant int64_t & nb [[buffer(28)]],
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] + src1[tpig % nb];
}
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
- constant int64_t & nb [[buffer(27)]],
+ constant int64_t & nb [[buffer(28)]],
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * src1[tpig % nb];
}
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
- constant int64_t & nb [[buffer(27)]],
+ constant int64_t & nb [[buffer(28)]],
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] / src1[tpig % nb];
}
dst[tpig] = src0[tpig] * scale;
}
-kernel void kernel_silu(
- device const float4 * src0,
- device float4 * dst,
+kernel void kernel_relu(
+ device const float * src0,
+ device float * dst,
uint tpig[[thread_position_in_grid]]) {
- device const float4 & x = src0[tpig];
- dst[tpig] = x / (1.0f + exp(-x));
+ dst[tpig] = max(0.0f, src0[tpig]);
}
-kernel void kernel_relu(
+kernel void kernel_tanh(
device const float * src0,
device float * dst,
uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = max(0.0f, src0[tpig]);
+ device const float & x = src0[tpig];
+ dst[tpig] = precise::tanh(x);
+}
+
+constant float GELU_COEF_A = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+
+ // BEWARE !!!
+ // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
+ // This was observed with Falcon 7B and 40B models
+ //
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_quick(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_silu(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+ dst[tpig] = x / (1.0f + exp(-x));
}
kernel void kernel_sqr(
dst_row[0] = row_sum;
}
-constant float GELU_COEF_A = 0.044715f;
-constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
-
-kernel void kernel_gelu(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float4 & x = src0[tpig];
-
- // BEWARE !!!
- // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
- // This was observed with Falcon 7B and 40B models
- //
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
-
kernel void kernel_soft_max(
device const float * src0,
device const float * src1,
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
- device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
- device const float * pmask = src1 ? src1 + i01*ne00 : nullptr;
- device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
+ device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
// parallel max
float lmax = -INFINITY;
pdst[i00] = exp_psrc0;
}
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
float sum = simd_sum(lsum);
+
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
- device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
- device const float4 * pmask = src1 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
- device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
// parallel max
float4 lmax4 = -INFINITY;
}
const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
float sum = simd_sum(lsum);
+
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
buf[tiisg] = 0.0f;
}
}
+kernel void kernel_group_norm(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int32_t & n_groups,
+ constant float & eps,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t ne = ne00*ne01*ne02;
+ const int64_t gs = ne00*ne01*((ne02 + n_groups - 1) / n_groups);
+
+ int start = tgpig * gs;
+ int end = start + gs;
+
+ start += tpitg;
+
+ if (end >= ne) {
+ end = ne;
+ }
+
+ float tmp = 0.0f; // partial sum for thread in warp
+
+ for (int j = start; j < end; j += ntg) {
+ tmp += src0[j];
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ tmp = simd_sum(tmp);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = tmp;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ tmp = buf[tiisg];
+ tmp = simd_sum(tmp);
+ }
+
+ const float mean = tmp / gs;
+ tmp = 0.0f;
+
+ for (int j = start; j < end; j += ntg) {
+ float xi = src0[j] - mean;
+ dst[j] = xi;
+ tmp += xi * xi;
+ }
+
+ tmp = simd_sum(tmp);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = tmp;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ tmp = buf[tiisg];
+ tmp = simd_sum(tmp);
+ }
+
+ const float variance = tmp / gs;
+ const float scale = 1.0f/sqrt(variance + eps);
+ for (int j = start; j < end; j += ntg) {
+ dst[j] *= scale;
+ }
+}
+
// function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
// il indicates where the q4 quants begin (0 or QK4_0/4)
// we assume that the yl's have been multiplied with the appropriate scale factor
// giard against the number of rows not being divisible by
// N_DST, so this is another explicit assumption of the implementation.
template<typename block_q_type, int nr, int nsg, int nw>
-void mul_vec_q_n_f32(
+void mul_vec_q_n_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q4_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q5_0_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mv_q5_1_f32(
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
}
#define NB_Q8_0 8
-kernel void kernel_mul_mv_q8_0_f32(
+void kernel_mul_mv_q8_0_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nr = N_DST;
const int nsg = N_SIMDGROUP;
const int nw = N_SIMDWIDTH;
}
}
+[[host_name("kernel_mul_mv_q8_0_f32")]]
+kernel void kernel_mul_mv_q8_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg);
+}
+
#define N_F32_F32 4
-kernel void kernel_mul_mv_f32_f32(
+void kernel_mul_mv_f32_f32_impl(
device const char * src0,
device const char * src1,
device float * dst,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
}
}
+[[host_name("kernel_mul_mv_f32_f32")]]
+kernel void kernel_mul_mv_f32_f32(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f32_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
#define N_F16_F16 4
kernel void kernel_mul_mv_f16_f16(
}
}
-kernel void kernel_mul_mv_f16_f32_1row(
+void kernel_mul_mv_f16_f32_1row_impl(
device const char * src0,
device const char * src1,
device float * dst,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
}
}
+}
+[[host_name("kernel_mul_mv_f16_f32_1row")]]
+kernel void kernel_mul_mv_f16_f32_1row(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f16_f32_1row_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
}
#define N_F16_F32 4
-kernel void kernel_mul_mv_f16_f32(
+void kernel_mul_mv_f16_f32_impl(
device const char * src0,
device const char * src1,
device float * dst,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]]) {
}
}
+[[host_name("kernel_mul_mv_f16_f32")]]
+kernel void kernel_mul_mv_f16_f32(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_f16_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb00, nb01, nb02, ne10, ne11, ne12, nb10, nb11, nb12, ne0, ne1, r2, r3, tgpig, tiisg);
+}
+
// Assumes row size (ne00) is a multiple of 4
kernel void kernel_mul_mv_f16_f32_l4(
device const char * src0,
}
}
-// bitonic sort implementation following the CUDA kernels as reference
-typedef void (argsort_t)(
- device const float * x,
- device int32_t * dst,
- constant int64_t & ncols,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]]);
-
-template<ggml_sort_order order>
-kernel void kernel_argsort_f32_i32(
- device const float * x,
- device int32_t * dst,
- constant int64_t & ncols,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]]) {
- // bitonic sort
- int col = tpitg[0];
- int row = tgpig[1];
-
+kernel void kernel_upscale_f32(
+ device const char * src0,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int32_t & sf,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const int64_t i3 = tgpig.z;
+ const int64_t i2 = tgpig.y;
+ const int64_t i1 = tgpig.x;
+
+ const int64_t i03 = i3;
+ const int64_t i02 = i2;
+ const int64_t i01 = i1/sf;
+
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ dst_ptr[i0] = src0_ptr[i0/sf];
+ }
+}
+
+kernel void kernel_pad_f32(
+ device const char * src0,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+
+ const int64_t i3 = tgpig.z;
+ const int64_t i2 = tgpig.y;
+ const int64_t i1 = tgpig.x;
+
+ const int64_t i03 = i3;
+ const int64_t i02 = i2;
+ const int64_t i01 = i1;
+
+ device const float * src0_ptr = (device const float *) (src0 + i03*nb03 + i02*nb02 + i01*nb01);
+ device float * dst_ptr = (device float *) (dst + i3*nb3 + i2*nb2 + i1*nb1);
+
+ if (i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ if (i0 < ne00) {
+ dst_ptr[i0] = src0_ptr[i0];
+ } else {
+ dst_ptr[i0] = 0.0f;
+ }
+ }
+
+ return;
+ }
+
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ dst_ptr[i0] = 0.0f;
+ }
+}
+
+// bitonic sort implementation following the CUDA kernels as reference
+typedef void (argsort_t)(
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]);
+
+template<ggml_sort_order order>
+kernel void kernel_argsort_f32_i32(
+ device const float * x,
+ device int32_t * dst,
+ constant int64_t & ncols,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]]) {
+ // bitonic sort
+ int col = tpitg[0];
+ int row = tgpig[1];
+
if (col >= ncols) return;
device const float * x_row = x + row * ncols;
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
+kernel void kernel_leaky_relu_f32(
+ device const float * src0,
+ device float * dst,
+ constant float & slope,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
+}
+
kernel void kernel_cpy_f16_f16(
- device const half * src0,
- device half * dst,
+ device const half * src0,
+ device half * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
}
}
+kernel void kernel_cpy_f16_f32(
+ device const half * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ dst_data[i00] = src[0];
+ }
+}
+
kernel void kernel_cpy_f32_f16(
device const float * src0,
device half * dst,
}
kernel void kernel_concat(
- device const char * src0,
- device const char * src1,
- device char * dst,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
- device const char * src0_ptr = src0 + i03 * nb03 + i02 * nb02 + i01 * nb01 + tpitg.x*nb00;
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00;
device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10;
device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0;
//====================================== dot products =========================
-kernel void kernel_mul_mv_q2_K_f32(
+void kernel_mul_mv_q2_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
}
}
-#if QK_K == 256
-kernel void kernel_mul_mv_q3_K_f32(
+[[host_name("kernel_mul_mv_q2_K_f32")]]
+kernel void kernel_mul_mv_q2_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant uint & r2 [[buffer(17)]],
constant uint & r3 [[buffer(18)]],
uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+#if QK_K == 256
+void kernel_mul_mv_q3_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK_K;
}
}
#else
-kernel void kernel_mul_mv_q3_K_f32(
+void kernel_mul_mv_q3_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
}
#endif
+[[host_name("kernel_mul_mv_q3_K_f32")]]
+kernel void kernel_mul_mv_q3_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
#if QK_K == 256
-kernel void kernel_mul_mv_q4_K_f32(
+void kernel_mul_mv_q4_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01 [[buffer(4)]],
- constant int64_t & ne02 [[buffer(5)]],
- constant int64_t & ne10 [[buffer(9)]],
- constant int64_t & ne12 [[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
}
}
#else
-kernel void kernel_mul_mv_q4_K_f32(
+void kernel_mul_mv_q4_K_f32_impl(
device const void * src0,
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne01[[buffer(4)]],
- constant int64_t & ne02[[buffer(5)]],
- constant int64_t & ne10[[buffer(9)]],
- constant int64_t & ne12[[buffer(11)]],
- constant int64_t & ne0 [[buffer(15)]],
- constant int64_t & ne1 [[buffer(16)]],
- constant uint & r2 [[buffer(17)]],
- constant uint & r3 [[buffer(18)]],
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
}
#endif
-kernel void kernel_mul_mv_q5_K_f32(
+[[host_name("kernel_mul_mv_q4_K_f32")]]
+kernel void kernel_mul_mv_q4_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
+void kernel_mul_mv_q5_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
const int nb = ne00/QK_K;
const int64_t r0 = tgpig.x;
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
}
}
-
}
-kernel void kernel_mul_mv_q6_K_f32(
+[[host_name("kernel_mul_mv_q5_K_f32")]]
+kernel void kernel_mul_mv_q5_K_f32(
device const void * src0,
device const float * src1,
device float * dst,
constant uint & r2 [[buffer(17)]],
constant uint & r3 [[buffer(18)]],
uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- const uint8_t kmask1 = 0x03;
- const uint8_t kmask2 = 0x0C;
- const uint8_t kmask3 = 0x30;
- const uint8_t kmask4 = 0xC0;
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int nb = ne00/QK_K;
+ kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
- const int64_t r0 = tgpig.x;
- const int64_t r1 = tgpig.y;
+void kernel_mul_mv_q6_K_f32_impl(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne10,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & r2,
+ constant uint & r3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ const uint8_t kmask1 = 0x03;
+ const uint8_t kmask2 = 0x0C;
+ const uint8_t kmask3 = 0x30;
+ const uint8_t kmask4 = 0xC0;
+
+ const int nb = ne00/QK_K;
+
+ const int64_t r0 = tgpig.x;
+ const int64_t r1 = tgpig.y;
const int im = tgpig.z;
const int row = 2 * r0 + sgitg;
}
}
+[[host_name("kernel_mul_mv_q6_K_f32")]]
+kernel void kernel_mul_mv_q6_K_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0 [[buffer(15)]],
+ constant int64_t & ne1 [[buffer(16)]],
+ constant uint & r2 [[buffer(17)]],
+ constant uint & r3 [[buffer(18)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
+}
+
//============================= templates and their specializations =============================
// NOTE: this is not dequantizing - we are simply fitting the template
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
- const half d = xb->d;
- const half min = xb->dmin;
+ const float d = xb->d;
+ const float min = xb->dmin;
device const uint8_t * q = (device const uint8_t *)xb->qs;
- half dl, ml;
+ float dl, ml;
uint8_t sc = xb->scales[il];
#if QK_K == 256
q = q + (il/4) * 32 + 16 * (il&1);
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
- const half d = il < 2 ? xb->d : xb->d / 16.h;
- const half min = xb->dmin;
- const half dl = d * sc[0];
- const half ml = min * sc[1];
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
+ const float min = xb->dmin;
+ const float dl = d * sc[0];
+ const float ml = min * sc[1];
#else
q = q + 16 * (il&1);
device const uint8_t * s = xb->scales;
uint8_t ul = 1 << (il/2);
il = il & 3;
const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
- const half d = il < 2 ? xb->d : xb->d / 16.h;
- const half min = xb->dmin;
- const half dl = d * sc[0];
- const half ml = min * sc[1];
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
+ const float min = xb->dmin;
+ const float dl = d * sc[0];
+ const float ml = min * sc[1];
- const ushort mask = il<2 ? 0x0F : 0xF0;
- const half qh_val = il<2 ? 16.h : 256.h;
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ const float qh_val = il<2 ? 16.f : 256.f;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows(
device const void * src0,
- device const int * src1,
+ device const char * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
constant uint64_t & nb1,
- uint tgpig[[threadgroup_position_in_grid]],
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]],
- uint tptg[[threads_per_threadgroup]]) {
- const int i = tgpig;
- const int r = ((device int32_t *) src1)[i];
+ uint3 tptg [[threads_per_threadgroup]]) {
+ //const int64_t i = tgpig;
+ //const int64_t r = ((device int32_t *) src1)[i];
+
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
- for (int ind = tiitg; ind < ne00/16; ind += tptg) {
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int64_t ind = tiitg; ind < ne00/16; ind += tptg.x) {
float4x4 temp;
dequantize_func(
- ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
- *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
+ ((device const block_q *) ((device char *) src0 + r*nb01 + i02*nb02)) + ind/nl, ind%nl, temp);
+ *(((device float4x4 *) ((device char *) dst + i11*nb2 + i10*nb1)) + ind) = temp;
+ }
+}
+
+kernel void kernel_get_rows_f32(
+ device const void * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((device float *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
+ }
+}
+
+kernel void kernel_get_rows_f16(
+ device const void * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint3 tptg [[threads_per_threadgroup]]) {
+ const int64_t i10 = tgpig.x;
+ const int64_t i11 = tgpig.y;
+
+ const int64_t r = ((device int32_t *) ((device char *) src1 + i11*nb11 + i10*nb10))[0];
+
+ const int64_t i02 = i11;
+
+ for (int ind = tiitg; ind < ne00; ind += tptg.x) {
+ ((device float *) ((device char *) dst + i11*nb2 + i10*nb1))[ind] =
+ ((device half *) ((device char *) src0 + r*nb01 + i02*nb02))[ind];
}
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm_id(
- device const int32_t * ids,
+ device const uchar * ids,
device const uchar * src1,
- device float * dst,
+ device uchar * dst,
+ constant int64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & ne12,
+ constant int64_t & ne13,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
+ constant int64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
uint sgitg[[simdgroup_index_in_threadgroup]]) {
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
- src0[ids[idx]],
- src1,
- dst,
+ src0[id],
+ src1 + bid*nb11,
+ (device float *) (dst + bid*nb1),
ne00,
ne02,
nb01,
#define QK_NL 4
#endif
+//
+// get rows
+//
+
typedef void (get_rows_t)(
device const void * src0,
- device const int * src1,
+ device const char * src1,
device float * dst,
constant int64_t & ne00,
constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
constant uint64_t & nb1,
- uint, uint, uint);
+ constant uint64_t & nb2,
+ uint3, uint, uint3);
-template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
-template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
+//template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
+//template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
+//
+// matrix-matrix multiplication
+//
+
typedef void (mat_mm_t)(
device const uchar * src0,
device const uchar * src1,
template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
+//
+// indirect matrix-matrix multiplication
+//
+
typedef void (mat_mm_id_t)(
- device const int32_t * ids,
+ device const uchar * ids,
device const uchar * src1,
- device float * dst,
+ device uchar * dst,
+ constant int64_t & nbi1,
constant int64_t & ne00,
constant int64_t & ne02,
constant int64_t & nb01,
constant int64_t & nb02,
constant int64_t & ne12,
+ constant int64_t & ne13,
constant int64_t & nb10,
constant int64_t & nb11,
constant int64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
+ constant int64_t & nb1,
constant uint & r2,
constant uint & r3,
constant int & idx,
template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q4_K, QK_NL, dequantize_q4_K>;
template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q5_K, QK_NL, dequantize_q5_K>;
template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_q6_K, QK_NL, dequantize_q6_K>;
+
+//
+// matrix-vector multiplication
+//
+
+[[host_name("kernel_mul_mv_id_f32_f32")]]
+kernel void kernel_mul_mv_id_f32_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_f32_f32_impl(
+ src0[id],
+ src1 + bid*nb11,
+ (device float *) (dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ nb00,
+ nb01,
+ nb02,
+ ne10,
+ ne11,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_f16_f32")]]
+kernel void kernel_mul_mv_id_f16_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_f16_f32_impl(
+ src0[id],
+ src1 + bid*nb11,
+ (device float *) (dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ nb00,
+ nb01,
+ nb02,
+ ne10,
+ ne11,
+ ne12,
+ nb10,
+ nb11,
+ nb12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg);
+}
+
+[[host_name("kernel_mul_mv_id_q8_0_f32")]]
+kernel void kernel_mul_mv_id_q8_0_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q8_0_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_0_f32")]]
+kernel void kernel_mul_mv_id_q4_0_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_1_f32")]]
+kernel void kernel_mul_mv_id_q4_1_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_0_f32")]]
+kernel void kernel_mul_mv_id_q5_0_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_1_f32")]]
+kernel void kernel_mul_mv_id_q5_1_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q2_K_f32")]]
+kernel void kernel_mul_mv_id_q2_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q2_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q3_K_f32")]]
+kernel void kernel_mul_mv_id_q3_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q3_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q4_K_f32")]]
+kernel void kernel_mul_mv_id_q4_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q4_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q5_K_f32")]]
+kernel void kernel_mul_mv_id_q5_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q5_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
+
+[[host_name("kernel_mul_mv_id_q6_K_f32")]]
+kernel void kernel_mul_mv_id_q6_K_f32(
+ device const char * ids,
+ device const char * src1,
+ device uchar * dst,
+ constant int64_t & nbi1,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & nb1,
+ constant uint & r2,
+ constant uint & r3,
+ constant int & idx,
+ device const char * src00,
+ device const char * src01,
+ device const char * src02,
+ device const char * src03,
+ device const char * src04,
+ device const char * src05,
+ device const char * src06,
+ device const char * src07,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
+
+ const int64_t bid = tgpig.z/(ne12*ne13);
+
+ tgpig.z = tgpig.z%(ne12*ne13);
+
+ const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
+
+ kernel_mul_mv_q6_K_f32_impl(
+ src0[id],
+ (device const float *) (src1 + bid*nb11),
+ (device float *) ( dst + bid*nb1),
+ ne00,
+ ne01,
+ ne02,
+ ne10,
+ ne12,
+ ne0,
+ ne1,
+ r2,
+ r3,
+ tgpig,
+ tiisg,
+ sgitg);
+}
size_t vl = __riscv_vsetvl_e8m1(qk/2);
- // These tempory registers are for masking and shift operations
+ // These temporary registers are for masking and shift operations
vuint32m2_t vt_1 = __riscv_vid_v_u32m2(vl);
vuint32m2_t vt_2 = __riscv_vsll_vv_u32m2(__riscv_vmv_v_x_u32m2(1, vl), vt_1, vl);
vl = 16;
- // retreive lane to multiply with scale
+ // retrieve lane to multiply with scale
vint32m2_t aux0_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 0), (scale[0]), vl);
vint32m2_t aux0_1 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a0, 1), (scale[1]), vl);
vint32m2_t aux1_0 = __riscv_vwmul_vx_i32m2(__riscv_vget_v_i16m2_i16m1(a1, 0), (scale[2]), vl);
-#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnigns on Windows
+#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
#define _USE_MATH_DEFINES // For M_PI on MSVC
#include "ggml-impl.h"
// we should just be careful :)
#pragma warning(disable: 4244 4267)
-// disable POSIX deprecation warnigns
+// disable POSIX deprecation warnings
// these functions are never going away, anyway
#pragma warning(disable: 4996)
#endif
inline static void ggml_vec_tanh_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = tanhf(x[i]); }
inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expf(x[i])-1; }
inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; }
-inline static void ggml_vec_leaky_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.1f*x[i]; }
+inline static void ggml_vec_leaky_relu_f32 (const int n, float * y, const float * x, const float ns) { for (int i = 0; i < n; ++i) y[i] = ((x[i] > 0.f) ? x[i] : 0.f) + ns * ((x[i] < 0.0f) ? x[i] : 0.f); }
static const float GELU_COEF_A = 0.044715f;
static const float GELU_QUICK_COEF = -1.702f;
"POOL_1D",
"POOL_2D",
"UPSCALE",
+ "PAD",
"ARGSORT",
+ "LEAKY_RELU",
"FLASH_ATTN",
"FLASH_FF",
"CROSS_ENTROPY_LOSS_BACK",
};
-static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
+static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none",
"pool_1d(x)",
"pool_2d(x)",
"upscale(x)",
+ "pad(x)",
"argsort(x)",
+ "leaky_relu(x)",
"flash_attn(x)",
"flash_ff(x)",
"cross_entropy_loss_back(x,y)",
};
-static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70");
+static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
"GELU",
"GELU_QUICK",
"SILU",
- "LEAKY",
};
-static_assert(GGML_UNARY_OP_COUNT == 11, "GGML_UNARY_OP_COUNT != 11");
+static_assert(GGML_UNARY_OP_COUNT == 10, "GGML_UNARY_OP_COUNT != 10");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
// WARN:
-// Mis-confguration can lead to problem that's hard to reason about:
+// Mis-configuration can lead to problem that's hard to reason about:
// * At best it crash or talks nosense.
// * At worst it talks slightly difference but hard to perceive.
//
return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_RELU);
}
-// ggml_leaky
+// ggml_leaky_relu
-struct ggml_tensor * ggml_leaky(
+struct ggml_tensor * ggml_leaky_relu(
struct ggml_context * ctx,
- struct ggml_tensor * a) {
- return ggml_unary(ctx, a, GGML_UNARY_OP_LEAKY);
+ struct ggml_tensor * a, float negative_slope, bool inplace) {
+ bool is_node = false;
+
+ if (!inplace && (a->grad)) {
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
+ ggml_set_op_params(result, &negative_slope, sizeof(negative_slope));
+
+ result->op = GGML_OP_LEAKY_RELU;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
}
// ggml_gelu
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
- result->op = GGML_OP_GROUP_NORM;
result->op_params[0] = n_groups;
+
+ result->op = GGML_OP_GROUP_NORM;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = NULL; // TODO: maybe store epsilon here?
struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx,
- struct ggml_tensor * as[],
+ struct ggml_tensor * const as[],
+ int n_as,
struct ggml_tensor * ids,
int id,
struct ggml_tensor * b) {
- int64_t n_as = ids->ne[0];
-
GGML_ASSERT(ids->type == GGML_TYPE_I32);
- GGML_ASSERT(ggml_is_vector(ids));
+ GGML_ASSERT(ids->ne[2] == 1 && ids->ne[3] == 1);
+ GGML_ASSERT(ids->ne[1] == b->ne[1]);
+ GGML_ASSERT(ids->ne[2] == b->ne[2] && ids->ne[3] == b->ne[3]);
GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2);
- GGML_ASSERT(id >= 0 && id < n_as);
+ GGML_ASSERT(id >= 0 && id < ids->ne[0]);
bool is_node = false;
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne);
ggml_set_op_params_i32(result, 0, id);
+ ggml_set_op_params_i32(result, 1, n_as);
result->op = GGML_OP_MUL_MAT_ID;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = ids;
result->src[1] = b;
- for (int64_t i = 0; i < n_as; i++) {
+ for (int i = 0; i < n_as; i++) {
struct ggml_tensor * a = as[i];
GGML_ASSERT(ggml_are_same_shape(as[0], a));
GGML_ASSERT(ggml_can_mul_mat(a, b));
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b) {
- GGML_ASSERT(ggml_is_matrix(a) && ggml_is_vector(b) && b->type == GGML_TYPE_I32);
+ GGML_ASSERT(a->ne[2] == b->ne[1]);
+ GGML_ASSERT(b->ne[3] == 1);
+ GGML_ASSERT(b->type == GGML_TYPE_I32);
bool is_node = false;
// TODO: implement non F32 return
//struct ggml_tensor * result = ggml_new_tensor_2d(ctx, a->type, a->ne[0], b->ne[0]);
- struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0]);
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, a->ne[0], b->ne[0], b->ne[1], b->ne[2]);
result->op = GGML_OP_GET_ROWS;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
return result;
}
+struct ggml_tensor * ggml_pad(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int p0, int p1, int p2, int p3) {
+ bool is_node = false;
+
+ if (a->grad) {
+ GGML_ASSERT(false); // TODO: implement backward
+ is_node = true;
+ }
+
+ struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
+ a->ne[0] + p0,
+ a->ne[1] + p1,
+ a->ne[2] + p2,
+ a->ne[3] + p3);
+
+ result->op = GGML_OP_PAD;
+ result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+ result->src[0] = a;
+
+ return result;
+}
+
struct ggml_tensor * ggml_upscale(
struct ggml_context * ctx,
struct ggml_tensor * a,
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
// view src0 and dst with these strides and data offset inbytes during acc
- // nb0 is implicitely element_size because src0 and dst are contiguous
+ // nb0 is implicitly element_size because src0 and dst are contiguous
size_t nb1 = ((int32_t *) dst->op_params)[0];
size_t nb2 = ((int32_t *) dst->op_params)[1];
size_t nb3 = ((int32_t *) dst->op_params)[2];
const int ith = params->ith;
const int nth = params->nth;
+// TODO: OpenCL kernel support broadcast
#ifdef GGML_USE_CLBLAST
if (src1->backend == GGML_BACKEND_GPU) {
+ GGML_ASSERT(ggml_are_same_shape(src0, src1));
if (ith == 0) {
ggml_cl_mul(src0, src1, dst);
}
} break;
}
}
+// ggml_compute_forward_leaky_relu
-// ggml_compute_forward_leaky
-
-static void ggml_compute_forward_leaky_f32(
+static void ggml_compute_forward_leaky_relu_f32(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
const int n = ggml_nrows(src0);
const int nc = src0->ne[0];
+ float negative_slope;
+ memcpy(&negative_slope, dst->op_params, sizeof(float));
+
assert(dst->nb[0] == sizeof(float));
assert(src0->nb[0] == sizeof(float));
for (int i = 0; i < n; i++) {
- ggml_vec_leaky_f32(nc,
+ ggml_vec_leaky_relu_f32(nc,
(float *) ((char *) dst->data + i*( dst->nb[1])),
- (float *) ((char *) src0->data + i*(src0->nb[1])));
+ (float *) ((char *) src0->data + i*(src0->nb[1])), negative_slope);
}
}
-static void ggml_compute_forward_leaky(
+static void ggml_compute_forward_leaky_relu(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
struct ggml_tensor * dst) {
switch (src0->type) {
case GGML_TYPE_F32:
{
- ggml_compute_forward_leaky_f32(params, src0, dst);
+ ggml_compute_forward_leaky_relu_f32(params, src0, dst);
} break;
default:
{
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
+ // NOTE: with GGML_OP_MUL_MAT_ID we don't want to go through the BLAS branch because it will dequantize (to_float)
+ // all the experts for each batch element and the processing would become incredibly slow
// TODO: find the optimal values for these
- if (ggml_is_contiguous(src0) &&
+ if (dst->op != GGML_OP_MUL_MAT_ID &&
+ ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) &&
//src0->type == GGML_TYPE_F32 &&
src1->type == GGML_TYPE_F32 &&
}
#endif
+// off1 = offset in i11 and i1
+// cne1 = ne11 and ne1
+// in a normal matrix multiplication, off1 = 0 and cne1 = ne1
+// during GGML_TASK_INIT, the full src1 is converted regardless of off1 and cne1
static void ggml_compute_forward_mul_mat(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * src1,
- struct ggml_tensor * dst) {
+ struct ggml_tensor * dst,
+ int64_t off1, int64_t cne1) {
int64_t t0 = ggml_perf_time_us();
UNUSED(t0);
const int64_t i03 = i13/r3;
const int64_t i02 = i12/r2;
- const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
- const float * y = (float *) ((char *) src1->data + i12*nb12 + i13*nb13);
-
- float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
+ const void * x = (char *) src0->data + i02*nb02 + i03*nb03;
+ const float * y = (float *) ((char *) src1->data + off1*nb11 + i12*nb12 + i13*nb13);
+ float * d = (float *) ((char *) dst->data + off1*nb1 + i12*nb2 + i13*nb3);
if (type != GGML_TYPE_F32) {
float * const wdata = params->wdata;
}
cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans,
- ne11, ne01, ne10,
- 1.0f, y, ne10,
- x, ne00,
- 0.0f, d, ne01);
+ cne1, ne01, ne10,
+ 1.0f, y, ne10,
+ x, ne00,
+ 0.0f, d, ne01);
}
}
const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
assert(params->wsize >= ne11*ne12*ne13*row_size);
+ assert(src1->type == GGML_TYPE_F32);
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
const int64_t nr0 = ne01; // src0 rows
- const int64_t nr1 = ne11*ne12*ne13; // src1 rows
+ const int64_t nr1 = cne1*ne12*ne13; // src1 rows
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
- const int64_t i13 = (ir1/(ne12*ne11));
- const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
- const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
+ const int64_t i13 = (ir1/(ne12*cne1));
+ const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
+ const int64_t i11 = (ir1 - i13*ne12*cne1 - i12*cne1) + off1;
// broadcast src0 into src1
const int64_t i03 = i13/r3;
static void ggml_compute_forward_mul_mat_id(
const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
- const struct ggml_tensor * ids = dst->src[0];
- const struct ggml_tensor * src1 = dst->src[1];
-
- const int id = ggml_get_op_params_i32(dst, 0);
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ // during GGML_TASK_INIT the entire src1 is converted to vec_dot_type
+ ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]);
+ return;
+ }
- const int a_id = ((int32_t *)ids->data)[id];
+ const struct ggml_tensor * ids = src0;
+ const int id = ggml_get_op_params_i32(dst, 0);
+ const int n_as = ggml_get_op_params_i32(dst, 1);
- GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]);
+ for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
+ const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
- const struct ggml_tensor * src0 = dst->src[a_id + 2];
+ GGML_ASSERT(row_id >= 0 && row_id < n_as);
- ggml_compute_forward_mul_mat(params, src0, src1, dst);
+ const struct ggml_tensor * src0_row = dst->src[row_id + 2];
+ ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
+ }
}
// ggml_compute_forward_out_prod
GGML_ASSERT(ggml_is_contiguous(dst) && ggml_is_contiguous(src0));
// view src0 and dst with these strides and data offset inbytes during set
- // nb0 is implicitely element_size because src0 and dst are contiguous
+ // nb0 is implicitly element_size because src0 and dst are contiguous
size_t nb1 = ((int32_t *) dst->op_params)[0];
size_t nb2 = ((int32_t *) dst->op_params)[1];
size_t nb3 = ((int32_t *) dst->op_params)[2];
return;
}
- const int nc = src0->ne[0];
- const int nr = ggml_nelements(src1);
+ GGML_TENSOR_BINARY_OP_LOCALS
+
+ const int64_t nc = ne00;
+ const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
+
const enum ggml_type type = src0->type;
ggml_to_float_t const dequantize_row_q = type_traits[type].to_float;
- assert( dst->ne[0] == nc);
- assert( dst->ne[1] == nr);
- assert(src0->nb[0] == ggml_type_size(type));
+ assert(ne0 == nc);
+ assert(ne02 == ne11);
+ assert(nb00 == ggml_type_size(type));
+ assert(ggml_nrows(dst) == nr);
- for (int i = 0; i < nr; ++i) {
- const int r = ((int32_t *) src1->data)[i];
+ // TODO: multi-thread
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
- dequantize_row_q(
- (const void *) ((char *) src0->data + r*src0->nb[1]),
- (float *) ((char *) dst->data + i*dst->nb[1]), nc);
+ dequantize_row_q(
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+ }
+ }
}
}
return;
}
- const int nc = src0->ne[0];
- const int nr = ggml_nelements(src1);
+ GGML_TENSOR_BINARY_OP_LOCALS
- assert( dst->ne[0] == nc);
- assert( dst->ne[1] == nr);
- assert(src0->nb[0] == sizeof(ggml_fp16_t));
+ const int64_t nc = ne00;
+ const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
- for (int i = 0; i < nr; ++i) {
- const int r = ((int32_t *) src1->data)[i];
+ assert(ne0 == nc);
+ assert(ne02 == ne11);
+ assert(nb00 == sizeof(ggml_fp16_t));
+ assert(ggml_nrows(dst) == nr);
- for (int j = 0; j < nc; ++j) {
- ggml_fp16_t v = ((ggml_fp16_t *) ((char *) src0->data + r*src0->nb[1]))[j];
- ((float *) ((char *) dst->data + i*dst->nb[1]))[j] = GGML_FP16_TO_FP32(v);
+ // TODO: multi-thread
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+ ggml_fp16_to_fp32_row(
+ (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+ }
}
}
}
return;
}
- const int nc = src0->ne[0];
- const int nr = ggml_nelements(src1);
+ GGML_TENSOR_BINARY_OP_LOCALS
- assert( dst->ne[0] == nc);
- assert( dst->ne[1] == nr);
- assert(src0->nb[0] == sizeof(float));
+ const int64_t nc = ne00;
+ const int64_t nr = ggml_nelements(src1); GGML_UNUSED(nr);
- for (int i = 0; i < nr; ++i) {
- const int r = ((int32_t *) src1->data)[i];
+ assert(ne0 == nc);
+ assert(ne02 == ne11);
+ assert(nb00 == sizeof(float));
+ assert(ggml_nrows(dst) == nr);
- ggml_vec_cpy_f32(nc,
- (float *) ((char *) dst->data + i*dst->nb[1]),
- (float *) ((char *) src0->data + r*src0->nb[1]));
+ // TODO: multi-thread
+ for (int64_t i12 = 0; i12 < ne12; ++i12) {
+ for (int64_t i11 = 0; i11 < ne11; ++i11) {
+ for (int64_t i10 = 0; i10 < ne10; ++i10) {
+ const int64_t i01 = *(int32_t *) ((char *) src1->data + i10*nb10 + i11*nb11 + i12*nb12);
+
+ ggml_vec_cpy_f32(nc,
+ (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3),
+ (float *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03));
+ }
+ }
}
}
GGML_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith;
+ const int nth = params->nth;
GGML_TENSOR_UNARY_OP_LOCALS
// TODO: optimize
- for (int i03 = 0; i03 < ne03; i03++) {
- for (int i02 = ith; i02 < ne02; i02++) {
- for (int m = 0; m < dst->ne[1]; m++) {
- int i01 = m / scale_factor;
- for (int n = 0; n < dst->ne[0]; n++) {
- int i00 = n / scale_factor;
-
- const float * x = (float *)((char *) src0->data + i00 * nb00 +i01 * nb01 + i02 * nb02 + i03 * nb03);
+ for (int64_t i3 = 0; i3 < ne3; i3++) {
+ const int64_t i03 = i3;
+ for (int64_t i2 = ith; i2 < ne2; i2 += nth) {
+ const int64_t i02 = i2;
+ for (int64_t i1 = 0; i1 < ne1; i1++) {
+ const int64_t i01 = i1 / scale_factor;
+ for (int64_t i0 = 0; i0 < ne0; i0++) {
+ const int64_t i00 = i0 / scale_factor;
- float * y = (float *)((char *) dst->data + n * dst->nb[0] + m * dst->nb[1] + i02 * dst->nb[2] + i03 * dst->nb[3]);
+ const float * x = (float *)((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
+ float * y = (float *)((char *) dst->data + i0*nb0 + i1*nb1 + i2*nb2 + i3*nb3);
*y = *x;
}
}
}
+// ggml_compute_forward_pad
+
+static void ggml_compute_forward_pad_f32(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+
+ if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
+ return;
+ }
+
+ GGML_ASSERT(src0->nb[0] == sizeof(float));
+ GGML_ASSERT( dst->nb[0] == sizeof(float));
+
+ const int ith = params->ith;
+ const int nth = params->nth;
+
+ GGML_TENSOR_UNARY_OP_LOCALS
+
+ float * dst_ptr = (float *) dst->data;
+
+ // TODO: optimize
+
+ for (int64_t i2 = 0; i2 < ne2; ++i2) {
+ for (int64_t i1 = ith; i1 < ne1; i1 += nth) {
+ for (int64_t i0 = 0; i0 < ne0; ++i0) {
+ for (int64_t i3 = 0; i3 < ne3; ++i3) {
+ const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
+
+ const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
+
+ if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
+ dst_ptr[dst_idx] = *src_ptr;
+ } else {
+ dst_ptr[dst_idx] = 0;
+ }
+ }
+ }
+ }
+ }
+}
+
+static void ggml_compute_forward_pad(
+ const struct ggml_compute_params * params,
+ const struct ggml_tensor * src0,
+ struct ggml_tensor * dst) {
+ switch (src0->type) {
+ case GGML_TYPE_F32:
+ {
+ ggml_compute_forward_pad_f32(params, src0, dst);
+ } break;
+ default:
+ {
+ GGML_ASSERT(false);
+ } break;
+ }
+}
+
// ggml_compute_forward_argsort
static void ggml_compute_forward_argsort_f32(
{
ggml_compute_forward_silu(params, src0, dst);
} break;
- case GGML_UNARY_OP_LEAKY:
- {
- ggml_compute_forward_leaky(params, src0, dst);
- } break;
default:
{
GGML_ASSERT(false);
} break;
case GGML_OP_MUL_MAT:
{
- ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor);
+ ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor, 0, tensor->ne[1]);
} break;
case GGML_OP_MUL_MAT_ID:
{
- ggml_compute_forward_mul_mat_id(params, tensor);
+ ggml_compute_forward_mul_mat_id(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_OUT_PROD:
{
{
ggml_compute_forward_upscale(params, tensor->src[0], tensor);
} break;
+ case GGML_OP_PAD:
+ {
+ ggml_compute_forward_pad(params, tensor->src[0], tensor);
+ } break;
case GGML_OP_ARGSORT:
{
ggml_compute_forward_argsort(params, tensor->src[0], tensor);
} break;
+ case GGML_OP_LEAKY_RELU:
+ {
+ ggml_compute_forward_leaky_relu(params, tensor->src[0], tensor);
+ } break;
case GGML_OP_FLASH_ATTN:
{
const int32_t t = ggml_get_op_params_i32(tensor, 0);
// insert new tensors recomputing src, reusing already made replacements,
// remember replacements: remember new tensors with mapping from corresponding gf nodes
// recurse for input tensors,
- // unless (i.e. terminating when) input tensors are replacments (like checkpoints)
+ // unless (i.e. terminating when) input tensors are replacements (like checkpoints)
node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]);
}
// insert rewritten backward node with replacements made into resulting backward graph gb
{
GGML_ASSERT(false); // TODO: not implemented
} break;
+ case GGML_OP_PAD:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
case GGML_OP_ARGSORT:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
+ case GGML_OP_LEAKY_RELU:
+ {
+ GGML_ASSERT(false); // TODO: not implemented
+ } break;
case GGML_OP_FLASH_ATTN:
{
struct ggml_tensor * flash_grad = NULL;
case GGML_OP_ARGMAX:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
+ case GGML_OP_LEAKY_RELU:
{
n_tasks = 1;
} break;
case GGML_UNARY_OP_TANH:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_RELU:
- case GGML_UNARY_OP_LEAKY:
{
n_tasks = 1;
} break;
{
n_tasks = n_threads;
} break;
+ case GGML_OP_PAD:
+ {
+ n_tasks = n_threads;
+ } break;
case GGML_OP_ARGSORT:
{
n_tasks = n_threads;
#define GGML_QNT_VERSION_FACTOR 1000 // do not change this
#define GGML_MAX_DIMS 4
-#define GGML_MAX_PARAMS 1024
+#define GGML_MAX_PARAMS 2048
#define GGML_MAX_CONTEXTS 64
-#define GGML_MAX_SRC 6
+#define GGML_MAX_SRC 10
#define GGML_MAX_NAME 64
#define GGML_MAX_OP_PARAMS 64
#define GGML_DEFAULT_N_THREADS 4
GGML_OP_POOL_1D,
GGML_OP_POOL_2D,
GGML_OP_UPSCALE, // nearest interpolate
+ GGML_OP_PAD,
GGML_OP_ARGSORT,
+ GGML_OP_LEAKY_RELU,
GGML_OP_FLASH_ATTN,
GGML_OP_FLASH_FF,
GGML_UNARY_OP_GELU,
GGML_UNARY_OP_GELU_QUICK,
GGML_UNARY_OP_SILU,
- GGML_UNARY_OP_LEAKY,
GGML_UNARY_OP_COUNT,
};
struct ggml_tensor * a,
struct ggml_tensor * b);
+ // dst = a
+ // view(dst, nb1, nb2, nb3, offset) += b
+ // return dst
GGML_API struct ggml_tensor * ggml_acc(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_context * ctx,
struct ggml_tensor * a);
- GGML_API struct ggml_tensor * ggml_leaky(
+ GGML_API struct ggml_tensor * ggml_leaky_relu(
struct ggml_context * ctx,
- struct ggml_tensor * a);
+ struct ggml_tensor * a, float negative_slope, bool inplace);
GGML_API struct ggml_tensor * ggml_relu_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a);
- // TODO: double-check this computation is correct
GGML_API struct ggml_tensor * ggml_gelu(
struct ggml_context * ctx,
struct ggml_tensor * a);
// ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b)
GGML_API struct ggml_tensor * ggml_mul_mat_id(
struct ggml_context * ctx,
- struct ggml_tensor * as[],
+ struct ggml_tensor * const as[],
+ int n_as,
struct ggml_tensor * ids,
int id,
struct ggml_tensor * b);
struct ggml_context * ctx,
struct ggml_tensor * a);
+ // supports 3D: a->ne[2] == b->ne[1]
GGML_API struct ggml_tensor * ggml_get_rows(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * a,
int scale_factor);
+ // pad each dimension with zeros: [x, ..., x] -> [x, ..., x, 0, ..., 0]
+ GGML_API struct ggml_tensor * ggml_pad(
+ struct ggml_context * ctx,
+ struct ggml_tensor * a,
+ int p0,
+ int p1,
+ int p2,
+ int p3);
+
// sort rows
enum ggml_sort_order {
GGML_SORT_ASC,