float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
- if (__builtin_popcount(n_head) != 1) {
- GGML_ASSERT(false && "only power-of-two n_head implemented");
- }
-
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
+ const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
[encoder setComputePipelineState:ctx->pipeline_alibi_f32];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
- [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
+ [encoder setBytes:&m0 length:sizeof( float) atIndex:18];
+ [encoder setBytes:&m1 length:sizeof( float) atIndex:19];
+ [encoder setBytes:&n_heads_log2_floor length:sizeof(int) atIndex:20];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
constant uint64_t & nb1,
constant uint64_t & nb2,
constant uint64_t & nb3,
- constant float & m0,
+ constant float & m0,
+ constant float & m1,
+ constant int & n_heads_log2_floor,
uint3 tgpig[[threadgroup_position_in_grid]],
uint3 tpitg[[thread_position_in_threadgroup]],
uint3 ntg[[threads_per_threadgroup]]) {
const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
device float * dst_data = (device float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0);
- float m_k = pow(m0, i2 + 1);
+ float m_k;
+ if (i2 < n_heads_log2_floor) {
+ m_k = pow(m0, i2 + 1);
+ } else {
+ m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1);
+ }
for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1);