// TODO: this should not be a relative path, but can't figure out how to set Metal include paths in Package.swift
#include "../ggml-common.h"
#endif
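+// ggml-metal-impl.h declares the ggml_metal_kargs_* structs used as single constant-buffer
+// kernel arguments below, replacing the long lists of individual constant parameters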
+#include "ggml-metal-impl.h"
#include <metal_stdlib>
// pros: works for non-contiguous tensors, supports broadcast across all dims
// cons: not very efficient
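+// mapping: one threadgroup per dst row (i01, i02, i03); the threads stride over ne0 in steps of ntg.x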
kernel void kernel_add(
+ constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- constant int64_t & offs,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig.z;
- const int64_t i02 = tgpig.y;
- const int64_t i01 = tgpig.x;
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig.z;
+ const int i02 = tgpig.y;
+ const int i01 = tgpig.x;
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
+ const int i13 = i03%args.ne13;
+ const int i12 = i02%args.ne12;
+ const int i11 = i01%args.ne11;
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
+ device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+ device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+ device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- const int i10 = i0 % ne10;
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i10 = i0%args.ne10;
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10));
}
}
kernel void kernel_sub(
+ constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- constant int64_t & offs,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig.z;
- const int64_t i02 = tgpig.y;
- const int64_t i01 = tgpig.x;
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig.z;
+ const int i02 = tgpig.y;
+ const int i01 = tgpig.x;
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
+ const int i13 = i03%args.ne13;
+ const int i12 = i02%args.ne12;
+ const int i11 = i01%args.ne11;
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
+ device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs;
+ device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+ device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs;
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- const int i10 = i0 % ne10;
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i10 = i0%args.ne10;
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10));
}
}
kernel void kernel_mul(
+ constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig.z;
- const int64_t i02 = tgpig.y;
- const int64_t i01 = tgpig.x;
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig.z;
+ const int i02 = tgpig.y;
+ const int i01 = tgpig.x;
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
+ const int i13 = i03%args.ne13;
+ const int i12 = i02%args.ne12;
+ const int i11 = i01%args.ne11;
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+ device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+ device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+ device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1;
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- const int i10 = i0 % ne10;
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i10 = i0%args.ne10;
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10));
}
}
kernel void kernel_div(
+ constant ggml_metal_kargs_bin & args,
device const char * src0,
device const char * src1,
device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig.z;
- const int64_t i02 = tgpig.y;
- const int64_t i01 = tgpig.x;
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig.z;
+ const int i02 = tgpig.y;
+ const int i01 = tgpig.x;
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
+ const int i13 = i03%args.ne13;
+ const int i12 = i02%args.ne12;
+ const int i11 = i01%args.ne11;
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+ device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+ device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11;
+ device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1;
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- const int i10 = i0 % ne10;
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i10 = i0%args.ne10;
+ *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10));
}
}
template<typename T>
kernel void kernel_repeat(
+ constant ggml_metal_kargs_repeat & args,
device const char * src0,
device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i3 = tgpig.z;
- const int64_t i2 = tgpig.y;
- const int64_t i1 = tgpig.x;
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i3 = tgpig.z;
+ const int i2 = tgpig.y;
+ const int i1 = tgpig.x;
- const int64_t i03 = i3 % ne03;
- const int64_t i02 = i2 % ne02;
- const int64_t i01 = i1 % ne01;
+ const int i03 = i3%args.ne03;
+ const int i02 = i2%args.ne02;
+ const int i01 = i1%args.ne01;
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
- device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
+ device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01;
+ device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1;
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- const int i00 = i0 % ne00;
- *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ const int i00 = i0%args.ne00;
+ *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00));
}
}
// assumption: src1 is a row
// broadcast src1 into src0
kernel void kernel_add_row(
+ constant ggml_metal_kargs_bin & args,
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
- constant uint64_t & nb [[buffer(28)]],
uint tpig[[thread_position_in_grid]]) {
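+    // nb = row length in float4 units, now derived from args.ne00 instead of a dedicated buffer(28) argument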
+ const uint nb = args.ne00/4;
dst[tpig] = src0[tpig] + src1[tpig % nb];
}
kernel void kernel_sub_row(
+ constant ggml_metal_kargs_bin & args,
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
- constant uint64_t & nb [[buffer(28)]],
uint tpig[[thread_position_in_grid]]) {
+ const uint nb = args.ne00/4;
dst[tpig] = src0[tpig] - src1[tpig % nb];
}
kernel void kernel_mul_row(
+ constant ggml_metal_kargs_bin & args,
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
- constant uint64_t & nb [[buffer(28)]],
uint tpig[[thread_position_in_grid]]) {
+ const uint nb = args.ne00/4;
dst[tpig] = src0[tpig] * src1[tpig % nb];
}
kernel void kernel_div_row(
+ constant ggml_metal_kargs_bin & args,
device const float4 * src0,
device const float4 * src1,
device float4 * dst,
- constant uint64_t & nb [[buffer(28)]],
uint tpig[[thread_position_in_grid]]) {
+ const uint nb = args.ne00/4;
dst[tpig] = src0[tpig] / src1[tpig % nb];
}
}
kernel void kernel_norm(
- device const void * src0,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant float & eps,
- threadgroup float * sum [[threadgroup(0)]],
- uint tgpig[[threadgroup_position_in_grid]],
- uint tpitg[[thread_position_in_threadgroup]],
- uint ntg[[threads_per_threadgroup]]) {
- device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
- // MEAN
- // parallel sum
- sum[tpitg] = 0.0f;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- sum[tpitg] += x[i00];
+ constant ggml_metal_kargs_norm & args,
+ device const char * src0,
+ device char * dst,
+ threadgroup float * shmem_f32 [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ ushort tpitg[[thread_position_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort ntg[[threads_per_threadgroup]]) {
+ if (sgitg == 0) {
+ shmem_f32[tiisg] = 0.0f;
+ }
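+
+    // reduction strategy: simd_sum within each simdgroup, then the per-simdgroup partials are combined via shmem_f32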
+
+ device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
+
+ float4 sumf4(0.0f);
+
+ float sumf = 0.0f;
+
+ for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+ sumf4 += x[i00];
}
- // reduce
+ sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3];
+ sumf = simd_sum(sumf);
+
threadgroup_barrier(mem_flags::mem_threadgroup);
- for (uint i = ntg/2; i > 0; i /= 2) {
- if (tpitg < i) {
- sum[tpitg] += sum[tpitg + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ shmem_f32[sgitg] = sumf;
}
- const float mean = sum[0] / ne00;
- // recenter and VARIANCE
threadgroup_barrier(mem_flags::mem_threadgroup);
- device float * y = dst + tgpig*ne00;
- sum[tpitg] = 0.0f;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+
+ sumf = shmem_f32[tiisg];
+ sumf = simd_sum(sumf);
+
+ const float mean = sumf/args.ne00;
+
+ device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
+
+ sumf = 0.0f;
+ for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
y[i00] = x[i00] - mean;
- sum[tpitg] += y[i00] * y[i00];
+ sumf += dot(y[i00], y[i00]);
}
+ sumf = simd_sum(sumf);
- // reduce
threadgroup_barrier(mem_flags::mem_threadgroup);
- for (uint i = ntg/2; i > 0; i /= 2) {
- if (tpitg < i) {
- sum[tpitg] += sum[tpitg + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ shmem_f32[sgitg] = sumf;
}
- const float variance = sum[0] / ne00;
- const float scale = 1.0f/sqrt(variance + eps);
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sumf = shmem_f32[tiisg];
+ sumf = simd_sum(sumf);
+
+ const float variance = sumf/args.ne00;
+
+ const float scale = 1.0f/sqrt(variance + args.eps);
+ for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
y[i00] = y[i00] * scale;
}
}
kernel void kernel_rms_norm(
- device const void * src0,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant float & eps,
- threadgroup float * buf [[threadgroup(0)]],
- uint tgpig[[threadgroup_position_in_grid]],
- uint tpitg[[thread_position_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint ntg[[threads_per_threadgroup]]) {
- device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+ constant ggml_metal_kargs_rms_norm & args,
+ device const char * src0,
+ device char * dst,
+ threadgroup float * shmem_f32 [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ ushort tpitg[[thread_position_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort ntg[[threads_per_threadgroup]]) {
+ if (sgitg == 0) {
+ shmem_f32[tiisg] = 0.0f;
+ }
+
+ device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01);
- float4 sumf = 0;
- float all_sum = 0;
+ float sumf = 0.0f;
// parallel sum
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- sumf += x[i00] * x[i00];
+ for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
+ sumf += dot(x[i00], x[i00]);
}
- all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
- all_sum = simd_sum(all_sum);
- if (ntg > N_SIMDWIDTH) {
- if (sgitg == 0) {
- buf[tiisg] = 0.0f;
- }
+ sumf = simd_sum(sumf);
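+
+    // the ntg > N_SIMDWIDTH special case is gone: the two-stage reduction now runs unconditionally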
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- if (tiisg == 0) {
- buf[sgitg] = all_sum;
- }
+ if (tiisg == 0) {
+ shmem_f32[sgitg] = sumf;
+ }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ threadgroup_barrier(mem_flags::mem_threadgroup);
- all_sum = buf[tiisg];
- all_sum = simd_sum(all_sum);
- }
+ sumf = shmem_f32[tiisg];
+ sumf = simd_sum(sumf);
- const float mean = all_sum/ne00;
- const float scale = 1.0f/sqrt(mean + eps);
+ const float mean = sumf/args.ne00;
+ const float scale = 1.0f/sqrt(mean + args.eps);
- device float4 * y = (device float4 *) (dst + tgpig*ne00);
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ device float4 * y = (device float4 *) dst + tgpig*args.ne00_4;
+ for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) {
y[i00] = x[i00] * scale;
}
}
// quantizations where the block size is 32. It also does not
// guard against the number of rows not being divisible by
// N_DST, so this is another explicit assumption of the implementation.
-template<typename block_q_type, int nr, int nsg, int nw>
+template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
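+// note: templated on args_t so the same implementation can be instantiated with different kargs structs
+// (a constant reference here; presumably the same pattern serves the indirect mul_mv_id path)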
void mul_vec_q_n_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
- const int nb = ne00/QK4_0;
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+ const int nb = args.ne00/QK4_0;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr;
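+// each simdgroup computes nr consecutive dst rows, so a threadgroup of nsg simdgroups covers nsg*nr rows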
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- //device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
+ //device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
// pointers to src0 rows
device const block_q_type * ax[nr];
for (int row = 0; row < nr; ++row) {
- const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
ax[row] = (device const block_q_type *) ((device char *) src0 + offset0);
}
float yl[16]; // src1 vector cache
float sumf[nr] = {0.f};
- const int ix = (tiisg/2);
- const int il = (tiisg%2)*8;
+ const short ix = (tiisg/2);
+ const short il = (tiisg%2)*8;
- device const float * yb = y + ix * QK4_0 + il;
+ device const float * yb = y + ix*QK4_0 + il;
// each thread in a SIMD group deals with half a block.
for (int ib = ix; ib < nb; ib += nw/2) {
yb += QK4_0 * 16;
}
+ device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
+
for (int row = 0; row < nr; ++row) {
const float tot = simd_sum(sumf[row]);
- if (tiisg == 0 && first_row + row < ne01) {
- dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot;
+
+ if (tiisg == 0 && first_row + row < args.ne01) {
+ dst_f32[first_row + row] = tot;
}
}
}
kernel void kernel_mul_mv_q4_0_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q4_1_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q5_0_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
kernel void kernel_mul_mv_q5_1_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
-
#define NB_Q8_0 8
+template<typename args_t>
void kernel_mul_mv_q8_0_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
const int nr = N_DST;
const int nsg = N_SIMDGROUP;
const int nw = N_SIMDWIDTH;
- const int nb = ne00/QK8_0;
+ const int nb = args.ne00/QK8_0;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
- const int first_row = (r0 * nsg + sgitg) * nr;
+ const int first_row = (r0*nsg + sgitg)*nr;
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- //device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
+ //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
// pointers to src0 rows
device const block_q8_0 * ax[nr];
for (int row = 0; row < nr; ++row) {
- const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0);
}
float yl[NB_Q8_0];
- float sumf[nr]={0.f};
+ float sumf[nr] = { 0.f };
- const int ix = tiisg/4;
- const int il = tiisg%4;
+ const short ix = tiisg/4;
+ const short il = tiisg%4;
- device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
+ device const float * yb = y + ix*QK8_0 + il*NB_Q8_0;
// each thread in a SIMD group deals with NB_Q8_0 quants at a time
for (int ib = ix; ib < nb; ib += nw/4) {
- for (int i = 0; i < NB_Q8_0; ++i) {
+ for (short i = 0; i < NB_Q8_0; ++i) {
yl[i] = yb[i];
}
for (int row = 0; row < nr; row++) {
- device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il;
+ device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0;
float sumq = 0.f;
- for (int iq = 0; iq < NB_Q8_0; ++iq) {
+ for (short iq = 0; iq < NB_Q8_0; ++iq) {
sumq += qs[iq] * yl[iq];
}
sumf[row] += sumq*ax[row][ib].d;
}
- yb += NB_Q8_0 * nw;
+ yb += nw*NB_Q8_0;
}
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
for (int row = 0; row < nr; ++row) {
const float tot = simd_sum(sumf[row]);
- if (tiisg == 0 && first_row + row < ne01) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+
+ if (tiisg == 0 && first_row + row < args.ne01) {
+ dst_f32[first_row + row] = tot;
}
}
}
[[host_name("kernel_mul_mv_q8_0_f32")]]
kernel void kernel_mul_mv_q8_0_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg);
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
#define N_MV_T_T 4
-template<typename T0, typename T04, typename T1, typename T14>
+template<typename T0, typename T04, typename T1, typename T14, typename args_t>
void kernel_mul_mv_impl(
- device const char * src0,
- device const char * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb00,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne11,
- int64_t ne12,
- uint64_t nb10,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- uint3 tgpig,
- uint tiisg) {
- const int64_t r0 = tgpig.x;
- const int64_t rb = tgpig.y*N_MV_T_T;
- const int64_t im = tgpig.z;
-
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
-
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig,
+ ushort tiisg) {
+ const int r0 = tgpig.x;
+ const int rb = tgpig.y*N_MV_T_T;
+ const int im = tgpig.z;
+
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
+
+ const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
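+    // args.r2/args.r3 are the broadcast ratios ne12/ne02 and ne13/ne03: several src1 batches share one src0 matrix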
device const T0 * x = (device const T0 *) (src0 + offset0);
- if (ne00 < 128) {
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+
+ if (args.ne00 < 128) {
for (int row = 0; row < N_MV_T_T; ++row) {
int r1 = rb + row;
- if (r1 >= ne11) {
+ if (r1 >= args.ne11) {
break;
}
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const T1 * y = (device const T1 *) (src1 + offset1);
float sumf = 0;
- for (int i = tiisg; i < ne00; i += 32) {
+ for (int i = tiisg; i < args.ne00; i += 32) {
sumf += (T0) x[i] * (T1) y[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
}
}
} else {
device const T04 * x4 = (device const T04 *) x;
for (int row = 0; row < N_MV_T_T; ++row) {
int r1 = rb + row;
- if (r1 >= ne11) {
+ if (r1 >= args.ne11) {
break;
}
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const T1 * y = (device const T1 *) (src1 + offset1);
device const T14 * y4 = (device const T14 *) y;
float sumf = 0;
- for (int i = tiisg; i < ne00/4; i += 32) {
- for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
+ for (int i = tiisg; i < args.ne00/4; i += 32) {
+ sumf += dot((float4) x4[i], (float4) y4[i]);
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
+ dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
}
}
}
template<typename T0, typename T04, typename T1, typename T14>
kernel void kernel_mul_mv(
- device const char * src0,
- device const char * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
- kernel_mul_mv_impl<T0, T04, T1, T14>(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]]) {
+ kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
+ args,
src0,
src1,
dst,
- ne00,
- ne01,
- ne02,
- nb00,
- nb01,
- nb02,
- nb03,
- ne10,
- ne11,
- ne12,
- nb10,
- nb11,
- nb12,
- nb13,
- ne0,
- ne1,
- r2,
- r3,
tgpig,
tiisg);
}
template<typename T, typename T4>
kernel void kernel_mul_mv_1row(
- device const char * src0,
- device const char * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]]) {
- const int64_t r0 = tgpig.x;
- const int64_t r1 = tgpig.y;
- const int64_t im = tgpig.z;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const T * x = (device const T *) (src0 + offset0);
device const float * y = (device const float *) (src1 + offset1);
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
float sumf = 0;
- if (ne00 < 128) {
- for (int i = tiisg; i < ne00; i += 32) {
+ if (args.ne00 < 128) {
+ for (int i = tiisg; i < args.ne00; i += 32) {
sumf += (float) x[i] * (float) y[i];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ dst_f32[r0] = all_sum;
}
} else {
device const T4 * x4 = (device const T4 *) x;
device const float4 * y4 = (device const float4 *) y;
- for (int i = tiisg; i < ne00/4; i += 32) {
- for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
+ for (int i = tiisg; i < args.ne00/4; i += 32) {
+ sumf += dot((float4) x4[i], y4[i]);
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
- for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]);
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
+ dst_f32[r0] = all_sum;
}
}
}
// Assumes row size (ne00) is a multiple of 4
template<typename T, typename T4>
kernel void kernel_mul_mv_l4(
- device const char * src0,
- device const char * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]]) {
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]]) {
- const int nrows = ne11;
- const int64_t r0 = tgpig.x;
- const int64_t im = tgpig.z;
+ const int nrows = args.ne11;
+ const int r0 = tgpig.x;
+ const int im = tgpig.z;
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
+ const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
device const T4 * x4 = (device const T4 *) (src0 + offset0);
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1;
+
for (int r1 = 0; r1 < nrows; ++r1) {
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const float4 * y4 = (device const float4 *) (src1 + offset1);
float sumf = 0;
- for (int i = tiisg; i < ne00/4; i += 32) {
- for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]);
+ for (int i = tiisg; i < args.ne00/4; i += 32) {
+ sumf += dot((float4) x4[i], y4[i]);
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
- dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum;
}
}
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static void rope_yarn(
- float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale,
+ float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale,
thread float * cos_theta, thread float * sin_theta) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
template<typename T>
kernel void kernel_rope_norm(
- device const void * src0,
- device const int32_t * src1,
- device const float * src2,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- constant int & n_past,
- constant int & n_dims,
- constant int & n_ctx_orig,
- constant float & freq_base,
- constant float & freq_scale,
- constant float & ext_factor,
- constant float & attn_factor,
- constant float & beta_fast,
- constant float & beta_slow,
- uint tiitg[[thread_index_in_threadgroup]],
- uint3 tptg[[threads_per_threadgroup]],
- uint3 tgpig[[threadgroup_position_in_grid]]) {
- const int64_t i3 = tgpig[2];
- const int64_t i2 = tgpig[1];
- const int64_t i1 = tgpig[0];
+ constant ggml_metal_kargs_rope & args,
+ device const char * src0,
+ device const char * src1,
+ device const char * src2,
+ device char * dst,
+ ushort tiitg[[thread_index_in_threadgroup]],
+ ushort3 tptg [[threads_per_threadgroup]],
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
+ const int i3 = tgpig[2];
+ const int i2 = tgpig[1];
+ const int i1 = tgpig[0];
float corr_dims[2];
- rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+ rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
- device const int32_t * pos = src1;
+ device const int32_t * pos = (device const int32_t *) src1;
const float theta_base = (float) pos[i2];
- const float inv_ndims = -1.f/n_dims;
+ const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
- for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
- if (i0 < n_dims) {
- const int64_t ic = i0/2;
+ for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+ if (i0 < args.n_dims) {
+ const int ic = i0/2;
- const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+ const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
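+            // i.e. theta = theta_base * freq_base^(-i0/n_dims): the rotation frequency decays geometrically with the dim index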
- const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+ const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
- rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+ rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
const float x0 = src[0];
const float x1 = src[1];
dst_data[0] = x0*cos_theta - x1*sin_theta;
dst_data[1] = x0*sin_theta + x1*cos_theta;
} else {
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
template<typename T>
kernel void kernel_rope_neox(
- device const void * src0,
- device const int32_t * src1,
- device const float * src2,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- constant int & n_past,
- constant int & n_dims,
- constant int & n_ctx_orig,
- constant float & freq_base,
- constant float & freq_scale,
- constant float & ext_factor,
- constant float & attn_factor,
- constant float & beta_fast,
- constant float & beta_slow,
- uint tiitg[[thread_index_in_threadgroup]],
- uint3 tptg[[threads_per_threadgroup]],
- uint3 tgpig[[threadgroup_position_in_grid]]) {
- const int64_t i3 = tgpig[2];
- const int64_t i2 = tgpig[1];
- const int64_t i1 = tgpig[0];
+ constant ggml_metal_kargs_rope & args,
+ device const char * src0,
+ device const char * src1,
+ device const char * src2,
+ device char * dst,
+ ushort tiitg[[thread_index_in_threadgroup]],
+ ushort3 tptg [[threads_per_threadgroup]],
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
+ const int i3 = tgpig[2];
+ const int i2 = tgpig[1];
+ const int i1 = tgpig[0];
float corr_dims[2];
- rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims);
+ rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
- device const int32_t * pos = src1;
+ device const int32_t * pos = (device const int32_t *) src1;
const float theta_base = (float) pos[i2];
- const float inv_ndims = -1.f/n_dims;
+ const float inv_ndims = -1.f/args.n_dims;
float cos_theta;
float sin_theta;
- for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) {
- if (i0 < n_dims) {
- const int64_t ic = i0/2;
+ for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
+ if (i0 < args.n_dims) {
+ const int ic = i0/2;
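+
+            // neox-style rope pairs element ic with ic + n_dims/2, instead of the adjacent pairs used by kernel_rope_norm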
- const float theta = theta_base * pow(freq_base, inv_ndims*i0);
+ const float theta = theta_base * pow(args.freq_base, inv_ndims*i0);
- const float freq_factor = src2 != src0 ? src2[ic] : 1.0f;
+ const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
- rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta);
+ rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00);
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0);
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
const float x0 = src[0];
- const float x1 = src[n_dims/2];
+ const float x1 = src[args.n_dims/2];
- dst_data[0] = x0*cos_theta - x1*sin_theta;
- dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta;
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
+ dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
} else {
- device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
- device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
dst_data[0] = src[0];
dst_data[1] = src[1];
short KV = 8, // key/value processed per simdgroup
short C = 32> // cache items per threadgroup
kernel void kernel_flash_attn_ext(
- device const char * q,
- device const char * k,
- device const char * v,
- device const char * mask,
- device float * dst,
- constant int32_t & ne01,
- constant int32_t & ne02,
- constant int32_t & ne03,
- constant uint32_t & nb01,
- constant uint32_t & nb02,
- constant uint32_t & nb03,
- constant int32_t & ne11,
- constant int32_t & ne_12_2, // assume K and V are same shape
- constant int32_t & ne_12_3,
- constant uint32_t & nb_12_1,
- constant uint32_t & nb_12_2,
- constant uint32_t & nb_12_3,
- constant uint32_t & nb31,
- constant int32_t & ne1,
- constant int32_t & ne2,
- constant float & scale,
- constant float & max_bias,
- constant float & m0,
- constant float & m1,
- constant uint16_t & n_head_log2,
- constant float & logit_softcap,
- threadgroup half * shared [[threadgroup(0)]],
- ushort3 tgpig[[threadgroup_position_in_grid]],
- ushort3 ntg[[threads_per_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ constant ggml_metal_kargs_flash_attn_ext & args,
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
+ device char * dst,
+ threadgroup half * shmem_f16 [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 ntg[[threads_per_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups
const int iq3 = tgpig[2];
const short TS = nsg*SH; // shared memory size per query, in s_t units (s_t == float)
const short T = D + 2*TS; // shared memory size per query, in half units
- threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
- threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation
- threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // same as above but in o4_t
- threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
+ threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t
+ threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*D); // reuse query data for accumulation
+ threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*D); // same as above but in o4_t
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix
- threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
- threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
+ threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory
+ threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t
- threadgroup v_t * sv = (threadgroup v_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
- threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
+ threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory
+ threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t
// store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
o8x8_t lo[D8];
// load heads from Q to shared memory
for (short j = sgitg; j < Q; j += nsg) {
- device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
for (short i = tiisg; i < D4; i += NW) {
- if (iq1 + j < ne01) {
+ if (iq1 + j < args.ne01) {
sq4[j*D4 + i] = (q4_t) q4[i];
} else {
sq4[j*D4 + i] = (q4_t) 0.0f;
const short ty = tiisg/4;
// broadcast kv
- //const short rk2 = ne02/ne12;
- //const short rk3 = ne03/ne13;
+ //const short rk2 = args.ne02/args.ne12;
+ //const short rk3 = args.ne03/args.ne13;
- const short ikv2 = iq2/(ne02/ne_12_2);
- const short ikv3 = iq3/(ne03/ne_12_3);
+ const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+ const short ikv3 = iq3/(args.ne03/args.ne_12_3);
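+
+    // grouped-query attention: the integer division maps several query heads onto one shared KV head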
// load the queries from shared memory into local memory
q8x8_t mq[D8];
half slope = 1.0f;
// ALiBi
- if (max_bias > 0.0f) {
+ if (args.max_bias > 0.0f) {
const short h = iq2;
- const half base = h < n_head_log2 ? m0 : m1;
- const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+ const half base = h < args.n_head_log2 ? args.m0 : args.m1;
+ const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exph);
}
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
- for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+ for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
const int ic = ic0 + C*sgitg;
- if (ic >= ne11) {
+ if (ic >= args.ne11) {
break;
}
// load the mask in shared memory
#pragma unroll(Q)
for (short j = 0; j < Q; ++j) {
- device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31);
+ device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31);
const half m = pm[ic + tiisg];
// this is a compile-time check, so it does not have runtime overhead
if (is_same<kd4x4_t, k4x4_t>::value) {
// we can read directly from global memory
- device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
+ device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
#pragma unroll(D8)
for (short i = 0; i < D8; ++i) {
k8x8_t mk;
- simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
+ simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10
simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
}
} else {
for (short ii = 0; ii < D16; ii += 4) {
- device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
+ device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
if (D16%4 == 0) {
// the head is evenly divisible by 4*16 = 64, so no need for bound checks
const half m = M[j];
// scale and apply the logit softcap / mask
- half s = ss[j*TS + tiisg]*scale;
+ half s = ss[j*TS + tiisg]*args.scale;
- if (logit_softcap != 0.0f) {
- s = logit_softcap*precise::tanh(s);
+ if (args.logit_softcap != 0.0f) {
+ s = args.logit_softcap*precise::tanh(s);
}
// mqk = mqk + mask*slope
if (is_same<vd4x4_t, v4x4_t>::value) {
// we can read directly from global memory
- device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
+ device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
#pragma unroll(D8)
for (short i = 0; i < D8; ++i) {
v8x8_t mv;
- simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
+ simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20
simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]);
}
} else {
for (short ii = 0; ii < D16; ii += 4) {
- device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
+ device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
if (D16%4 == 0) {
// no need for bound checks
// final rescale with 1/S and store to global memory
if (sgitg == 0) {
- for (short j = 0; j < Q && iq1 + j < ne01; ++j) {
+ for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) {
const float S = ss[j*TS + 0];
for (short i = tiisg; i < D4; i += NW) {
- dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
+ dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S;
}
}
}
short Q = 1, // queries per threadgroup
short C = 32> // cache items per threadgroup
kernel void kernel_flash_attn_ext_vec(
- device const char * q,
- device const char * k,
- device const char * v,
- device const char * mask,
- device float * dst,
- constant int32_t & ne01,
- constant int32_t & ne02,
- constant int32_t & ne03,
- constant uint32_t & nb01,
- constant uint32_t & nb02,
- constant uint32_t & nb03,
- constant int32_t & ne11,
- constant int32_t & ne_12_2, // assume K and V are same shape
- constant int32_t & ne_12_3,
- constant uint32_t & nb_12_1,
- constant uint32_t & nb_12_2,
- constant uint32_t & nb_12_3,
- constant uint32_t & nb31,
- constant int32_t & ne1,
- constant int32_t & ne2,
- constant float & scale,
- constant float & max_bias,
- constant float & m0,
- constant float & m1,
- constant uint16_t & n_head_log2,
- constant float & logit_softcap,
- threadgroup half * shared [[threadgroup(0)]],
- ushort3 tgpig[[threadgroup_position_in_grid]],
- ushort3 tpitg[[thread_position_in_threadgroup]],
- ushort3 ntg[[threads_per_threadgroup]],
- ushort tiisg[[thread_index_in_simdgroup]],
- ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ constant ggml_metal_kargs_flash_attn_ext & args,
+ device const char * q,
+ device const char * k,
+ device const char * v,
+ device const char * mask,
+ device char * dst,
+ threadgroup half * shmem_f16 [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 ntg[[threads_per_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short nsg = ntg.y; // number of simdgroups
const int iq3 = tgpig[2];
const short T = D + nsg*SH; // shared memory size per query in (half)
- //threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data
- threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t
- threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t
- threadgroup s_t * ss = (threadgroup s_t *) (shared + sgitg*SH + Q*D); // scratch buffer for attention
- threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + sgitg*SH + Q*D); // same as above but in s4_t
- threadgroup half * sm = (threadgroup half *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask
- threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results
+ //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data
+ threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t
+ threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 + 0*D); // same as above but in q4x4_t
+ threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*D); // scratch buffer for attention
+ threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*D); // same as above but in s4_t
+ threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask
+ threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results
// store the result for all queries in local memory in 4x4 matrices (the O matrix from the paper)
o4x4_t lo[D16/NL];
// load heads from Q to shared memory
- device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03));
+ device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03));
for (short i = tiisg; i < D4; i += NW) {
- if (iq1 < ne01) {
+ if (iq1 < args.ne01) {
sq4[i] = (q4_t) q4[i];
} else {
sq4[i] = (q4_t) 0.0f;
const short ty = tiisg/NL;
// broadcast kv
- //const short rk2 = ne02/ne12;
- //const short rk3 = ne03/ne13;
+ //const short rk2 = args.ne02/args.ne12;
+ //const short rk3 = args.ne03/args.ne13;
- const short ikv2 = iq2/(ne02/ne_12_2);
- const short ikv3 = iq3/(ne03/ne_12_3);
+ const short ikv2 = iq2/(args.ne02/args.ne_12_2);
+ const short ikv3 = iq3/(args.ne03/args.ne_12_3);
// load the queries from shared memory into local memory
q4x4_t mq[D16/NL];
const bool has_mask = mask != q;
// pointer to the mask
- device const half * pm = (device const half *) (mask + iq1*nb31);
+ device const half * pm = (device const half *) (mask + iq1*args.nb31);
half slope = 1.0f;
// ALiBi
- if (max_bias > 0.0f) {
+ if (args.max_bias > 0.0f) {
const short h = iq2;
- const half base = h < n_head_log2 ? m0 : m1;
- const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+ const half base = h < args.n_head_log2 ? args.m0 : args.m1;
+ const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1;
slope = pow(base, exph);
}
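// host-side setup assumed here (not part of this diff), following the usual
// ALiBi parameterization:
//   n_head_log2 = 1 << (int) floor(log2(n_head));
//   m0 = pow(2.0f, -        max_bias  / n_head_log2);
//   m1 = pow(2.0f, -(0.5f * max_bias) / n_head_log2);
// e.g. n_head = 8, max_bias = 8 gives head h a slope of 2^-(h+1)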
// loop over the KV cache
// each simdgroup handles blocks of Q rows and C columns
- for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
+ for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) {
const int ic = ic0 + C*sgitg;
- if (ic >= ne11) {
+ if (ic >= args.ne11) {
break;
}
for (short cc = 0; cc < C/4; ++cc) {
qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 };
- device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
+ device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
#pragma unroll(D16/NL)
for (short ii = 0; ii < D16; ii += NL) {
// mqk = mqk*scale + mask*slope
if (tx == 0) {
- mqk *= scale;
+ mqk *= args.scale;
- if (logit_softcap != 0.0f) {
- mqk = logit_softcap*precise::tanh(mqk);
+ if (args.logit_softcap != 0.0f) {
+ mqk = args.logit_softcap*precise::tanh(mqk);
}
mqk += sm[4*cc + ty]*slope;
// O = O + (Q*K^T)*V
{
for (short cc = 0; cc < C/4; ++cc) {
- device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3));
+ device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3));
const s4x4_t ms(ss[4*cc + ty]);
const float S = ss[0];
for (short i = tiisg; i < D16; i += NW) {
- dst44[((int64_t)iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
+ dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S;
}
}
}
template<typename T0, typename T1>
kernel void kernel_cpy(
- device const void * src0,
- device void * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ constant ggml_metal_kargs_cpy & args,
+ device const char * src0,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig[2];
+ const int i02 = tgpig[1];
+ const int i01 = tgpig[0];
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+ const int64_t i3 = n/(args.ne2*args.ne1*args.ne0);
+ const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0);
+ const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0;
+ const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0);
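+ // n linearizes the source element index over the src0 shape; the i3..i0
+ // decomposition above re-maps it onto the dst shape, so the same kernel also
+ // handles copies that reshape between tensors of equal element count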
- device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
- device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) {
+ device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
dst_data[i00] = (T1) src[0];
}
}
#endif
kernel void kernel_cpy_f32_q8_0(
- device const float * src0,
- device void * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ constant ggml_metal_kargs_cpy & args,
+ device const char * src0,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig[2];
+ const int i02 = tgpig[1];
+ const int i01 = tgpig[0];
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK8_0;
+ const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+ const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+ const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+ const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0;
- device block_q8_0 * dst_data = (device block_q8_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- for (int64_t i00 = tpitg.x*QK8_0; i00 < ne00; i00 += ntg.x*QK8_0) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) {
+ device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
float amax = 0.0f; // absolute max
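// per the q8_0 format, each QK8_0 = 32 block is presumably reduced to a half
// scale d = amax/127 plus the 32 values rounded to int8 in dst_data[i00/QK8_0]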
}
kernel void kernel_cpy_f32_q4_0(
- device const float * src0,
- device void * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ constant ggml_metal_kargs_cpy & args,
+ device const char * src0,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig[2];
+ const int i02 = tgpig[1];
+ const int i01 = tgpig[0];
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_0;
+ const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+ const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+ const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+ const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0;
- device block_q4_0 * dst_data = (device block_q4_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- for (int64_t i00 = tpitg.x*QK4_0; i00 < ne00; i00 += ntg.x*QK4_0) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) {
+ device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
float amax = 0.0f; // absolute max
float max = 0.0f;
}
kernel void kernel_cpy_f32_q4_1(
- device const float * src0,
- device void * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ constant ggml_metal_kargs_cpy & args,
+ device const char * src0,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig[2];
+ const int i02 = tgpig[1];
+ const int i01 = tgpig[0];
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_1;
+ const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+ const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+ const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+ const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1;
- device block_q4_1 * dst_data = (device block_q4_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- for (int64_t i00 = tpitg.x*QK4_1; i00 < ne00; i00 += ntg.x*QK4_1) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) {
+ device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
float min = FLT_MAX;
float max = -FLT_MAX;
}
kernel void kernel_cpy_f32_q5_0(
- device const float * src0,
- device void * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ constant ggml_metal_kargs_cpy & args,
+ device const char * src0,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig[2];
+ const int i02 = tgpig[1];
+ const int i01 = tgpig[0];
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_0;
+ const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+ const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+ const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+ const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0;
- device block_q5_0 * dst_data = (device block_q5_0 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- for (int64_t i00 = tpitg.x*QK5_0; i00 < ne00; i00 += ntg.x*QK5_0) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) {
+ device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
float amax = 0.0f; // absolute max
float max = 0.0f;
}
kernel void kernel_cpy_f32_q5_1(
- device const float * src0,
- device void * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ constant ggml_metal_kargs_cpy & args,
+ device const char * src0,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig[2];
+ const int i02 = tgpig[1];
+ const int i01 = tgpig[0];
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK5_1;
+ const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+ const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+ const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+ const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1;
- device block_q5_1 * dst_data = (device block_q5_1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- for (int64_t i00 = tpitg.x*QK5_1; i00 < ne00; i00 += ntg.x*QK5_1) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) {
+ device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
float max = src[0];
float min = src[0];
}
kernel void kernel_cpy_f32_iq4_nl(
- device const float * src0,
- device void * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig[2];
- const int64_t i02 = tgpig[1];
- const int64_t i01 = tgpig[0];
+ constant ggml_metal_kargs_cpy & args,
+ device const char * src0,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
+ const int i03 = tgpig[2];
+ const int i02 = tgpig[1];
+ const int i01 = tgpig[0];
- const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+ const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00;
- const int64_t i3 = n / (ne2*ne1*ne0);
- const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
- const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
- const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0)/QK4_NL;
+ const int64_t i3 = n / (args.ne2*args.ne1*args.ne0);
+ const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0);
+ const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0;
+ const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL;
- device block_iq4_nl * dst_data = (device block_iq4_nl *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
- for (int64_t i00 = tpitg.x*QK4_NL; i00 < ne00; i00 += ntg.x*QK4_NL) {
- device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) {
+ device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00);
float amax = 0.0f; // absolute max
float max = 0.0f;
}
dst_data[i00/QK4_NL].d = sumq2 > 0 ? sumqx/sumq2 : d;
-
}
}
kernel void kernel_concat(
+ constant ggml_metal_kargs_concat & args,
device const char * src0,
device const char * src1,
device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- constant int32_t & dim,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort3 tpitg[[thread_position_in_threadgroup]],
+ ushort3 ntg[[threads_per_threadgroup]]) {
- const int64_t i3 = tgpig.z;
- const int64_t i2 = tgpig.y;
- const int64_t i1 = tgpig.x;
+ const int i3 = tgpig.z;
+ const int i2 = tgpig.y;
+ const int i1 = tgpig.x;
- int64_t o[4] = {0, 0, 0, 0};
- o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03));
+ int o[4] = {0, 0, 0, 0};
+ o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03));
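+ // o[] is the shift applied to src1 indices: coordinates that fall past the
+ // end of src0 along the concat dim (e.g. i2 >= args.ne02 for args.dim == 2)
+ // are read from src1 at (i2 - o[2])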
device const float * x;
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
- x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00);
+ for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) {
+ if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) {
+ x = (device const float *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00);
} else {
- x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10);
+ x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10);
}
- device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
*y = *x;
}
}
+template<typename args_t>
void kernel_mul_mv_q2_K_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
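+ // number of super-blocks per row; QK_K = 256 weights per k-quant block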
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
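+ // note the offsets are widened to uint64_t: with large tensors the old
+ // 32-bit uint row*stride products could overflow before reaching the
+ // pointer arithmetic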
- device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
+ device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
(acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) -
dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0));
- qs += nb01/2;
- sc += nb01;
- dh += nb01/2;
+ qs += args.nb01/2;
+ sc += args.nb01;
+ dh += args.nb01/2;
}
y4 += 4 * QK_K;
}
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
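+ // dst now arrives as an untyped device pointer; dst_f32 rebases it at the
+ // start of the output row for this (im, r1) pair, so the store below only
+ // needs the column index first_row + row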
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ dst_f32[first_row + row] = all_sum;
}
}
}
[[host_name("kernel_mul_mv_q2_K_f32")]]
kernel void kernel_mul_mv_q2_K_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
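+ // the explicit `constant ggml_metal_kargs_mul_mv &` template argument above
+ // lets the same _impl be shared with other entry points (presumably the
+ // mul_mv_id path, which passes a different args type) without copying the
+ // argument struct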
}
+template<typename args_t>
void kernel_mul_mv_q3_K_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
-
- const int64_t r0 = tgpig.x;
- const int64_t r1 = tgpig.y;
- const int64_t im = tgpig.z;
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
+
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0);
- device const float * yy = (device const float *) ((device char *) src1 + offset1);
+ device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
+ device const float * yy = (device const float *) (src1 + offset1);
float yl[32];
const ushort4 hm = mm[2*ip + il/2];
- const int shift = 2*il;
- const float v1 = il == 0 ? 4.f : 64.f;
- const float v2 = 4.f * v1;
+ const short shift = 2*il;
+
+ const float v1 = il == 0 ? 4.f : 64.f;
+ const float v2 = 4.f * v1;
const uint16_t s_shift1 = 4*ip;
const uint16_t s_shift2 = s_shift1 + il;
sumf1[row] += d1 * (scales[1] - 32);
sumf2[row] += d2 * (scales[3] - 32);
- q += nb01/2;
- h += nb01/2;
- a += nb01/2;
- dh += nb01/2;
+ q += args.nb01/2;
+ h += args.nb01/2;
+ a += args.nb01/2;
+ dh += args.nb01/2;
}
y1 += 4 * QK_K;
const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
sumf1[row] = simd_sum(sumf);
}
+
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
if (tiisg == 0) {
for (int row = 0; row < 2; ++row) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row];
+ dst_f32[first_row + row] = sumf1[row];
}
}
}
[[host_name("kernel_mul_mv_q3_K_f32")]]
kernel void kernel_mul_mv_q3_K_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
+template<typename args_t>
void kernel_mul_mv_q4_K_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const int iq = it/4; // 0 or 1
const int ir = it%4; // 0...3
- const int nb = ne00/QK_K;
+ const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
//const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int first_row = r0 * N_DST;
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
+ device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
float yl[16];
float yh[16];
(acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) -
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
- q1 += nb01/2;
- sc += nb01/2;
- dh += nb01/2;
+ q1 += args.nb01/2;
+ sc += args.nb01/2;
+ dh += args.nb01/2;
}
y4 += 4 * QK_K;
}
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ dst_f32[first_row + row] = all_sum;
}
}
}
[[host_name("kernel_mul_mv_q4_K_f32")]]
kernel void kernel_mul_mv_q4_K_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
+template<typename args_t>
void kernel_mul_mv_q5_K_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
-
- const int64_t r0 = tgpig.x;
- const int64_t r1 = tgpig.y;
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
+
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0);
- device const float * yy = (device const float *) ((device char *) src1 + offset1);
+ device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
+ device const float * yy = (device const float *) (src1 + offset1);
float sumf[2]={0.f};
sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
- q1 += nb01;
- qh += nb01;
- dh += nb01/2;
- a += nb01/2;
+ q1 += args.nb01;
+ qh += args.nb01;
+ dh += args.nb01/2;
+ a += args.nb01/2;
}
y1 += 4 * QK_K;
}
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
for (int row = 0; row < 2; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
+ dst_f32[first_row + row] = tot;
}
}
}
[[host_name("kernel_mul_mv_q5_K_f32")]]
kernel void kernel_mul_mv_q5_K_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
+template <typename args_t>
void kernel_mul_mv_q6_K_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
const uint8_t kmask1 = 0x03;
const uint8_t kmask2 = 0x0C;
const uint8_t kmask3 = 0x30;
const uint8_t kmask4 = 0xC0;
- const int nb = ne00/QK_K;
+ const int nb = args.ne00/QK_K;
- const int64_t r0 = tgpig.x;
- const int64_t r1 = tgpig.y;
- const int im = tgpig.z;
+ const int r0 = tgpig.x;
+ const int r1 = tgpig.y;
+ const int im = tgpig.z;
- const int row = 2 * r0 + sgitg;
+ const int row = 2*r0 + sgitg;
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- const uint offset0 = row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0);
- device const float * yy = (device const float *) ((device char *) src1 + offset1);
+ device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
+ device const float * yy = (device const float *) (src1 + offset1);
float sumf = 0;
const int q_offset_h = 32*ip + l0;
for (int i = ix; i < nb; i += 2) {
-
device const uint8_t * q1 = x[i].ql + q_offset_l;
device const uint8_t * q2 = q1 + 32;
device const uint8_t * qh = x[i].qh + q_offset_h;
}
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
const float tot = simd_sum(sumf);
if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + row] = tot;
+ dst_f32[row] = tot;
}
}
[[host_name("kernel_mul_mv_q6_K_f32")]]
kernel void kernel_mul_mv_q6_K_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
// ======================= "True" 2-bit
+template<typename args_t>
void kernel_mul_mv_iq2_xxs_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
+ device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
const int nb32 = nb * (QK_K / 32);
- threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
- threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
+ threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
+ threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256);
{
int nval = 4;
int pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i];
+ for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i];
nval = 2;
pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+ for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
threadgroup_barrier(mem_flags::mem_threadgroup);
}
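// cooperative table load: thread (32*sgitg + tiisg) of the threadgroup copies
// its slice of the iq2xxs grid and sign tables into threadgroup memory; the
// barrier makes them visible before svalues/ssigns are read below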
float sum = 0;
for (int l = 0; l < 4; ++l) {
- const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
- const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]);
+ const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
for (int j = 0; j < 8; ++j) {
sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
}
sumf[row] += d * sum;
- dh += nb01/2;
- q2 += nb01/2;
+ dh += args.nb01/2;
+ q2 += args.nb01/2;
}
y4 += 32 * 32;
}
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
+ dst_f32[first_row + row] = all_sum * 0.25f;
}
}
}
[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
kernel void kernel_mul_mv_iq2_xxs_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
+template<typename args_t>
void kernel_mul_mv_iq2_xs_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
+ device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
const int nb32 = nb * (QK_K / 32);
- threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
- threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512);
+ threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem);
+ threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 512);
{
int nval = 8;
int pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i];
+ for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i];
nval = 2;
pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+ for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
threadgroup_barrier(mem_flags::mem_threadgroup);
}
float sum1 = 0, sum2 = 0;
for (int l = 0; l < 2; ++l) {
- const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
- const uint8_t signs = shared_signs[(q2[l] >> 9)];
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
+ const uint8_t signs = ssigns[(q2[l] >> 9)];
for (int j = 0; j < 8; ++j) {
sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
}
for (int l = 2; l < 4; ++l) {
- const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511));
- const uint8_t signs = shared_signs[(q2[l] >> 9)];
+ const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511));
+ const uint8_t signs = ssigns[(q2[l] >> 9)];
for (int j = 0; j < 8; ++j) {
sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
}
}
sumf[row] += d1 * sum1 + d2 * sum2;
- dh += nb01/2;
- q2 += nb01/2;
- sc += nb01;
+ dh += args.nb01/2;
+ q2 += args.nb01/2;
+ sc += args.nb01;
}
y4 += 32 * 32;
}
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
+ dst_f32[first_row + row] = all_sum * 0.25f;
}
- }
-}
-
-[[host_name("kernel_mul_mv_iq2_xs_f32")]]
-kernel void kernel_mul_mv_iq2_xs_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ }
+}
+
+[[host_name("kernel_mul_mv_iq2_xs_f32")]]
+kernel void kernel_mul_mv_iq2_xs_f32(
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
+template <typename args_t>
void kernel_mul_mv_iq3_xxs_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
+ device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
const int nb32 = nb * (QK_K / 32);
- threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
- threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256);
+ threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem);
+ threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256);
{
int nval = 4;
int pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i];
+ for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i];
nval = 2;
pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
+ for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i];
threadgroup_barrier(mem_flags::mem_threadgroup);
}
device const float * y4 = y + 32 * ix;
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
-
for (int i = 0; i < 32; ++i) {
yl[i] = y4[i];
}
device const half * dh = &xr->d;
for (int row = 0; row < N_DST; row++) {
-
const float db = dh[0];
const uint32_t aux32 = gas[0] | (gas[1] << 16);
const float d = db * (0.5f + (aux32 >> 28));
float2 sum = {0};
for (int l = 0; l < 4; ++l) {
- const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]);
- const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]);
- const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
+ const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]);
+ const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]);
+ const uint8_t signs = ssigns[(aux32 >> 7*l) & 127];
for (int j = 0; j < 4; ++j) {
sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
sumf[row] += d * (sum[0] + sum[1]);
- dh += nb01/2;
- q3 += nb01;
- gas += nb01/2;
+ dh += args.nb01/2;
+ q3 += args.nb01;
+ gas += args.nb01/2;
}
y4 += 32 * 32;
}
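+ // dst is contiguous f32: row r1 of batch im starts at im*ne0*ne1 + r1*ne0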
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f;
+ dst_f32[first_row + row] = all_sum * 0.5f;
}
}
}
[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
kernel void kernel_mul_mv_iq3_xxs_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
+template<typename args_t>
void kernel_mul_mv_iq3_s_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
+ device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
const int nb32 = nb * (QK_K / 32);
- threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
+ threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem;
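+ // cooperatively stage the 512-entry iq3_s grid into threadgroup memory (qh bits below select the upper half)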
{
int nval = 8;
int pos = (32*sgitg + tiisg)*nval;
- for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i];
+ for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i];
threadgroup_barrier(mem_flags::mem_threadgroup);
}
float2 sum = {0};
for (int l = 0; l < 4; ++l) {
- const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values;
- const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values;
+ const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues;
+ const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues;
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]);
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]);
for (int j = 0; j < 4; ++j) {
}
sumf[row] += d * (sum[0] + sum[1]);
- dh += nb01/2;
- qs += nb01;
- qh += nb01;
- sc += nb01;
- signs += nb01;
+ dh += args.nb01/2;
+ qs += args.nb01;
+ qh += args.nb01;
+ sc += args.nb01;
+ signs += args.nb01;
}
y4 += 32 * 32;
}
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ dst_f32[first_row + row] = all_sum;
}
}
}
[[host_name("kernel_mul_mv_iq3_s_f32")]]
kernel void kernel_mul_mv_iq3_s_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
+template <typename args_t>
void kernel_mul_mv_iq2_s_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
+ device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
const int nb32 = nb * (QK_K / 32);
- //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
+ //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem;
//{
// int nval = 32;
// int pos = (32*sgitg + tiisg)*nval;
- // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i];
+ // for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i];
// threadgroup_barrier(mem_flags::mem_threadgroup);
//}
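+ // note: the staging block above is kept disabled; the iq2_s grid is read directly from constant memory below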
float2 sum = {0};
for (int l = 0; l < 2; ++l) {
- //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
- //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
+ //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
+ //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300)));
constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300)));
for (int j = 0; j < 8; ++j) {
}
sumf[row] += d1 * sum[0] + d2 * sum[1];
- dh += nb01/2;
- qs += nb01;
- qh += nb01;
- sc += nb01;
- signs += nb01;
+ dh += args.nb01/2;
+ qs += args.nb01;
+ qh += args.nb01;
+ sc += args.nb01;
+ signs += args.nb01;
}
y4 += 32 * 32;
}
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f;
+ dst_f32[first_row + row] = all_sum * 0.25f;
}
}
}
[[host_name("kernel_mul_mv_iq2_s_f32")]]
kernel void kernel_mul_mv_iq2_s_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
+template<typename args_t>
void kernel_mul_mv_iq1_s_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_value,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
+ device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
}
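+ // qh bit 15 selects the sign of IQ1S_DELTA; bits 12..14 carry the 3-bit block scale (applied as 2*s + 1)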
sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1);
- dh += nb01/2;
- qs += nb01;
- qh += nb01/2;
+ dh += args.nb01/2;
+ qs += args.nb01;
+ qh += args.nb01/2;
}
y4 += 32 * 32;
}
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ dst_f32[first_row + row] = all_sum;
}
}
}
+template <typename args_t>
void kernel_mul_mv_iq1_m_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_value,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- const int nb = ne00/QK_K;
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
+ device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
(sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
- sc += nb01/2;
- qs += nb01;
- qh += nb01;
+ sc += args.nb01/2;
+ qs += args.nb01;
+ qh += args.nb01;
}
y4 += 32 * 32;
}
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ dst_f32[first_row + row] = all_sum;
}
}
}
+template<typename args_t>
void kernel_mul_mv_iq4_nl_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values_i8,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
- const int nb = ne00/QK4_NL;
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+ const int nb = args.ne00/QK4_NL;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * 2 + sgitg) * 2;
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
+ device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
const int ix = tiisg/2; // 0...15
const int it = tiisg%2; // 0 or 1
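+ // stage the 16-entry iq4nl codebook into threadgroup memory: each of the 32 lanes writes one slot, so the table is duplicated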
- shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
+ shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
- for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
+ for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
device const block_iq4_nl & xb = x[row*nb + ib];
device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it);
aux32[0] = q4[0] | (q4[1] << 16);
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
aux32[0] &= 0x0f0f0f0f;
- qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
- qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+ qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+ qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
acc1 += yl[0] * qf1;
acc2 += yl[1] * qf2;
aux32[0] = q4[2] | (q4[3] << 16);
aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f;
aux32[0] &= 0x0f0f0f0f;
- qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
- qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+ qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+ qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
acc1 += yl[2] * qf1;
acc2 += yl[3] * qf2;
yb += 16 * QK4_NL;
}
- for (int row = 0; row < 2 && first_row + row < ne01; ++row) {
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
+ for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ dst_f32[first_row + row] = all_sum;
}
}
}
+template<typename args_t>
void kernel_mul_mv_iq4_xs_f32_impl(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values_i8,
- uint3 tgpig,
- uint tiisg,
- uint sgitg) {
-
- threadgroup float * shared_values = (threadgroup float *)shared_values_i8;
- const int nb = ne00/QK_K;
+ args_t args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg) {
+
+ threadgroup float * shmem_f32 = (threadgroup float *) shmem;
+ const int nb = args.ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int im = tgpig.z;
const int first_row = (r0 * 2 + sgitg) * 2;
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const uint i12 = im%args.ne12;
+ const uint i13 = im/args.ne12;
- const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
- const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
+ const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
- device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0);
- device const float * y = (device const float *) ((device char *) src1 + offset1);
+ device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
+ device const float * y = (device const float *) (src1 + offset1);
const int ix = tiisg/16; // 0 or 1
const int it = tiisg%16; // 0...15
const int ib = it/2;
const int il = it%2;
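+ // the iq4nl codebook is staged into shmem exactly as in kernel_mul_mv_iq4_nl_f32_impl above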
- shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16];
+ shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16];
threadgroup_barrier(mem_flags::mem_threadgroup);
float4 yl[4];
float4 qf1, qf2;
for (int ibl = ix; ibl < nb; ibl += 2) {
-
device const float4 * y4 = (device const float4 *)yb;
yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5];
for (int row = 0; row < 2; ++row) {
-
device const block_iq4_xs & xb = x[row*nb + ibl];
device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il);
float4 acc1 = {0.f}, acc2 = {0.f};
- aux32[0] = q4[0] & 0x0f0f0f0f;
+ aux32[0] = (q4[0] ) & 0x0f0f0f0f;
aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f;
- qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
- qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+ qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+ qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
acc1 += yl[0] * qf1;
acc2 += yl[1] * qf2;
- aux32[0] = q4[1] & 0x0f0f0f0f;
+ aux32[0] = (q4[1] ) & 0x0f0f0f0f;
aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f;
- qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]};
- qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]};
+ qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]};
+ qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]};
acc1 += yl[2] * qf1;
acc2 += yl[3] * qf2;
yb += 2 * QK_K;
}
+ device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0;
+
for (int row = 0; row < 2; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
+ dst_f32[first_row + row] = all_sum;
}
}
}
[[host_name("kernel_mul_mv_iq1_s_f32")]]
kernel void kernel_mul_mv_iq1_s_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_iq1_m_f32")]]
kernel void kernel_mul_mv_iq1_m_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
kernel void kernel_mul_mv_iq4_nl_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
kernel void kernel_mul_mv_iq4_xs_f32(
- device const void * src0,
- device const float * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ constant ggml_metal_kargs_mul_mv & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
- kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
+ kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
// each block_q contains 16*nl weights
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
-kernel void kernel_mul_mm(device const uchar * src0,
- device const uchar * src1,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne02,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint & r2,
- constant uint & r3,
- threadgroup uchar * shared_memory [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- threadgroup T * sa = (threadgroup T *)(shared_memory);
- threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
-
- const uint r0 = tgpig.y;
- const uint r1 = tgpig.x;
- const uint im = tgpig.z;
+kernel void kernel_mul_mm(
+ constant ggml_metal_kargs_mul_mm & args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiitg[[thread_index_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ threadgroup T * sa = (threadgroup T *)(shmem);
+ threadgroup float * sb = (threadgroup float *)(shmem + 4096);
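+ // shmem layout: the first 4096 bytes hold the dequantized src0 tile (sa), the next 4096 the f32 src1 tile (sb)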
+
+ const int r0 = tgpig.y;
+ const int r1 = tgpig.x;
+ const int im = tgpig.z;
// if this block is of 64x32 shape or smaller
- short n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
- short n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
+ short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M;
+ short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N;
// a thread shouldn't load data outside of the matrix
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
short il = (tiitg % THREAD_PER_ROW);
- const uint i12 = im%ne12;
- const uint i13 = im/ne12;
+ const int i12 = im%args.ne12;
+ const int i13 = im/args.ne12;
- uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03;
- ushort offset1 = il/nl;
+ uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
+ short offset1 = il/nl;
- device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*nb01 + offset0) + offset1;
+ device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1;
device const float * y = (device const float *)(src1
- + nb13 * i13
- + nb12 * i12
- + nb11 * (r1 * BLOCK_SIZE_N + thread_col)
- + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
+ + args.nb13*i13
+ + args.nb12*i12
+ + args.nb11*(r1 * BLOCK_SIZE_N + thread_col)
+ + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
- for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+ for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) {
// load data and store to threadgroup memory
T4x4 temp_a;
dequantize_func(x, il, temp_a);
}
}
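+ // fast path: the full 64x32 block fits inside dst, so the simdgroup matrices are stored straight to global memory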
- if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
- device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \
- + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0;
+ if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) {
+ device float * C = (device float *) dst +
+ (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) + \
+ (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0;
+
for (short i = 0; i < 8; i++) {
- simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
+ simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0);
}
} else {
// the block is smaller than 64x32: avoid writing outside of the matrix
threadgroup_barrier(mem_flags::mem_threadgroup);
- threadgroup float * temp_str = ((threadgroup float *) shared_memory) \
+ threadgroup float * temp_str = ((threadgroup float *) shmem) \
+ 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M;
for (short i = 0; i < 8; i++) {
simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M);
if (sgitg == 0) {
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
- device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0;
+ device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0;
device float4 * D4 = (device float4 *) D;
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
}
// same as kernel_mul_mm, but src1 and dst are accessed via indices stored in rowids
+// TODO: this kernel needs to be reimplemented from scratch for better performance
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
void kernel_mul_mm_id_impl(
- device const uchar * src0,
- device const uchar * src1,
+ int32_t ne00,
+ int32_t ne02,
+ uint64_t nb01,
+ uint64_t nb02,
+ int32_t ne11,
+ int32_t ne12,
+ uint64_t nb10,
+ uint64_t nb11,
+ uint64_t nb12,
+ int32_t ne0,
+ int32_t ne1,
+ int64_t ne0ne1,
+ device const char * src0,
+ device const char * src1,
threadgroup ushort2 * rowids,
- device float * dst,
- constant int64_t & ne00,
- constant int64_t & ne02,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- int64_t ne1,
- int64_t ne0ne1,
- threadgroup uchar * shared_memory,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
-
- threadgroup half * sa = (threadgroup half *)(shared_memory);
- threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
-
- const uint r0 = tgpig.y;
- const uint r1 = tgpig.x;
-
- if (r1 * BLOCK_SIZE_N >= ne1) return;
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiitg[[thread_index_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ threadgroup half * sa = (threadgroup half *)(shmem);
+ threadgroup float * sb = (threadgroup float *)(shmem + 4096);
+
+ const int r0 = tgpig.y;
+ const int r1 = tgpig.x;
+
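+ // threadgroups past the number of gathered rows have nothing to compute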
+ if (r1*BLOCK_SIZE_N >= ne1) return;
// if this block is of 64x32 shape or smaller
short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
simdgroup_half8x8 ma[4];
simdgroup_float8x8 mb[2];
- simdgroup_float8x8 c_res[8];
+ simdgroup_float8x8 mc[8];
for (int i = 0; i < 8; i++){
- c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+ mc[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
}
short il = (tiitg % THREAD_PER_ROW);
threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+ #pragma unroll(BLOCK_SIZE_K/8)
for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+ #pragma unroll(4)
for (int i = 0; i < 4; i++) {
simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i);
}
simdgroup_barrier(mem_flags::mem_none);
+ #pragma unroll(2)
for (int i = 0; i < 2; i++) {
simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i);
}
lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+ #pragma unroll(8)
for (int i = 0; i < 8; i++){
- simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+ simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]);
}
}
}
{
threadgroup_barrier(mem_flags::mem_threadgroup);
- threadgroup float * temp_str = ((threadgroup float *)shared_memory) \
+ threadgroup float * temp_str = ((threadgroup float *) shmem) \
+ 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
for (int i = 0; i < 8; i++) {
- simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+ simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
- device float * C = dst + (BLOCK_SIZE_M * r0);
if (sgitg == 0) {
for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) {
threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j];
- int joff = jid[0] * ne0 + jid[1] * ne0ne1;
- for (int i = 0; i < n_rows; i++) {
- *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M);
+ int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1;
+
+ device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff;
+ device float4 * D4 = (device float4 *) D;
+
+ threadgroup float * C = temp_str + (j*BLOCK_SIZE_M);
+ threadgroup float4 * C4 = (threadgroup float4 *) C;
+
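+ // copy the n_rows results in float4 chunks, then finish the remainder element-wise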
+ int i = 0;
+ for (; i < n_rows/4; i++) {
+ *(D4 + i) = *(C4 + i);
+ }
+
+ i *= 4;
+ for (; i < n_rows; i++) {
+ *(D + i) = *(C + i);
}
}
}
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
kernel void kernel_mul_mm_id(
- device const uchar * src0s,
- device const uchar * src1,
- device float * dst,
- device const uchar * ids,
- constant int64_t & nei0,
- constant int64_t & nei1,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne02,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- threadgroup uchar * shared_memory [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ constant ggml_metal_kargs_mul_mm_id & args,
+ device const char * src0s,
+ device const char * src1,
+ device char * dst,
+ device const char * ids,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiitg[[thread_index_in_threadgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const int32_t i02 = tgpig.z;
+
tgpig.z = 0;
- device const uchar * src0 = src0s + i02*nb02;
+ device const char * src0 = src0s + i02*args.nb02;
// row indices
- threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192);
+ threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192);
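+ // rowids are placed after the 8192 bytes of shmem occupied by the sa/sb tiles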
// TODO: parallelize this loop
- int64_t _ne1 = 0;
- for (ushort ii1 = 0; ii1 < nei1; ii1++) {
- for (ushort ii0 = 0; ii0 < nei0; ii0++) {
- int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0];
+ int32_t _ne1 = 0;
+ for (ushort ii1 = 0; ii1 < args.nei1; ii1++) {
+ for (ushort ii0 = 0; ii0 < args.nei0; ii0++) {
+ int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0];
if (id == i02) {
- //if (tiitg == 0) {
+ if (tiitg == 0) {
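+ // only one thread writes the shared entry; _ne1 is incremented by every thread so the count stays uniform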
rowids[_ne1] = ushort2(ii0, ii1);
- //}
+ }
_ne1++;
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
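+ // hand off to the shared implementation, passing the gathered row count _ne1 in place of ne1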
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
+ args.ne00,
+ args.ne02,
+ args.nb01,
+ args.nb02,
+ args.ne11,
+ args.ne12,
+ args.nb10,
+ args.nb11,
+ args.nb12,
+ args.ne0,
+ _ne1,
+ (int64_t)args.ne0*args.ne1,
src0,
src1,
rowids,
dst,
- ne00,
- ne02,
- nb01,
- nb02,
- ne11,
- ne12,
- nb10,
- nb11,
- nb12,
- ne0,
- _ne1,
- ne0*ne1,
- shared_memory,
+ shmem,
tgpig,
tiitg,
sgitg);
//
typedef void (kernel_mul_mv_impl_t)(
- device const char * src0,
- device const char * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb00,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne11,
- int64_t ne12,
- uint64_t nb10,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- uint3 tgpig,
- uint tiisg);
+ ggml_metal_kargs_mul_mv args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ uint3 tgpig,
+ ushort tiisg);
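+ // variant for impls that additionally use threadgroup memory and the simdgroup index: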
typedef void (kernel_mul_mv2_impl_t)(
- device const void * src0,
- device const float * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne12,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiisg,
- uint sgitg);
+ ggml_metal_kargs_mul_mv args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiisg,
+ ushort sgitg);
template<kernel_mul_mv_impl_t impl_fn>
void mmv_fn(
- device const char * src0,
- device const char * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb00,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne11,
- int64_t ne12,
- int64_t ne13,
- uint64_t nb10,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint64_t nb1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiitg,
- uint tiisg,
- uint sgitg) {
- impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,nb03,ne10,ne11,ne12,nb10,nb11,nb12,nb13,ne0,ne1,r2,r3,tgpig,tiisg);
+ ggml_metal_kargs_mul_mv args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiitg,
+ ushort tiisg,
+ ushort sgitg) {
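+ // shmem, tiitg and sgitg are unused by this overload; they are accepted only to keep the mmv_fn signatures uniform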
+ impl_fn(args, src0, src1, dst, tgpig, tiisg);
}
template<kernel_mul_mv2_impl_t impl_fn>
void mmv_fn(
- device const char * src0,
- device const char * src1,
- device float * dst,
- int64_t ne00,
- int64_t ne01,
- int64_t ne02,
- uint64_t nb00,
- uint64_t nb01,
- uint64_t nb02,
- uint64_t nb03,
- int64_t ne10,
- int64_t ne11,
- int64_t ne12,
- int64_t ne13,
- uint64_t nb10,
- uint64_t nb11,
- uint64_t nb12,
- uint64_t nb13,
- int64_t ne0,
- int64_t ne1,
- uint64_t nb1,
- uint r2,
- uint r3,
- threadgroup int8_t * shared_values,
- uint3 tgpig,
- uint tiitg,
- uint tiisg,
- uint sgitg) {
- impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg);
-}
-
-typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
+ ggml_metal_kargs_mul_mv args,
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ threadgroup char * shmem,
+ uint3 tgpig,
+ ushort tiitg,
+ ushort tiisg,
+ ushort sgitg) {
+ impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg);
+}
+
+typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
template<mul_mv_impl_fn_t impl_fn>
kernel void kernel_mul_mv_id(
- device const char * src0s,
- device const char * src1,
- device float * dst,
- device const char * ids,
- constant int64_t & nei0,
- constant int64_t & nei1,
- constant uint64_t & nbi1,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant uint64_t & nb1,
- threadgroup int8_t * shared_values [[threadgroup(0)]],
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint tiitg[[thread_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]]) {
- const int iid1 = tgpig.z/nei0;
- const int idx = tgpig.z%nei0;
+ constant ggml_metal_kargs_mul_mv_id & args,
+ device const char * src0s,
+ device const char * src1,
+ device char * dst,
+ device const char * ids,
+ threadgroup char * shmem [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ ushort tiitg[[thread_index_in_threadgroup]],
+ ushort tiisg[[thread_index_in_simdgroup]],
+ ushort sgitg[[simdgroup_index_in_threadgroup]]) {
+ const int iid1 = tgpig.z/args.nei0;
+ const int idx = tgpig.z%args.nei0;
tgpig.z = 0;
- const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx];
+ const int32_t i02 = ((device const int32_t *) (ids + iid1*args.nbi1))[idx];
- const int64_t i11 = idx % ne11;
+ const int64_t i11 = idx % args.ne11;
const int64_t i12 = iid1;
const int64_t i1 = idx;
const int64_t i2 = i12;
- device const char * src0_cur = src0s + i02*nb02;
- device const char * src1_cur = src1 + i11*nb11 + i12*nb12;
- device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0;
+ device const char * src0_cur = src0s + i02*args.nb02;
+ device const char * src1_cur = src1 + i11*args.nb11 + i12*args.nb12;
+
+ device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float);
+
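+ // build per-dispatch args: each threadgroup processes a single (expert, token) row, so the batch dims collapse to 1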
+ ggml_metal_kargs_mul_mv args0 = {
+ /*.ne00 =*/ args.ne00,
+ /*.ne01 =*/ args.ne01,
+ /*.ne02 =*/ 1, // args.ne02,
+ /*.nb00 =*/ args.nb00,
+ /*.nb01 =*/ args.nb01,
+ /*.nb02 =*/ args.nb02,
+ /*.nb03 =*/ args.nb02, // args.ne02 == 1
+ /*.ne10 =*/ args.ne10,
+ /*.ne11 =*/ 1, // args.ne11,
+ /*.ne12 =*/ 1, // args.ne12,
+ /*.nb10 =*/ args.nb10,
+ /*.nb11 =*/ args.nb11,
+ /*.nb12 =*/ args.nb12,
+ /*.nb13 =*/ args.nb12, // args.ne12 == 1
+ /*.ne0 =*/ args.ne0,
+ /*.ne1 =*/ 1, // args.ne1,
+ /*.r2 =*/ 1,
+ /*.r3 =*/ 1,
+ };
impl_fn(
+ args0,
/* src0 */ src0_cur,
/* src1 */ src1_cur,
/* dst */ dst_cur,
- /* ne00 */ ne00,
- /* ne01 */ ne01,
- /* ne02 */ 1, // ne02,
- /* nb00 */ nb00,
- /* nb01 */ nb01,
- /* nb02 */ nb02,
- /* nb03 */ nb02, // ne02 == 1
- /* ne10 */ ne10,
- /* ne11 */ 1, // ne11,
- /* ne12 */ 1, // ne12,
- /* ne13 */ 1, // ne13,
- /* nb10 */ nb10,
- /* nb11 */ nb11,
- /* nb12 */ nb12,
- /* ne13 */ nb12, // ne12 == 1
- /* ne0 */ ne0,
- /* ne1 */ 1, // ne1,
- /* nb1 */ nb1,
- /* r2 */ 1,
- /* r3 */ 1,
- shared_values,
+ shmem,
tgpig,
tiitg,
tiisg,