]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : add ALiBi support for ggml_soft_max_ext (#5488)
authorGeorgi Gerganov <redacted>
Sat, 17 Feb 2024 21:04:16 +0000 (23:04 +0200)
committerGitHub <redacted>
Sat, 17 Feb 2024 21:04:16 +0000 (23:04 +0200)
* ggml : avoid recomputing alibi slopes (CPU)

* llama : reuse hparams.f_max_alibi_bias in all cases

ggml-ci

* ggml : support alibi bias in ggml_soft_max_ext (CPU + Metal)

ggml-ci

* ggml : handle all SRCs (do not break on first null)

ggml-ci

* tests : do not use slope for large soft_max

accumulates too much error

ggml-ci

* ggml : alternative ALiBi without extra tensor

We compute the slopes in the kernel

ggml-ci

* cuda : add ALiBi support in ggml_soft_max_ext

ggml-ci

* ggml : deprecate ggml_alibi

* ggml : support multi-sequence ALiBi (Metal)

ggml-ci

* cuda : add multi-seq ALiBi + remote F16 soft_max

ggml-ci

* ggml : update deprecation message

* ggml : fix pos ptr when no ALiBi

ggml-ci

* cuda : fix performance (pow -> powf)

* cuda : precompute ALiBi constants

* metal : pre-compute ALiBi slopes

ggml-ci

* llama : init kq_pos only if needed

ggml-ci

* test-backend-ops : add null pos test to soft_max

test-backend-ops : replace soft_max tests

ggml-ci

---------

Co-authored-by: slaren <redacted>
ggml-alloc.c
ggml-backend.c
ggml-cuda.cu
ggml-metal.m
ggml-metal.metal
ggml.c
ggml.h
llama.cpp
tests/test-backend-ops.cpp

index c28c37c4fd9ffb52bdecd686f547fd43768773ba..d4123564ff0d1786555321581d4ce94ea3cc0232 100644 (file)
@@ -551,7 +551,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr
         }
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             if (graph->nodes[i]->src[j] == NULL) {
-                break;
+                continue;
             }
             if (graph->nodes[i]->src[j]->flags & GGML_TENSOR_FLAG_INPUT) {
                 ggml_gallocr_allocate_node(galloc, graph->nodes[i]->src[j], get_node_buffer_id(node_buffer_ids, i));
@@ -787,7 +787,7 @@ static bool ggml_gallocr_needs_realloc(ggml_gallocr_t galloc, struct ggml_cgraph
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             struct ggml_tensor * src = node->src[j];
             if (src == NULL) {
-                break;
+                continue;
             }
             if (!ggml_gallocr_node_needs_realloc(galloc, src, node_alloc, &node_alloc->src[j])) {
 #ifndef NDEBUG
@@ -833,7 +833,7 @@ bool ggml_gallocr_alloc_graph(ggml_gallocr_t galloc, struct ggml_cgraph * graph)
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             struct ggml_tensor * src = node->src[j];
             if (src == NULL) {
-                break;
+                continue;
             }
             ggml_gallocr_init_tensor(galloc, src, node_alloc, &node_alloc->src[j]);
         }
index d019d813ad5f070f5745515fbf9d8c9172172cd4..66e8c293a9e3fc5bf454228c5577771a26336a88 100644 (file)
@@ -1041,7 +1041,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st
     for (int i = 0; i < GGML_MAX_SRC; i++) {
         const struct ggml_tensor * src = tensor->src[i];
         if (src == NULL) {
-            break;
+            continue;
         }
         if (src->buffer != NULL && src->buffer->usage == GGML_BACKEND_BUFFER_USAGE_WEIGHTS) {
             int src_backend = ggml_backend_sched_backend_from_buffer(sched, src->buffer);
@@ -1088,7 +1088,7 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             struct ggml_tensor * src = node->src[j];
             if (src == NULL) {
-                break;
+                continue;
             }
             ggml_backend_t src_backend = tensor_backend(src);
             fprintf(stderr, " %20.20s (%5.5s) [%5.5s %8.8s]", src->name,
@@ -1144,7 +1144,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             struct ggml_tensor * src = node->src[j];
             if (src == NULL) {
-                break;
+                continue;
             }
             if (tensor_backend_id(src) == -1) {
                 tensor_backend_id(src) = ggml_backend_sched_backend_id_from_cur(sched, src);
@@ -1256,7 +1256,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             struct ggml_tensor * src = node->src[j];
             if (src == NULL) {
-                break;
+                continue;
             }
             int src_backend_id = tensor_backend_id(src);
             if (src_backend_id == -1) {
@@ -1315,7 +1315,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
             for (int j = 0; j < GGML_MAX_SRC; j++) {
                 struct ggml_tensor * src = node->src[j];
                 if (src == NULL) {
-                    break;
+                    continue;
                 }
                 int src_backend_id = tensor_backend_id(src);
                 assert(src_backend_id != -1); // all inputs should be assigned by now
@@ -1362,7 +1362,7 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg
         for (int j = 0; j < GGML_MAX_SRC; j++) {
             struct ggml_tensor * src = node->src[j];
             if (src == NULL) {
-                break;
+                continue;
             }
             ggml_backend_t src_backend = tensor_backend(src);
             if (src_backend != tensor_backend /* && src_backend != NULL */) {
@@ -1668,7 +1668,7 @@ static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set,
     for (int i = 0; i < GGML_MAX_SRC; i++) {
         struct ggml_tensor * s = src->src[i];
         if (s == NULL) {
-            break;
+            continue;
         }
         dst->src[i] = graph_copy_dup_tensor(hash_set, node_copies, ctx_allocated, ctx_unallocated, s);
     }
@@ -1697,7 +1697,7 @@ static void graph_copy_init_tensor(struct ggml_hash_set hash_set, struct ggml_te
     for (int i = 0; i < GGML_MAX_SRC; i++) {
         struct ggml_tensor * s = src->src[i];
         if (s == NULL) {
-            break;
+            continue;
         }
         graph_copy_init_tensor(hash_set, node_copies, node_init, s);
     }
index b35fcb7fdb5d2aa516ebad87e197e5a43225fea6..5fd8a87e4150f0f06bd23884e94a81b1ebc632e2 100644 (file)
@@ -5956,148 +5956,30 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
 }
 
-template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
-static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-    const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
-    const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
+template <bool vals_smem, int ncols_template, int block_size_template>
+static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
+    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
 
     const int tid  = threadIdx.x;
     const int rowx = blockIdx.x;
-    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+    const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
 
     const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
 
     const int warp_id = threadIdx.x / WARP_SIZE;
     const int lane_id = threadIdx.x % WARP_SIZE;
 
-    extern __shared__ half data_soft_max_f16[];
-    half  * buf_iw = data_soft_max_f16 + 0; // shared memory buffer for inter-warp communication
-    // (shared memory) buffer to cache values between iterations:
-    half2 * vals = vals_smem ? (half2 *) (buf_iw + WARP_SIZE) : (half2 *) (dst + rowx*ncols_data);
-    // if the buffer is larger than max. shared memory per block, use dst as temp. buffer instead
-    // in that case col_smem == col_data must be enforced to avoid race conditions
-
-    half2 max_val = make_half2(-INFINITY, -INFINITY);
-
-#pragma unroll
-    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
-        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
-        const int col_smem = vals_smem ? col0 + tid : col_data;
-
-        const int ix = rowx*ncols_data + col_data;
-        const int iy = rowy*ncols_data + col_data;
-
-        half2 val;
-        if (need_check && col_data + 0 >= ncols_data) {
-            val.x = -INFINITY;
-        } else {
-            val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
-        }
-        if (need_check && col_data + WARP_SIZE >= ncols_data) {
-            val.y = -INFINITY;
-        } else {
-            val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
-        }
-        if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
-            vals[col_smem] = val;
-        }
-        max_val = __hmax2(max_val, val);
-    }
-
-    // find the max value in the block
-    max_val = warp_reduce_max(max_val);
-    if (block_size > WARP_SIZE) {
-        if (warp_id == 0) {
-            buf_iw[lane_id] = -INFINITY;
-        }
-        __syncthreads();
-
-        if (lane_id == 0) {
-            buf_iw[warp_id] = __hmax(max_val.x, max_val.y);
-        }
-        __syncthreads();
-
-        max_val = __half2half2(buf_iw[lane_id]);
-        max_val = warp_reduce_max(max_val);
-    } else {
-        max_val = __half2half2(__hmax(max_val.x, max_val.y));
-    }
-
-    half2 tmp = make_half2(0.0f, 0.0f); // partial sums
+    float slope = 0.0f;
 
-#pragma unroll
-    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
-        const int col_smem = vals_smem ? col0 + tid : 2*col0 + 2*warp_id*WARP_SIZE + lane_id;
-
-        if (ncols_template == 0 && col_smem >= (vals_smem ? ncols_smem : ncols_data)) {
-            break;
-        }
-
-        const half2 val = h2exp(vals[col_smem] - max_val);
-
-        tmp += val;
-        vals[col_smem] = val;
-    }
-
-    // find the sum of exps in the block
-    tmp = warp_reduce_sum(tmp);
-    if (block_size > WARP_SIZE) {
-        if (warp_id == 0) {
-            buf_iw[lane_id] = 0.0f;
-        }
-        __syncthreads();
-
-        if (lane_id == 0) {
-            buf_iw[warp_id] = tmp.x + tmp.y;
-        }
-        __syncthreads();
-
-        tmp = __half2half2(buf_iw[lane_id]);
-        tmp = warp_reduce_sum(tmp);
-    } else {
-        tmp = __half2half2(tmp.x + tmp.y);
-    }
-
-    const half2 inv_sum = make_half2(1.0f, 1.0f) / tmp;
-
-#pragma unroll
-    for (int col0 = 0; col0 < ncols_smem; col0 += block_size) {
-        const int col_data = 2*col0 + 2*WARP_SIZE*warp_id + lane_id;
-        const int col_smem = vals_smem ? col0 + tid : col_data;
-
-        const int idst = rowx*ncols_data + col_data;
-        const half2 result = vals[col_smem] * inv_sum;
-
-        if (need_check && col_data + 0 >= ncols_data) {
-            return;
-        }
-        dst[idst] = result.x;
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = rowx/nrows_y; // head index
 
-        if (need_check && col_data + WARP_SIZE >= ncols_data) {
-            return;
-        }
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
 
-        dst[idst + WARP_SIZE] = result.y;
+        slope = powf(base, exp);
     }
-#else
-    (void) x; (void) y; (void) dst; (void) ncols_par; (void) nrows_y; (void) scale;
-    NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
-}
-
-template <bool vals_smem, int ncols_template, int block_size_template>
-static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
-    const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
-
-    const int tid  = threadIdx.x;
-    const int rowx = blockIdx.x;
-    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
-
-    const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
-
-    const int warp_id = threadIdx.x / WARP_SIZE;
-    const int lane_id = threadIdx.x % WARP_SIZE;
 
     extern __shared__ float data_soft_max_f32[];
     float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
@@ -6117,7 +5999,8 @@ static __global__ void soft_max_f32(const float * x, const float * y, float * ds
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;
 
-        const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
+        const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + slope*pos[col];
+
         vals[col] = val;
         max_val = max(max_val, val);
     }
@@ -7589,89 +7472,53 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
-static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
-    int nth = WARP_SIZE;
-    while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
-    const dim3 block_dims(nth,     1, 1);
-    const dim3 block_nums(nrows_x, 1, 1);
-    const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
-    static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
-    if (shmem <= g_device_caps[g_main_device].smpb) {
-        switch (ncols_x) {
-            case 32:
-                soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-            case 64:
-                soft_max_f16<true, 64, 32, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-            case 128:
-                soft_max_f16<true, 128, 64, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-            case 256:
-                soft_max_f16<true, 256, 128, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-            case 512:
-                soft_max_f16<true, 512, 256, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-            case 1024:
-                soft_max_f16<true, 1024, 512, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-            case 2048:
-                soft_max_f16<true, 2048, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-            case 4096:
-                soft_max_f16<true, 4096, 1024, false><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-            default:
-                soft_max_f16<true, 0, 0, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-                break;
-        }
-    } else {
-        const size_t shmem_low = WARP_SIZE*sizeof(half);
-        soft_max_f16<false, 0, 0, true><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
-    }
-}
-
-static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
     const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
     static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
+
+    const uint32_t n_head_kv   = nrows_x/nrows_y;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
     if (shmem < g_device_caps[g_main_device].smpb) {
         switch (ncols_x) {
             case 32:
-                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 64:
-                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 128:
-                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 256:
-                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 512:
-                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 1024:
-                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 2048:
-                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             case 4096:
-                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
             default:
-                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+                soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
                 break;
         }
     } else {
         const size_t shmem_low = WARP_SIZE*sizeof(float);
-        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
+        soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
     }
 }
 
@@ -9090,30 +8937,36 @@ static void ggml_cuda_op_soft_max(
 
     GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
 
-    const int64_t ne00 = src0->ne[0];
+    const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
+    const int64_t nrows_y = src0->ne[1];
 
-    float scale = 1.0f;
-    memcpy(&scale, dst->op_params, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
 
-#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION >= CUDART_HMAX
-#ifdef GGML_CUDA_F16
-    const bool use_f16_soft_max = true;
-#else
-    const bool use_f16_soft_max = false;
-#endif // GGML_CUDA_F16
-#else
-    const bool use_f16_soft_max = false;
-#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX
+    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
-    if (use_f16_soft_max) {
-        soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
-    } else {
-        soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+    // positions tensor
+    float * src2_dd = dst_dd; // default to avoid null checks in the kernel
+    cuda_pool_alloc<float> src2_f;
+
+    ggml_tensor * src2 = dst->src[2];
+    const bool use_src2 = src2 != nullptr;
+
+    if (use_src2) {
+        const bool src2_on_device = use_src2 && src2->backend == GGML_BACKEND_GPU;
+        ggml_tensor_extra_gpu * src2_extra = use_src2 ? (ggml_tensor_extra_gpu *) src2->extra : nullptr;
+
+        if (src2_on_device) {
+            src2_dd = (float *) src2_extra->data_device[g_main_device];
+        } else {
+            src2_dd = src2_f.alloc(ggml_nelements(src2));
+            CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src2_dd, src2, 0, 0, 0, 1, main_stream));
+        }
     }
 
-    (void) dst;
+    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, src2_dd, dst_dd, ne00, nrows_x, nrows_y, scale, max_bias, main_stream);
 }
 
 static void ggml_cuda_op_scale(
index 6e76f8bedb50b6e40841b727a44c3a6870694865..c0848a293e48f3616141169c9141846f26d93fcb 100644 (file)
@@ -728,6 +728,7 @@ static bool ggml_metal_graph_compute(
 
         size_t offs_src0 = 0;
         size_t offs_src1 = 0;
+        size_t offs_src2 = 0;
         size_t offs_dst  = 0;
 
         id<MTLCommandBuffer> command_buffer  = command_buffers[cb_idx];
@@ -746,6 +747,7 @@ static bool ggml_metal_graph_compute(
 
             struct ggml_tensor * src0 = gf->nodes[i]->src[0];
             struct ggml_tensor * src1 = gf->nodes[i]->src[1];
+            struct ggml_tensor * src2 = gf->nodes[i]->src[2];
             struct ggml_tensor * dst  = gf->nodes[i];
 
             switch (dst->op) {
@@ -807,6 +809,7 @@ static bool ggml_metal_graph_compute(
 
             id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(src0, &offs_src0) : nil;
             id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(src1, &offs_src1) : nil;
+            id<MTLBuffer> id_src2 = src2 ? ggml_metal_get_buffer(src2, &offs_src2) : nil;
             id<MTLBuffer> id_dst  = dst  ? ggml_metal_get_buffer(dst,  &offs_dst)  : nil;
 
             //GGML_METAL_LOG_INFO("%s: op - %s\n", __func__, ggml_op_name(dst->op));
@@ -1188,7 +1191,16 @@ static bool ggml_metal_graph_compute(
                             pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SOFT_MAX].pipeline;
                         }
 
-                        const float scale = ((float *) dst->op_params)[0];
+                        const float scale    = ((float *) dst->op_params)[0];
+                        const float max_bias = ((float *) dst->op_params)[1];
+
+                        const int64_t nrows_x = ggml_nrows(src0);
+                        const int64_t nrows_y = src0->ne[1];
+                        const uint32_t n_head_kv   = nrows_x/nrows_y;
+                        const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
+
+                        const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+                        const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
                         [encoder setComputePipelineState:pipeline];
                         [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
@@ -1197,11 +1209,20 @@ static bool ggml_metal_graph_compute(
                         } else {
                             [encoder setBuffer:id_src0 offset:offs_src0   atIndex:1];
                         }
-                        [encoder setBuffer:id_dst  offset:offs_dst    atIndex:2];
-                        [encoder setBytes:&ne00  length:sizeof(ne00)  atIndex:3];
-                        [encoder setBytes:&ne01  length:sizeof(ne01)  atIndex:4];
-                        [encoder setBytes:&ne02  length:sizeof(ne02)  atIndex:5];
-                        [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+                        if (id_src2) {
+                            [encoder setBuffer:id_src2 offset:offs_src2   atIndex:2];
+                        } else {
+                            [encoder setBuffer:id_src0 offset:offs_src0   atIndex:2];
+                        }
+                        [encoder setBuffer:id_dst   offset:offs_dst          atIndex:3];
+                        [encoder setBytes:&ne00     length:sizeof(ne00)      atIndex:4];
+                        [encoder setBytes:&ne01     length:sizeof(ne01)      atIndex:5];
+                        [encoder setBytes:&ne02     length:sizeof(ne02)      atIndex:6];
+                        [encoder setBytes:&scale    length:sizeof(scale)     atIndex:7];
+                        [encoder setBytes:&max_bias length:sizeof(max_bias)  atIndex:8];
+                        [encoder setBytes:&m0       length:sizeof(m0)        atIndex:9];
+                        [encoder setBytes:&m1       length:sizeof(m1)        atIndex:10];
+                        [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:11];
                         [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
@@ -1514,8 +1535,6 @@ static bool ggml_metal_graph_compute(
                         // max size of the src1ids array in the kernel stack
                         GGML_ASSERT(ne11 <= 512);
 
-                        struct ggml_tensor * src2 = gf->nodes[i]->src[2];
-
                         const int64_t  ne20 = src2 ? src2->ne[0] : 0;
                         const int64_t  ne21 = src2 ? src2->ne[1] : 0;
                         const int64_t  ne22 = src2 ? src2->ne[2] : 0;
index efed6ad465e78d4fba26c293214ac980b131e7b3..09ebcc9e3040fdaa3a46bdb582c950ac14a28ba4 100644 (file)
@@ -351,12 +351,17 @@ kernel void kernel_sum_rows(
 kernel void kernel_soft_max(
         device const float * src0,
         device const float * src1,
+        device const float * src2,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
         constant     float & scale,
-        threadgroup float  * buf [[threadgroup(0)]],
+        constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
+        threadgroup  float * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]],
@@ -368,13 +373,26 @@ kernel void kernel_soft_max(
 
     device const float * psrc0 =         src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
     device const float * pmask = src1 != src0 ? src1                               + i01*ne00 : nullptr;
+    device const float * ppos  = src2 != src0 ? src2                                          : nullptr;
     device       float * pdst  =         dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
+    float slope = 0.0f;
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int64_t h = i02;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
     // parallel max
     float lmax = -INFINITY;
 
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
+        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
     }
 
     // find the max value in the block
@@ -399,7 +417,7 @@ kernel void kernel_soft_max(
     // parallel sum
     float lsum = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
+        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
         lsum += exp_psrc0;
         pdst[i00] = exp_psrc0;
     }
@@ -437,12 +455,17 @@ kernel void kernel_soft_max(
 kernel void kernel_soft_max_4(
         device const float * src0,
         device const float * src1,
+        device const float * src2,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
         constant     float & scale,
-        threadgroup float  * buf [[threadgroup(0)]],
+        constant     float & max_bias,
+        constant     float & m0,
+        constant     float & m1,
+        constant  uint32_t & n_head_log2,
+        threadgroup  float * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
         uint  sgitg[[simdgroup_index_in_threadgroup]],
@@ -454,13 +477,25 @@ kernel void kernel_soft_max_4(
 
     device const float4 * psrc4 =                (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
     device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 +                                      i01*ne00) : nullptr;
+    device const float4 * ppos  = src2 != src0 ? (device const float4 *)(src2)                                                 : nullptr;
     device       float4 * pdst4 =                (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
+    float slope = 0.0f;
+
+    if (max_bias > 0.0f) {
+        const int64_t h = i02;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slope = pow(base, exp);
+    }
+
     // parallel max
     float4 lmax4 = -INFINITY;
 
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]);
     }
 
     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
@@ -486,7 +521,7 @@ kernel void kernel_soft_max_4(
     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + slope*ppos[i00]) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
diff --git a/ggml.c b/ggml.c
index 264cfd705cd378aac42cf65b09e451a9c80fad79..e94024c62a1238883e5f367d92dea8928a18e416 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -5096,16 +5096,28 @@ static struct ggml_tensor * ggml_soft_max_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * mask,
+        struct ggml_tensor  * pos,
         float                 scale,
+        float                 max_bias,
         bool                  inplace) {
     GGML_ASSERT(ggml_is_contiguous(a));
+
     if (mask) {
         GGML_ASSERT(ggml_is_contiguous(mask));
-        GGML_ASSERT(mask->ne[2] == 1);
-        GGML_ASSERT(mask->ne[3] == 1);
+        GGML_ASSERT(ggml_is_matrix(mask));
         GGML_ASSERT(ggml_can_repeat_rows(mask, a));
     }
 
+    if (pos) {
+        GGML_ASSERT(ggml_is_vector(pos));
+        GGML_ASSERT(pos->type == GGML_TYPE_F32);
+        GGML_ASSERT(pos->ne[0] == a->ne[0]);
+    }
+
+    if (max_bias > 0.0f) {
+        GGML_ASSERT(pos);
+    }
+
     bool is_node = false;
 
     if (a->grad) {
@@ -5114,13 +5126,14 @@ static struct ggml_tensor * ggml_soft_max_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
-    float params[] = { scale };
+    float params[] = { scale, max_bias };
     ggml_set_op_params(result, params, sizeof(params));
 
     result->op   = GGML_OP_SOFT_MAX;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
     result->src[1] = mask;
+    result->src[2] = pos;
 
     return result;
 }
@@ -5128,21 +5141,23 @@ static struct ggml_tensor * ggml_soft_max_impl(
 struct ggml_tensor * ggml_soft_max(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
+    return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, false);
 }
 
 struct ggml_tensor * ggml_soft_max_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
+    return ggml_soft_max_impl(ctx, a, NULL, NULL, 1.0f, 0.0f, true);
 }
 
 struct ggml_tensor * ggml_soft_max_ext(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
         struct ggml_tensor  * mask,
-        float                 scale) {
-    return ggml_soft_max_impl(ctx, a, mask, scale, false);
+        struct ggml_tensor  * pos,
+        float                 scale,
+        float                 max_bias) {
+    return ggml_soft_max_impl(ctx, a, mask, pos, scale, max_bias, false);
 }
 
 // ggml_soft_max_back
@@ -11495,6 +11510,7 @@ static void ggml_compute_forward_soft_max_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
+        const struct ggml_tensor * src2,
               struct ggml_tensor * dst) {
     assert(ggml_is_contiguous(dst));
     assert(ggml_are_same_shape(src0, dst));
@@ -11503,16 +11519,29 @@ static void ggml_compute_forward_soft_max_f32(
         return;
     }
 
-    float scale = 1.0f;
-    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) dst->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
 
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
     const int nth = params->nth;
 
+    GGML_TENSOR_UNARY_OP_LOCALS
+
     const int64_t ne11 = src1 ? src1->ne[1] : 1;
 
+    // TODO: is this supposed to be ceil instead of floor?
+    //       https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370
+    const uint32_t n_head_kv   = ne02;
+    const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head_kv));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
+
     const int nc = src0->ne[0];
     const int nr = ggml_nrows(src0);
 
@@ -11525,6 +11554,9 @@ static void ggml_compute_forward_soft_max_f32(
 
     float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
 
+    // when max_bias <= 0.0f, src2 is not used and we default it to src0 to avoid branching
+    float * pos = src2 ? (float *) src2->data : src0->data;
+
     for (int i1 = ir0; i1 < ir1; i1++) {
         float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
         float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
@@ -11538,6 +11570,16 @@ static void ggml_compute_forward_soft_max_f32(
             ggml_vec_acc_f32(nc, wp, mp);
         }
 
+        // ALiBi bias
+        if (max_bias > 0.0f) {
+            const uint32_t h  = (i1/ne01)%ne02; // head
+            const float slope = h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1);
+
+            for (int i = 0; i < nc; i++) {
+                wp[i] = wp[i] + slope*pos[i];
+            }
+        }
+
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
@@ -11582,11 +11624,12 @@ static void ggml_compute_forward_soft_max(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
         const struct ggml_tensor * src1,
+        const struct ggml_tensor * src2,
               struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
+                ggml_compute_forward_soft_max_f32(params, src0, src1, src2, dst);
             } break;
         default:
             {
@@ -11730,22 +11773,20 @@ static void ggml_compute_forward_alibi_f32(
     const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
 
-    for (int64_t i = 0; i < ne0; i++) {
-        for (int64_t j = 0; j < ne1; j++) {
-            for (int64_t k = 0; k < ne2_ne3; k++) {
-                float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
-                float *      pdst = (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
-
-                // TODO: k*nb2 or k*nb3
+    for (int64_t k = 0; k < ne2_ne3; k++) {
+        // TODO: k*nb2 or k*nb3
+        float m_k;
 
-                float m_k;
-
-                if (k < n_heads_log2_floor) {
-                    m_k = powf(m0, k + 1);
-                } else {
-                    m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
-                }
+        if (k < n_heads_log2_floor) {
+            m_k = powf(m0, k + 1);
+        } else {
+            m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+        }
 
+        for (int64_t i = 0; i < ne0; i++) {
+            for (int64_t j = 0; j < ne1; j++) {
+                float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
+                float *      pdst = (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
                 pdst[0] = i * m_k + src[0];
             }
         }
@@ -11790,21 +11831,20 @@ static void ggml_compute_forward_alibi_f16(
     const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
 
-    for (int i = 0; i < ne0; i++) {
-        for (int j = 0; j < ne1; j++) {
-            for (int k = 0; k < ne2_ne3; k++) {
-                ggml_fp16_t * const src  = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
-                      float *      pdst  =       (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
-
-                // TODO: k*nb2 or k*nb3
+    for (int k = 0; k < ne2_ne3; k++) {
+        // TODO: k*nb2 or k*nb3
+        float m_k;
 
-                float m_k;
+        if (k < n_heads_log2_floor) {
+            m_k = powf(m0, k + 1);
+        } else {
+            m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
+        }
 
-                if (k < n_heads_log2_floor) {
-                    m_k = powf(m0, k + 1);
-                } else {
-                    m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
-                }
+        for (int i = 0; i < ne0; i++) {
+            for (int j = 0; j < ne1; j++) {
+                ggml_fp16_t * const src  = (ggml_fp16_t *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
+                float       *      pdst  =       (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
 
                 // we return F32
                 pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
@@ -15116,7 +15156,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX:
             {
-                ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
+                ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
diff --git a/ggml.h b/ggml.h
index 270018185f397c664cd9f5045b26198a405a9a63..6c1956772324c2daff38fcd24bba429a225256f3 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -1383,13 +1383,17 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
-    // fused soft_max(a*scale + mask)
+    // fused soft_max(a*scale + mask + pos[i]*(ALiBi slope))
     // mask is optional
+    // pos is required when max_bias > 0.0f
+    // max_bias = 0.0f for no ALiBi
     GGML_API struct ggml_tensor * ggml_soft_max_ext(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             struct ggml_tensor  * mask,
-            float                 scale);
+            struct ggml_tensor  * pos,
+            float                 scale,
+            float                 max_bias);
 
     GGML_API struct ggml_tensor * ggml_soft_max_back(
             struct ggml_context * ctx,
@@ -1491,12 +1495,13 @@ extern "C" {
 
     // alibi position embedding
     // in-place, returns view(a)
-    GGML_API struct ggml_tensor * ggml_alibi(
+    GGML_DEPRECATED(GGML_API struct ggml_tensor * ggml_alibi(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
             int                   n_past,
             int                   n_head,
-            float                 bias_max);
+            float                 bias_max),
+        "use ggml_soft_max_ext instead (will be removed in Mar 2024)");
 
     // clamp
     // in-place, returns view(a)
index 8966c3e66591699c0887273ef8810853b0e8cf79..6ac9caa957a05dd37f7b4ab2bd7e9a42116b2f92 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -1557,12 +1557,13 @@ struct llama_hparams {
     uint32_t n_yarn_orig_ctx;
     int32_t  rope_scaling_type_train;
 
-    float f_clamp_kqv;
-    float f_max_alibi_bias;
+    float f_clamp_kqv      = 0.0f;
+    float f_max_alibi_bias = 0.0f;
 
     bool causal_attn = true;
-    uint32_t pooling_type = LLAMA_POOLING_NONE;
+    bool need_kq_pos = false;
 
+    uint32_t pooling_type = LLAMA_POOLING_NONE;
 
     bool operator!=(const llama_hparams & other) const {
         if (this->vocab_only    != other.vocab_only)    return true;
@@ -1923,6 +1924,7 @@ struct llama_context {
     struct ggml_tensor * inp_embd;      // F32 [n_embd, n_batch]
     struct ggml_tensor * inp_pos;       // I32 [n_batch]
     struct ggml_tensor * inp_KQ_mask;   // F32 [n_ctx, n_batch]
+    struct ggml_tensor * inp_KQ_pos;    // F32 [n_ctx]
     struct ggml_tensor * inp_K_shift;   // I32 [n_ctx]
     struct ggml_tensor * inp_mean;      // F32 [n_batch, n_batch]
     struct ggml_tensor * inp_cls;       // I32 [n_batch]
@@ -3054,6 +3056,11 @@ static void llm_load_hparams(
                     case 40: model.type = e_model::MODEL_13B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
+
+                if (model.type == e_model::MODEL_13B) {
+                    // TODO: become GGUF KV parameter
+                    hparams.f_max_alibi_bias = 8.0f;
+                }
             } break;
         case LLM_ARCH_STARCODER:
             {
@@ -3081,6 +3088,9 @@ static void llm_load_hparams(
                     case 32: model.type = e_model::MODEL_1B; break;
                     default: model.type = e_model::MODEL_UNKNOWN;
                 }
+
+                // TODO: become GGUF KV parameter
+                hparams.f_max_alibi_bias = 8.0f;
             } break;
         case LLM_ARCH_BERT:
             {
@@ -3126,11 +3136,12 @@ static void llm_load_hparams(
                             case 4096: model.type = e_model::MODEL_7B; break;
                         } break;
                 }
+
+                // TODO: become GGUF KV parameter
+                hparams.f_max_alibi_bias = 8.0f;
             } break;
         case LLM_ARCH_MPT:
             {
-                hparams.f_clamp_kqv = 0.0f;
-
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS,  hparams.f_norm_eps);
                 ml.get_key(LLM_KV_ATTENTION_CLAMP_KQV,      hparams.f_clamp_kqv, false);
                 ml.get_key(LLM_KV_ATTENTION_MAX_ALIBI_BIAS, hparams.f_max_alibi_bias);
@@ -3232,6 +3243,10 @@ static void llm_load_hparams(
     }
 
     model.ftype = ml.ftype;
+
+    if (hparams.f_max_alibi_bias > 0.0f) {
+        hparams.need_kq_pos = true;
+    }
 }
 
 // TODO: This should probably be in llama.h
@@ -4774,10 +4789,10 @@ static struct ggml_tensor * llm_build_kqv(
          struct ggml_tensor * wo_b,
          struct ggml_tensor * q_cur,
          struct ggml_tensor * kq_mask,
+         struct ggml_tensor * kq_pos,
                     int64_t   n_ctx,
                     int32_t   n_tokens,
                     int32_t   n_kv,
-                    float     max_alibi_bias,
                     float     kq_scale,
          const llm_build_cb & cb,
                     int       il) {
@@ -4807,26 +4822,26 @@ static struct ggml_tensor * llm_build_kqv(
         ggml_mul_mat_set_prec(kq, GGML_PREC_F32);
     }
 
-    if (max_alibi_bias > 0.0f) {
-        // temporary branch until we figure out how to handle ggml_alibi through ggml_add
+#if defined(GGML_USE_VULKAN) || defined(GGML_USE_KOMPUTE) || defined(GGML_USE_SYCL)
+#pragma message("TODO: ALiBi support in ggml_soft_max_ext is not implemented for Vulkan, Kompute, and SYCL")
+#pragma message("      Falling back to ggml_alibi(). Will become an error in Mar 2024")
+#pragma message("ref:  https://github.com/ggerganov/llama.cpp/pull/5488")
+    if (hparams.f_max_alibi_bias > 0.0f) {
         kq = ggml_scale(ctx, kq, kq_scale);
         cb(kq, "kq_scaled", il);
 
-        if (max_alibi_bias > 0.0f) {
-            // TODO: n_head or n_head_kv
-            // TODO: K-shift is likely not working
-            // TODO: change to ggml_add
-            kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias);
-            cb(kq, "kq_scaled_alibi", il);
-        }
+        kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, hparams.f_max_alibi_bias);
+        cb(kq, "kq_scaled_alibi", il);
 
         kq = ggml_add(ctx, kq, kq_mask);
         cb(kq, "kq_masked", il);
 
         kq = ggml_soft_max(ctx, kq);
         cb(kq, "kq_soft_max", il);
-    } else {
-        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
+    } else
+#endif
+    {
+        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_pos, kq_scale, hparams.f_max_alibi_bias);
         cb(kq, "kq_soft_max_ext", il);
     }
 
@@ -4874,11 +4889,11 @@ static struct ggml_tensor * llm_build_kv(
          struct ggml_tensor * v_cur,
          struct ggml_tensor * q_cur,
          struct ggml_tensor * kq_mask,
+         struct ggml_tensor * kq_pos,
                     int64_t   n_ctx,
                     int32_t   n_tokens,
                     int32_t   kv_head,
                     int32_t   n_kv,
-                    float     max_alibi_bias,
                     float     kq_scale,
          const llm_build_cb & cb,
                     int       il) {
@@ -4892,9 +4907,8 @@ static struct ggml_tensor * llm_build_kv(
     llm_build_kv_store(ctx, hparams, kv, graph, k_cur, v_cur, n_ctx, n_tokens, kv_head, cb, il);
 
     struct ggml_tensor * cur;
-    cur  = llm_build_kqv(ctx, model, hparams, kv, graph,
-            wo, wo_b,
-            q_cur, kq_mask, n_ctx, n_tokens, n_kv, max_alibi_bias, kq_scale, cb, il);
+    cur  = llm_build_kqv(ctx, model, hparams, kv, graph, wo, wo_b,
+            q_cur, kq_mask, kq_pos, n_ctx, n_tokens, n_kv, kq_scale, cb, il);
     cb(cur, "kqv_out", il);
 
     return cur;
@@ -5062,7 +5076,7 @@ struct llm_build_context {
                 }
 
                 Qcur = ggml_rope_custom(
-                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head,    n_tokens), inp_pos,
+                    ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
                     hparams.n_rot, 0, 0, n_orig_ctx, freq_base, freq_scale,
                     ext_factor, attn_factor, beta_fast, beta_slow
                 );
@@ -5077,7 +5091,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5207,6 +5221,10 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
+        // positions of the tokens in the KV cache
+        struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
+        cb(KQ_pos, "KQ_pos", -1);
+
         // shift the entire K-cache if needed
         if (do_rope_shift) {
             llm_build_k_shift(ctx0, hparams, cparams, kv_self, gf, lctx.inp_K_shift, LLM_ROPE, n_ctx, freq_base, freq_scale, cb);
@@ -5255,12 +5273,9 @@ struct llm_build_context {
                 cb(Kcur, "Kcur", il);
 
 
-                // apply ALiBi for 13B model
-                const float max_alibi_bias = model.type == MODEL_13B ? 8.0f : -1.0f;
-
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5384,7 +5399,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5483,7 +5498,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5688,7 +5703,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Q, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Q, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5750,6 +5765,10 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
+        // positions of the tokens in the KV cache
+        struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
+        cb(KQ_pos, "KQ_pos", -1);
+
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * inpSA = inpL;
 
@@ -5777,7 +5796,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5878,7 +5897,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             } else {
                 // compute Q and K and RoPE them
@@ -5909,7 +5928,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -5985,6 +6004,10 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
+        // positions of the tokens in the KV cache
+        struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
+        cb(KQ_pos, "KQ_pos", -1);
+
         inpL = llm_build_norm(ctx0, inpL, hparams,
                 model.tok_norm,
                 model.tok_norm_b,
@@ -6018,7 +6041,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, 8.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -6078,6 +6101,10 @@ struct llm_build_context {
         struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
         cb(KQ_mask, "KQ_mask", -1);
 
+        // positions of the tokens in the KV cache
+        struct ggml_tensor * KQ_pos = ggml_view_1d(ctx0, lctx.inp_KQ_pos, n_kv, 0);
+        cb(KQ_pos, "KQ_pos", -1);
+
         for (int il = 0; il < n_layer; ++il) {
             struct ggml_tensor * attn_norm;
 
@@ -6111,7 +6138,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, hparams.f_max_alibi_bias, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, KQ_pos, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -6233,7 +6260,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -6348,7 +6375,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -6469,7 +6496,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -6596,7 +6623,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f, cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f, cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -6699,7 +6726,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
             struct ggml_tensor * sa_out = cur;
@@ -6798,7 +6825,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -6907,7 +6934,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -7025,7 +7052,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, NULL,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -7144,7 +7171,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -7276,7 +7303,7 @@ struct llm_build_context {
 
                 cur = llm_build_kv(ctx0, model, hparams, kv_self, gf,
                         model.layers[il].wo, model.layers[il].bo,
-                        Kcur, Vcur, Qcur, KQ_mask, n_ctx, n_tokens, kv_head, n_kv, -1.0f, 1.0f/sqrtf(float(n_embd_head)), cb, il);
+                        Kcur, Vcur, Qcur, KQ_mask, nullptr, n_ctx, n_tokens, kv_head, n_kv, 1.0f/sqrtf(float(n_embd_head)), cb, il);
                 cb(cur, "kqv_out", il);
             }
 
@@ -7507,6 +7534,18 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {
         }
     }
 
+    if (hparams.need_kq_pos) {
+        const int64_t n_kv = kv_self.n;
+
+        assert(ggml_backend_buffer_is_host(lctx.inp_KQ_pos->buffer));
+
+        float * data = (float *) lctx.inp_KQ_pos->data;
+
+        for (int i = 0; i < n_kv; ++i) {
+            data[i] = float(lctx.kv_self.cells[i].pos);
+        }
+    }
+
     if (kv_self.has_shift) {
         const int64_t n_ctx = cparams.n_ctx;
 
@@ -11434,7 +11473,7 @@ struct llama_context * llama_new_context_with_model(
         // graph inputs
         {
             ggml_init_params init_params = {
-                /* .mem_size   */ ggml_tensor_overhead()*7,
+                /* .mem_size   */ ggml_tensor_overhead()*8,
                 /* .mem_buffer */ nullptr,
                 /* .no_alloc   */ true,
             };
@@ -11444,6 +11483,7 @@ struct llama_context * llama_new_context_with_model(
             ctx->inp_embd    = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, hparams.n_embd, cparams.n_batch);
             ctx->inp_pos     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
             ctx->inp_KQ_mask = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx, cparams.n_batch);
+            ctx->inp_KQ_pos  = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_ctx);
             ctx->inp_K_shift = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_ctx);
             ctx->inp_mean    = ggml_new_tensor_2d(ctx->ctx_input, GGML_TYPE_F32, cparams.n_batch, cparams.n_batch);
             ctx->inp_cls     = ggml_new_tensor_1d(ctx->ctx_input, GGML_TYPE_I32, cparams.n_batch);
@@ -11452,6 +11492,7 @@ struct llama_context * llama_new_context_with_model(
             ggml_set_name(ctx->inp_embd,    "inp_embd");
             ggml_set_name(ctx->inp_pos,     "inp_pos");
             ggml_set_name(ctx->inp_KQ_mask, "inp_KQ_mask");
+            ggml_set_name(ctx->inp_KQ_pos,  "inp_KQ_pos");
             ggml_set_name(ctx->inp_K_shift, "inp_K_shift");
             ggml_set_name(ctx->inp_mean,    "inp_mean");
             ggml_set_name(ctx->inp_cls,     "inp_cls");
index 9af8517d950db9364c0550d51dfd1d6a22b6ffe1..30a7d1f5ab3e2ccc9187831c7e49cf92ce484b05 100644 (file)
@@ -1085,24 +1085,32 @@ struct test_diag_mask_inf : public test_case {
 struct test_soft_max : public test_case {
     const ggml_type type;
     const std::array<int64_t, 4> ne;
-    const float scale;
     const bool mask;
+    const float scale;
+    const float max_bias;
 
     std::string vars() override {
-        return VARS_TO_STR4(type, ne, scale, mask);
+        return VARS_TO_STR5(type, ne, mask, scale, max_bias);
     }
 
     test_soft_max(ggml_type type = GGML_TYPE_F32,
             std::array<int64_t, 4> ne = {10, 10, 10, 10},
+            bool mask = false,
             float scale = 1.0f,
-            bool mask = false)
-        : type(type), ne(ne), scale(scale), mask(mask) {}
+            float max_bias = 0.0f)
+        : type(type), ne(ne), mask(mask), scale(scale), max_bias(max_bias) {}
 
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_tensor * b = nullptr;
-        if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); }
-        ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale);
+        ggml_tensor * mask = nullptr;
+        if (this->mask) {
+            mask = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]);
+        }
+        ggml_tensor * pos = nullptr;
+        if (max_bias > 0.0f) {
+            pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]);
+        }
+        ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias);
         return out;
     }
 };
@@ -1147,30 +1155,6 @@ struct test_rope : public test_case {
     }
 };
 
-// GGML_OP_ALIBI
-struct test_alibi : public test_case {
-    const ggml_type type;
-    const std::array<int64_t, 4> ne;
-    int n_past;
-    int n_head;
-    float bias_max;
-
-    std::string vars() override {
-        return VARS_TO_STR5(type, ne, n_past, n_head, bias_max);
-    }
-
-    test_alibi(ggml_type type = GGML_TYPE_F32,
-            std::array<int64_t, 4> ne = {10, 10, 10, 10},
-            int n_past = 512, int n_head = 10, float bias_max = 0.5f)
-        : type(type), ne(ne), n_past(n_past), n_head(n_head), bias_max(bias_max) {}
-
-    ggml_tensor * build_graph(ggml_context * ctx) override {
-        ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
-        ggml_tensor * out = ggml_alibi(ctx, a, n_past, n_head, bias_max);
-        return out;
-    }
-};
-
 // GGML_OP_POOL2D
 struct test_pool2d : public test_case {
     enum ggml_op_pool pool_type;
@@ -1488,7 +1472,7 @@ struct test_moe : public test_case {
         ggml_tensor * cur = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_tokens);
 
         ggml_tensor * logits = ggml_mul_mat(ctx, ffn_gate_inp, cur);
-        ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, 1.0f/sqrtf(n_embd));
+        ggml_tensor * probs = ggml_soft_max_ext(ctx, logits, nullptr, nullptr, 1.0f/sqrtf(n_embd), 0.0f);
 
         // select experts
         ggml_tensor * selected_experts = ggml_top_k(ctx, probs, n_experts_per_tok);
@@ -1617,7 +1601,6 @@ public:
         ggml_cpy(ctx, v_cur_t, v_cache_view);
     }
 
-    // if max_alibi_bias > 0 then apply ALiBi
     struct ggml_tensor * llm_build_kqv(
             struct ggml_context * ctx,
              struct ggml_tensor * k_l,
@@ -1636,7 +1619,7 @@ public:
 
         struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
 
-        kq = ggml_soft_max_ext(ctx, kq, kq_mask, kq_scale);
+        kq = ggml_soft_max_ext(ctx, kq, kq_mask, nullptr, kq_scale, 0.0f);
 
         // split cached v into n_head heads
         struct ggml_tensor * v =
@@ -2083,6 +2066,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10,  1}, 5));
     test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 10, 10}, 5));
 
+#if 0
     std::uniform_int_distribution<> dist_ne1(1, 50);
     int exponent = 1;
     while (exponent < (1 << 17)) {
@@ -2091,14 +2075,29 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         for (int n = 0; n < 10; ++n) {
             int64_t ne0 = dist_ne0(rng);
             int64_t ne1 = dist_ne1(rng);
-            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}));
+            test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, n/2 == 0, 0.1f, ne0 < 1000 ? 4.0f : 0.0f));
         }
 
         exponent <<= 1;
     }
+#endif
+    for (bool mask : {false, true}) {
+        for (float max_bias : {0.0f, 8.0f}) {
+            for (float scale : {1.0f, 0.1f}) {
+                for (int64_t ne0 : {16, 1024}) {
+                    for (int64_t ne1 : {16, 1024}) {
+                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0, ne1, 1, 1}, mask, scale, max_bias));
+                        test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {ne0-1, ne1-1, 1, 1}, mask, scale, max_bias));
+                    }
+                }
+            }
+        }
+    }
 
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, 0.1f));
-    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, 0.1f, true));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 0.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {16, 2, 32, 1}, false, 0.1f, 8.0f));
+    test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true,  0.1f, 8.0f));
 
     for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) {
         test_cases.emplace_back(new test_rope(type, {128,  32, 10, 1}, 128, 0, 512)); // llama 7B
@@ -2113,7 +2112,6 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
         test_cases.emplace_back(new test_rope(type, { 80,  32, 10, 1},  32, 2, 512)); // neox (phi-2)
     }
 
-    test_cases.emplace_back(new test_alibi());
     test_cases.emplace_back(new test_concat(GGML_TYPE_F32));
     test_cases.emplace_back(new test_concat(GGML_TYPE_I32));