GGML_METAL_DECL_KERNEL(rms_norm);
GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
+ GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
GGML_METAL_ADD_KERNEL(rms_norm);
GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
+ GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
GGML_METAL_DEL_KERNEL(rms_norm);
GGML_METAL_DEL_KERNEL(norm);
GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
+ GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
{
nth0 = 32;
nth1 = 1;
- [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
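+ // for small batches (fewer than 4 src1 rows in total) use the 1-row kernel
+ // below, which reduces with simd_sum and needs no threadgroup memory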
+ if (ne11 * ne12 < 4) {
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+ } else {
+ [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+ }
} break;
case GGML_TYPE_Q4_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
- nth0 = 2;
- nth1 = 32;
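+ // 4 x 8 = 32 threads, i.e. exactly one simdgroup per threadgroup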
+ nth0 = 4; //1;
+ nth1 = 8; //32;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32];
} break;
case GGML_TYPE_Q5_K:
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
- src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {
+ src0t == GGML_TYPE_Q2_K) {// || src0t == GGML_TYPE_Q4_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
+ else if (src0t == GGML_TYPE_Q4_K) {
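+ // Q4_K threadgroups now compute 4 rows each, hence (ne01 + 3)/4 groups along dim 0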
+ [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+ }
else if (src0t == GGML_TYPE_Q3_K) {
#ifdef GGML_QKK_64
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
- [encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
- [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
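+ // the new f16 x f32 kernel handles N_F16_F32 (= 4) src1 rows per threadgroup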
+ int64_t ny = (ne11 + 3)/4;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
} break;
threadgroup_barrier(mem_flags::mem_threadgroup);
}
- // broadcast
- if (tpitg[0] == 0) {
- buf[0] = buf[0];
- }
+ //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
+ // the loop, and when that is done, buf[0] has the correct (synchronized) value
+ //if (tpitg[0] == 0) {
+ // buf[0] = buf[0];
+ //}
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
const float max = buf[0];
// parallel sum
buf[tpitg[0]] = 0.0f;
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
- buf[tpitg[0]] += exp(psrc0[i00] - max);
+ const float exp_psrc0 = exp(psrc0[i00] - max);
+ buf[tpitg[0]] += exp_psrc0;
+ // Remember the result of exp here. exp is expensive, so we really do not
+ // wish to compute it twice.
+ pdst[i00] = exp_psrc0;
}
// reduce
threadgroup_barrier(mem_flags::mem_threadgroup);
}
- // broadcast
- if (tpitg[0] == 0) {
- buf[0] = buf[0];
- }
+ // broadcast - not needed, see above
+ //// broadcast
+ //if (tpitg[0] == 0) {
+ // buf[0] = buf[0];
+ //}
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
const float sum = buf[0];
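+ // pdst already holds exp(psrc0 - max) from the loop above, so only the
+ // division by the sum remains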
for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
- pdst[i00] = exp(psrc0[i00] - max) / sum;
+ pdst[i00] /= sum;
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
- // broadcast
- if (tpitg == 0) {
- sum[0] /= ne00;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ //// broadcast
+ //if (tpitg == 0) {
+ // sum[0] /= ne00;
+ //}
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
const float mean = sum[0];
- // recenter
+ // recenter and compute the variance in a single pass over the data
device float * y = dst + tgpig*ne00;
- for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
- y[i00] = x[i00] - mean;
- }
-
- // VARIANCE
- // parallel sum
sum[tpitg] = 0.0f;
for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ y[i00] = x[i00] - mean;
sum[tpitg] += y[i00] * y[i00];
}
+
+ //// VARIANCE
+ //// parallel sum
+ //sum[tpitg] = 0.0f;
+ //for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ // sum[tpitg] += y[i00] * y[i00];
+ //}
// reduce
threadgroup_barrier(mem_flags::mem_threadgroup);
for (uint i = ntg/2; i > 0; i /= 2) {
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
- // broadcast
- if (tpitg == 0) {
- sum[0] /= ne00;
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
+ //// broadcast
+ //if (tpitg == 0) {
+ // sum[0] /= ne00;
+ //}
+ //threadgroup_barrier(mem_flags::mem_threadgroup);
const float variance = sum[0];
const float scale = 1.0f/sqrt(variance + eps);
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
}
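+// number of q8_0 quants each thread processes per inner loop iteration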
+#define NB_Q8_0 8
+
kernel void kernel_mul_mat_q8_0_f32(
device const void * src0,
device const float * src1,
device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
- float yl[16];
+ float yl[NB_Q8_0];
float sumf[nr]={0.f};
- const int ix = tiisg/2;
- const int il = tiisg%2;
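+ // four threads now cooperate on each block: ix selects the block, il selects
+ // which NB_Q8_0-quant slice of the block this thread reads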
+ const int ix = tiisg/4;
+ const int il = tiisg%4;
- device const float * yb = y + ix * QK8_0 + 16*il;
+ device const float * yb = y + ix * QK8_0 + NB_Q8_0*il;
- // each thread in a SIMD group deals with half a block.
- for (int ib = ix; ib < nb; ib += nw/2) {
- for (int i = 0; i < 16; ++i) {
+ // each thread in a SIMD group deals with NB_Q8_0 quants at a time
+ for (int ib = ix; ib < nb; ib += nw/4) {
+ for (int i = 0; i < NB_Q8_0; ++i) {
yl[i] = yb[i];
}
for (int row = 0; row < nr; row++) {
- device const int8_t * qs = x[ib+row*nb].qs + 16*il;
+ device const int8_t * qs = x[ib+row*nb].qs + NB_Q8_0*il;
float sumq = 0.f;
- for (int iq = 0; iq < 16; ++iq) {
+ for (int iq = 0; iq < NB_Q8_0; ++iq) {
sumq += qs[iq] * yl[iq];
}
sumf[row] += sumq*x[ib+row*nb].d;
}
- yb += QK8_0 * 16;
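+ // advance past the NB_Q8_0*nw quants consumed by the whole simdgroup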
+ yb += NB_Q8_0 * nw;
}
for (int row = 0; row < nr; ++row) {
}
}
-kernel void kernel_mul_mat_f16_f32(
+kernel void kernel_mul_mat_f16_f32_1row(
device const char * src0,
device const char * src1,
device float * dst,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
- threadgroup float * sum [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
- uint3 tpig[[thread_position_in_grid]],
- uint3 tpitg[[thread_position_in_threadgroup]],
- uint3 tptg[[threads_per_threadgroup]]) {
+ uint tiisg[[thread_index_in_simdgroup]]) {
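+ // a single 32-thread simdgroup computes one dst element; the reduction is
+ // done with simd_sum, so no threadgroup memory or barriers are needed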
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
- uint ith = tpitg.x;
- uint nth = tptg.x;
+ float sumf = 0;
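+ // short rows take a simple scalar loop; longer rows take the vectorized
+ // half4/float4 path below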
+ if (ne00 < 128) {
+ for (int i = tiisg; i < ne00; i += 32) {
+ sumf += (float) x[i] * (float) y[i];
+ }
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ } else {
+ device const half4 * x4 = (device const half4 *) x;
+ device const float4 * y4 = (device const float4 *) y;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float)x4[i][k] * y4[i][k];
+ }
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
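+ // lane 0 adds the tail elements (ne00 not a multiple of 4) to the reduced sum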
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
- sum[ith] = 0.0f;
+}
- for (int i = ith; i < ne00; i += nth) {
- sum[ith] += (float) x[i] * (float) y[i];
- }
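+// number of src1 rows processed per threadgroup by kernel_mul_mat_f16_f32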
+#define N_F16_F32 4
- // accumulate the sum from all threads in the threadgroup
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%4 == 0) {
- for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith%16 == 0) {
- for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
- }
- threadgroup_barrier(mem_flags::mem_threadgroup);
- if (ith == 0) {
- for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
- dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
+kernel void kernel_mul_mat_f16_f32(
+ device const char * src0,
+ device const char * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant int64_t & ne10,
+ constant int64_t & ne11,
+ constant int64_t & ne12,
+ constant uint64_t & nb10,
+ constant uint64_t & nb11,
+ constant uint64_t & nb12,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]]) {
+
+ const int64_t r0 = tgpig.x;
+ const int64_t rb = N_F16_F32*tgpig.y;
+ const int64_t im = tgpig.z;
+
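+ // rows rb .. rb + N_F16_F32 - 1 of src1 are handled here, with a bounds
+ // check against ne11 for the last group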
+ device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+ if (ne00 < 128) {
+ for (int row = 0; row < N_F16_F32; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00; i += 32) {
+ sumf += (float) x[i] * (float) y[i];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
+ } else {
+ device const half4 * x4 = (device const half4 *)x;
+ for (int row = 0; row < N_F16_F32; ++row) {
+ int r1 = rb + row;
+ if (r1 >= ne11) {
+ break;
+ }
+
+ device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
+ device const float4 * y4 = (device const float4 *) y;
+
+ float sumf = 0;
+ for (int i = tiisg; i < ne00/4; i += 32) {
+ for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+ }
+
+ float all_sum = simd_sum(sumf);
+ if (tiisg == 0) {
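+ // as in the 1-row kernel, lane 0 folds the tail elements into the result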
+ for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
+ dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+ }
+ }
}
- // Original implementation. Left behind commented out for now
- //threadgroup_barrier(mem_flags::mem_threadgroup);
- //for (uint i = tptg.x/2; i > 0; i /= 2) {
- // if (tpitg.x < i) {
- // sum[tpitg.x] += sum[tpitg.x + i];
- // }
- // threadgroup_barrier(mem_flags::mem_threadgroup);
- //}
- //
- //if (tpitg.x == 0) {
- // dst[im*ne1*ne0 + r1*ne0 + r0] = sum[0];
- //}
}
kernel void kernel_alibi_f32(
const int r0 = tgpig.x;
const int r1 = tgpig.y;
const int r2 = tgpig.z;
- const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
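+ // with 4x8 threadgroups there is a single simdgroup per threadgroup, so
+ // sgitg is always 0 and the first row depends only on the threadgroup index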
+ //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
+ const int first_row = r0 * N_DST;
const int ib_row = first_row * nb;
const uint offset0 = r2/gqa*(nb*ne0);
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;