GGML_METAL_DECL_KERNEL(get_rows_q5_k);
GGML_METAL_DECL_KERNEL(get_rows_q6_k);
GGML_METAL_DECL_KERNEL(rms_norm);
+ GGML_METAL_DECL_KERNEL(norm);
GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q5_k_f32);
GGML_METAL_DECL_KERNEL(mul_mat_q6_k_f32);
GGML_METAL_DECL_KERNEL(rope);
+ GGML_METAL_DECL_KERNEL(alibi_f32);
GGML_METAL_DECL_KERNEL(cpy_f32_f16);
GGML_METAL_DECL_KERNEL(cpy_f32_f32);
+ GGML_METAL_DECL_KERNEL(cpy_f16_f16);
#undef GGML_METAL_DECL_KERNEL
};
GGML_METAL_ADD_KERNEL(get_rows_q5_k);
GGML_METAL_ADD_KERNEL(get_rows_q6_k);
GGML_METAL_ADD_KERNEL(rms_norm);
+ GGML_METAL_ADD_KERNEL(norm);
GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q5_k_f32);
GGML_METAL_ADD_KERNEL(mul_mat_q6_k_f32);
GGML_METAL_ADD_KERNEL(rope);
+ GGML_METAL_ADD_KERNEL(alibi_f32);
GGML_METAL_ADD_KERNEL(cpy_f32_f16);
GGML_METAL_ADD_KERNEL(cpy_f32_f32);
+ GGML_METAL_ADD_KERNEL(cpy_f16_f16);
#undef GGML_METAL_ADD_KERNEL
}
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
+ case GGML_OP_NORM:
+ {
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+
+ const float eps = 1e-5f;
+
+ const int nth = 256;
+
+ [encoder setComputePipelineState:ctx->pipeline_norm];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
+ [encoder setBytes:&eps length:sizeof( float) atIndex:4];
+ [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
+
+ const int64_t nrows = ggml_nrows(src0);
+
+ [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
+ case GGML_OP_ALIBI:
+ {
+ GGML_ASSERT((src0t == GGML_TYPE_F32));
+ const int n_past = ((int32_t *) src1->data)[0];
+ const int n_head = ((int32_t *) src1->data)[1];
+ const float max_bias = ((float *) src1->data)[2];
+ if (__builtin_popcount(n_head) != 1) {
+ GGML_ASSERT(false && "only power-of-two n_head implemented");
+ }
+ const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
+ const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+ if (encoder == nil) {
+ encoder = [command_buffer computeCommandEncoder];
+ }
+ [encoder setComputePipelineState:ctx->pipeline_alibi_f32];
+ [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
+ [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
+ const int nth = 32;
+ [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
+ } break;
case GGML_OP_ROPE:
{
if (encoder == nil) {
default: GGML_ASSERT(false && "not implemented");
};
} break;
+ case GGML_TYPE_F16:
+ {
+ switch (dstt) {
+ case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f16_f16]; break;
+ case GGML_TYPE_F32: GGML_ASSERT(false && "cpy_f16_f32 not implemented"); break;
+ default: GGML_ASSERT(false && "not implemented");
+ };
+ } break;
default: GGML_ASSERT(false && "not implemented");
}
(device float *) ((device char *) dst + i*nb1), ne00);
}
+kernel void kernel_norm(
+ device const void * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant uint64_t & nb01,
+ constant float & eps,
+ threadgroup float * sum [[threadgroup(0)]],
+ uint tgpig[[threadgroup_position_in_grid]],
+ uint tpitg[[thread_position_in_threadgroup]],
+ uint ntg[[threads_per_threadgroup]]) {
+ device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01);
+ // MEAN
+ // parallel sum
+ sum[tpitg] = 0.0f;
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ sum[tpitg] += x[i00];
+ }
+ // reduce
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ for (uint i = ntg/2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ sum[tpitg] += sum[tpitg + i];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ // broadcast
+ if (tpitg == 0) {
+ sum[0] /= ne00;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ const float mean = sum[0];
+
+ // recenter
+ device float * y = dst + tgpig*ne00;
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ y[i00] = x[i00] - mean;
+ }
+
+ // VARIANCE
+ // parallel sum
+ sum[tpitg] = 0.0f;
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ sum[tpitg] += y[i00] * y[i00];
+ }
+ // reduce
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ for (uint i = ntg/2; i > 0; i /= 2) {
+ if (tpitg < i) {
+ sum[tpitg] += sum[tpitg + i];
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ }
+ // broadcast
+ if (tpitg == 0) {
+ sum[0] /= ne00;
+ }
+ threadgroup_barrier(mem_flags::mem_threadgroup);
+ const float variance = sum[0];
+
+ const float scale = 1.0f/sqrt(variance + eps);
+ for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+ y[i00] = y[i00] * scale;
+ }
+}
+
+
kernel void kernel_rms_norm(
device const void * src0,
device float * dst,
}
}
+kernel void kernel_alibi_f32(
+ device const float * src0,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ constant float & m0,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+ device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+ float m_k = pow(m0, i2 + 1);
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);
+ }
+}
+
kernel void kernel_rope(
device const void * src0,
device float * dst,
}
}
+kernel void kernel_cpy_f16_f16(
+ device const half * src0,
+ device half * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01,
+ constant int64_t & ne02,
+ constant int64_t & ne03,
+ constant uint64_t & nb00,
+ constant uint64_t & nb01,
+ constant uint64_t & nb02,
+ constant uint64_t & nb03,
+ constant int64_t & ne0,
+ constant int64_t & ne1,
+ constant int64_t & ne2,
+ constant int64_t & ne3,
+ constant uint64_t & nb0,
+ constant uint64_t & nb1,
+ constant uint64_t & nb2,
+ constant uint64_t & nb3,
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint3 tpitg[[thread_position_in_threadgroup]],
+ uint3 ntg[[threads_per_threadgroup]]) {
+ const int64_t i03 = tgpig[2];
+ const int64_t i02 = tgpig[1];
+ const int64_t i01 = tgpig[0];
+
+ const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+
+ const int64_t i3 = n / (ne2*ne1*ne0);
+ const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
+ const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
+ const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
+
+ device half * dst_data = (device half *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
+
+ for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
+ device const half * src = (device half *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
+ dst_data[i00] = src[0];
+ }
+}
+
kernel void kernel_cpy_f32_f16(
device const float * src0,
device half * dst,