GGML_METAL_DECL_KERNEL(relu);
GGML_METAL_DECL_KERNEL(gelu);
GGML_METAL_DECL_KERNEL(soft_max);
+ GGML_METAL_DECL_KERNEL(soft_max_4);
GGML_METAL_DECL_KERNEL(diag_mask_inf);
+ GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(get_rows_q4_1);
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
+ GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
GGML_METAL_ADD_KERNEL(relu);
GGML_METAL_ADD_KERNEL(gelu);
GGML_METAL_ADD_KERNEL(soft_max);
+ GGML_METAL_ADD_KERNEL(soft_max_4);
GGML_METAL_ADD_KERNEL(diag_mask_inf);
+ GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(get_rows_q4_1);
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
+ GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
GGML_METAL_DEL_KERNEL(relu);
GGML_METAL_DEL_KERNEL(gelu);
GGML_METAL_DEL_KERNEL(soft_max);
- GGML_METAL_DEL_KERNEL(diag_mask_inf);
+ GGML_METAL_DEL_KERNEL(soft_max_4);
+ GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
GGML_METAL_DEL_KERNEL(get_rows_f16);
GGML_METAL_DEL_KERNEL(get_rows_q4_0);
GGML_METAL_DEL_KERNEL(get_rows_q4_1);
GGML_METAL_DEL_KERNEL(norm);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
+ GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
- const int64_t n = ggml_nelements(dst);
+ const int64_t n = ggml_nelements(dst)/4;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- const int64_t n = ggml_nelements(dst);
+ const int64_t n = ggml_nelements(dst)/4;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
- const int64_t n = ggml_nelements(dst);
+ const int64_t n = ggml_nelements(dst)/4;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
{
const int nth = 32;
- [encoder setComputePipelineState:ctx->pipeline_soft_max];
+ if (ne00%4 == 0) {
+ [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_soft_max];
+ }
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
- [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
{
const int n_past = ((int32_t *)(dst->op_params))[0];
- [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+ if (ne00%8 == 0) {
+ [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+ }
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
- [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ if (ne00%8 == 0) {
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ }
+ else {
+ [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+ }
} break;
case GGML_OP_MUL_MAT:
{
} else {
int nth0 = 32;
int nth1 = 1;
+ int nrows = 1;
// use custom matrix x vector kernel
switch (src0t) {
nth1 = 1;
if (ne11 * ne12 < 4) {
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+ } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
+ nrows = ne11;
} else {
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+ nrows = 4;
}
} break;
case GGML_TYPE_Q4_0:
else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
- int64_t ny = (ne11 + 3)/4;
+ int64_t ny = (ne11 + nrows - 1)/nrows;
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
}
kernel void kernel_scale(
- device const float * src0,
- device float * dst,
+ device const float4 * src0,
+ device float4 * dst,
constant float & scale,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * scale;
}
kernel void kernel_silu(
- device const float * src0,
- device float * dst,
+ device const float4 * src0,
+ device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
- float x = src0[tpig];
+ device const float4 & x = src0[tpig];
dst[tpig] = x / (1.0f + exp(-x));
}
constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
kernel void kernel_gelu(
- device const float * src0,
- device float * dst,
+ device const float4 * src0,
+ device float4 * dst,
uint tpig[[thread_position_in_grid]]) {
- float x = src0[tpig];
+ device const float4 & x = src0[tpig];
// BEWARE !!!
// Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs!
constant int64_t & ne00,
constant int64_t & ne01,
constant int64_t & ne02,
- threadgroup float * buf [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
// parallel max
- buf[tpitg[0]] = -INFINITY;
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
- buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
- }
-
- // reduce
- threadgroup_barrier(mem_flags::mem_threadgroup);
- for (uint i = ntg[0]/2; i > 0; i /= 2) {
- if (tpitg[0] < i) {
- buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ float lmax = psrc0[tpitg[0]];
+ for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+ lmax = MAX(lmax, psrc0[i00]);
}
-
- //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
- // the loop, and when that is done, buf[0] has the correct (synchronized) value
- //if (tpitg[0] == 0) {
- // buf[0] = buf[0];
- //}
-
- //threadgroup_barrier(mem_flags::mem_threadgroup);
-
- const float max = buf[0];
+ const float max = simd_max(lmax);
// parallel sum
- buf[tpitg[0]] = 0.0f;
+ float lsum = 0.0f;
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
const float exp_psrc0 = exp(psrc0[i00] - max);
- buf[tpitg[0]] += exp_psrc0;
+ lsum += exp_psrc0;
// Remember the result of exp here. exp is expensive, so we really do not
// whish to compute it twice.
pdst[i00] = exp_psrc0;
}
- // reduce
- threadgroup_barrier(mem_flags::mem_threadgroup);
- for (uint i = ntg[0]/2; i > 0; i /= 2) {
- if (tpitg[0] < i) {
- buf[tpitg[0]] += buf[tpitg[0] + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ const float sum = simd_sum(lsum);
+
+ for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
+ pdst[i00] /= sum;
}
+}
- // broadcast - not needed, see above
- //// broadcast
- //if (tpitg[0] == 0) {
- // buf[0] = buf[0];
- //}
+kernel void kernel_soft_max_4(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
- //threadgroup_barrier(mem_flags::mem_threadgroup);
+ device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+ device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
- const float sum = buf[0];
+ // parallel max
+ float4 lmax4 = psrc4[tpitg[0]];
+ for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ lmax4 = fmax(lmax4, psrc4[i00]);
+ }
+ float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
- for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
- pdst[i00] /= sum;
+ const float max = simd_max(lmax);
+
+ // parallel sum
+ float4 lsum4 = 0.0f;
+ for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ const float4 exp_psrc4 = exp(psrc4[i00] - max);
+ lsum4 += exp_psrc4;
+ pdst4[i00] = exp_psrc4;
+ }
+ float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+ const float sum = simd_sum(lsum);
+
+ for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+ pdst4[i00] /= sum;
}
}
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
} else {
dst[i02*ne01*ne00 + i01*ne00 + i00] = src0[i02*ne01*ne00 + i01*ne00 + i00];
+ }
+}
+
+kernel void kernel_diag_mask_inf_8(
+ device const float4 * src0,
+ device float4 * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int & n_past,
+ uint3 tpig[[thread_position_in_grid]]) {
+
+ const int64_t i = 2*tpig[0];
+
+ dst[i+0] = src0[i+0];
+ dst[i+1] = src0[i+1];
+ int64_t i4 = 4*i;
+ const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+ const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
+ const int64_t i00 = i4;
+ for (int k = 3; k >= 0; --k) {
+ if (i00 + 4 + k <= n_past + i01) {
+ break;
+ }
+ dst[i+1][k] = -INFINITY;
+ if (i00 + k > n_past + i01) {
+ dst[i][k] = -INFINITY;
+ }
}
}
}
}
+// Assumes row size (ne00) is a multiple of 4
+kernel void kernel_mul_mat_f16_f32_l4(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
+ const int nrows = ne11;
+ const int64_t r0 = tgpig.x;
+ const int64_t im = tgpig.z;
+
+ device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+ for (int r1 = 0; r1 < nrows; ++r1) {
+ device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+}
+
kernel void kernel_alibi_f32(
device const float * src0,
device float * dst,
template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
- const half d = il ? (xb->d / 16.h) : xb->d;
- const half m = il ? ( -8.h * 16.h) : -8.h;
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float md = -8.h * xb->d;
const ushort mask0 = il ? 0x00F0 : 0x000F;
- const ushort mask1 = il ? 0xF000 : 0x0F00;
+ const ushort mask1 = mask0 << 8;
for (int i=0;i<8;i++) {
- reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d;
- reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
+ reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+ reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
}
+
}
template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
- const half d = il ? (xb->d / 16.h) : xb->d;
- const half m = xb->m;
+ const float d1 = il ? (xb->d / 16.h) : xb->d;
+ const float d2 = d1 / 256.f;
+ const float m = xb->m;
const ushort mask0 = il ? 0x00F0 : 0x000F;
- const ushort mask1 = il ? 0xF000 : 0x0F00;
+ const ushort mask1 = mask0 << 8;
for (int i=0;i<8;i++) {
- reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m;
- reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
+ reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+ reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
}
}
template <typename type4x4>
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
- const float d_all = (float)(xb->d);
+ const half d_all = xb->d;
device const uint8_t * q = (device const uint8_t *)xb->qs;
device const uint8_t * h = (device const uint8_t *)xb->hmask;
device const int8_t * scales = (device const int8_t *)xb->scales;
((il/4)>0 ? 12 : 3);
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
- int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
- (scale_2&kmask2) | ((scale_1&kmask1) << 4);
- float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+ : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+ half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
+ const half ml = 4.h * dl;
- il = (il/2)%4;
- float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
- uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ il = (il/2) & 3;
+ const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ dl *= coef;
for (int i = 0; i < 16; ++i) {
- reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
+ reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
}
+
#else
float kcoef = il&1 ? 1.f/16.f : 1.f;
uint16_t kmask = il&1 ? 0xF0 : 0x0F;
#endif
}
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+ return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+ : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
+
template <typename type4x4>
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
- device const uint8_t * q = xb->qs;
+ device const uchar * q = xb->qs;
#if QK_K == 256
- const float d = (float)(xb->d);
- const float min = (float)(xb->dmin);
short is = (il/4) * 2;
q = q + (il/4) * 32 + 16 * (il&1);
- il = il%4;
- const uchar4 sc = get_scale_min_k4(is, xb->scales);
- const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
- const float ml = il<2 ? min * sc[1] : min * sc[3];
+ il = il & 3;
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+ const half d = il < 2 ? xb->d : xb->d / 16.h;
+ const half min = xb->dmin;
+ const half dl = d * sc[0];
+ const half ml = min * sc[1];
#else
q = q + 16 * (il&1);
device const uint8_t * s = xb->scales;
device const half2 * dh = (device const half2 *)xb->d;
const float2 d = (float2)dh[0];
const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
- const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
+ const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
#endif
const ushort mask = il<2 ? 0x0F : 0xF0;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - ml;
}
+
}
template <typename type4x4>
device const uint8_t * qh = xb->qh;
#if QK_K == 256
- const float d = (float)(xb->d);
- const float min = (float)(xb->dmin);
short is = (il/4) * 2;
q = q + 32 * (il/4) + 16 * (il&1);
qh = qh + 16 * (il&1);
uint8_t ul = 1 << (il/2);
- il = il%4;
- const uchar4 sc = get_scale_min_k4(is, xb->scales);
- const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
- const float ml = il<2 ? min * sc[1] : min * sc[3];
+ il = il & 3;
+ const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+ const half d = il < 2 ? xb->d : xb->d / 16.h;
+ const half min = xb->dmin;
+ const half dl = d * sc[0];
+ const half ml = min * sc[1];
- const ushort mask = il<2 ? 0x0F : 0xF0;
- const float qh_val = il<2 ? 16.f : 256.f;
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ const half qh_val = il<2 ? 16.h : 256.h;
for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
}
template <typename type4x4>
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
- const float d_all = (float)(xb->d);
+ const half d_all = xb->d;
device const uint8_t * ql = (device const uint8_t *)xb->ql;
device const uint8_t * qh = (device const uint8_t *)xb->qh;
device const int8_t * scales = (device const int8_t *)xb->scales;
#if QK_K == 256
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
qh = qh + 32*(il/8) + 16*(il&1);
- float sc = scales[(il%2) + 2 * ((il/2))];
- il = (il/2)%4;
+ half sc = scales[(il%2) + 2 * ((il/2))];
+ il = (il/2) & 3;
#else
ql = ql + 16 * (il&1);
- float sc = scales[il];
+ half sc = scales[il];
#endif
+ const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
+ const half coef = il>1 ? 1.f/16.h : 1.h;
+ const half ml = d_all * sc * 32.h;
+ const half dl = d_all * sc * coef;
for (int i = 0; i < 16; ++i) {
- uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
- uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
- const float coef = il>1 ? 1.f/16.f : 1.f;
- float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
- ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
- reg[i/4][i%4] = d_all * sc * q * coef;
+ const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+ : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+ reg[i/4][i%4] = dl * q - ml;
}
}