#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
-enum ggml_sort_order {
- GGML_SORT_ORDER_ASC,
- GGML_SORT_ORDER_DESC,
+constexpr constant static float kvalues_iq4nl_f[16] = {
+ -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
};
-// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
-// pros: works for non-contiguous tensors, supports broadcast across all dims
-// cons: not very efficient
-kernel void kernel_add(
- device const char * src0,
- device const char * src1,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- constant int64_t & offs,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig.z;
- const int64_t i02 = tgpig.y;
- const int64_t i01 = tgpig.x;
-
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
-
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
+// NOTE: this is not dequantizing - we are simply fitting the template
+template <typename type4x4>
+void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
+ float4x4 temp = *(((device float4x4 *)src));
+ for (int i = 0; i < 16; i++){
+ reg[i/4][i%4] = temp[i/4][i%4];
+ }
+}
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- const int i10 = i0 % ne10;
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
+template <typename type4x4>
+void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
+ half4x4 temp = *(((device half4x4 *)src));
+ for (int i = 0; i < 16; i++){
+ reg[i/4][i%4] = temp[i/4][i%4];
}
}
-kernel void kernel_sub(
- device const char * src0,
- device const char * src1,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- constant int64_t & offs,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig.z;
- const int64_t i02 = tgpig.y;
- const int64_t i01 = tgpig.x;
+template <typename type4x4>
+void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float md = -8.h * xb->d;
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
+ const ushort mask1 = mask0 << 8;
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
+ for (int i=0;i<8;i++) {
+ reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+ reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
+ }
+}
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
+template <typename type4x4>
+void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float m = xb->m;
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
+ const ushort mask1 = mask0 << 8;
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- const int i10 = i0 % ne10;
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
+ for (int i=0;i<8;i++) {
+ reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+ reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
}
}
-kernel void kernel_mul(
- device const char * src0,
- device const char * src1,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig.z;
- const int64_t i02 = tgpig.y;
- const int64_t i01 = tgpig.x;
+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+ const float d = xb->d;
+ const float md = -16.h * xb->d;
+ const ushort mask = il ? 0x00F0 : 0x000F;
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+ const int x_mv = il ? 4 : 0;
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- const int i10 = i0 % ne10;
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
+ const int gh_mv = il ? 12 : 0;
+ const int gh_bk = il ? 0 : 4;
+
+ for (int i = 0; i < 8; i++) {
+ // extract the 5-th bits for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+ // combine the 4-bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+ reg[i/2][2*(i%2)+0] = d * x0 + md;
+ reg[i/2][2*(i%2)+1] = d * x1 + md;
}
}
-kernel void kernel_div(
- device const char * src0,
- device const char * src1,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne10,
- constant int64_t & ne11,
- constant int64_t & ne12,
- constant int64_t & ne13,
- constant uint64_t & nb10,
- constant uint64_t & nb11,
- constant uint64_t & nb12,
- constant uint64_t & nb13,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = tgpig.z;
- const int64_t i02 = tgpig.y;
- const int64_t i01 = tgpig.x;
+template <typename type4x4>
+void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+ const float d = xb->d;
+ const float m = xb->m;
+ const ushort mask = il ? 0x00F0 : 0x000F;
- const int64_t i13 = i03 % ne13;
- const int64_t i12 = i02 % ne12;
- const int64_t i11 = i01 % ne11;
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
- device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
- device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
+ const int x_mv = il ? 4 : 0;
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- const int i10 = i0 % ne10;
- *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
- }
-}
+ const int gh_mv = il ? 12 : 0;
+ const int gh_bk = il ? 0 : 4;
-template<typename T>
-kernel void kernel_repeat(
- device const char * src0,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant int64_t & ne03,
- constant uint64_t & nb00,
- constant uint64_t & nb01,
- constant uint64_t & nb02,
- constant uint64_t & nb03,
- constant int64_t & ne0,
- constant int64_t & ne1,
- constant int64_t & ne2,
- constant int64_t & ne3,
- constant uint64_t & nb0,
- constant uint64_t & nb1,
- constant uint64_t & nb2,
- constant uint64_t & nb3,
- uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 ntg[[threads_per_threadgroup]]) {
- const int64_t i3 = tgpig.z;
- const int64_t i2 = tgpig.y;
- const int64_t i1 = tgpig.x;
-
- const int64_t i03 = i3 % ne03;
- const int64_t i02 = i2 % ne02;
- const int64_t i01 = i1 % ne01;
+ for (int i = 0; i < 8; i++) {
+ // extract the 5-th bits for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
- device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
- device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
+ // combine the 4-bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
- for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
- const int i00 = i0 % ne00;
- *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
+ reg[i/2][2*(i%2)+0] = d * x0 + m;
+ reg[i/2][2*(i%2)+1] = d * x1 + m;
}
}
-typedef decltype(kernel_repeat<float>) kernel_repeat_t;
-
-template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
-template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
-template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
-template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
-
-// assumption: src1 is a row
-// broadcast src1 into src0
-kernel void kernel_add_row(
- device const float4 * src0,
- device const float4 * src1,
- device float4 * dst,
- constant uint64_t & nb [[buffer(28)]],
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] + src1[tpig % nb];
-}
+template <typename type4x4>
+void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
+ device const int8_t * qs = ((device const int8_t *)xb->qs);
+ const half d = xb->d;
-kernel void kernel_sub_row(
- device const float4 * src0,
- device const float4 * src1,
- device float4 * dst,
- constant uint64_t & nb [[buffer(28)]],
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] - src1[tpig % nb];
+ for (int i = 0; i < 16; i++) {
+ reg[i/4][i%4] = (qs[i + 16*il] * d);
+ }
}
-kernel void kernel_mul_row(
- device const float4 * src0,
- device const float4 * src1,
- device float4 * dst,
- constant uint64_t & nb [[buffer(28)]],
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * src1[tpig % nb];
-}
+template <typename type4x4>
+void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
+ const float d = xb->d;
+ const float min = xb->dmin;
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
+ float dl, ml;
+ uint8_t sc = xb->scales[il];
-kernel void kernel_div_row(
- device const float4 * src0,
- device const float4 * src1,
- device float4 * dst,
- constant uint64_t & nb [[buffer(28)]],
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] / src1[tpig % nb];
-}
+ q = q + 32*(il/8) + 16*(il&1);
+ il = (il/2)%4;
-kernel void kernel_scale(
- device const float * src0,
- device float * dst,
- constant float & scale,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * scale;
+ half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+ }
}
-kernel void kernel_scale_4(
- device const float4 * src0,
- device float4 * dst,
- constant float & scale,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * scale;
-}
+template <typename type4x4>
+void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
+ const half d_all = xb->d;
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
+ device const uint8_t * h = (device const uint8_t *)xb->hmask;
+ device const int8_t * scales = (device const int8_t *)xb->scales;
-kernel void kernel_clamp(
- device const float * src0,
- device float * dst,
- constant float & min,
- constant float & max,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
-}
+ q = q + 32 * (il/8) + 16 * (il&1);
+ h = h + 16 * (il&1);
+ uint8_t m = 1 << (il/2);
+ uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
+ ((il/4)>0 ? 12 : 3);
+ uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
+ uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+ : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+ float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+ const float ml = 4.f * dl;
-kernel void kernel_relu(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = max(0.0f, src0[tpig]);
-}
+ il = (il/2) & 3;
+ const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ dl *= coef;
-kernel void kernel_sigmoid(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
+ }
}
-kernel void kernel_tanh(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float & x = src0[tpig];
- dst[tpig] = precise::tanh(x);
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+ return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+ : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
}
-constant float GELU_COEF_A = 0.044715f;
-constant float GELU_QUICK_COEF = -1.702f;
-constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+template <typename type4x4>
+void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
+ device const uchar * q = xb->qs;
-kernel void kernel_gelu(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float & x = src0[tpig];
+ short is = (il/4) * 2;
+ q = q + (il/4) * 32 + 16 * (il&1);
+ il = il & 3;
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+ const float d = il < 2 ? xb->d : xb->d / 16.h;
+ const float min = xb->dmin;
+ const float dl = d * sc[0];
+ const float ml = min * sc[1];
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+ }
}
-kernel void kernel_gelu_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float4 & x = src0[tpig];
-
- // BEWARE !!!
- // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
- // This was observed with Falcon 7B and 40B models
- //
- dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
-}
+template <typename type4x4>
+void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
+ device const uint8_t * q = xb->qs;
+ device const uint8_t * qh = xb->qh;
-kernel void kernel_gelu_quick(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float & x = src0[tpig];
+ short is = (il/4) * 2;
+ q = q + 32 * (il/4) + 16 * (il&1);
+ qh = qh + 16 * (il&1);
+ uint8_t ul = 1 << (il/2);
+ il = il & 3;
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+ const float d = il < 2 ? xb->d : xb->d / 16.f;
+ const float min = xb->dmin;
+ const float dl = d * sc[0];
+ const float ml = min * sc[1];
- dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ const float qh_val = il<2 ? 16.f : 256.f;
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
+ }
}
-kernel void kernel_gelu_quick_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float4 & x = src0[tpig];
-
- dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
-}
+template <typename type4x4>
+void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
+ const half d_all = xb->d;
+ device const uint8_t * ql = (device const uint8_t *)xb->ql;
+ device const uint8_t * qh = (device const uint8_t *)xb->qh;
+ device const int8_t * scales = (device const int8_t *)xb->scales;
-kernel void kernel_silu(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float & x = src0[tpig];
- dst[tpig] = x / (1.0f + exp(-x));
-}
+ ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
+ qh = qh + 32*(il/8) + 16*(il&1);
+ float sc = scales[(il%2) + 2 * ((il/2))];
+ il = (il/2) & 3;
-kernel void kernel_silu_4(
- device const float4 * src0,
- device float4 * dst,
- uint tpig[[thread_position_in_grid]]) {
- device const float4 & x = src0[tpig];
- dst[tpig] = x / (1.0f + exp(-x));
+ const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
+ const float coef = il>1 ? 1.f/16.f : 1.f;
+ const float ml = d_all * sc * 32.f;
+ const float dl = d_all * sc * coef;
+ for (int i = 0; i < 16; ++i) {
+ const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+ : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+ reg[i/4][i%4] = dl * q - ml;
+ }
}
-kernel void kernel_sqr(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = src0[tpig] * src0[tpig];
+template <typename type4x4>
+void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
+ device const uint16_t * q2 = xb->qs + 4*ib32;
+ const uint32_t aux32_g = q2[0] | (q2[1] << 16);
+ const uint32_t aux32_s = q2[2] | (q2[3] << 16);
+ thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
+ const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
+ constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
+ uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
+ for (int i = 0; i < 8; ++i) {
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+ grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
+ signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
+ for (int i = 0; i < 8; ++i) {
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
}
-kernel void kernel_sqrt(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = sqrt(src0[tpig]);
+template <typename type4x4>
+void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint16_t * q2 = xb->qs + 4*ib32;
+ const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
+ constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
+ uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
+ for (int i = 0; i < 8; ++i) {
+ reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
+ grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
+ signs = ksigns_iq2xs[q2[2*il+1] >> 9];
+ for (int i = 0; i < 8; ++i) {
+ reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
+ }
}
-kernel void kernel_sin(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = sin(src0[tpig]);
+template <typename type4x4>
+void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint8_t * q3 = xb->qs + 8*ib32;
+ device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
+ const uint32_t aux32 = gas[0] | (gas[1] << 16);
+ const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
+ constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
+ constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
+ uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
+ for (int i = 0; i < 4; ++i) {
+ reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+ reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+ }
+ grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
+ grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
+ signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
+ for (int i = 0; i < 4; ++i) {
+ reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
+ reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
+ }
}
-kernel void kernel_cos(
- device const float * src0,
- device float * dst,
- uint tpig[[thread_position_in_grid]]) {
- dst[tpig] = cos(src0[tpig]);
+template <typename type4x4>
+void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint8_t * qs = xb->qs + 8*ib32;
+ device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
+ const uint8_t qh = xb->qh[ib32] >> 4*il;
+ const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
+ constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
+ for (int i = 0; i < 4; ++i) {
+ reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
+ reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
+ }
+ grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
+ grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
+ for (int i = 0; i < 4; ++i) {
+ reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
+ reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
+ }
}
-kernel void kernel_sum_rows(
- device const float * src0,
- device float * dst,
+template <typename type4x4>
+void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const float d = xb->d;
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+ device const uint8_t * signs = qs + QK_K/8;
+ const uint8_t qh = xb->qh[ib32] >> 4*il;
+ const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
+ constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
+ for (int i = 0; i < 8; ++i) {
+ reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
+ reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const int ib32 = il/2;
+ il = il%2;
+ const float d = xb->d;
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+ device const uint16_t * qh = xb->qh;
+ const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
+ const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
+ const uint16_t h = qh[ib32] >> 6*il;
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
+ for (int i = 0; i < 4; ++i) {
+ reg[0][i] = dl * (grid1[i] & 0xf) + ml;
+ reg[1][i] = dl * (grid1[i] >> 4) + ml;
+ reg[2][i] = dl * (grid2[i] & 0xf) + ml;
+ reg[3][i] = dl * (grid2[i] >> 4) + ml;
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const int ib32 = il/2;
+ il = il%2;
+ device const uint16_t * sc = (device const uint16_t *)xb->scales;
+
+ iq1m_scale_t scale;
+ scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
+ const float d = scale.f16;
+
+ device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
+ device const uint8_t * qh = xb->qh + 2*ib32 + il;
+
+ const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
+ const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+ const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
+ constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
+ constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
+ for (int i = 0; i < 4; ++i) {
+ reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
+ reg[1][i] = dl * (grid1[i] >> 4) + ml1;
+ reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
+ reg[3][i] = dl * (grid2[i] >> 4) + ml2;
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
+ device const uint16_t * q4 = (device const uint16_t *)xb->qs;
+ const float d = xb->d;
+ uint32_t aux32;
+ thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+ for (int i = 0; i < 4; ++i) {
+ aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
+ reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+ reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+ reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+ reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+ }
+}
+
+template <typename type4x4>
+void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
+ // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
+ const int ib32 = il/2;
+ il = il%2;
+ // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
+ device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
+ const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
+ const float d = (float)xb->d * (ls - 32);
+ uint32_t aux32;
+ thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
+ for (int i = 0; i < 4; ++i) {
+ aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
+ reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
+ reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
+ reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
+ reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
+ }
+}
+
+enum ggml_sort_order {
+ GGML_SORT_ORDER_ASC,
+ GGML_SORT_ORDER_DESC,
+};
+
+// general-purpose kernel for addition, subtraction, multiplication and division of two tensors
+// pros: works for non-contiguous tensors, supports broadcast across all dims
+// cons: not very efficient
+kernel void kernel_add(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
- uint3 tpig[[thread_position_in_grid]]) {
- int64_t i3 = tpig.z;
- int64_t i2 = tpig.y;
- int64_t i1 = tpig.x;
+ constant int64_t & offs,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
- if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
- return;
- }
-
- device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
- device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
- float row_sum = 0;
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
- for (int64_t i0 = 0; i0 < ne00; i0++) {
- row_sum += src_row[i0];
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10));
}
-
- dst_row[0] = row_sum;
}
-template<typename T>
-kernel void kernel_soft_max(
- device const char * src0,
- device const char * src1,
- device char * dst,
- constant int64_t & ne00,
- constant int64_t & ne01,
- constant int64_t & ne02,
- constant float & scale,
- constant float & max_bias,
- constant float & m0,
- constant float & m1,
- constant uint32_t & n_head_log2,
- threadgroup float * buf [[threadgroup(0)]],
- uint tgpig[[threadgroup_position_in_grid]],
- uint tpitg[[thread_position_in_threadgroup]],
- uint sgitg[[simdgroup_index_in_threadgroup]],
- uint tiisg[[thread_index_in_simdgroup]],
- uint ntg[[threads_per_threadgroup]]) {
- const int64_t i03 = (tgpig) / (ne02*ne01);
- const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
- const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+kernel void kernel_sub(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant int64_t & offs,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
- device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
- device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
- float slope = 1.0f;
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs;
- // ALiBi
- if (max_bias > 0.0f) {
- const int64_t h = i02;
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10));
+ }
+}
- const float base = h < n_head_log2 ? m0 : m1;
- const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+kernel void kernel_mul(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
- slope = pow(base, exp);
- }
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
- // parallel max
- float lmax = -INFINITY;
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10));
}
+}
- // find the max value in the block
- float max_val = simd_max(lmax);
- if (ntg > N_SIMDWIDTH) {
- if (sgitg == 0) {
- buf[tiisg] = -INFINITY;
- }
-
- threadgroup_barrier(mem_flags::mem_threadgroup);
+kernel void kernel_div(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig.z;
+ const int64_t i02 = tgpig.y;
+ const int64_t i01 = tgpig.x;
- if (tiisg == 0) {
- buf[sgitg] = max_val;
- }
+ const int64_t i13 = i03 % ne13;
+ const int64_t i12 = i02 % ne12;
+ const int64_t i11 = i01 % ne11;
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11;
+ device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1;
- max_val = buf[tiisg];
- max_val = simd_max(max_val);
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i10 = i0 % ne10;
+ *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10));
}
+}
- // parallel sum
- float lsum = 0.0f;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
- lsum += exp_psrc0;
- pdst[i00] = exp_psrc0;
- }
+template<typename T>
+kernel void kernel_repeat(
+ device const char * src0,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i3 = tgpig.z;
+ const int64_t i2 = tgpig.y;
+ const int64_t i1 = tgpig.x;
- // This barrier fixes a failing test
- // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
- threadgroup_barrier(mem_flags::mem_none);
+ const int64_t i03 = i3 % ne03;
+ const int64_t i02 = i2 % ne02;
+ const int64_t i01 = i1 % ne01;
- float sum = simd_sum(lsum);
+ device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01;
+ device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ;
- if (ntg > N_SIMDWIDTH) {
- if (sgitg == 0) {
- buf[tiisg] = 0.0f;
- }
+ for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) {
+ const int i00 = i0 % ne00;
+ *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00));
+ }
+}
- threadgroup_barrier(mem_flags::mem_threadgroup);
+typedef decltype(kernel_repeat<float>) kernel_repeat_t;
- if (tiisg == 0) {
- buf[sgitg] = sum;
- }
+template [[host_name("kernel_repeat_f32")]] kernel kernel_repeat_t kernel_repeat<float>;
+template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat<half>;
+template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat<int>;
+template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat<short>;
- threadgroup_barrier(mem_flags::mem_threadgroup);
+// assumption: src1 is a row
+// broadcast src1 into src0
+kernel void kernel_add_row(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ constant uint64_t & nb [[buffer(28)]],
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] + src1[tpig % nb];
+}
- sum = buf[tiisg];
- sum = simd_sum(sum);
+kernel void kernel_sub_row(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ constant uint64_t & nb [[buffer(28)]],
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] - src1[tpig % nb];
+}
+
+kernel void kernel_mul_row(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ constant uint64_t & nb [[buffer(28)]],
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * src1[tpig % nb];
+}
+
+kernel void kernel_div_row(
+ device const float4 * src0,
+ device const float4 * src1,
+ device float4 * dst,
+ constant uint64_t & nb [[buffer(28)]],
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] / src1[tpig % nb];
+}
+
+kernel void kernel_scale(
+ device const float * src0,
+ device float * dst,
+ constant float & scale,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
+ device const float4 * src0,
+ device float4 * dst,
+ constant float & scale,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_clamp(
+ device const float * src0,
+ device float * dst,
+ constant float & min,
+ constant float & max,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] < min ? min : (src0[tpig] > max ? max : src0[tpig]);
+}
+
+kernel void kernel_relu(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = max(0.0f, src0[tpig]);
+}
+
+kernel void kernel_sigmoid(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig]));
+}
+
+kernel void kernel_tanh(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+ dst[tpig] = precise::tanh(x);
+}
+
+constant float GELU_COEF_A = 0.044715f;
+constant float GELU_QUICK_COEF = -1.702f;
+constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
+
+kernel void kernel_gelu(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_4(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+
+ // BEWARE !!!
+ // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
+ // This was observed with Falcon 7B and 40B models
+ //
+ dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
+}
+
+kernel void kernel_gelu_quick(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_gelu_quick_4(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+
+ dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x)));
+}
+
+kernel void kernel_silu(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float & x = src0[tpig];
+ dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_silu_4(
+ device const float4 * src0,
+ device float4 * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ device const float4 & x = src0[tpig];
+ dst[tpig] = x / (1.0f + exp(-x));
+}
+
+kernel void kernel_sqr(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * src0[tpig];
+}
+
+kernel void kernel_sqrt(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = sqrt(src0[tpig]);
+}
+
+kernel void kernel_sin(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = sin(src0[tpig]);
+}
+
+kernel void kernel_cos(
+ device const float * src0,
+ device float * dst,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = cos(src0[tpig]);
+}
+
+kernel void kernel_sum_rows(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant int64_t & ne13,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant uint64_t & nb13,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tpig[[thread_position_in_grid]]) {
+ int64_t i3 = tpig.z;
+ int64_t i2 = tpig.y;
+ int64_t i1 = tpig.x;
+
+ if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+ return;
}
- const float inv_sum = 1.0f/sum;
+ device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
+ device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- pdst[i00] *= inv_sum;
+ float row_sum = 0;
+
+ for (int64_t i0 = 0; i0 < ne00; i0++) {
+ row_sum += src_row[i0];
}
+
+ dst_row[0] = row_sum;
}
template<typename T>
-kernel void kernel_soft_max_4(
+kernel void kernel_soft_max(
device const char * src0,
device const char * src1,
device char * dst,
const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
- device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
- device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
- device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+ device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
+ device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
float slope = 1.0f;
+ // ALiBi
if (max_bias > 0.0f) {
const int64_t h = i02;
}
// parallel max
- float4 lmax4 = -INFINITY;
+ float lmax = -INFINITY;
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
}
- const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
-
+ // find the max value in the block
float max_val = simd_max(lmax);
if (ntg > N_SIMDWIDTH) {
if (sgitg == 0) {
}
// parallel sum
- float4 lsum4 = 0.0f;
- for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
- const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
- lsum4 += exp_psrc4;
- pdst4[i00] = exp_psrc4;
+ float lsum = 0.0f;
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
+ lsum += exp_psrc0;
+ pdst[i00] = exp_psrc0;
}
- const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+ // This barrier fixes a failing test
+ // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
+ threadgroup_barrier(mem_flags::mem_none);
+
+ float sum = simd_sum(lsum);
+
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = 0.0f;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = sum;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ sum = buf[tiisg];
+ sum = simd_sum(sum);
+ }
+
+ const float inv_sum = 1.0f/sum;
+
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ pdst[i00] *= inv_sum;
+ }
+}
+
+template<typename T>
+kernel void kernel_soft_max_4(
+ device const char * src0,
+ device const char * src1,
+ device char * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant float & scale,
+ constant float & max_bias,
+ constant float & m0,
+ constant float & m1,
+ constant uint32_t & n_head_log2,
+ threadgroup float * buf [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = (tgpig) / (ne02*ne01);
+ const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
+ const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
+
+ device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+ device const T * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00/4 : nullptr;
+ device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+
+ float slope = 1.0f;
+
+ if (max_bias > 0.0f) {
+ const int64_t h = i02;
+
+ const float base = h < n_head_log2 ? m0 : m1;
+ const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+ slope = pow(base, exp);
+ }
+
+ // parallel max
+ float4 lmax4 = -INFINITY;
+
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
+ }
+
+ const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+
+ float max_val = simd_max(lmax);
+ if (ntg > N_SIMDWIDTH) {
+ if (sgitg == 0) {
+ buf[tiisg] = -INFINITY;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ if (tiisg == 0) {
+ buf[sgitg] = max_val;
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+
+ max_val = buf[tiisg];
+ max_val = simd_max(max_val);
+ }
+
+ // parallel sum
+ float4 lsum4 = 0.0f;
+ for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+ const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
+ lsum4 += exp_psrc4;
+ pdst4[i00] = exp_psrc4;
+ }
+
+ const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
// This barrier fixes a failing test
// ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
-constexpr constant static float kvalues_iq4nl_f[16] = {
- -127.f, -104.f, -83.f, -65.f, -49.f, -35.f, -22.f, -10.f, 1.f, 13.f, 25.f, 38.f, 53.f, 69.f, 89.f, 113.f
-};
-
kernel void kernel_cpy_f32_iq4_nl(
device const float * src0,
device void * dst,
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
}
-//============================= templates and their specializations =============================
-
-// NOTE: this is not dequantizing - we are simply fitting the template
-template <typename type4x4>
-void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
- float4x4 temp = *(((device float4x4 *)src));
- for (int i = 0; i < 16; i++){
- reg[i/4][i%4] = temp[i/4][i%4];
- }
-}
-
-template <typename type4x4>
-void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
- half4x4 temp = *(((device half4x4 *)src));
- for (int i = 0; i < 16; i++){
- reg[i/4][i%4] = temp[i/4][i%4];
- }
-}
-
-template <typename type4x4>
-void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 1);
- const float d1 = il ? (xb->d / 16.h) : xb->d;
- const float d2 = d1 / 256.f;
- const float md = -8.h * xb->d;
- const ushort mask0 = il ? 0x00F0 : 0x000F;
- const ushort mask1 = mask0 << 8;
-
- for (int i=0;i<8;i++) {
- reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
- reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
- }
-}
-
-template <typename type4x4>
-void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 2);
- const float d1 = il ? (xb->d / 16.h) : xb->d;
- const float d2 = d1 / 256.f;
- const float m = xb->m;
- const ushort mask0 = il ? 0x00F0 : 0x000F;
- const ushort mask1 = mask0 << 8;
-
- for (int i=0;i<8;i++) {
- reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
- reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
- }
-}
-
-template <typename type4x4>
-void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 3);
- const float d = xb->d;
- const float md = -16.h * xb->d;
- const ushort mask = il ? 0x00F0 : 0x000F;
-
- const uint32_t qh = *((device const uint32_t *)xb->qh);
-
- const int x_mv = il ? 4 : 0;
-
- const int gh_mv = il ? 12 : 0;
- const int gh_bk = il ? 0 : 4;
-
- for (int i = 0; i < 8; i++) {
- // extract the 5-th bits for x0 and x1
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
-
- // combine the 4-bits from qs with the 5th bit
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
-
- reg[i/2][2*(i%2)+0] = d * x0 + md;
- reg[i/2][2*(i%2)+1] = d * x1 + md;
- }
-}
-
-template <typename type4x4>
-void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
- device const uint16_t * qs = ((device const uint16_t *)xb + 4);
- const float d = xb->d;
- const float m = xb->m;
- const ushort mask = il ? 0x00F0 : 0x000F;
-
- const uint32_t qh = *((device const uint32_t *)xb->qh);
-
- const int x_mv = il ? 4 : 0;
-
- const int gh_mv = il ? 12 : 0;
- const int gh_bk = il ? 0 : 4;
-
- for (int i = 0; i < 8; i++) {
- // extract the 5-th bits for x0 and x1
- const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
- const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
-
- // combine the 4-bits from qs with the 5th bit
- const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
- const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
-
- reg[i/2][2*(i%2)+0] = d * x0 + m;
- reg[i/2][2*(i%2)+1] = d * x1 + m;
- }
-}
-
-template <typename type4x4>
-void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
- device const int8_t * qs = ((device const int8_t *)xb->qs);
- const half d = xb->d;
-
- for (int i = 0; i < 16; i++) {
- reg[i/4][i%4] = (qs[i + 16*il] * d);
- }
-}
-
-template <typename type4x4>
-void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
- const float d = xb->d;
- const float min = xb->dmin;
- device const uint8_t * q = (device const uint8_t *)xb->qs;
- float dl, ml;
- uint8_t sc = xb->scales[il];
-
- q = q + 32*(il/8) + 16*(il&1);
- il = (il/2)%4;
-
- half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
- uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
- dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
- for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = dl * (q[i] & mask) - ml;
- }
-}
-
-template <typename type4x4>
-void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
- const half d_all = xb->d;
- device const uint8_t * q = (device const uint8_t *)xb->qs;
- device const uint8_t * h = (device const uint8_t *)xb->hmask;
- device const int8_t * scales = (device const int8_t *)xb->scales;
-
- q = q + 32 * (il/8) + 16 * (il&1);
- h = h + 16 * (il&1);
- uint8_t m = 1 << (il/2);
- uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
- ((il/4)>0 ? 12 : 3);
- uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
- uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
- int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
- : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
- float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
- const float ml = 4.f * dl;
-
- il = (il/2) & 3;
- const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
- const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
- dl *= coef;
-
- for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
- }
-}
-
-static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
- return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
- : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
-}
-
-template <typename type4x4>
-void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
- device const uchar * q = xb->qs;
-
- short is = (il/4) * 2;
- q = q + (il/4) * 32 + 16 * (il&1);
- il = il & 3;
- const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
- const float d = il < 2 ? xb->d : xb->d / 16.h;
- const float min = xb->dmin;
- const float dl = d * sc[0];
- const float ml = min * sc[1];
-
- const ushort mask = il<2 ? 0x0F : 0xF0;
- for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = dl * (q[i] & mask) - ml;
- }
-}
-
-template <typename type4x4>
-void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
- device const uint8_t * q = xb->qs;
- device const uint8_t * qh = xb->qh;
-
- short is = (il/4) * 2;
- q = q + 32 * (il/4) + 16 * (il&1);
- qh = qh + 16 * (il&1);
- uint8_t ul = 1 << (il/2);
- il = il & 3;
- const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
- const float d = il < 2 ? xb->d : xb->d / 16.f;
- const float min = xb->dmin;
- const float dl = d * sc[0];
- const float ml = min * sc[1];
-
- const ushort mask = il<2 ? 0x0F : 0xF0;
- const float qh_val = il<2 ? 16.f : 256.f;
- for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
- }
-}
-
-template <typename type4x4>
-void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
- const half d_all = xb->d;
- device const uint8_t * ql = (device const uint8_t *)xb->ql;
- device const uint8_t * qh = (device const uint8_t *)xb->qh;
- device const int8_t * scales = (device const int8_t *)xb->scales;
-
- ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
- qh = qh + 32*(il/8) + 16*(il&1);
- float sc = scales[(il%2) + 2 * ((il/2))];
- il = (il/2) & 3;
-
- const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
- const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
- const float coef = il>1 ? 1.f/16.f : 1.f;
- const float ml = d_all * sc * 32.f;
- const float dl = d_all * sc * coef;
- for (int i = 0; i < 16; ++i) {
- const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
- : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
- reg[i/4][i%4] = dl * q - ml;
- }
-}
-
-template <typename type4x4>
-void dequantize_iq2_xxs(device const block_iq2_xxs * xb, short il, thread type4x4 & reg) {
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
- const float d = xb->d;
- const int ib32 = il/2;
- il = il%2;
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
- // each block of 32 needs 2 uint32_t's for the quants & scale, so 4 uint16_t's.
- device const uint16_t * q2 = xb->qs + 4*ib32;
- const uint32_t aux32_g = q2[0] | (q2[1] << 16);
- const uint32_t aux32_s = q2[2] | (q2[3] << 16);
- thread const uint8_t * aux8 = (thread const uint8_t *)&aux32_g;
- const float dl = d * (0.5f + (aux32_s >> 28)) * 0.25f;
- constant uint8_t * grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+0]);
- uint8_t signs = ksigns_iq2xs[(aux32_s >> 14*il) & 127];
- for (int i = 0; i < 8; ++i) {
- reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
- }
- grid = (constant uint8_t *)(iq2xxs_grid + aux8[2*il+1]);
- signs = ksigns_iq2xs[(aux32_s >> (14*il+7)) & 127];
- for (int i = 0; i < 8; ++i) {
- reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
- }
-}
-
-template <typename type4x4>
-void dequantize_iq2_xs(device const block_iq2_xs * xb, short il, thread type4x4 & reg) {
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
- const float d = xb->d;
- const int ib32 = il/2;
- il = il%2;
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
- device const uint16_t * q2 = xb->qs + 4*ib32;
- const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
- constant uint8_t * grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+0] & 511));
- uint8_t signs = ksigns_iq2xs[q2[2*il+0] >> 9];
- for (int i = 0; i < 8; ++i) {
- reg[i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
- }
- grid = (constant uint8_t *)(iq2xs_grid + (q2[2*il+1] & 511));
- signs = ksigns_iq2xs[q2[2*il+1] >> 9];
- for (int i = 0; i < 8; ++i) {
- reg[2+i/4][i%4] = dl * grid[i] * (signs & kmask_iq2xs[i] ? -1.f : 1.f);
- }
-}
-
-template <typename type4x4>
-void dequantize_iq3_xxs(device const block_iq3_xxs * xb, short il, thread type4x4 & reg) {
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
- const float d = xb->d;
- const int ib32 = il/2;
- il = il%2;
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
- device const uint8_t * q3 = xb->qs + 8*ib32;
- device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
- const uint32_t aux32 = gas[0] | (gas[1] << 16);
- const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
- constant uint8_t * grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+0]);
- constant uint8_t * grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+1]);
- uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
- for (int i = 0; i < 4; ++i) {
- reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
- reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
- }
- grid1 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+2]);
- grid2 = (constant uint8_t *)(iq3xxs_grid + q3[4*il+3]);
- signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
- for (int i = 0; i < 4; ++i) {
- reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
- reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
- }
-}
-
-template <typename type4x4>
-void dequantize_iq3_s(device const block_iq3_s * xb, short il, thread type4x4 & reg) {
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
- const float d = xb->d;
- const int ib32 = il/2;
- il = il%2;
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
- device const uint8_t * qs = xb->qs + 8*ib32;
- device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
- const uint8_t qh = xb->qh[ib32] >> 4*il;
- const float dl = d * (1 + 2*((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf));
- constant uint8_t * grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+0] | ((qh << 8) & 256)));
- constant uint8_t * grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+1] | ((qh << 7) & 256)));
- for (int i = 0; i < 4; ++i) {
- reg[0][i] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i+0]);
- reg[1][i] = dl * grid2[i] * select(1, -1, signs[0] & kmask_iq2xs[i+4]);
- }
- grid1 = (constant uint8_t *)(iq3s_grid + (qs[4*il+2] | ((qh << 6) & 256)));
- grid2 = (constant uint8_t *)(iq3s_grid + (qs[4*il+3] | ((qh << 5) & 256)));
- for (int i = 0; i < 4; ++i) {
- reg[2][i] = dl * grid1[i] * select(1, -1, signs[1] & kmask_iq2xs[i+0]);
- reg[3][i] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i+4]);
- }
-}
-
-template <typename type4x4>
-void dequantize_iq2_s(device const block_iq2_s * xb, short il, thread type4x4 & reg) {
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
- const float d = xb->d;
- const int ib32 = il/2;
- il = il%2;
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
- device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
- device const uint8_t * signs = qs + QK_K/8;
- const uint8_t qh = xb->qh[ib32] >> 4*il;
- const float dl = d * (0.5f + ((xb->scales[ib32] >> 4*il) & 0xf)) * 0.25f;
- constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[0] | ((qh << 8) & 0x300)));
- constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[1] | ((qh << 6) & 0x300)));
- for (int i = 0; i < 8; ++i) {
- reg[i/4+0][i%4] = dl * grid1[i] * select(1, -1, signs[0] & kmask_iq2xs[i]);
- reg[i/4+2][i%4] = dl * grid2[i] * select(1, -1, signs[1] & kmask_iq2xs[i]);
- }
-}
-
-template <typename type4x4>
-void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 & reg) {
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
- const int ib32 = il/2;
- il = il%2;
- const float d = xb->d;
- device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
- device const uint16_t * qh = xb->qh;
- const float dl = d * (2*((qh[ib32] >> 12) & 7) + 1);
- const float ml = dl * (qh[ib32] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA);
- const uint16_t h = qh[ib32] >> 6*il;
- constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((h << 8) & 0x700)));
- constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((h << 5) & 0x700)));
- for (int i = 0; i < 4; ++i) {
- reg[0][i] = dl * (grid1[i] & 0xf) + ml;
- reg[1][i] = dl * (grid1[i] >> 4) + ml;
- reg[2][i] = dl * (grid2[i] & 0xf) + ml;
- reg[3][i] = dl * (grid2[i] >> 4) + ml;
- }
-}
-
-template <typename type4x4>
-void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
- const int ib32 = il/2;
- il = il%2;
- device const uint16_t * sc = (device const uint16_t *)xb->scales;
-
- iq1m_scale_t scale;
- scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
- const float d = scale.f16;
-
- device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
- device const uint8_t * qh = xb->qh + 2*ib32 + il;
-
- const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
- const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
- const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
- constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
- constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
- for (int i = 0; i < 4; ++i) {
- reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
- reg[1][i] = dl * (grid1[i] >> 4) + ml1;
- reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
- reg[3][i] = dl * (grid2[i] >> 4) + ml2;
- }
-}
-
-template <typename type4x4>
-void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
- device const uint16_t * q4 = (device const uint16_t *)xb->qs;
- const float d = xb->d;
- uint32_t aux32;
- thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
- for (int i = 0; i < 4; ++i) {
- aux32 = ((q4[2*i] | (q4[2*i+1] << 16)) >> 4*il) & 0x0f0f0f0f;
- reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
- reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
- reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
- reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
- }
-}
-
-template <typename type4x4>
-void dequantize_iq4_xs(device const block_iq4_xs * xb, short il, thread type4x4 & reg) {
- // il is 0...15 for QK_K = 256 => index of block of 32 is il/2
- const int ib32 = il/2;
- il = il%2;
- // il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
- device const uint32_t * q4 = (device const uint32_t *)xb->qs + 4*ib32;
- const int ls = ((xb->scales_l[ib32/2] >> 4*(ib32%2)) & 0xf) | (((xb->scales_h >> 2*ib32) & 3) << 4);
- const float d = (float)xb->d * (ls - 32);
- uint32_t aux32;
- thread const uint8_t * q8 = (thread const uint8_t *)&aux32;
- for (int i = 0; i < 4; ++i) {
- aux32 = (q4[i] >> 4*il) & 0x0f0f0f0f;
- reg[i][0] = d * kvalues_iq4nl_f[q8[0]];
- reg[i][1] = d * kvalues_iq4nl_f[q8[1]];
- reg[i][2] = d * kvalues_iq4nl_f[q8[2]];
- reg[i][3] = d * kvalues_iq4nl_f[q8[3]];
- }
-}
-
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
kernel void kernel_get_rows_q(
device const void * src0,