if (op->src[0]->ne[0] == 256) {
return false;
}
- {
- float logit_softcap;
-
- memcpy(&logit_softcap, ((const float *) op->op_params) + 2, sizeof(logit_softcap));
-
- if (logit_softcap != 0.0f) {
- return false;
- }
- }
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
float scale;
float max_bias;
+ float logit_softcap;
+ memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
+ memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+ memcpy(&logit_softcap, ((int32_t *) dst->op_params) + 2, sizeof(logit_softcap));
- memcpy(&scale, ((int32_t *) dst->op_params) + 0, sizeof(scale));
- memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+ if (logit_softcap != 0.0f) {
+ scale /= logit_softcap;
+ }
const uint32_t n_head = src0->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
} else {
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
}
- [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
- [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
- [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
- [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
- [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
- [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
- [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
- [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
- [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
- [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
- [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
- [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
- [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
- [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
- [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
- [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
- [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
- [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
- [encoder setBytes:&scale length:sizeof( float) atIndex:23];
- [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
- [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
- [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
- [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
+ [encoder setBuffer:id_dst offset:offs_dst atIndex:4];
+ [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5];
+ [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6];
+ [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7];
+ [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8];
+ [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9];
+ [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10];
+ [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11];
+ [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12];
+ [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13];
+ [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14];
+ [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15];
+ [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16];
+ [encoder setBytes:&nb21 length:sizeof(uint64_t) atIndex:17];
+ [encoder setBytes:&nb22 length:sizeof(uint64_t) atIndex:18];
+ [encoder setBytes:&nb23 length:sizeof(uint64_t) atIndex:19];
+ [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:20];
+ [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:21];
+ [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:22];
+ [encoder setBytes:&scale length:sizeof( float) atIndex:23];
+ [encoder setBytes:&max_bias length:sizeof( float) atIndex:24];
+ [encoder setBytes:&m0 length:sizeof(m0) atIndex:25];
+ [encoder setBytes:&m1 length:sizeof(m1) atIndex:26];
+ [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:27];
+ [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:28];
if (!use_vec_kernel) {
// half8x8 kernel
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
+ constant float & logit_softcap,
threadgroup half * shared,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
+ constant float & logit_softcap,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
const short tx = tiisg%4;
const short ty = tiisg/4;
+ // mqk = mqk*scale
+ ss[8*cc + ty*TF + 2*tx + 0] *= scale;
+ ss[8*cc + ty*TF + 2*tx + 1] *= scale;
+
+ if (logit_softcap != 0.0f) {
+ ss[8*cc + ty*TF + 2*tx + 0] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 0]);
+ ss[8*cc + ty*TF + 2*tx + 1] = logit_softcap*precise::tanh(ss[8*cc + ty*TF + 2*tx + 1]);
+ }
+
if (mask != q) {
- // mqk = mqk*scale + mask*slope
- ss[8*cc + ty*TF + 2*tx + 0] = scale*ss[8*cc + ty*TF + 2*tx + 0] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
- ss[8*cc + ty*TF + 2*tx + 1] = scale*ss[8*cc + ty*TF + 2*tx + 1] + slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
- } else {
- // mqk = mqk*scale
- ss[8*cc + ty*TF + 2*tx + 0] *= scale;
- ss[8*cc + ty*TF + 2*tx + 1] *= scale;
+ // mqk = mqk + mask*slope
+ ss[8*cc + ty*TF + 2*tx + 0] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 0];
+ ss[8*cc + ty*TF + 2*tx + 1] += slope*mp[ic + 8*cc + ty*nb31/sizeof(half) + 2*tx + 1];
}
}
}
constant float & m0,
constant float & m1,
constant uint32_t & n_head_log2,
+ constant float & logit_softcap,
threadgroup half * shared [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
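- // mqk = mqk*scale + mask*slope
+ // mqk = mqk*scale, then (if logit_softcap != 0) mqk = logit_softcap*tanh(mqk), then mqk += mask*slope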
if (tiisg == 0) {
- mqk = mqk*scale + ((mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f);
+ mqk *= scale;
+
+ if (logit_softcap != 0.0f) {
+ mqk = logit_softcap*precise::tanh(mqk);
+ }
+
+ mqk += (mask != q) ? ((float4) mp4[ic/4 + cc])*slope : (float4) 0.0f;
ss4[cc] = mqk;
}