#import <Foundation/Foundation.h>
#import <Metal/Metal.h>
-#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
#undef MIN
#undef MAX
GGML_METAL_DECL_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32);
+ GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_DECL_KERNEL(rope);
GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
ctx->n_buffers = 0;
ctx->concur_list_len = 0;
- // determine if we can use MPS
- if (MPSSupportsMTLDevice(ctx->device)) {
- fprintf(stderr, "%s: using MPS\n", __func__);
- } else {
- fprintf(stderr, "%s: not using MPS\n", __func__);
- GGML_ASSERT(false && "MPS not supported");
- }
#if 0
// compile from source string and show compile log
GGML_METAL_ADD_KERNEL(mul_mat_q4_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q5_K_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q6_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32);
+ GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32);
GGML_METAL_ADD_KERNEL(rope);
GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
- id<MTLComputeCommandEncoder> encoder = nil;
+ id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const int node_end = (cb_idx == n_cb - 1) ? n_nodes : (cb_idx + 1) * n_nodes_per_cb;
const int i = has_concur ? ctx->concur_list[ind] : ind;
if (i == -1) {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- continue;
- }
[encoder memoryBarrierWithScope:MTLBarrierScopeBuffers];
continue;
}
} break;
case GGML_OP_ADD:
{
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
if (ggml_nelements(src1) == ne10) {
// src1 is a row
[encoder setComputePipelineState:ctx->pipeline_add_row];
} break;
case GGML_OP_MUL:
{
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
if (ggml_nelements(src1) == ne10) {
// src1 is a row
[encoder setComputePipelineState:ctx->pipeline_mul_row];
} break;
case GGML_OP_SCALE:
{
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
const float scale = *(const float *) src1->data;
[encoder setComputePipelineState:ctx->pipeline_scale];
switch (ggml_get_unary_op(gf->nodes[i])) {
case GGML_UNARY_OP_SILU:
{
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
[encoder setComputePipelineState:ctx->pipeline_silu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
} break;
case GGML_UNARY_OP_RELU:
{
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
[encoder setComputePipelineState:ctx->pipeline_relu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
} break;
case GGML_UNARY_OP_GELU:
{
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
[encoder setComputePipelineState:ctx->pipeline_gelu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
} break;
case GGML_OP_SOFT_MAX:
{
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
const int nth = 32;
[encoder setComputePipelineState:ctx->pipeline_soft_max];
} break;
case GGML_OP_DIAG_MASK_INF:
{
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
const int n_past = ((int32_t *)(dst->op_params))[0];
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
GGML_ASSERT(ne00 == ne10);
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
+ uint gqa = ne12/ne02;
GGML_ASSERT(ne03 == ne13);
+ // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
+ // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if (ggml_is_contiguous(src0) &&
ggml_is_contiguous(src1) &&
- (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
-
- if (encoder != nil) {
- [encoder endEncoding];
- encoder = nil;
- }
-
- MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
- MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
-
- // for F32 x F32 we use MPS
- MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
- matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
-
- MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
- matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
-
- MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
- matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
-
- MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
- initWithDevice:ctx->device transposeLeft:false transposeRight:true
- resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
-
- // we need to do ne12 multiplications
- // TODO: is there a way to do this in parallel - currently very slow ..
- // TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
- for (int64_t i02 = 0; i02 < ne12; ++i02) {
- size_t offs_src0_cur = offs_src0 + i02/(ne12/ne02)*nb02; // gqa not used for now
- size_t offs_src1_cur = offs_src1 + i02*nb12;
- size_t offs_dst_cur = offs_dst + i02*nb2;
-
- MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
- MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
- MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
-
- [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
- }
- } else {
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
+ src1t == GGML_TYPE_F32 &&
+ [ctx->device supportsFamily:MTLGPUFamilyApple7] &&
+ ne00%32 == 0 &&
+ ne11 > 1) {
+ switch (src0->type) {
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break;
+ case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
+ case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+ case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
+ case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
+ case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_K_f32]; break;
+ case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_K_f32]; break;
+ case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q6_K_f32]; break;
+ default: GGML_ASSERT(false && "MUL MAT-MAT not implemented");
+ }
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
+ [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
+ [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5];
+ [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6];
+ [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
+ [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
+ [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
+ [encoder setThreadgroupMemoryLength:8192 atIndex:0];
+ [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
}
-
+ else {
int nth0 = 32;
int nth1 = 1;
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
+ [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q3_K) {
#ifdef GGML_QKK_64
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#else
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
#endif
}
else if (src0t == GGML_TYPE_Q5_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q6_K) {
- [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} break;
case GGML_OP_GET_ROWS:
{
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
switch (src0->type) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
} break;
case GGML_OP_RMS_NORM:
{
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
} break;
case GGML_OP_NORM:
{
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
const float eps = 1e-5f;
const int nth = 256;
} break;
case GGML_OP_ALIBI:
{
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
GGML_ASSERT((src0t == GGML_TYPE_F32));
const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
} break;
case GGML_OP_ROPE:
{
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
case GGML_OP_CPY:
case GGML_OP_CONT:
{
- if (encoder == nil) {
- encoder = [command_buffer computeCommandEncoderWithDescriptor: edesc];
- }
-
const int nth = 32;
switch (src0t) {
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
-static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
- const int qk = QK4_0;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- const half d = x[i].d;
-
- for (int j = 0; j < qk/2; ++j) {
- const int x0 = (x[i].qs[j] & 0x0F) - 8;
- const int x1 = (x[i].qs[j] >> 4) - 8;
-
- y[i*qk + j + 0 ] = x0*d;
- y[i*qk + j + qk/2] = x1*d;
- }
- }
-}
-
-static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, int k) {
- const int qk = QK4_1;
-
- assert(k % qk == 0);
-
- const int nb = k / qk;
-
- for (int i = 0; i < nb; i++) {
- const half d = x[i].d;
- const half m = x[i].m;
-
- for (int j = 0; j < qk/2; ++j) {
- const int x0 = (x[i].qs[j] & 0x0F);
- const int x1 = (x[i].qs[j] >> 4);
-
- y[i*qk + j + 0 ] = x0*d + m;
- y[i*qk + j + qk/2] = x1*d + m;
- }
- }
-}
-
kernel void kernel_add(
device const float * src0,
device const float * src1,
}
}
-kernel void kernel_get_rows_f16(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- for (int j = 0; j < ne00; j++) {
- dst[i*nb1 + j] = ((device half *) ((device char *) src0 + r*nb01))[j];
- }
-}
-
-kernel void kernel_get_rows_q4_0(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- dequantize_row_q4_0(
- (device const block_q4_0 *) ((device char *) src0 + r*nb01),
- (device float *) ((device char *) dst + i*nb1), ne00);
-}
-
-kernel void kernel_get_rows_q4_1(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- dequantize_row_q4_1(
- (device const block_q4_1 *) ((device char *) src0 + r*nb01),
- (device float *) ((device char *) dst + i*nb1), ne00);
-}
-
kernel void kernel_norm(
device const void * src0,
device float * dst,
// N_DST, so this is another explicit assumption of the implementation.
template<typename block_q_type, int nr, int nsg, int nw>
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
- int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
- uint2 tgpig, uint tiisg, uint sgitg) {
+ int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa,
+ uint3 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
+ const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr;
- device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb;
- device const float * y = (device const float *) src1 + r1*ne10;
+ const uint offset0 = first_row * nb + im/gqa*(nb*ne0);
+ device const block_q_type * x = (device const block_q_type *) src0 + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
float yl[16]; // src1 vector cache
float sumf[nr]={0.f};
for (int row = 0; row < nr; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0 && first_row + row < ne01) {
- dst[r1*ne0 + first_row + row] = tot;
+ dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot;
}
}
}
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
constant int64_t & ne01[[buffer(4)]],
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mat_q4_1_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
constant int64_t & ne01[[buffer(4)]],
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
- mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
+ mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}
kernel void kernel_mul_mat_f16_f32(
return r;
}
-//========================================== dequantization =============================
-
-static void dequantize_row_q2_K(device const block_q2_K * x, device float * y, int k) {
- assert(k % QK_K == 0);
- const int nb = k / QK_K;
-
- for (int i = 0; i < nb; i++) {
-
- const float d = x[i].d;
- const float min = x[i].dmin;
-
- device const uint8_t * q = x[i].qs;
-
-#if QK_K == 256
- int is = 0;
- float dl, ml;
- for (int n = 0; n < QK_K; n += 128) {
- int shift = 0;
- for (int j = 0; j < 4; ++j) {
-
- uint8_t sc = x[i].scales[is++];
- dl = d * (sc & 0xF); ml = min * (sc >> 4);
- for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
-
- sc = x[i].scales[is++];
- dl = d * (sc & 0xF); ml = min * (sc >> 4);
- for (int l = 0; l < 16; ++l) *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
-
- shift += 2;
- }
- q += 32;
- }
-#else
- float dl1 = d * (x[i].scales[0] & 0xF), ml1 = min * (x[i].scales[0] >> 4);
- float dl2 = d * (x[i].scales[1] & 0xF), ml2 = min * (x[i].scales[1] >> 4);
- float dl3 = d * (x[i].scales[2] & 0xF), ml3 = min * (x[i].scales[2] >> 4);
- float dl4 = d * (x[i].scales[3] & 0xF), ml4 = min * (x[i].scales[3] >> 4);
- for (int l = 0; l < 16; ++l) {
- y[l+ 0] = dl1 * ((q[l] >> 0) & 3) - ml1;
- y[l+16] = dl2 * ((q[l] >> 2) & 3) - ml2;
- y[l+32] = dl3 * ((q[l] >> 4) & 3) - ml3;
- y[l+48] = dl4 * ((q[l] >> 6) & 3) - ml4;
- }
- y += QK_K;
-#endif
-
- }
-}
-
-static void dequantize_row_q3_K(device const block_q3_K * x, device float * y, int k) {
- assert(k % QK_K == 0);
- const int nb = k / QK_K;
-
-#if QK_K == 256
-
- const uint16_t kmask1 = 0x0303;
- const uint16_t kmask2 = 0x0f0f;
-
- uint16_t aux[8];
- thread const int8_t * scales = (thread const int8_t*)aux;
-
- for (int i = 0; i < nb; i++) {
-
- const float d_all = (float)(x[i].d);
-
- device const uint8_t * q = x[i].qs;
- device const uint8_t * h = x[i].hmask;
- uint8_t m = 1;
-
- device const uint16_t * a = (device const uint16_t *)x[i].scales;
- aux[0] = (a[0] & kmask2) | (((a[4] >> 0) & kmask1) << 4);
- aux[1] = (a[1] & kmask2) | (((a[5] >> 0) & kmask1) << 4);
- aux[2] = (a[2] & kmask2) | (((a[4] >> 2) & kmask1) << 4);
- aux[3] = (a[3] & kmask2) | (((a[5] >> 2) & kmask1) << 4);
- aux[4] = ((a[0] >> 4) & kmask2) | (((a[4] >> 4) & kmask1) << 4);
- aux[5] = ((a[1] >> 4) & kmask2) | (((a[5] >> 4) & kmask1) << 4);
- aux[6] = ((a[2] >> 4) & kmask2) | (((a[4] >> 6) & kmask1) << 4);
- aux[7] = ((a[3] >> 4) & kmask2) | (((a[5] >> 6) & kmask1) << 4);
-
- int is = 0;
- float dl;
- for (int n = 0; n < QK_K; n += 128) {
- int shift = 0;
- for (int j = 0; j < 4; ++j) {
-
- dl = d_all * (scales[is++] - 32);
- for (int l = 0; l < 16; ++l) {
- *y++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((h[l+ 0] & m) ? 0 : 4));
- }
-
- dl = d_all * (scales[is++] - 32);
- for (int l = 0; l < 16; ++l) {
- *y++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((h[l+16] & m) ? 0 : 4));
- }
-
- shift += 2;
- m <<= 1;
- }
- q += 32;
- }
- }
-#else
- for (int i = 0; i < nb; i++) {
-
- const float d_all = (float)(x[i].d);
-
- device const uint8_t * q = x[i].qs;
- device const uint8_t * hm = x[i].hmask;
-
- const float d1 = d_all * ((x[i].scales[0] & 0xF) - 8);
- const float d2 = d_all * ((x[i].scales[0] >> 4) - 8);
- const float d3 = d_all * ((x[i].scales[1] & 0xF) - 8);
- const float d4 = d_all * ((x[i].scales[1] >> 4) - 8);
-
- for (int l = 0; l < 8; ++l) {
- uint8_t h = hm[l];
- y[l+ 0] = d1 * ((int8_t)((q[l+0] >> 0) & 3) - ((h & 0x01) ? 0 : 4));
- y[l+ 8] = d1 * ((int8_t)((q[l+8] >> 0) & 3) - ((h & 0x02) ? 0 : 4));
- y[l+16] = d2 * ((int8_t)((q[l+0] >> 2) & 3) - ((h & 0x04) ? 0 : 4));
- y[l+24] = d2 * ((int8_t)((q[l+8] >> 2) & 3) - ((h & 0x08) ? 0 : 4));
- y[l+32] = d3 * ((int8_t)((q[l+0] >> 4) & 3) - ((h & 0x10) ? 0 : 4));
- y[l+40] = d3 * ((int8_t)((q[l+8] >> 4) & 3) - ((h & 0x20) ? 0 : 4));
- y[l+48] = d4 * ((int8_t)((q[l+0] >> 6) & 3) - ((h & 0x40) ? 0 : 4));
- y[l+56] = d4 * ((int8_t)((q[l+8] >> 6) & 3) - ((h & 0x80) ? 0 : 4));
- }
- y += QK_K;
- }
-#endif
-
-}
-
-static void dequantize_row_q4_K(device const block_q4_K * x, device float * y, int k) {
- assert(k % QK_K == 0);
- const int nb = k / QK_K;
-
- for (int i = 0; i < nb; i++) {
-
- device const uint8_t * q = x[i].qs;
-
-#if QK_K == 256
- const float d = x[i].d;
- const float min = x[i].dmin;
-
- device const uint8_t * scales = x[i].scales;
-
- int is = 0;
- for (int j = 0; j < QK_K; j += 64) {
- const uchar4 sc = get_scale_min_k4(is, scales);
- const float d1 = d * sc[0]; const float m1 = min * sc[1];
- const float d2 = d * sc[2]; const float m2 = min * sc[3];
- for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
- for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
- q += 32; is += 2;
- }
-#else
- device const uint8_t * s = x[i].scales;
- device const half2 * dh = (device const half2 *)x[i].d;
- const float2 d = (float2)dh[0];
- const float d1 = d[0] * (s[0] & 0xF);
- const float d2 = d[0] * (s[1] & 0xF);
- const float m1 = d[1] * (s[0] >> 4);
- const float m2 = d[1] * (s[1] >> 4);
- for (int l = 0; l < 32; ++l) {
- y[l+ 0] = d1 * (q[l] & 0xF) - m1;
- y[l+32] = d2 * (q[l] >> 4) - m2;
- }
- y += QK_K;
-#endif
-
- }
-}
-
-static void dequantize_row_q5_K(device const block_q5_K * x, device float * y, int k) {
- assert(k % QK_K == 0);
- const int nb = k / QK_K;
-
-#if QK_K == 256
- for (int i = 0; i < nb; i++) {
-
- const float d = (float)(x[i].d);
- const float min = (float)(x[i].dmin);
-
- device const uint8_t * ql = x[i].qs;
- device const uint8_t * qh = x[i].qh;
-
- int is = 0;
- uint8_t u1 = 1, u2 = 2;
- for (int j = 0; j < QK_K; j += 64) {
- const uchar4 sc = get_scale_min_k4(is, x[i].scales);
- const float d1 = d * sc[0]; const float m1 = min * sc[1];
- const float d2 = d * sc[2]; const float m2 = min * sc[3];
- for (int l = 0; l < 32; ++l) *y++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
- for (int l = 0; l < 32; ++l) *y++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
- ql += 32; is += 2;
- u1 <<= 2; u2 <<= 2;
- }
- }
-#else
- for (int i = 0; i < nb; i++) {
-
- const float d = (float)x[i].d;
-
- device const uint8_t * ql = x[i].qs;
- device const uint8_t * qh = x[i].qh;
- device const int8_t * sc = x[i].scales;
-
- for (int l = 0; l < 8; ++l) {
- y[l+ 0] = d * sc[0] * ((ql[l+ 0] & 0xF) - (qh[l] & 0x01 ? 0 : 16));
- y[l+ 8] = d * sc[0] * ((ql[l+ 8] & 0xF) - (qh[l] & 0x02 ? 0 : 16));
- y[l+16] = d * sc[1] * ((ql[l+16] & 0xF) - (qh[l] & 0x04 ? 0 : 16));
- y[l+24] = d * sc[1] * ((ql[l+24] & 0xF) - (qh[l] & 0x08 ? 0 : 16));
- y[l+32] = d * sc[2] * ((ql[l+ 0] >> 4) - (qh[l] & 0x10 ? 0 : 16));
- y[l+40] = d * sc[2] * ((ql[l+ 8] >> 4) - (qh[l] & 0x20 ? 0 : 16));
- y[l+48] = d * sc[3] * ((ql[l+16] >> 4) - (qh[l] & 0x40 ? 0 : 16));
- y[l+56] = d * sc[3] * ((ql[l+24] >> 4) - (qh[l] & 0x80 ? 0 : 16));
- }
- y += QK_K;
- }
-#endif
-
-}
-
-static void dequantize_row_q6_K(device const block_q6_K * x, device float * y, int k) {
- assert(k % QK_K == 0);
- const int nb = k / QK_K;
-
- for (int i = 0; i < nb; i++) {
-
- device const uint8_t * ql = x[i].ql;
- device const uint8_t * qh = x[i].qh;
- device const int8_t * sc = x[i].scales;
-
- const float d = x[i].d;
-
-#if QK_K == 256
- for (int n = 0; n < QK_K; n += 128) {
- for (int l = 0; l < 32; ++l) {
- int is = l/16;
- const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
- const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
- const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
- const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
- y[l + 0] = d * sc[is + 0] * q1;
- y[l + 32] = d * sc[is + 2] * q2;
- y[l + 64] = d * sc[is + 4] * q3;
- y[l + 96] = d * sc[is + 6] * q4;
- }
- y += 128;
- ql += 64;
- qh += 32;
- sc += 8;
- }
-#else
- for (int l = 0; l < 16; ++l) {
- const int8_t q1 = (int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
- const int8_t q2 = (int8_t)((ql[l+16] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
- const int8_t q3 = (int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
- const int8_t q4 = (int8_t)((ql[l+16] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
- y[l+ 0] = d * sc[0] * q1;
- y[l+16] = d * sc[1] * q2;
- y[l+32] = d * sc[2] * q3;
- y[l+48] = d * sc[3] * q4;
- }
- y += 64;
-#endif
- }
-}
-
-kernel void kernel_get_rows_q2_K(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- dequantize_row_q2_K(
- (device const block_q2_K *) ((device char *) src0 + r*nb01),
- (device float *) ((device char *) dst + i*nb1), ne00);
-}
-
-kernel void kernel_get_rows_q3_K(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- dequantize_row_q3_K(
- (device const block_q3_K *) ((device char *) src0 + r*nb01),
- (device float *) ((device char *) dst + i*nb1), ne00);
-}
-
-kernel void kernel_get_rows_q4_K(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- dequantize_row_q4_K(
- (device const block_q4_K *) ((device char *) src0 + r*nb01),
- (device float *) ((device char *) dst + i*nb1), ne00);
-}
-
-kernel void kernel_get_rows_q5_K(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- dequantize_row_q5_K(
- (device const block_q5_K *) ((device char *) src0 + r*nb01),
- (device float *) ((device char *) dst + i*nb1), ne00);
-}
-
-kernel void kernel_get_rows_q6_K(
- device const void * src0,
- device const int * src1,
- device float * dst,
- constant int64_t & ne00,
- constant uint64_t & nb01,
- constant uint64_t & nb1,
- uint tpig[[thread_position_in_grid]]) {
- const int i = tpig;
- const int r = ((device int32_t *) src1)[i];
-
- dequantize_row_q6_K(
- (device const block_q6_K *) ((device char *) src0 + r*nb01),
- (device float *) ((device char *) dst + i*nb1), ne00);
-}
-
//====================================== dot products =========================
kernel void kernel_mul_mat_q2_K_f32(
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
constant int64_t & ne01[[buffer(4)]],
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
+ const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
- device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row;
- device const float * y = (device const float *) src1 + r1*ne10;
+ const uint offset0 = r2/gqa*(nb*ne0);
+ device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float yl[32];
float sumf[N_DST]={0.f}, all_sum;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + first_row + row] = all_sum;
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
}
}
}
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
- constant int64_t & ne1,
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
+ const int64_t r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-
- device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
+ const uint offset0 = r2/gqa*(nb*ne0);
+ device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float yl[16];
const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
const float tot = simd_sum(sumf);
if (tiisg == 0) {
- dst[r1*ne0 + first_row + row] = tot;
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
}
}
}
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
- constant int64_t & ne1,
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
+ const int64_t r2 = tgpig.z;
const int row = 2 * r0 + sgitg;
-
- device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
+ const uint offset0 = r2/gqa*(nb*ne0);
+ device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
const int ix = tiisg/4;
const int il = 4 * (tiisg%4);// 0, 4, 8, 12
const int im = il/8; // 0, 0, 1, 1
const float tot = simd_sum(sumf);
if (tiisg == 0) {
- dst[r1*ne0 + row] = tot;
+ dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
}
}
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
constant int64_t & ne01[[buffer(4)]],
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
+ const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
- device const float * y = (device const float *) src1 + r1*ne10;
+ const uint offset0 = r2/gqa*(nb*ne0);
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float yl[16];
float yh[16];
float sumf[N_DST]={0.f}, all_sum;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + first_row + row] = all_sum;
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum;
}
}
}
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
constant int64_t & ne01[[buffer(4)]],
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK_K;
const int r0 = tgpig.x;
const int r1 = tgpig.y;
+ const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb;
- device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row;
- device const float * y = (device const float *) src1 + r1*ne10;
+ const uint offset0 = r2/gqa*(nb*ne0);
+ device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
+ device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float yl[8];
float yh[8];
float sumf[N_DST]={0.f}, all_sum;
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + first_row + row] = all_sum;
+ dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum;
}
}
}
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
+ const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
-
- device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
+ const uint offset0 = r2/gqa*(nb*ne0);
+ device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float sumf[2]={0.f};
for (int row = 0; row < 2; ++row) {
const float tot = simd_sum(sumf[row]);
if (tiisg == 0) {
- dst[r1*ne0 + first_row + row] = tot;
+ dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
}
}
device const float * src1,
device float * dst,
constant int64_t & ne00,
- constant int64_t & ne10,
- constant int64_t & ne0,
- uint2 tgpig[[threadgroup_position_in_grid]],
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
+ const int r2 = tgpig.z;
const int row = 2 * r0 + sgitg;
-
- device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb;
- device const float * yy = (device const float *) src1 + r1*ne10;
+ const uint offset0 = r2/gqa*(nb*ne0);
+ device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
+ device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1;
float sumf = 0;
const float tot = simd_sum(sumf);
if (tiisg == 0) {
- dst[r1*ne0 + row] = tot;
+ dst[r1*ne0 + r2*ne0*ne1 + row] = tot;
+ }
+}
+
+//============================= templates and their specializations =============================
+
+template <typename type4x4>
+void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
+ half4x4 temp = *(((device half4x4 *)src));
+ for (int i = 0; i < 16; i++){
+ reg[i/4][i%4] = temp[i/4][i%4];
+ }
+}
+
+template <typename type4x4>
+void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 1);
+ const half d = il ? (xb->d / 16.h) : xb->d;
+ const half m = il ? (-8.h * 16.h) : -8.h;
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
+ const ushort mask1 = il ? 0xF000 : 0x0F00;
+
+ for (int i=0;i<8;i++) {
+ reg[i/2][2*(i%2)] = (((qs[i] & mask0)) + m) * d;
+ reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 2);
+ const half d = il ? (xb->d / 16.h) : xb->d;
+ const half m = xb->m;
+ const ushort mask0 = il ? 0x00F0 : 0x000F;
+ const ushort mask1 = il ? 0xF000 : 0x0F00;
+
+ for (int i=0;i<8;i++) {
+ reg[i/2][2*(i%2)] = (((qs[i] & mask0)) * d) + m;
+ reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
+ const half d = xb->d;
+ const half min = xb->dmin;
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
+ half dl, ml;
+ uint8_t sc = xb->scales[il];
+
+#if QK_K == 256
+ q = q + 32*(il/8) + 16*(il&1);
+ il = (il/2)%4;
+#endif
+ half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ uchar mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ dl = d * (sc & 0xF) * coef, ml = min * (sc >> 4);
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
+ const float d_all = (float)(xb->d);
+ device const uint8_t * q = (device const uint8_t *)xb->qs;
+ device const uint8_t * h = (device const uint8_t *)xb->hmask;
+ device const int8_t * scales = (device const int8_t *)xb->scales;
+
+#if QK_K == 256
+ q = q + 32 * (il/8) + 16 * (il&1);
+ h = h + 16 * (il&1);
+ uint8_t m = 1 << (il/2);
+ uint16_t kmask1 = (il/4)>1 ? ((il/4)>2 ? 192 : 48) : \
+ ((il/4)>0 ? 12 : 3);
+ uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
+ uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
+ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
+ (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+ float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+
+ il = (il/2)%4;
+ float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
+ }
+#else
+ float kcoef = il&1 ? 1.f/16.f : 1.f;
+ uint16_t kmask = il&1 ? 0xF0 : 0x0F;
+ float dl = d_all * ((scales[il/2] & kmask) * kcoef - 8);
+ float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+ uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ uint8_t m = 1<<(il*2);
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i%8] & (m * (1 + i/8))) ? 0 : 4.f/coef));
+ }
+#endif
+}
+
+template <typename type4x4>
+void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
+ device const uint8_t * q = xb->qs;
+
+#if QK_K == 256
+ const float d = (float)(xb->d);
+ const float min = (float)(xb->dmin);
+ short is = (il/4) * 2;
+ q = q + (il/4) * 32 + 16 * (il&1);
+ il = il%4;
+ const uchar4 sc = get_scale_min_k4(is, xb->scales);
+ const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
+ const float ml = il<2 ? min * sc[1] : min * sc[3];
+#else
+ q = q + 16 * (il&1);
+ device const uint8_t * s = xb->scales;
+ device const half2 * dh = (device const half2 *)xb->d;
+ const float2 d = (float2)dh[0];
+ const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
+ const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4);
+#endif
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * (q[i] & mask) - ml;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg) {
+ device const uint8_t * q = xb->qs;
+ device const uint8_t * qh = xb->qh;
+
+#if QK_K == 256
+ const float d = (float)(xb->d);
+ const float min = (float)(xb->dmin);
+ short is = (il/4) * 2;
+ q = q + 32 * (il/4) + 16 * (il&1);
+ qh = qh + 16 * (il&1);
+ uint8_t ul = 1 << (il/2);
+ il = il%4;
+ const uchar4 sc = get_scale_min_k4(is, xb->scales);
+ const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
+ const float ml = il<2 ? min * sc[1] : min * sc[3];
+
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ const float qh_val = il<2 ? 16.f : 256.f;
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
+ }
+#else
+ q = q + 16 * (il&1);
+ device const int8_t * s = xb->scales;
+ const float dl = xb->d * s[il];
+ uint8_t m = 1<<(il*2);
+ const float coef = il<2 ? 1.f : 1.f/16.f;
+ const ushort mask = il<2 ? 0x0F : 0xF0;
+ for (int i = 0; i < 16; ++i) {
+ reg[i/4][i%4] = coef * dl * ((q[i] & mask) - (qh[i%8] & (m*(1+i/8)) ? 0.f : 16.f/coef));
+ }
+#endif
+}
+
+template <typename type4x4>
+void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
+ const float d_all = (float)(xb->d);
+ device const uint8_t * ql = (device const uint8_t *)xb->ql;
+ device const uint8_t * qh = (device const uint8_t *)xb->qh;
+ device const int8_t * scales = (device const int8_t *)xb->scales;
+
+#if QK_K == 256
+ ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
+ qh = qh + 32*(il/8) + 16*(il&1);
+ float sc = scales[(il%2) + 2 * ((il/2))];
+ il = (il/2)%4;
+#else
+ ql = ql + 16 * (il&1);
+ float sc = scales[il];
+#endif
+ for (int i = 0; i < 16; ++i) {
+ uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+ uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
+ const float coef = il>1 ? 1.f/16.f : 1.f;
+ float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
+ ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
+ reg[i/4][i%4] = d_all * sc * q * coef;
+ }
+}
+
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
+kernel void kernel_get_rows(
+ device const void * src0,
+ device const int * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb1,
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint tptg[[threads_per_threadgroup]]) {
+ const int i = tgpig;
+ const int r = ((device int32_t *) src1)[i];
+
+ for (int ind = tiitg; ind < ne00/16; ind += tptg) {
+ float4x4 temp;
+ dequantize_func(
+ ((device const block_q *) ((device char *) src0 + r*nb01)) + ind/nl, ind%nl, temp);
+ *(((device float4x4 *) ((device char *) dst + i*nb1)) + ind) = temp;
+ }
+}
+
+#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
+#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix A
+#define BLOCK_SIZE_K 32
+#define THREAD_MAT_M 4 // each thread take 4 simdgroup matrices from matrix A
+#define THREAD_MAT_N 2 // each thread take 2 simdgroup matrices from matrix B
+#define THREAD_PER_BLOCK 128
+#define THREAD_PER_ROW 2 // 2 thread for each row in matrix A to load numbers
+#define THREAD_PER_COL 4 // 4 thread for each row in matrix B to load numbers
+#define SG_MAT_SIZE 64 // simdgroup matrix is of shape 8x8
+#define SG_MAT_ROW 8
+
+// each block_q contains 16*nl weights
+template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
+kernel void kernel_mul_mm(device const uchar * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne02,
+ constant int64_t & nb01,
+ constant int64_t & nb02,
+ constant int64_t & ne12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant uint & gqa,
+ threadgroup uchar * shared_memory [[threadgroup(0)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiitg[[thread_index_in_threadgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+
+ threadgroup half * sa = ((threadgroup half *)shared_memory);
+ threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
+
+ const uint r0 = tgpig.y;
+ const uint r1 = tgpig.x;
+ const uint im = tgpig.z;
+ // if this block is of 64x32 shape or smaller
+ short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M;
+ short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N;
+ // a thread shouldn't load data outside of the matrix
+ short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
+ short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
+
+ simdgroup_half8x8 ma[4];
+ simdgroup_float8x8 mb[2];
+ simdgroup_float8x8 c_res[8];
+ for (int i = 0; i < 8; i++){
+ c_res[i] = make_filled_simdgroup_matrix<float, 8>(0.f);
+ }
+
+ short il = (tiitg % THREAD_PER_ROW);
+ uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
+ device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
+ device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
+ + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
+
+ for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
+ //load data and store to threadgroup memory
+ half4x4 temp_a;
+ dequantize_func(x, il, temp_a);
+ #pragma unroll(16)
+ for (int i = 0; i < 16; i++) {
+ *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \
+ + 16 * (tiitg % THREAD_PER_ROW) + 8 * (i / 8)) \
+ + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4];
+ }
+ *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) \
+ = *((device float2x4 *)y);
+ il = (il + 2 < nl) ? il + 2 : il % 2;
+ x = (il < 2) ? x + (2+nl-1)/nl : x;
+ y += BLOCK_SIZE_K;
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ //load matrices from threadgroup memory and conduct outer products
+ threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2));
+ threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2));
+ #pragma unroll(4)
+ for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) {
+ #pragma unroll(4)
+ for (int i = 0; i < 4; i++) {
+ simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i);
+ }
+ simdgroup_barrier(mem_flags::mem_none);
+ #pragma unroll(2)
+ for (int i = 0; i < 2; i++) {
+ simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i);
+ }
+
+ lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE;
+ lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE;
+ #pragma unroll(8)
+ for (int i = 0; i < 8; i++){
+ simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]);
+ }
+ }
+ }
+
+ if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) {
+ device float *C = dst + BLOCK_SIZE_M * r0 + 32 * (sgitg&1) \
+ + (BLOCK_SIZE_N * r1 + 16 * (sgitg>>1)) * ne0 + im*ne1*ne0;
+ for (int i = 0; i < 8; i++) {
+ simdgroup_store(c_res[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0);
+ }
+ } else {
+ // block is smaller than 64x32, we should avoid writing data outside of the matrix
+ threadgroup float *temp_str = ((threadgroup float *)shared_memory) \
+ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M;
+ for (int i = 0; i < 8; i++) {
+ simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M);
+ }
+
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ device float *C = dst + BLOCK_SIZE_M * r0 + (BLOCK_SIZE_N * r1) * ne0 + im*ne1*ne0;
+ if (sgitg==0) {
+ for (int i = 0; i < n_rows; i++) {
+ for (int j = tiitg; j< n_cols; j += BLOCK_SIZE_N) {
+ *(C + i + j * ne0) = *(temp_str + i + j * BLOCK_SIZE_M);
+ }
+ }
+ }
}
}
+
+#if QK_K == 256
+#define QK_NL 16
+#else
+#define QK_NL 4
+#endif
+
+typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
+ constant uint64_t &, constant uint64_t &, uint, uint, uint);
+
+template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
+
+typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
+ constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
+ constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
+
+template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
+template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
+template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
+template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
+template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
+template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_K, QK_NL, dequantize_q5_K>;
+template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;