]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : full ALiBi support (#7192)
authorGeorgi Gerganov <redacted>
Sat, 11 May 2024 07:32:41 +0000 (10:32 +0300)
committerGitHub <redacted>
Sat, 11 May 2024 07:32:41 +0000 (10:32 +0300)
* ggml : full ALiBi support

* ggml : update ggml_soft_max_ext() CUDA, SYCL

* ggml : ggml_flash_attn_ext() support ALiBi (CPU)

* ggml : ggml_flash_attn_ext() support ALiBi (Metal)

* ggml : fix warning

* ggml : ggml_flash_attn_ext() support ALiBi (CUDA)

ggml-ci

* ggml : fix assert message

* vulkan : add dev notes

* ggml : require mask when using ALiBi

ggml-ci

* convert : fix convert for refact models

16 files changed:
convert-hf-to-gguf.py
ggml-cuda.cu
ggml-cuda/alibi.cu [deleted file]
ggml-cuda/alibi.cuh [deleted file]
ggml-cuda/fattn.cu
ggml-cuda/softmax.cu
ggml-kompute.cpp
ggml-metal.m
ggml-metal.metal
ggml-sycl.cpp
ggml-vulkan.cpp
ggml.c
ggml.h
gguf-py/gguf/tensor_mapping.py
llama.cpp
tests/test-backend-ops.cpp

index 1dc18b2a5772189314f5bafd01f8e044f9b3c227..3315ca74b044be3a491b3f3e40f5194eed1fe5a7 100755 (executable)
@@ -1013,6 +1013,18 @@ class StarCoderModel(Model):
 class RefactModel(Model):
     model_arch = gguf.MODEL_ARCH.REFACT
 
+    def set_vocab(self):
+        super().set_vocab()
+
+        # TODO: how to determine special FIM tokens automatically?
+        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=False,
+                                          special_token_types = ['prefix', 'suffix', 'middle', 'fsep', 'eot'])
+        special_vocab._set_special_token("prefix", 1)
+        special_vocab._set_special_token("suffix", 3)
+        special_vocab._set_special_token("middle", 2)
+        special_vocab._set_special_token("fsep",   4) # is this correct?
+        special_vocab.add_to_gguf(self.gguf_writer)
+
     def set_gguf_parameters(self):
         hidden_dim = self.hparams["n_embd"]
         inner_dim = 4 * hidden_dim
index 6f89a7cc3e9005250a11f2e0ab32115337001c9a..c5c7787969f5fa3ec1cc0fa5fba64e1714253551 100644 (file)
@@ -4,7 +4,6 @@
 
 #include "ggml-cuda/common.cuh"
 #include "ggml-cuda/acc.cuh"
-#include "ggml-cuda/alibi.cuh"
 #include "ggml-cuda/arange.cuh"
 #include "ggml-cuda/argsort.cuh"
 #include "ggml-cuda/binbcast.cuh"
@@ -2277,9 +2276,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_ROPE:
             ggml_cuda_op_rope(ctx, dst);
             break;
-        case GGML_OP_ALIBI:
-            ggml_cuda_op_alibi(ctx, dst);
-            break;
         case GGML_OP_IM2COL:
             ggml_cuda_op_im2col(ctx, dst);
             break;
@@ -2829,7 +2825,6 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_ROPE:
-        case GGML_OP_ALIBI:
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
diff --git a/ggml-cuda/alibi.cu b/ggml-cuda/alibi.cu
deleted file mode 100644 (file)
index 6c7f1fd..0000000
+++ /dev/null
@@ -1,63 +0,0 @@
-#include "alibi.cuh"
-
-static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
-                                 const int n_heads_log2_floor, const float m0, const float m1) {
-    const int col = blockDim.x*blockIdx.x + threadIdx.x;
-
-    if (col >= ncols) {
-        return;
-    }
-
-    const int row = blockDim.y*blockIdx.y + threadIdx.y;
-    const int i = row*ncols + col;
-
-    const int k = row/k_rows;
-
-    float m_k;
-    if (k < n_heads_log2_floor) {
-        m_k = powf(m0, k + 1);
-    } else {
-        m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
-    }
-
-    dst[i] = col * m_k + x[i];
-}
-
-static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
-                           const int k_rows, const int n_heads_log2_floor, const float m0,
-                           const float m1, cudaStream_t stream) {
-    const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1);
-    const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE);
-    const dim3 block_nums(num_blocks_x, nrows, 1);
-    alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
-}
-
-void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
-    const ggml_tensor * src0 = dst->src[0];
-    const float * src0_d = (const float *)src0->data;
-    float * dst_d = (float *)dst->data;
-    cudaStream_t stream = ctx.stream();
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    const int64_t ne00 = src0->ne[0];
-    const int64_t ne01 = src0->ne[1];
-    const int64_t ne02 = src0->ne[2];
-    const int64_t nrows = ggml_nrows(src0);
-
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_head = ((int32_t *) dst->op_params)[1];
-    float max_bias;
-    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
-    //GGML_ASSERT(ne01 + n_past == ne00);
-    GGML_ASSERT(n_head == ne02);
-
-    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-
-    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
-    alibi_f32_cuda(src0_d, dst_d, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, stream);
-}
diff --git a/ggml-cuda/alibi.cuh b/ggml-cuda/alibi.cuh
deleted file mode 100644 (file)
index 630adfc..0000000
+++ /dev/null
@@ -1,5 +0,0 @@
-#include "common.cuh"
-
-#define CUDA_ALIBI_BLOCK_SIZE 32
-
-void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
index 7c486f4829bdd7868a99d028bbea2bbc0f4c8b2f..ac5d6672b30d342912d087b566b3784bd3bba710 100644 (file)
@@ -23,6 +23,10 @@ static __global__ void flash_attn_vec_ext_f16(
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -58,6 +62,18 @@ static __global__ void flash_attn_vec_ext_f16(
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);
 
+    half slopeh = __float2half(1.0f);
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = blockIdx.y;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slopeh = __float2half(powf(base, exph));
+    }
+
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
@@ -141,7 +157,7 @@ static __global__ void flash_attn_vec_ext_f16(
             for (int j = 0; j < ncols; ++j) {
                 sum2[j] = warp_reduce_sum(sum2[j]);
                 half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
-                sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+                sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
 
                 if (ncols == 1) {
                     kqmax_new        = ggml_cuda_hmax(kqmax_new,        sum);
@@ -249,6 +265,10 @@ static __global__ void flash_attn_ext_f16(
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -305,6 +325,20 @@ static __global__ void flash_attn_ext_f16(
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
 
+    half  slopeh = __float2half(1.0f);
+    half2 slope2 = make_half2(1.0f, 1.0f);
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = blockIdx.y;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slopeh = __float2half(powf(base, exph));
+        slope2 = make_half2(slopeh, slopeh);
+    }
+
     frag_b Q_b[D/16][ncols/frag_n];
 
     // A single buffer for temporarily holding tiles of KQ and VKQ parts:
@@ -421,7 +455,7 @@ static __global__ void flash_attn_ext_f16(
                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
                     KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
                 }
                 KQ_max_new = warp_reduce_max(KQ_max_new);
@@ -464,7 +498,7 @@ static __global__ void flash_attn_ext_f16(
                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
                     KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
                 }
                 KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
@@ -710,8 +744,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
     const     dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
     const     int  shmem = 0;
 
-    float scale;
-    memcpy(&scale, KQV->op_params, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
         <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -720,7 +763,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
                 (const char *) V->data,
                 mask ? ((const char *) mask->data) : nullptr,
                 parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale,
+                scale, max_bias, m0, m1, n_head_log2,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                 mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
@@ -761,8 +804,17 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
     const     dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
     const     int  shmem = 0;
 
-    float scale;
-    memcpy(&scale, KQV->op_params, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
         <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -771,7 +823,7 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
                 (const char *) V->data,
                 mask ? ((const char *) mask->data) : nullptr,
                 (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale,
+                scale, max_bias, m0, m1, n_head_log2,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                 mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
@@ -837,7 +889,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const int cc  = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
-    const int32_t precision = KQV->op_params[1];
+    const int32_t precision = KQV->op_params[2];
 
     if (!fp16_mma_available(cc)) {
         GGML_ASSERT(precision == GGML_PREC_DEFAULT);
index 6ed225999bddfc914118a6fda3536222047bf6aa..ca85285a3f4583e80d7b487a4d32bf191c332b43 100644 (file)
@@ -11,7 +11,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
 }
 
 template <bool vals_smem, int ncols_template, int block_size_template, typename T>
-static __global__ void soft_max_f32(const float * x, const T * mask, const T * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+static __global__ void soft_max_f32(const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
     const int tid  = threadIdx.x;
@@ -23,16 +23,16 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    float slope = 0.0f;
+    float slope = 1.0f;
 
     // ALiBi
     if (max_bias > 0.0f) {
         const int h = rowx/nrows_y; // head index
 
         const float base = h < n_head_log2 ? m0 : m1;
-        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 
-        slope = powf(base, exp);
+        slope = powf(base, exph);
     }
 
     extern __shared__ float data_soft_max_f32[];
@@ -53,7 +53,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
         const int64_t ix = (int64_t)rowx*ncols + col;
         const int64_t iy = (int64_t)rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? t2f32(mask[iy]) : 0.0f) + (pos ? slope*t2f32(pos[col]) : 0.0f);
+        const float val = x[ix]*scale + (mask ? slope*t2f32(mask[iy]) : 0.0f);
 
         vals[col] = val;
         max_val = max(max_val, val);
@@ -125,7 +125,7 @@ static __global__ void soft_max_f32(const float * x, const T * mask, const T * p
 }
 
 template<typename T>
-static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const T * mask, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth,     1, 1);
@@ -133,8 +133,8 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl
     const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
 
-    const uint32_t n_head_kv   = nrows_x/nrows_y;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+    const uint32_t n_head      = nrows_x/nrows_y;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -142,43 +142,42 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const T * pos, fl
     if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
         switch (ncols_x) {
             case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
         }
     } else {
         const size_t shmem_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
     }
 }
 
 void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const ggml_tensor * src1 = dst->src[1];
-    const ggml_tensor * src2 = dst->src[2];
 
     const float * src0_d = (const float *)src0->data;
     const void  * src1_d = src1 ? (const void *)src1->data : nullptr;
@@ -190,7 +189,6 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
-    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
 
     const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -202,26 +200,15 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
-    // positions tensor
-    void * src2_d = nullptr;
-
-    const bool use_src2 = src2 != nullptr;
-
-    if (use_src2) {
-        src2_d = (void *)src2->data;
-    }
-
-    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
     if (use_f16) {
         const half * src1_dd = (const half *)src1_d;
-        const half * src2_dd = (const half *)src2_d;
 
-        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
     } else {
         const float * src1_dd = (const float *)src1_d;
-        const float * src2_dd = (const float *)src2_d;
 
-        soft_max_f32_cuda(src0_d, src1_dd, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
+        soft_max_f32_cuda(src0_d, src1_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
     }
 }
index 9a469821d804214afb3d99bb83a7cfa31e093349..3f033d58be481544d3da127a8a23f6164cbfb465 100644 (file)
@@ -1559,12 +1559,18 @@ static void ggml_vk_graph_compute(struct ggml_kompute_context * ctx, struct ggml
                 case GGML_OP_SOFT_MAX:
                     {
                         float scale;
-                        memcpy(&scale, dst->op_params, sizeof(float));
+                        float max_bias;
 
-#pragma message("TODO: add ggml_vk_soft_max() F16/F32 src1 and src2 support")
+                        memcpy(&scale,    (float *)dst->op_params + 0, sizeof(float));
+                        memcpy(&max_bias, (float *)dst->op_params + 1, sizeof(float));
+
+#pragma message("TODO: add ggml_vk_soft_max() F16 src1 support")
 #pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
                         GGML_ASSERT(!src1 || src1t == GGML_TYPE_F32);
-                        GGML_ASSERT(src2 == nullptr);
+
+#pragma message("TODO: add ALiBi support")
+#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/7192")
+                        GGML_ASSERT(max_bias == 0.0f);
 
                         ggml_vk_soft_max(seq, id_src0, id_src1, id_dst, off_src0, off_src1, off_dst, ne00, ne01, ne02, ne03, scale);
                     } break;
index 18ce5b88a1c01748c2acb9c0bb2a959c3242f62f..1bbb8fb4fefa190937dc3b267a35ef8765ce45da 100644 (file)
@@ -169,7 +169,6 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F32,
     GGML_METAL_KERNEL_TYPE_ROPE_F16,
-    GGML_METAL_KERNEL_TYPE_ALIBI_F32,
     GGML_METAL_KERNEL_TYPE_IM2COL_F16,
     GGML_METAL_KERNEL_TYPE_IM2COL_F32,
     GGML_METAL_KERNEL_TYPE_UPSCALE_F32,
@@ -623,7 +622,6 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ4_XS_F32,          mul_mm_id_iq4_xs_f32,           ctx->support_simdgroup_mm);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32,                      rope_f32,                       true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16,                      rope_f16,                       true);
-        GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32,                     alibi_f32,                      true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16,                    im2col_f16,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32,                    im2col_f32,                     true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32,                   upscale_f32,                    true);
@@ -759,7 +757,6 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
         case GGML_OP_GROUP_NORM:
             return ctx->support_simdgroup_reduction;
         case GGML_OP_NORM:
-        case GGML_OP_ALIBI:
         case GGML_OP_ROPE:
         case GGML_OP_IM2COL:
             return true;
@@ -1358,13 +1355,12 @@ static enum ggml_status ggml_metal_graph_compute(
                 case GGML_OP_SOFT_MAX:
                     {
                         GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16 || src1->type == GGML_TYPE_F32);
-                        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F16 || src2->type == GGML_TYPE_F32);
 
                         int nth = 32; // SIMD width
 
                         id<MTLComputePipelineState> pipeline = nil;
 
-                        const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+                        const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
                         if (ne00%4 == 0) {
                             while (nth < ne00/4 && nth < 256) {
@@ -1395,8 +1391,8 @@ static enum ggml_status ggml_metal_graph_compute(
                         const int64_t nrows_x = ggml_nrows(src0);
                         const int64_t nrows_y = src0->ne[1];
 
-                        const uint32_t n_head_kv   = nrows_x/nrows_y;
-                        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+                        const uint32_t n_head      = nrows_x/nrows_y;
+                        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
 
                         const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
                         const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -1408,20 +1404,15 @@ static enum ggml_status ggml_metal_graph_compute(
                         } else {
                             [encoder setBuffer:id_src0 offset:offs_src0   atIndex:1];
                         }
-                        if (id_src2) {
-                            [encoder setBuffer:id_src2 offset:offs_src2   atIndex:2];
-                        } else {
-                            [encoder setBuffer:id_src0 offset:offs_src0   atIndex:2];
-                        }
-                        [encoder setBuffer:id_dst   offset:offs_dst          atIndex:3];
-                        [encoder setBytes:&ne00     length:sizeof(ne00)      atIndex:4];
-                        [encoder setBytes:&ne01     length:sizeof(ne01)      atIndex:5];
-                        [encoder setBytes:&ne02     length:sizeof(ne02)      atIndex:6];
-                        [encoder setBytes:&scale    length:sizeof(scale)     atIndex:7];
-                        [encoder setBytes:&max_bias length:sizeof(max_bias)  atIndex:8];
-                        [encoder setBytes:&m0       length:sizeof(m0)        atIndex:9];
-                        [encoder setBytes:&m1       length:sizeof(m1)        atIndex:10];
-                        [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
+                        [encoder setBuffer:id_dst      offset:offs_dst            atIndex:2];
+                        [encoder setBytes:&ne00        length:sizeof(ne00)        atIndex:3];
+                        [encoder setBytes:&ne01        length:sizeof(ne01)        atIndex:4];
+                        [encoder setBytes:&ne02        length:sizeof(ne02)        atIndex:5];
+                        [encoder setBytes:&scale       length:sizeof(scale)       atIndex:6];
+                        [encoder setBytes:&max_bias    length:sizeof(max_bias)    atIndex:7];
+                        [encoder setBytes:&m0          length:sizeof(m0)          atIndex:8];
+                        [encoder setBytes:&m1          length:sizeof(m1)          atIndex:9];
+                        [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10];
                         [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -2226,49 +2217,6 @@ static enum ggml_status ggml_metal_graph_compute(
 
                         [encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                     } break;
-                case GGML_OP_ALIBI:
-                    {
-                        GGML_ASSERT((src0t == GGML_TYPE_F32));
-
-                        const int nth = MIN(1024, ne00);
-
-                        //const int n_past = ((int32_t *) dst->op_params)[0];
-                        const int n_head = ((int32_t *) dst->op_params)[1];
-
-                        float max_bias;
-                        memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
-                        const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-                        const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
-                        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
-                        id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ALIBI_F32].pipeline;
-
-                        [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                        [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                        [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                        [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
-                        [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
-                        [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
-                        [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
-                        [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
-                        [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
-                        [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
-                        [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:10];
-                        [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:11];
-                        [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:12];
-                        [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:13];
-                        [encoder setBytes:&nb0  length:sizeof(uint64_t) atIndex:14];
-                        [encoder setBytes:&nb1  length:sizeof(uint64_t) atIndex:15];
-                        [encoder setBytes:&nb2  length:sizeof(uint64_t) atIndex:16];
-                        [encoder setBytes:&nb3  length:sizeof(uint64_t) atIndex:17];
-                        [encoder setBytes:&m0   length:sizeof(   float) atIndex:18];
-                        [encoder setBytes:&m1   length:sizeof(   float) atIndex:19];
-                        [encoder setBytes:&n_heads_log2_floor   length:sizeof(int) atIndex:20];
-
-                        [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
-                    } break;
                 case GGML_OP_ROPE:
                     {
                         GGML_ASSERT(ne10 == ne02);
@@ -2566,7 +2514,7 @@ static enum ggml_status ggml_metal_graph_compute(
                                 "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
 
                         const int64_t  ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
-                        const int64_t  ne31 = src3 ? src3->ne[1] : 0;
+                      //const int64_t  ne31 = src3 ? src3->ne[1] : 0;
                         const int64_t  ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
                         const int64_t  ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
 
@@ -2578,7 +2526,16 @@ static enum ggml_status ggml_metal_graph_compute(
                         const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
 
                         float scale;
-                        memcpy(&scale, dst->op_params, sizeof(float));
+                        float max_bias;
+
+                        memcpy(&scale,    ((int32_t *) dst->op_params) + 0, sizeof(scale));
+                        memcpy(&max_bias, ((int32_t *) dst->op_params) + 1, sizeof(max_bias));
+
+                        const uint32_t n_head      = src0->ne[2];
+                        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+                        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+                        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
                         id<MTLComputePipelineState> pipeline = nil;
 
@@ -2615,34 +2572,37 @@ static enum ggml_status ggml_metal_graph_compute(
                         }
 
                         [encoder setComputePipelineState:pipeline];
-                        [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
-                        [encoder setBuffer:id_src1 offset:offs_src1        atIndex:1];
-                        [encoder setBuffer:id_src2 offset:offs_src2        atIndex:2];
-                        [encoder setBuffer:id_src3 offset:offs_src3        atIndex:3];
-                        [encoder setBuffer:id_dst  offset:offs_dst         atIndex:4];
-                        [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:5];
-                        [encoder setBytes:&ne01    length:sizeof( int64_t) atIndex:6];
-                        [encoder setBytes:&ne02    length:sizeof( int64_t) atIndex:7];
-                        [encoder setBytes:&ne03    length:sizeof( int64_t) atIndex:8];
-                        [encoder setBytes:&nb00    length:sizeof(uint64_t) atIndex:9];
-                        [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:10];
-                        [encoder setBytes:&nb02    length:sizeof(uint64_t) atIndex:11];
-                        [encoder setBytes:&nb03    length:sizeof(uint64_t) atIndex:12];
-                        [encoder setBytes:&ne10    length:sizeof( int64_t) atIndex:13];
-                        [encoder setBytes:&ne11    length:sizeof( int64_t) atIndex:14];
-                        [encoder setBytes:&ne12    length:sizeof( int64_t) atIndex:15];
-                        [encoder setBytes:&ne13    length:sizeof( int64_t) atIndex:16];
-                        [encoder setBytes:&nb10    length:sizeof(uint64_t) atIndex:17];
-                        [encoder setBytes:&nb11    length:sizeof(uint64_t) atIndex:18];
-                        [encoder setBytes:&nb12    length:sizeof(uint64_t) atIndex:19];
-                        [encoder setBytes:&nb13    length:sizeof(uint64_t) atIndex:20];
-                        [encoder setBytes:&ne31    length:sizeof( int64_t) atIndex:21];
-                        [encoder setBytes:&nb31    length:sizeof(uint64_t) atIndex:22];
-                        [encoder setBytes:&ne0     length:sizeof( int64_t) atIndex:23];
-                        [encoder setBytes:&ne1     length:sizeof( int64_t) atIndex:24];
-                        [encoder setBytes:&ne2     length:sizeof( int64_t) atIndex:25];
-                        [encoder setBytes:&ne3     length:sizeof( int64_t) atIndex:26];
-                        [encoder setBytes:&scale   length:sizeof(   float) atIndex:27];
+                        [encoder setBuffer:id_src0     offset:offs_src0           atIndex:0];
+                        [encoder setBuffer:id_src1     offset:offs_src1           atIndex:1];
+                        [encoder setBuffer:id_src2     offset:offs_src2           atIndex:2];
+                        [encoder setBuffer:id_src3     offset:offs_src3           atIndex:3];
+                        [encoder setBuffer:id_dst      offset:offs_dst            atIndex:4];
+                        [encoder setBytes:&ne00        length:sizeof( int64_t)    atIndex:5];
+                        [encoder setBytes:&ne01        length:sizeof( int64_t)    atIndex:6];
+                        [encoder setBytes:&ne02        length:sizeof( int64_t)    atIndex:7];
+                        [encoder setBytes:&ne03        length:sizeof( int64_t)    atIndex:8];
+                        [encoder setBytes:&nb00        length:sizeof(uint64_t)    atIndex:9];
+                        [encoder setBytes:&nb01        length:sizeof(uint64_t)    atIndex:10];
+                        [encoder setBytes:&nb02        length:sizeof(uint64_t)    atIndex:11];
+                        [encoder setBytes:&nb03        length:sizeof(uint64_t)    atIndex:12];
+                        [encoder setBytes:&ne10        length:sizeof( int64_t)    atIndex:13];
+                        [encoder setBytes:&ne11        length:sizeof( int64_t)    atIndex:14];
+                        [encoder setBytes:&ne12        length:sizeof( int64_t)    atIndex:15];
+                        [encoder setBytes:&ne13        length:sizeof( int64_t)    atIndex:16];
+                        [encoder setBytes:&nb10        length:sizeof(uint64_t)    atIndex:17];
+                        [encoder setBytes:&nb11        length:sizeof(uint64_t)    atIndex:18];
+                        [encoder setBytes:&nb12        length:sizeof(uint64_t)    atIndex:19];
+                        [encoder setBytes:&nb13        length:sizeof(uint64_t)    atIndex:20];
+                        [encoder setBytes:&nb31        length:sizeof(uint64_t)    atIndex:21];
+                        [encoder setBytes:&ne0         length:sizeof( int64_t)    atIndex:22];
+                        [encoder setBytes:&ne1         length:sizeof( int64_t)    atIndex:23];
+                        [encoder setBytes:&ne2         length:sizeof( int64_t)    atIndex:24];
+                        [encoder setBytes:&ne3         length:sizeof( int64_t)    atIndex:25];
+                        [encoder setBytes:&scale       length:sizeof(   float)    atIndex:26];
+                        [encoder setBytes:&max_bias    length:sizeof(   float)    atIndex:27];
+                        [encoder setBytes:&m0          length:sizeof(m0)          atIndex:28];
+                        [encoder setBytes:&m1          length:sizeof(m1)          atIndex:29];
+                        [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:30];
 
                         if (!use_vec_kernel) {
                             // half8x8 kernel
index 46c7d503930a4adea2f0d1eea3cc8305835342cc..ee9de57a3a6f1b144965063e162f34104cad67b2 100644 (file)
@@ -356,7 +356,6 @@ template<typename T>
 kernel void kernel_soft_max(
         device const  char * src0,
         device const  char * src1,
-        device const  char * src2,
         device        char * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
@@ -378,10 +377,9 @@ kernel void kernel_soft_max(
 
     device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
     device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*ne00 : nullptr;
-    device const     T * ppos  = src2 != src0 ? (device const    T *) src2                    : nullptr;
     device       float * pdst  = (device       float *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
-    float slope = 0.0f;
+    float slope = 1.0f;
 
     // ALiBi
     if (max_bias > 0.0f) {
@@ -397,7 +395,7 @@ kernel void kernel_soft_max(
     float lmax = -INFINITY;
 
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f));
+        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
     }
 
     // find the max value in the block
@@ -422,7 +420,7 @@ kernel void kernel_soft_max(
     // parallel sum
     float lsum = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val);
+        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
         lsum += exp_psrc0;
         pdst[i00] = exp_psrc0;
     }
@@ -461,7 +459,6 @@ template<typename T>
 kernel void kernel_soft_max_4(
         device const  char * src0,
         device const  char * src1,
-        device const  char * src2,
         device        char * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
@@ -483,10 +480,9 @@ kernel void kernel_soft_max_4(
 
     device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
     device const      T * pmask = src1 != src0 ? (device const     T *) src1         + i01*ne00/4 : nullptr;
-    device const      T * ppos  = src2 != src0 ? (device const     T *) src2                      : nullptr;
     device       float4 * pdst4 = (device       float4 *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
 
-    float slope = 0.0f;
+    float slope = 1.0f;
 
     if (max_bias > 0.0f) {
         const int64_t h = i02;
@@ -501,7 +497,7 @@ kernel void kernel_soft_max_4(
     float4 lmax4 = -INFINITY;
 
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)));
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f)));
     }
 
     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -527,7 +523,7 @@ kernel void kernel_soft_max_4(
     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val);
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? slope*pmask[i00] : 0.0f))) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
@@ -1595,60 +1591,6 @@ kernel void kernel_mul_mv_f16_f32_l4(
     }
 }
 
-kernel void kernel_alibi_f32(
-        device const float * src0,
-        device       float * dst,
-        constant   int64_t & ne00,
-        constant   int64_t & ne01,
-        constant   int64_t & ne02,
-        constant   int64_t & ne03,
-        constant  uint64_t & nb00,
-        constant  uint64_t & nb01,
-        constant  uint64_t & nb02,
-        constant  uint64_t & nb03,
-        constant   int64_t & ne0,
-        constant   int64_t & ne1,
-        constant   int64_t & ne2,
-        constant   int64_t & ne3,
-        constant  uint64_t & nb0,
-        constant  uint64_t & nb1,
-        constant  uint64_t & nb2,
-        constant  uint64_t & nb3,
-        constant     float & m0,
-        constant     float & m1,
-        constant       int & n_heads_log2_floor,
-        uint3 tgpig[[threadgroup_position_in_grid]],
-        uint3 tpitg[[thread_position_in_threadgroup]],
-        uint3   ntg[[threads_per_threadgroup]]) {
-    const int64_t i03 = tgpig[2];
-    const int64_t i02 = tgpig[1];
-    const int64_t i01 = tgpig[0];
-
-    const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-
-    const int64_t i3 = n / (ne2*ne1*ne0);
-    const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0);
-    const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0;
-  //const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0);
-
-    const int64_t k = i3*ne3 + i2;
-
-    float m_k;
-    if (k < n_heads_log2_floor) {
-        m_k = pow(m0, k + 1);
-    } else {
-        m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1);
-    }
-
-    device       char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1;
-    device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01;
-    for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
-        const  float   src_v = *(device float *)(src_row + i00*nb00);
-        device float * dst_v =  (device float *)(dst_row + i00*nb0);
-        *dst_v = i00 * m_k + src_v;
-    }
-}
-
 static float rope_yarn_ramp(const float low, const float high, const int i0) {
     const float y = (i0 / 2 - low) / max(0.001f, high - low);
     return 1.0f - min(1.0f, max(0.0f, y));
@@ -2116,13 +2058,16 @@ typedef void (flash_attn_ext_f16_t)(
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
         constant  uint64_t & nb13,
-        constant   int64_t & ne31,
         constant  uint64_t & nb31,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   int64_t & ne2,
         constant   int64_t & ne3,
         constant     float & scale,
+        constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
         threadgroup   half * shared,
         uint3  tgpig[[threadgroup_position_in_grid]],
         uint3  tpitg[[thread_position_in_threadgroup]],
@@ -2154,13 +2099,16 @@ kernel void kernel_flash_attn_ext_f16(
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
         constant  uint64_t & nb13,
-        constant   int64_t & ne31,
         constant  uint64_t & nb31,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   int64_t & ne2,
         constant   int64_t & ne3,
         constant     float & scale,
+        constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
         threadgroup   half * shared [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
         uint3  tpitg[[thread_position_in_threadgroup]],
@@ -2257,6 +2205,19 @@ kernel void kernel_flash_attn_ext_f16(
         // prepare diagonal scale matrix
         simdgroup_float8x8 mscale(scale);
 
+        // prepare diagonal slope matrix
+        simdgroup_float8x8 mslope(1.0f);
+
+        // ALiBi
+        if (max_bias > 0.0f) {
+            const short h = iq2;
+
+            const float base = h < n_head_log2 ? m0 : m1;
+            const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+            mslope = simdgroup_float8x8(pow(base, exph));
+        }
+
         // loop over the KV cache
         // each simdgroup handles blocks of Q rows and C columns
         for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) {
@@ -2279,9 +2240,10 @@ kernel void kernel_flash_attn_ext_f16(
                         simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk);
                     }
 
-                    // mqk = mqk*scale + mask
+                    // mqk = mqk*scale + mask*slope
                     simdgroup_half8x8 mm;
                     simdgroup_load(mm, mp + ic + 8*cc, nb31/sizeof(half), 0, false);
+                    simdgroup_multiply(mm, mslope, mm);
                     simdgroup_multiply_accumulate(mqk, mqk, mscale, mm);
 
                     simdgroup_store(mqk, ss + 8*cc, TF, 0, false);
@@ -2472,13 +2434,16 @@ kernel void kernel_flash_attn_ext_vec_f16(
         constant  uint64_t & nb11,
         constant  uint64_t & nb12,
         constant  uint64_t & nb13,
-        constant   int64_t & ne31,
         constant  uint64_t & nb31,
         constant   int64_t & ne0,
         constant   int64_t & ne1,
         constant   int64_t & ne2,
         constant   int64_t & ne3,
         constant     float & scale,
+        constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
         threadgroup   half * shared [[threadgroup(0)]],
         uint3  tgpig[[threadgroup_position_in_grid]],
         uint3  tpitg[[thread_position_in_threadgroup]],
@@ -2497,6 +2462,18 @@ kernel void kernel_flash_attn_ext_vec_f16(
 
     const short T  = D + 2*nsg*SH; // shared memory size per query in (half)
 
+    float slope = 1.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const short h = iq2;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
   //threadgroup half   * sq  = (threadgroup half   *) (shared +              0*D); // holds the query data
     threadgroup half4  * sq4 = (threadgroup half4  *) (shared +              0*D); // same as above but in half4
     threadgroup float  * ss  = (threadgroup float  *) (shared + 2*sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix
@@ -2603,10 +2580,10 @@ kernel void kernel_flash_attn_ext_vec_f16(
                     mqk += simd_shuffle_down(mqk,  2);
                     mqk += simd_shuffle_down(mqk,  1);
 
-                    // mqk = mqk*scale + mask
+                    // mqk = mqk*scale + mask*slope
                     if (tiisg == 0) {
                         float4 mm = (float4) mp4[ic/4 + cc];
-                        mqk = mqk*scale + mm;
+                        mqk = mqk*scale + mm*slope;
 
                         ss4[cc] = mqk;
                     }
@@ -2840,7 +2817,8 @@ kernel void kernel_cpy_f32_f16(
     for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) {
         device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00);
 
-        dst_data[i00] = src[0];
+        // TODO: is there a better way to handle -INFINITY?
+        dst_data[i00] = src[0] == -INFINITY ? -MAXHALF : src[0];
     }
 }
 
index 79aec4d9f02e1ba48bfdf775c3acf9bfdaaac1a1..e93d2af631cc0809345609f6134860b860500f9d 100644 (file)
@@ -3154,7 +3154,6 @@ typedef float (*vec_dot_q_mul_mat_sycl_t)(
 #define SYCL_SCALE_BLOCK_SIZE 256
 #define SYCL_CLAMP_BLOCK_SIZE 256
 #define SYCL_ROPE_BLOCK_SIZE 256
-#define SYCL_ALIBI_BLOCK_SIZE 32
 #define SYCL_DIAG_MASK_INF_BLOCK_SIZE 32
 #define SYCL_QUANTIZE_BLOCK_SIZE 256
 #define SYCL_DEQUANTIZE_BLOCK_SIZE 256
@@ -9316,32 +9315,6 @@ static void rope_glm_f32(
     dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
 }
 
-static void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
-                                 const int n_heads_log2_floor, const float m0, const float m1,
-                                 const sycl::nd_item<3> &item_ct1) {
-    const int col = item_ct1.get_local_range(2) * item_ct1.get_group(2) +
-                    item_ct1.get_local_id(2);
-
-    if (col >= ncols) {
-        return;
-    }
-
-    const int row = item_ct1.get_local_range(1) * item_ct1.get_group(1) +
-                    item_ct1.get_local_id(1);
-    const int i = row*ncols + col;
-
-    const int k = row/k_rows;
-
-    float m_k;
-    if (k < n_heads_log2_floor) {
-        m_k = dpct::pow(m0, k + 1);
-    } else {
-        m_k = dpct::pow(m1, 2 * (k - n_heads_log2_floor) + 1);
-    }
-
-    dst[i] = col * m_k + x[i];
-}
-
 static void k_sum_rows_f32(const float * x, float * dst, const int ncols,
                            const sycl::nd_item<3> &item_ct1) {
     const int row = item_ct1.get_group(1);
@@ -9443,7 +9416,7 @@ static void diag_mask_inf_f32(const float * x, float * dst, const int ncols, con
 
 
 template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
+static void soft_max_f32(const float * x, const float * mask, float * dst, const int ncols_par,
                          const int nrows_y, const float scale, const float max_bias, const float m0,
                          const float m1, uint32_t n_head_log2, const sycl::nd_item<3> &item_ct1, float *buf) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
@@ -9457,7 +9430,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
     const int warp_id = item_ct1.get_local_id(2) / WARP_SIZE;
     const int lane_id = item_ct1.get_local_id(2) % WARP_SIZE;
 
-    float slope = 0.0f;
+    float slope = 1.0f;
 
     // ALiBi
     if (max_bias > 0.0f) {
@@ -9482,7 +9455,7 @@ static void soft_max_f32(const float * x, const float * mask, const float *pos,
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
 
-        const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
+        const float val = x[ix]*scale + (mask ? slope*mask[iy] : 0.0f);
 
         vals[col] = val;
         max_val = sycl::max(max_val, val);
@@ -12964,20 +12937,6 @@ static void rope_glm_f32_sycl(const float *x, float *dst, int ncols, int nrows,
                          });
 }
 
-static void alibi_f32_sycl(const float *x, float *dst, const int ncols,
-                           const int nrows, const int k_rows,
-                           const int n_heads_log2_floor, const float m0,
-                           const float m1, dpct::queue_ptr stream) {
-    const sycl::range<3> block_dims(1, 1, SYCL_ALIBI_BLOCK_SIZE);
-    const int num_blocks_x = (ncols + SYCL_ALIBI_BLOCK_SIZE - 1) / (SYCL_ALIBI_BLOCK_SIZE);
-    const sycl::range<3> block_nums(1, nrows, num_blocks_x);
-    stream->parallel_for(sycl::nd_range<3>(block_nums * block_dims, block_dims),
-                         [=](sycl::nd_item<3> item_ct1) {
-                             alibi_f32(x, dst, ncols, k_rows,
-                                       n_heads_log2_floor, m0, m1, item_ct1);
-                         });
-}
-
 static void sum_rows_f32_sycl(const float *x, float *dst, const int ncols,
                               const int nrows, dpct::queue_ptr stream) {
     const sycl::range<3> block_dims(1, 1, WARP_SIZE);
@@ -13058,7 +13017,7 @@ static void diag_mask_inf_f32_sycl(const float *x, float *dst,
 }
 
 template <bool vals_smem, int ncols_template, int block_size_template>
-static void soft_max_f32_submitter(const float * x, const float * mask, const float *pos, float * dst, const int ncols_par,
+static void soft_max_f32_submitter(const float * x, const float * mask, float * dst, const int ncols_par,
                                    const int nrows_y, const float scale, const float max_bias, const float m0,
                                    const float m1, uint32_t n_head_log2, sycl::range<3> block_nums, sycl::range<3> block_dims,
                                    const size_t n_local_scratch, dpct::queue_ptr stream) {
@@ -13068,7 +13027,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl
         cgh.parallel_for(
             sycl::nd_range<3>(block_nums * block_dims, block_dims),
             [=](sycl::nd_item<3> item_ct1) [[intel::reqd_sub_group_size(32)]] {
-                soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, pos, dst, ncols_par,
+                soft_max_f32<vals_smem, ncols_template, block_size_template>(x, mask, dst, ncols_par,
                                                                              nrows_y, scale, max_bias, m0,
                                                                              m1, n_head_log2, item_ct1,
                                                                              local_buf_acc.get_pointer());
@@ -13076,7 +13035,7 @@ static void soft_max_f32_submitter(const float * x, const float * mask, const fl
     });
 }
 
-static void soft_max_f32_sycl(const float * x, const float * mask, const float * pos,
+static void soft_max_f32_sycl(const float * x, const float * mask,
                               float * dst, const int ncols_x, const int nrows_x,
                               const int nrows_y, const float scale, const float max_bias,
                               dpct::queue_ptr stream) {
@@ -13098,60 +13057,60 @@ static void soft_max_f32_sycl(const float * x, const float * mask, const float *
     const size_t local_mem_size = stream->get_device().get_info<sycl::info::device::local_mem_size>();
     if (n_local_scratch*sizeof(float) < local_mem_size) {
         if (ncols_x > max_block_size) {
-            soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+            soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
                                                max_bias, m0, m1, n_head_log2, block_nums,
                                                block_dims, n_local_scratch, stream);
             return;
         }
         switch (ncols_x) {
             case 32:
-                soft_max_f32_submitter<true, 32, 32>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 32, 32>(x, mask, dst, ncols_x, nrows_y, scale,
                                                      max_bias, m0, m1, n_head_log2, block_nums,
                                                      block_dims, n_local_scratch, stream);
                 break;
             case 64:
-                soft_max_f32_submitter<true, 64, 64>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 64, 64>(x, mask, dst, ncols_x, nrows_y, scale,
                                                      max_bias, m0, m1, n_head_log2, block_nums,
                                                      block_dims, n_local_scratch, stream);
                 break;
             case 128:
-                soft_max_f32_submitter<true, 128, 128>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 128, 128>(x, mask, dst, ncols_x, nrows_y, scale,
                                                        max_bias, m0, m1, n_head_log2, block_nums,
                                                        block_dims, n_local_scratch, stream);
                 break;
             case 256:
-                soft_max_f32_submitter<true, 256, 256>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 256, 256>(x, mask, dst, ncols_x, nrows_y, scale,
                                                        max_bias, m0, m1, n_head_log2, block_nums,
                                                        block_dims, n_local_scratch, stream);
                 break;
             case 512:
-                soft_max_f32_submitter<true, 512, 512>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 512, 512>(x, mask, dst, ncols_x, nrows_y, scale,
                                                        max_bias, m0, m1, n_head_log2, block_nums,
                                                        block_dims, n_local_scratch, stream);
                 break;
             case 1024:
-                soft_max_f32_submitter<true, 1024, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 1024, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
                                                          max_bias, m0, m1, n_head_log2, block_nums,
                                                          block_dims, n_local_scratch, stream);
                 break;
             case 2048:
-                soft_max_f32_submitter<true, 2048, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 2048, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
                                                          max_bias, m0, m1, n_head_log2, block_nums,
                                                          block_dims, n_local_scratch, stream);
                 break;
             case 4096:
-                soft_max_f32_submitter<true, 4096, 1024>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 4096, 1024>(x, mask, dst, ncols_x, nrows_y, scale,
                                                          max_bias, m0, m1, n_head_log2, block_nums,
                                                          block_dims, n_local_scratch, stream);
                 break;
             default:
-                soft_max_f32_submitter<true, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+                soft_max_f32_submitter<true, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
                                                    max_bias, m0, m1, n_head_log2, block_nums,
                                                    block_dims, n_local_scratch, stream);
                 break;
         }
     } else {
-        soft_max_f32_submitter<false, 0, 0>(x, mask, pos, dst, ncols_x, nrows_y, scale,
+        soft_max_f32_submitter<false, 0, 0>(x, mask, dst, ncols_x, nrows_y, scale,
                                             max_bias, m0, m1, n_head_log2, block_nums,
                                             block_dims, WARP_SIZE, stream);
     }
@@ -14562,36 +14521,6 @@ inline void ggml_sycl_op_rope(const ggml_tensor *src0, const ggml_tensor *src1,
     (void) src1_dd;
 }
 
-inline void ggml_sycl_op_alibi(const ggml_tensor *src0, const ggml_tensor *src1,
-                               ggml_tensor *dst, const float *src0_dd,
-                               const float *src1_dd, float *dst_dd,
-                               const dpct::queue_ptr &main_stream) {
-
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
-    GGML_ASSERT( dst->type == GGML_TYPE_F32);
-
-    GGML_TENSOR_LOCALS_3(int64_t, ne0, src0, ne);
-    const int64_t nrows = ggml_nrows(src0);
-
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_head = ((int32_t *) dst->op_params)[1];
-    float max_bias;
-    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
-    //GGML_ASSERT(ne01 + n_past == ne00);
-    GGML_ASSERT(n_head == ne02);
-
-    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-
-    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
-    alibi_f32_sycl(src0_dd, dst_dd, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, main_stream);
-
-    (void) src1;
-    (void) src1_dd;
-}
-
 static void ggml_sycl_op_pool2d(const ggml_tensor *src0,
                                 const ggml_tensor *src1, ggml_tensor *dst,
                                 const float *src0_dd, const float *src1_dd,
@@ -14746,12 +14675,9 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
-    const ggml_tensor * src2 = dst->src[2];
-
-#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 and src2 support")
+#pragma message("TODO: add ggml_sycl_op_soft_max() F16 src1 support")
 #pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5021")
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
-    GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); // src2 contains positions and it is optional
 
     const int64_t ne00 = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
@@ -14763,25 +14689,7 @@ inline void ggml_sycl_op_soft_max(const ggml_tensor *src0,
     memcpy(&scale, dst->op_params + 0, sizeof(float));
     memcpy(&max_bias, dst->op_params + 1, sizeof(float));
 
-    // positions tensor
-    float * src2_dd = nullptr;
-    sycl_pool_alloc<float> src2_f;
-
-    const bool use_src2 = src2 != nullptr;
-
-    if (use_src2) {
-        const bool src2_on_device = src2->backend == GGML_BACKEND_TYPE_GPU;
-
-        if (src2_on_device) {
-            ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) src2->extra;
-            src2_dd = (float *) src2_extra->data_device[g_main_device];
-        } else {
-            src2_dd = src2_f.alloc(ggml_nelements(src2));
-            SYCL_CHECK(ggml_sycl_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
-        }
-    }
-
-    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00,
+    soft_max_f32_sycl(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00,
                       nrows_x, nrows_y, scale, max_bias, main_stream);
 }
 
@@ -16232,10 +16140,6 @@ static void ggml_sycl_rope(const ggml_tensor * src0, const ggml_tensor * src1, g
     ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_rope);
 }
 
-static void ggml_sycl_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
-    ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_alibi);
-}
-
 static void ggml_sycl_pool2d(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_sycl_op_flatten(src0, src1, dst, ggml_sycl_op_pool2d);
 }
@@ -16612,9 +16516,6 @@ bool ggml_sycl_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_ROPE:
             func = ggml_sycl_rope;
             break;
-        case GGML_OP_ALIBI:
-            func = ggml_sycl_alibi;
-            break;
         case GGML_OP_IM2COL:
             func = ggml_sycl_im2col;
             break;
@@ -17744,7 +17645,6 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
         case GGML_OP_DIAG_MASK_INF:
         case GGML_OP_SOFT_MAX:
         case GGML_OP_ROPE:
-        case GGML_OP_ALIBI:
         case GGML_OP_IM2COL:
         case GGML_OP_POOL_2D:
         case GGML_OP_SUM_ROWS:
index 95f71897405397e0f4757c88325993d9913bf33f..b9449be0357b511a2776d9eac05d32cd5abb272e 100644 (file)
@@ -3830,9 +3830,8 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
         return nullptr;
     case GGML_OP_SOFT_MAX:
         GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16);
-        GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32 || src2->type == GGML_TYPE_F16);
 
-        if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && (src2 == nullptr || src2->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
+        if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
             return ctx->device->pipeline_soft_max_f32;
         }
         if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && src2->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
@@ -4286,6 +4285,9 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context * subctx,
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
+#pragma message("TODO: src2 is no longer used in soft_max - should be removed and ALiBi calculation should be updated")
+#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/7192")
+
     ggml_vk_op_f32<vk_op_soft_max_push_constants>(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, {
         ncols,
         src1 != nullptr ? nrows_y : (uint32_t)0,
diff --git a/ggml.c b/ggml.c
index 093d38d00ae3528d7305353e9e4d98ea5f5a7fd6..4ee5d24af75d001ab5a0b47369f07a9acaa2d13c 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -2185,7 +2185,6 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "SOFT_MAX_BACK",
     "ROPE",
     "ROPE_BACK",
-    "ALIBI",
     "CLAMP",
     "CONV_TRANSPOSE_1D",
     "IM2COL",
@@ -2227,7 +2226,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CROSS_ENTROPY_LOSS_BACK",
 };
 
-static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -2276,7 +2275,6 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "soft_max_back(x)",
     "rope(x)",
     "rope_back(x)",
-    "alibi(x)",
     "clamp(x)",
     "conv_transpose_1d(x)",
     "im2col(x)",
@@ -2318,7 +2316,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "cross_entropy_loss_back(x,y)",
 };
 
-static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77");
+static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -5646,7 +5644,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * mask,
-        struct ggml_tensor  * pos,
         float                 scale,
         float                 max_bias,
         bool                  inplace) {
@@ -5660,18 +5657,8 @@ static struct ggml_tensor * ggml_soft_max_impl(
         GGML_ASSERT(mask->ne[1] >= a->ne[1]);
     }
 
-    if (pos) {
-        GGML_ASSERT(ggml_is_vector(pos));
-        GGML_ASSERT(pos->type == GGML_TYPE_F16 || pos->type == GGML_TYPE_F32);
-        GGML_ASSERT(pos->ne[0] == a->ne[0]);
-    }
-
-    if (pos && mask) {
-        GGML_ASSERT(pos->type == mask->type);
-    }
-
     if (max_bias > 0.0f) {
-        GGML_ASSERT(pos);
+        GGML_ASSERT(mask);
     }
 
     bool is_node = false;
@@ -5689,7 +5676,6 @@ static struct ggml_tensor * ggml_soft_max_impl(
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = mask;
-    result->src[2] = pos;
 
     return result;
 }
@@ -5697,23 +5683,22 @@ static struct ggml_tensor * ggml_soft_max_impl(
 struct ggml_tensor * ggml_soft_max(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false);
+    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, false);
 }
 
 struct ggml_tensor * ggml_soft_max_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true);
+    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, 0.0f, true);
 }
 
 struct ggml_tensor * ggml_soft_max_ext(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * mask,
-        struct ggml_tensor  * pos,
         float                 scale,
         float                 max_bias) {
-    return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false);
+    return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false);
 }
 
 // ggml_soft_max_back
@@ -5928,37 +5913,6 @@ struct ggml_tensor * ggml_rope_back(
     return result;
 }
 
-// ggml_alibi
-
-struct ggml_tensor * ggml_alibi(
-        struct ggml_context * ctx,
-        struct ggml_tensor  * a,
-        int                   n_past,
-        int                   n_head,
-        float                 bias_max) {
-    GGML_ASSERT(n_past >= 0);
-    bool is_node = false;
-
-    if (a->grad) {
-        GGML_ASSERT(false); // TODO: implement backward
-        is_node = true;
-    }
-
-    // TODO: when implement backward, fix this:
-    //struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
-    struct ggml_tensor * result = ggml_view_tensor(ctx, a);
-
-    int32_t op_params[3] = { n_past, n_head };
-    memcpy(op_params + 2, &bias_max, sizeof(float));
-    ggml_set_op_params(result, op_params, sizeof(op_params));
-
-    result->op   = GGML_OP_ALIBI;
-    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
-    result->src[0] = a;
-
-    return result;
-}
-
 // ggml_clamp
 
 struct ggml_tensor * ggml_clamp(
@@ -6486,9 +6440,11 @@ struct ggml_tensor * ggml_flash_attn_ext(
         struct ggml_tensor  * k,
         struct ggml_tensor  * v,
         struct ggml_tensor  * mask,
-        float                 scale) {
+        float                 scale,
+        float                 max_bias) {
     GGML_ASSERT(ggml_can_mul_mat(k, q));
     // TODO: check if vT can be multiplied by (k*qT)
+
     if (mask) {
         GGML_ASSERT(ggml_is_contiguous(mask));
         GGML_ASSERT(mask->ne[2] == 1);
@@ -6498,6 +6454,10 @@ struct ggml_tensor * ggml_flash_attn_ext(
         //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
     }
 
+    if (max_bias > 0.0f) {
+        GGML_ASSERT(mask);
+    }
+
     bool is_node = false;
 
     if (q->grad || k->grad || v->grad) {
@@ -6508,7 +6468,7 @@ struct ggml_tensor * ggml_flash_attn_ext(
     int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
     struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
 
-    float params[] = { scale };
+    float params[] = { scale, max_bias };
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_FLASH_ATTN_EXT;
@@ -6528,7 +6488,7 @@ void ggml_flash_attn_ext_set_prec(
 
     const int32_t prec_i32 = (int32_t) prec;
 
-    ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
+    ggml_set_op_params_i32(a, 2, prec_i32); // scale is on first pos, max_bias on second
 }
 
 // ggml_flash_ff
@@ -13333,7 +13293,6 @@ static void ggml_compute_forward_soft_max_f32(
 
     const struct ggml_tensor * src0 = dst->src[0];
     const struct ggml_tensor * src1 = dst->src[1];
-    const struct ggml_tensor * src2 = dst->src[2];
 
     assert(ggml_is_contiguous(dst));
     assert(ggml_are_same_shape(src0, dst));
@@ -13359,8 +13318,8 @@ static void ggml_compute_forward_soft_max_f32(
 
     // TODO: is this supposed to be ceil instead of floor?
     //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
-    const uint32_t n_head_kv   = ne02;
-    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
+    const uint32_t n_head      = ne02;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
 
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
@@ -13377,13 +13336,13 @@ static void ggml_compute_forward_soft_max_f32(
 
     float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
 
-    // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
-    ggml_fp16_t * pos_f16 = src2 ? (ggml_fp16_t *) src2->data : src0->data;
-    float       * pos_f32 = src2 ? (float       *) src2->data : src0->data;
-
-    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16) || (src2 && src2->type == GGML_TYPE_F16);
+    const bool use_f16 = (src1 && src1->type == GGML_TYPE_F16);
 
     for (int i1 = ir0; i1 < ir1; i1++) {
+        // ALiBi
+        const uint32_t h = (i1/ne01)%ne02; // head
+        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
         float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
         float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
 
@@ -13396,27 +13355,11 @@ static void ggml_compute_forward_soft_max_f32(
         if (mp_f32) {
             if (use_f16) {
                 for (int i = 0; i < nc; ++i) {
-                    wp[i] += GGML_FP16_TO_FP32(mp_f16[i]);
+                    wp[i] += slope*GGML_FP16_TO_FP32(mp_f16[i]);
                 }
             } else {
                 for (int i = 0; i < nc; ++i) {
-                    wp[i] += mp_f32[i];
-                }
-            }
-        }
-
-        // ALiBi bias
-        if (max_bias > 0.0f) {
-            const uint32_t h  = (i1/ne01)%ne02; // head
-            const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
-
-            if (use_f16) {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*GGML_FP16_TO_FP32(pos_f16[i]);
-                }
-            } else {
-                for (int i = 0; i < nc; ++i) {
-                    wp[i] += slope*pos_f32[i];
+                    wp[i] += slope*mp_f32[i];
                 }
             }
         }
@@ -13578,178 +13521,6 @@ static void ggml_compute_forward_soft_max_back(
     }
 }
 
-// ggml_compute_forward_alibi
-
-static void ggml_compute_forward_alibi_f32(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    assert(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_head = ((int32_t *) dst->op_params)[1];
-    float max_bias;
-    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
-    const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
-    const int64_t ne1 = src0->ne[1]; // seq_len_without_past
-    const int64_t ne2 = src0->ne[2]; // n_head -> this is k
-    //const int64_t ne3 = src0->ne[3]; // 1 -> bsz
-
-    const int64_t n  = ggml_nrows(src0);
-    const int64_t ne2_ne3 = n/ne1; // ne2*ne3
-
-    const size_t nb0 = src0->nb[0];
-    const size_t nb1 = src0->nb[1];
-    const size_t nb2 = src0->nb[2];
-    //const int nb3 = src0->nb[3];
-
-    GGML_ASSERT(nb0 == sizeof(float));
-    GGML_ASSERT(n_head == ne2);
-
-    // add alibi to src0 (KQ_scaled)
-    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-
-    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
-    for (int64_t k = 0; k < ne2_ne3; k++) {
-        // TODO: k*nb2 or k*nb3
-        float m_k;
-
-        if (k < n_heads_log2_floor) {
-            m_k = powf(m0, k + 1);
-        } else {
-            m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
-        }
-
-        for (int64_t i = 0; i < ne0; i++) {
-            for (int64_t j = 0; j < ne1; j++) {
-                float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
-                float *      pdst = (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
-                pdst[0] = i * m_k + src[0];
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_alibi_f16(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    assert(params->ith == 0);
-
-    if (params->type == GGML_TASK_TYPE_INIT || params->type == GGML_TASK_TYPE_FINALIZE) {
-        return;
-    }
-
-    //const int n_past = ((int32_t *) dst->op_params)[0];
-    const int n_head = ((int32_t *) dst->op_params)[1];
-    float max_bias;
-    memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
-
-    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
-    const int ne1 = src0->ne[1]; // seq_len_without_past
-    const int ne2 = src0->ne[2]; // n_head -> this is k
-    //const int ne3 = src0->ne[3]; // 1 -> bsz
-
-    const int n  = ggml_nrows(src0);
-    const int ne2_ne3 = n/ne1; // ne2*ne3
-
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
-    //const int nb3 = src0->nb[3];
-
-    GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
-    //GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
-    GGML_ASSERT(n_head == ne2);
-
-    // add alibi to src0 (KQ_scaled)
-    const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
-
-    const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
-    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
-
-    for (int k = 0; k < ne2_ne3; k++) {
-        // TODO: k*nb2 or k*nb3
-        float m_k;
-
-        if (k < n_heads_log2_floor) {
-            m_k = powf(m0, k + 1);
-        } else {
-            m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
-        }
-
-        for (int i = 0; i < ne0; i++) {
-            for (int j = 0; j < ne1; j++) {
-                ggml_fp16_t * const src  = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
-                float       *      pdst  =       (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
-
-                // we return F32
-                pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
-            }
-        }
-    }
-}
-
-static void ggml_compute_forward_alibi(
-        const struct ggml_compute_params * params,
-        struct ggml_tensor * dst) {
-
-    const struct ggml_tensor * src0 = dst->src[0];
-
-    switch (src0->type) {
-        case GGML_TYPE_F16:
-            {
-                ggml_compute_forward_alibi_f16(params, dst);
-            } break;
-        case GGML_TYPE_F32:
-            {
-                ggml_compute_forward_alibi_f32(params, dst);
-            } break;
-        case GGML_TYPE_BF16:
-        case GGML_TYPE_Q4_0:
-        case GGML_TYPE_Q4_1:
-        case GGML_TYPE_Q5_0:
-        case GGML_TYPE_Q5_1:
-        case GGML_TYPE_Q8_0:
-        case GGML_TYPE_Q8_1:
-        case GGML_TYPE_Q2_K:
-        case GGML_TYPE_Q3_K:
-        case GGML_TYPE_Q4_K:
-        case GGML_TYPE_Q5_K:
-        case GGML_TYPE_Q6_K:
-        case GGML_TYPE_IQ2_XXS:
-        case GGML_TYPE_IQ2_XS:
-        case GGML_TYPE_IQ3_XXS:
-        case GGML_TYPE_IQ1_S:
-        case GGML_TYPE_IQ1_M:
-        case GGML_TYPE_IQ4_NL:
-        case GGML_TYPE_IQ4_XS:
-        case GGML_TYPE_IQ3_S:
-        case GGML_TYPE_IQ2_S:
-        case GGML_TYPE_Q8_K:
-        case GGML_TYPE_I8:
-        case GGML_TYPE_I16:
-        case GGML_TYPE_I32:
-        case GGML_TYPE_I64:
-        case GGML_TYPE_F64:
-        case GGML_TYPE_COUNT:
-            {
-                GGML_ASSERT(false);
-            } break;
-    }
-}
-
 // ggml_compute_forward_clamp
 
 static void ggml_compute_forward_clamp_f32(
@@ -15763,8 +15534,17 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
-    float scale = 1.0f;
-    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = neq2;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     // loop over n_batch and n_head
     for (int ir = ir0; ir < ir1; ++ir) {
@@ -15773,6 +15553,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         const int iq2 = (ir - iq3*neq2*neq1)/neq1;
         const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
 
+        const uint32_t h = iq2; // head
+        const float slope = (max_bias > 0.0f) ? h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1) : 1.0f;
+
         float S = 0.0f;
         float M = -INFINITY;
 
@@ -15796,7 +15579,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
         // loop over n_kv and n_head_kv
         // ref: https://arxiv.org/pdf/2112.05682.pdf
         for (int64_t ic = 0; ic < nek1; ++ic) {
-            const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            const float mv = mp ? slope*GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
             if (mv == -INFINITY) {
                 continue;
             }
@@ -15867,7 +15650,7 @@ static void ggml_compute_forward_flash_attn_ext(
         const struct ggml_tensor * v,
         const struct ggml_tensor * mask,
         struct ggml_tensor * dst) {
-    switch (dst->op_params[1]) {
+    switch (dst->op_params[2]) {
         case GGML_PREC_DEFAULT:
         case GGML_PREC_F32:
             {
@@ -17630,10 +17413,6 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             {
                 ggml_compute_forward_rope_back(params, tensor);
             } break;
-        case GGML_OP_ALIBI:
-            {
-                ggml_compute_forward_alibi(params, tensor);
-            } break;
         case GGML_OP_CLAMP:
             {
                 ggml_compute_forward_clamp(params, tensor);
@@ -18652,10 +18431,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                             zero_table);
                 }
             } break;
-        case GGML_OP_ALIBI:
-            {
-                GGML_ASSERT(false); // TODO: not implemented
-            } break;
         case GGML_OP_CLAMP:
             {
                 GGML_ASSERT(false); // TODO: not implemented
@@ -19428,10 +19203,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_
             {
                 n_tasks = n_threads;
             } break;
-        case GGML_OP_ALIBI:
-            {
-                n_tasks = 1; //TODO
-            } break;
         case GGML_OP_CLAMP:
             {
                 n_tasks = 1; //TODO
diff --git a/ggml.h b/ggml.h
index fe605382206807ed3483f95f1c0e882123bec414..76c3328315964c73589f2f5295466b5155549eda 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -468,7 +468,6 @@ extern "C" {
         GGML_OP_SOFT_MAX_BACK,
         GGML_OP_ROPE,
         GGML_OP_ROPE_BACK,
-        GGML_OP_ALIBI,
         GGML_OP_CLAMP,
         GGML_OP_CONV_TRANSPOSE_1D,
         GGML_OP_IM2COL,
@@ -1428,15 +1427,13 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
-    // fused soft_max(a*scale + mask + pos[i]*(ALiBi slope))
+    // fused soft_max(a*scale + mask*(ALiBi slope))
     // mask is optional
-    // pos is required when max_bias > 0.0f
     // max_bias = 0.0f for no ALiBi
     GGML_API struct ggml_tensor * ggml_soft_max_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * mask,
-            struct ggml_tensor  * pos,
             float                 scale,
             float                 max_bias);
 
@@ -1538,16 +1535,6 @@ extern "C" {
             float                 xpos_base,
             bool                  xpos_down);
 
-    // alibi position embedding
-    // in-place, returns view(a)
-    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_alibi(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            int                   n_past,
-            int                   n_head,
-            float                 bias_max),
-        "use ggml_soft_max_ext instead (will be removed in Mar 2024)");
-
     // clamp
     // in-place, returns view(a)
     GGML_API struct ggml_tensor * ggml_clamp(
@@ -1744,7 +1731,8 @@ extern "C" {
             struct ggml_tensor  * k,
             struct ggml_tensor  * v,
             struct ggml_tensor  * mask,
-            float                 scale);
+            float                 scale,
+            float                 max_bias);
 
     GGML_API void ggml_flash_attn_ext_set_prec(
             struct ggml_tensor * a,
index e5750d4191f6b1aa686bd573b82e929994254280..990fe63c2acd19cddd5c4cff6fa60b62df7c0153 100644 (file)
@@ -137,6 +137,7 @@ class TensorNameMap:
             "layers.{bid}.attention.wk",                               # llama-pth
             "encoder.layer.{bid}.attention.self.key",                  # bert
             "transformer.h.{bid}.attn.k_proj",                         # gpt-j
+            "transformer.h.{bid}.attn.k",                              # refact
             "model.layers.layers.{bid}.self_attn.k_proj",              # plamo
             "model.layers.{bid}.attention.wk",                         # internlm2
             "transformer.decoder_layer.{bid}.multi_head_attention.key" # Grok
@@ -148,6 +149,7 @@ class TensorNameMap:
             "layers.{bid}.attention.wv",                                 # llama-pth
             "encoder.layer.{bid}.attention.self.value",                  # bert
             "transformer.h.{bid}.attn.v_proj",                           # gpt-j
+            "transformer.h.{bid}.attn.v",                                # refact
             "model.layers.layers.{bid}.self_attn.v_proj",                # plamo
             "model.layers.{bid}.attention.wv",                           # internlm2
             "transformer.decoder_layer.{bid}.multi_head_attention.value" # Grok
@@ -229,6 +231,7 @@ class TensorNameMap:
             "layers.{bid}.feed_forward.w3",                           # llama-pth
             "encoder.layer.{bid}.intermediate.dense",                 # bert
             "transformer.h.{bid}.mlp.fc_in",                          # gpt-j
+            "transformer.h.{bid}.mlp.linear_3",                       # refact
             "language_model.encoder.layers.{bid}.mlp.dense_h_to_4h",  # persimmon
             "model.layers.{bid}.mlp.dense_h_to_4h",                   # persimmon
             "transformer.h.{bid}.mlp.w1",                             # qwen
@@ -266,6 +269,7 @@ class TensorNameMap:
             "model.layers.layers.{bid}.mlp.gate_proj",    # plamo
             "model.layers.{bid}.feed_forward.w1",         # internlm2
             "encoder.layers.{bid}.mlp.fc12",              # nomic-bert
+            "transformer.h.{bid}.mlp.linear_1",           # refact
         ),
 
         MODEL_TENSOR.FFN_GATE_EXP: (
index 2f1123d4e16788ccebeaf1cf73a49eca9e74f765..dede68cb5a3f9bf4c55897b9fb21ccd9fba15c11 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1845,7 +1845,7 @@ struct llama_hparams {
     float f_logit_scale    = 0.0f;
 
     bool causal_attn = true;
-    bool use_alibi   = false; // currently, we need KQ_pos data for ALiBi-based models
+    bool use_alibi   = false;
 
     enum llama_pooling_type      pooling_type            = LLAMA_POOLING_TYPE_NONE;
     enum llama_rope_type         rope_type               = LLAMA_ROPE_TYPE_NONE;
@@ -2317,7 +2317,6 @@ struct llama_context {
     struct ggml_tensor * inp_pos;       // I32 [n_batch]
     struct ggml_tensor * inp_out_ids;   // I32 [n_outputs]
     struct ggml_tensor * inp_KQ_mask;   // F32 [kv_size, n_batch]
-    struct ggml_tensor * inp_KQ_pos;    // F32 [n_kv]
     struct ggml_tensor * inp_K_shift;   // I32 [kv_size]
     struct ggml_tensor * inp_mean;      // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls;       // I32 [n_batch]
@@ -6500,7 +6499,6 @@ static struct ggml_tensor * llm_build_kqv(
          struct ggml_tensor * wo_b,
          struct ggml_tensor * q_cur,
          struct ggml_tensor * kq_mask,
-         struct ggml_tensor * kq_pos,
                     int32_t   n_tokens,
                     int32_t   n_kv,
                     float     kq_scale,
@@ -6530,10 +6528,6 @@ static struct ggml_tensor * llm_build_kqv(
         GGML_UNUSED(model);
         GGML_UNUSED(n_ctx);
 
-        // note: if this assert triggers, then some check has failed earlier
-        //       the idea is to detect during context creation that ALiBi would be used and disable Flash Attention
-        GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention");
-
         // split cached v into n_head heads (not transposed)
         struct ggml_tensor * v =
             ggml_view_3d(ctx, kv.v_l[il],
@@ -6543,7 +6537,7 @@ static struct ggml_tensor * llm_build_kqv(
                     0);
         cb(v, "v", il);
 
-        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
+        cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale, hparams.f_max_alibi_bias);
 
         if (model.arch == LLM_ARCH_PHI2 || model.arch == LLM_ARCH_PHI3) {
             ggml_flash_attn_ext_set_prec(cur, GGML_PREC_F32);
@@ -6574,28 +6568,8 @@ static struct ggml_tensor * llm_build_kqv(
             kq = ggml_scale(ctx, kq, 30);
         }
 
-#if defined(GGML_USE_KOMPUTE)
-#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Kompute")
-#pragma message("      Falling back to ggml_alibi(). Will become an error in Mar 2024")
-#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5488")
-        if (hparams.use_alibi) {
-            kq = ggml_scale(ctx, kq, kq_scale);
-            cb(kq, "kq_scaled", il);
-
-            kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias);
-            cb(kq, "kq_scaled_alibi", il);
-
-            kq = ggml_add(ctx, kq, kq_mask);
-            cb(kq, "kq_masked", il);
-
-            kq = ggml_soft_max(ctx, kq);
-            cb(kq, "kq_soft_max", il);
-        } else
-#endif
-        {
-            kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias);
-            cb(kq, "kq_soft_max_ext", il);
-        }
+        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, hparams.f_max_alibi_bias);
+        cb(kq, "kq_soft_max_ext", il);
 
         GGML_ASSERT(kv.size == n_ctx);
 
@@ -6645,7 +6619,6 @@ static struct ggml_tensor * llm_build_kv(
          struct ggml_tensor * v_cur,
          struct ggml_tensor * q_cur,
          struct ggml_tensor * kq_mask,
-         struct ggml_tensor * kq_pos,
                     int32_t   n_tokens,
                     int32_t   kv_head,
                     int32_t   n_kv,
@@ -6664,7 +6637,7 @@ static struct ggml_tensor * llm_build_kv(
     struct ggml_tensor * cur;
 
     cur  = llm_build_kqv(ctx, model, hparams, cparams, kv, graph, wo, wo_b,
-            q_cur, kq_mask, kq_pos, n_tokens, n_kv, kq_scale, cb, il);
+            q_cur, kq_mask, n_tokens, n_kv, kq_scale, cb, il);
     cb(cur, "kqv_out", il);
 
     return cur;
@@ -6771,18 +6744,17 @@ struct llm_build_context {
 
         ctx0 = ggml_init(params);
 
-        lctx.inp_tokens = nullptr;
-        lctx.inp_embd = nullptr;
-        lctx.inp_pos = nullptr;
+        lctx.inp_tokens  = nullptr;
+        lctx.inp_embd    = nullptr;
+        lctx.inp_pos     = nullptr;
         lctx.inp_out_ids = nullptr;
         lctx.inp_KQ_mask = nullptr;
-        lctx.inp_KQ_pos = nullptr;
         lctx.inp_K_shift = nullptr;
-        lctx.inp_mean = nullptr;
-        lctx.inp_cls = nullptr;
-        lctx.inp_s_copy = nullptr;
-        lctx.inp_s_mask = nullptr;
-        lctx.inp_s_seq = nullptr;
+        lctx.inp_mean    = nullptr;
+        lctx.inp_cls     = nullptr;
+        lctx.inp_s_copy  = nullptr;
+        lctx.inp_s_mask  = nullptr;
+        lctx.inp_s_seq   = nullptr;
     }
 
     void free() {
@@ -6932,19 +6904,6 @@ struct llm_build_context {
         return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16) : lctx.inp_KQ_mask;
     }
 
-    struct ggml_tensor * build_inp_KQ_pos(bool causal = true) {
-        if (causal) {
-            lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv);
-        } else {
-            // TODO: this will be needed for ALiBi-based BERT models
-            //       https://github.com/ggerganov/llama.cpp/pull/6826
-            lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_tokens);
-        }
-        cb(lctx.inp_KQ_pos, "KQ_pos", -1);
-        ggml_set_input(lctx.inp_KQ_pos);
-        return flash_attn ? ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16) : lctx.inp_KQ_pos;
-    }
-
     struct ggml_tensor * build_inp_mean() {
         lctx.inp_mean = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens);
         cb(lctx.inp_mean, "inp_mean", -1);
@@ -7050,7 +7009,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -7143,9 +7102,6 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
-        // positions of the tokens in the KV cache
-        struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -7190,7 +7146,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -7260,9 +7216,6 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
-        // positions of the tokens in the KV cache
-        struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -7297,7 +7250,7 @@ struct llm_build_context {
                 cb(Kcur, "Kcur", il);
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -7417,7 +7370,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -7542,7 +7495,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -7694,7 +7647,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -7806,7 +7759,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -8010,7 +7963,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Q, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Q, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -8076,9 +8029,6 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
-        // positions of the tokens in the KV cache
-        struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
-
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -8106,7 +8056,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -8246,7 +8196,7 @@ struct llm_build_context {
             struct ggml_tensor * kq = ggml_mul_mat(ctx0, k, q);
             cb(kq, "kq", il);
 
-            kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, nullptr, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
+            kq = ggml_soft_max_ext(ctx0, kq, KQ_mask, 1.0f/sqrtf(float(n_embd_head)), hparams.f_max_alibi_bias);
             cb(kq, "kq_soft_max_ext", il);
 
             struct ggml_tensor * v = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_2d(ctx0, Vcur, n_embd_gqa, n_tokens)));
@@ -8363,9 +8313,6 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
-        // positions of the tokens in the KV cache
-        struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
-
         inpL = llm_build_norm(ctx0, inpL, hparams,
                 model.tok_norm,
                 model.tok_norm_b,
@@ -8399,7 +8346,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -8464,9 +8411,6 @@ struct llm_build_context {
         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
         struct ggml_tensor * KQ_mask = build_inp_KQ_mask();
 
-        // positions of the tokens in the KV cache
-        struct ggml_tensor * KQ_pos = build_inp_KQ_pos();
-
         if (model.pos_embd) {
             // inp_pos - contains the positions
             struct ggml_tensor * inp_pos = build_inp_pos();
@@ -8530,13 +8474,13 @@ struct llm_build_context {
 
                     cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                             model.layers[il].wo, model.layers[il].bo,
-                            Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                            Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 } else {
                     Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
 
                     cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                             model.layers[il].wo, model.layers[il].bo,
-                            Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                            Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 }
             }
 
@@ -8680,7 +8624,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -8798,7 +8742,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -8911,7 +8855,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9025,7 +8969,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9180,7 +9124,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9297,7 +9241,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9410,7 +9354,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
             struct ggml_tensor * sa_out = cur;
 
@@ -9513,7 +9457,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9620,7 +9564,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9736,7 +9680,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9853,7 +9797,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -9983,7 +9927,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -10104,7 +10048,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f, cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -10223,7 +10167,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -10513,7 +10457,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -10644,7 +10588,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, cparams, kv_self, gf,
                         model.layers[il].wo, nullptr,
-                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
             }
 
             if (il == n_layer - 1) {
@@ -11032,11 +10976,21 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                         if (!lctx.kv_self.cells[i].has_seq_id(seq_id) || lctx.kv_self.cells[i].pos > pos) {
                             f = -INFINITY;
                         } else {
-                            f = 0.0f;
+                            if (hparams.use_alibi) {
+                                f = -fabs(lctx.kv_self.cells[i].pos - pos);
+                            } else {
+                                f = 0.0f;
+                            }
                         }
                         data[h*(n_kv*n_tokens) + j*n_kv + i] = f;
                     }
                 }
+
+                for (int i = n_tokens; i < GGML_PAD(n_tokens, GGML_KQ_MASK_PAD); ++i) {
+                    for (int j = 0; j < n_kv; ++j) {
+                        data[h*(n_kv*n_tokens) + i*n_kv + j] = -INFINITY;
+                    }
+                }
             }
         } else {
             // when using kv cache, the mask needs to match the kv cache size
@@ -11055,7 +11009,11 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
                         float f = -INFINITY;
                         for (int s = 0; s < batch.n_seq_id[i]; ++s) {
                             if (batch.seq_id[i][s] == seq_id) {
-                                f = 0.0f;
+                                if (hparams.use_alibi) {
+                                    f = -fabs(batch.pos[i] - batch.pos[j]);
+                                } else {
+                                    f = 0.0f;
+                                }
                                 break;
                             }
                         }
@@ -11071,21 +11029,6 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
-    // ALiBi requires the KQ_pos tensor to provide the sequence position of each token in the batch
-    // this allows to process multiple sequences in parallel with ALiBi-based models
-    if (hparams.use_alibi) {
-        const int64_t n_kv = kv_self.n;
-
-        GGML_ASSERT(lctx.inp_KQ_pos);
-        GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer));
-
-        float * data = (float *) lctx.inp_KQ_pos->data;
-
-        for (int i = 0; i < n_kv; ++i) {
-            data[i] = float(lctx.kv_self.cells[i].pos);
-        }
-    }
-
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) {
         const int64_t n_tokens = batch.n_tokens;
 
@@ -15509,11 +15452,6 @@ struct llama_context * llama_new_context_with_model(
         }
     }
 
-    if (cparams.flash_attn && hparams.use_alibi) {
-        LLAMA_LOG_WARN("%s: flash_attn is not yet compatible with ALiBi - forcing off\n", __func__);
-        cparams.flash_attn = false;
-    }
-
     if (cparams.flash_attn && model->arch == LLM_ARCH_GROK) {
         LLAMA_LOG_WARN("%s: flash_attn is not compatible with Grok - forcing off\n", __func__);
         cparams.flash_attn = false;
index 0d66de5d9d389e385f94520c5f8fb35ef470ed03..731788b958fca7b8c493e32047fd82da8d30dbb8 100644 (file)
@@ -1111,11 +1111,7 @@ struct test_soft_max : public test_case {
         if (this->mask) {
             mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, ne[0], ne[1]);
         }
-        ggml_tensor * pos = nullptr;
-        if (max_bias > 0.0f) {
-            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
-        }
-        ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias);
+        ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, scale, max_bias);
         return out;
     }
 };
@@ -1490,23 +1486,25 @@ struct test_flash_attn_ext : public test_case {
     const int64_t kv; // kv size
     const int64_t nb; // batch size
 
+    const float max_bias; // ALiBi
+
     std::string vars() override {
-        return VARS_TO_STR4(hs, nh, kv, nb);
+        return VARS_TO_STR5(hs, nh, kv, nb, max_bias);
     }
 
     double max_nmse_err() override {
         return 5e-4;
     }
 
-    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
-        : hs(hs), nh(nh), kv(kv), nb(nb) {}
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8, float max_bias = 0.0f)
+        : hs(hs), nh(nh), kv(kv), nb(nb), max_bias(max_bias) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
         ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
         ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
         ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
-        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs), max_bias);
         return out;
     }
 };
@@ -1611,7 +1609,7 @@ public:
 
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
 
-        kq = ggml_soft_max_ext(ctx, kq, kq_mask, nullptr, kq_scale, 0.0f);
+        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale, 0.0f);
 
         // split cached v into n_head heads
         struct ggml_tensor * v =
@@ -2128,6 +2126,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 #endif
     for (bool mask : {false, true}) {
         for (float max_bias : {0.0f, 8.0f}) {
+            if (!mask && max_bias > 0.0f) continue;
             for (float scale : {1.0f, 0.1f}) {
                 for (int64_t ne0 : {16, 1024}) {
                     for (int64_t ne1 : {16, 1024}) {
@@ -2141,7 +2140,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 0.0f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 8.0f));
     test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 8.0f));
 
     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
@@ -2180,10 +2178,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
 #else
     for (int hs : { 64, 80, 128, 256, }) {
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
-        for (int nh : { 32, }) {
-            for (int kv : { 512, 1024, }) {
-                for (int nb : { 1, 2, 4, 8, }) {
-                    test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
+        for (float max_bias : {0.0f, 8.0f}) {
+            for (int nh : { 32, }) {
+                for (int kv : { 512, 1024, }) {
+                    for (int nb : { 1, 2, 4, 8, }) {
+                        test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb, max_bias));
+                    }
                 }
             }
         }