]> git.djapps.eu Git - pkg/ggml/sources/llama.cpp/commitdiff
ggml : add ggml_soft_max_ext (#4256)
authorGeorgi Gerganov <redacted>
Fri, 1 Dec 2023 08:51:24 +0000 (10:51 +0200)
committerGitHub <redacted>
Fri, 1 Dec 2023 08:51:24 +0000 (10:51 +0200)
* metal : implement soft_max_ext

* cuda : implement soft_max_ext

* ggml : implement soft_max_ext (CPU)

* batched-bench : print threads

ggml-ci

* metal : simplify soft_max encoding

ggml-ci

* cuda : use 512 threads for soft_max instead of 32

* ggml : update soft max cpu

* cuda : do warp-based block reduce

* cuda : increase max block size to 1024

* cuda : fix warp reduction initialization of shared mem

* metal : warp-based reduction for soft max kernel

* metal : warp-based reduce for rms_norm

* metal : simplify soft max kernel

ggml-ci

* alloc : fix build with debug

examples/batched-bench/batched-bench.cpp
ggml-alloc.c
ggml-cuda.cu
ggml-metal.m
ggml-metal.metal
ggml.c
ggml.h
llama.cpp

index 533c55c17aad17da651014866a30b6d63a07ec3a..57596ed9860507978334417f18565a2d675a3dce 100644 (file)
@@ -155,7 +155,7 @@ int main(int argc, char ** argv) {
     }
 
     LOG_TEE("\n");
-    LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq);
+    LOG_TEE("%s: n_kv_max = %d, is_pp_shared = %d, n_gpu_layers = %d, mmq = %d, n_threads = %d, n_threads_batch = %d\n", __func__, n_kv_max, is_pp_shared, n_gpu_layers, mmq, ctx_params.n_threads, ctx_params.n_threads_batch);
     LOG_TEE("\n");
 
     LOG_TEE("|%6s | %6s | %4s | %6s | %8s | %8s | %8s | %8s | %8s | %8s |\n", "PP",     "TG",     "B",    "N_KV",     "T_PP s",   "S_PP t/s", "T_TG s",   "S_TG t/s", "T s",      "S t/s");
index cdfe4caf69613d6778f8a80af9b0d6ffc8e67652..0d4e12ae99d3d380aae7ec3cf4b0c8f8ec473026 100644 (file)
@@ -137,7 +137,7 @@ void ggml_tallocr_alloc(ggml_tallocr_t alloc, struct ggml_tensor * tensor) {
 
 #ifdef GGML_ALLOCATOR_DEBUG
     add_allocated_tensor(alloc, tensor);
-    size_t cur_max = (char*)addr - (char*)alloc->data + size;
+    size_t cur_max = (char*)addr - (char*)alloc->base + size;
     if (cur_max > alloc->max_size) {
         printf("max_size = %.2f MB: tensors: ", cur_max / 1024.0 / 1024.0);
         for (int i = 0; i < 1024; i++) {
index 5b80e4ae313293bdc630a45e57567a2efea2dcc3..9019a849f0bff14945251c63145e52f3d05d6455 100644 (file)
@@ -443,6 +443,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SCALE_BLOCK_SIZE 256
 #define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
+#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
 #define CUDA_QUANTIZE_BLOCK_SIZE 256
@@ -501,6 +502,31 @@ static size_t g_scratch_offset = 0;
 
 static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
 
+static __device__ __forceinline__ float warp_reduce_sum(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
+    }
+    return x;
+}
+
+static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
+        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
+    }
+    return a;
+}
+
+static __device__ __forceinline__ float warp_reduce_max(float x) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+}
+
 static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
     const int i = blockDim.x*blockIdx.x + threadIdx.x;
 
@@ -577,15 +603,6 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) {
     dst[i] = x[i] * x[i];
 }
 
-static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
-        a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
-    }
-    return a;
-}
-
 template <int block_size>
 static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -624,14 +641,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
     }
 }
 
-static __device__ __forceinline__ float warp_reduce_sum(float x) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        x += __shfl_xor_sync(0xffffffff, x, mask, 32);
-    }
-    return x;
-}
-
 template <int block_size>
 static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
     const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -4717,45 +4726,74 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
     dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
 }
 
-// the CUDA soft max implementation differs from the CPU implementation
-// instead of doubles floats are used
-static __global__ void soft_max_f32(const float * x, float * dst, const int ncols) {
-    const int row = blockDim.x*blockIdx.x + threadIdx.x;
-    const int block_size = blockDim.y;
-    const int tid = threadIdx.y;
+static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols, const int nrows_y, const float scale) {
+    const int tid  = threadIdx.x;
+    const int rowx = blockIdx.x;
+    const int rowy = rowx % nrows_y; // broadcast the mask (y) in the row dimension
+
+    const int block_size = blockDim.x;
+
+    const int warp_id = threadIdx.x / WARP_SIZE;
+    const int lane_id = threadIdx.x % WARP_SIZE;
+
+    __shared__ float buf[CUDA_SOFT_MAX_BLOCK_SIZE/WARP_SIZE];
 
     float max_val = -INFINITY;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        max_val = max(max_val, x[i]);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        max_val = max(max_val, x[ix]*scale + (y ? y[iy] : 0.0f));
     }
 
     // find the max value in the block
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        max_val = max(max_val, __shfl_xor_sync(0xffffffff, max_val, mask, 32));
+    max_val = warp_reduce_max(max_val);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = -INFINITY;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = max_val;
+        }
+        __syncthreads();
+
+        max_val = buf[lane_id];
+        max_val = warp_reduce_max(max_val);
     }
 
     float tmp = 0.f;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
-        const float val = expf(x[i] - max_val);
+        const int ix = rowx*ncols + col;
+        const int iy = rowy*ncols + col;
+        const float val = expf((x[ix]*scale + (y ? y[iy] : 0.0f)) - max_val);
         tmp += val;
-        dst[i] = val;
+        dst[ix] = val;
     }
 
-    // sum up partial sums
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1) {
-        tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
+    // find the sum of exps in the block
+    tmp = warp_reduce_sum(tmp);
+    if (block_size > WARP_SIZE) {
+        if (warp_id == 0) {
+            buf[lane_id] = 0.f;
+        }
+        __syncthreads();
+
+        if (lane_id == 0) {
+            buf[warp_id] = tmp;
+        }
+        __syncthreads();
+
+        tmp = buf[lane_id];
+        tmp = warp_reduce_sum(tmp);
     }
 
     const float inv_tmp = 1.f / tmp;
 
     for (int col = tid; col < ncols; col += block_size) {
-        const int i = row*ncols + col;
+        const int i = rowx*ncols + col;
         dst[i] *= inv_tmp;
     }
 }
@@ -5792,10 +5830,12 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }
 
-static void soft_max_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, cudaStream_t stream) {
-    const dim3 block_dims(1, WARP_SIZE, 1);
+static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+    int nth = WARP_SIZE;
+    while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
+    const dim3 block_dims(nth,     1, 1);
     const dim3 block_nums(nrows_x, 1, 1);
-    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x);
+    soft_max_f32<<<block_nums, block_dims, 0, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
 }
 
 static void im2col_f32_f16_cuda(const float * x, half * dst,
@@ -6846,14 +6886,18 @@ inline void ggml_cuda_op_soft_max(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);
 
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+
     const int64_t ne00 = src0->ne[0];
-    const int64_t nrows = ggml_nrows(src0);
+    const int64_t nrows_x = ggml_nrows(src0);
+    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
 
-    soft_max_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream);
+    float scale = 1.0f;
+    memcpy(&scale, dst->op_params, sizeof(float));
+
+    soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
 
-    (void) src1;
     (void) dst;
-    (void) src1_dd;
 }
 
 inline void ggml_cuda_op_scale(
index d52a1c3c48210b1f2805edde477b45001796c7e6..6cfacf64fcd9fc0b61a6290b8850fe40a29fb7e3 100644 (file)
@@ -1028,20 +1028,27 @@ void ggml_metal_graph_compute(
                             int nth = 32; // SIMD width
 
                             if (ne00%4 == 0) {
+                                while (nth < ne00/4 && nth < 256) {
+                                    nth *= 2;
+                                }
                                 [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
                             } else {
-                                do {
+                                while (nth < ne00 && nth < 1024) {
                                     nth *= 2;
-                                } while (nth <= ne00 && nth <= 1024);
-                                nth /= 2;
+                                }
                                 [encoder setComputePipelineState:ctx->pipeline_soft_max];
                             }
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
-                            [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
-                            [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+
+                            const float scale = ((float *) dst->op_params)[0];
+
+                            [encoder setBuffer:id_src0 offset:offs_src0   atIndex:0];
+                            [encoder setBuffer:id_src1 offset:offs_src1   atIndex:1];
+                            [encoder setBuffer:id_dst  offset:offs_dst    atIndex:2];
+                            [encoder setBytes:&ne00  length:sizeof(ne00)  atIndex:3];
+                            [encoder setBytes:&ne01  length:sizeof(ne01)  atIndex:4];
+                            [encoder setBytes:&ne02  length:sizeof(ne02)  atIndex:5];
+                            [encoder setBytes:&scale length:sizeof(scale) atIndex:6];
+                            [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                             [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
@@ -1351,15 +1358,19 @@ void ggml_metal_graph_compute(
                             float eps;
                             memcpy(&eps, dst->op_params, sizeof(float));
 
-                            const int nth = MIN(512, ne00);
+                            int nth = 32; // SIMD width
+
+                            while (nth < ne00/4 && nth < 1024) {
+                                nth *= 2;
+                            }
 
                             [encoder setComputePipelineState:ctx->pipeline_rms_norm];
-                            [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
-                            [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
-                            [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
-                            [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
-                            [encoder setBytes:&eps  length:sizeof(   float) atIndex:4];
-                            [encoder setThreadgroupMemoryLength:GGML_PAD(nth/32*sizeof(float), 16) atIndex:0];
+                            [encoder setBuffer:id_src0 offset:offs_src0        atIndex:0];
+                            [encoder setBuffer:id_dst  offset:offs_dst         atIndex:1];
+                            [encoder setBytes:&ne00    length:sizeof( int64_t) atIndex:2];
+                            [encoder setBytes:&nb01    length:sizeof(uint64_t) atIndex:3];
+                            [encoder setBytes:&eps     length:sizeof(   float) atIndex:4];
+                            [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
 
                             const int64_t nrows = ggml_nrows(src0);
 
index 5d1357cd72d4592782802a60222e81d2cacb8d8f..9a79f815f3a72088665c711405966c9be26cce21 100644 (file)
@@ -39,6 +39,8 @@ typedef struct {
     int8_t  qs[QK8_0]; // quants
 } block_q8_0;
 
+#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
+
 // general-purpose kernel for addition of two tensors
 // pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3
 // cons: not very efficient
@@ -180,10 +182,12 @@ kernel void kernel_gelu(
 
 kernel void kernel_soft_max(
         device const float * src0,
+        device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
+        constant     float & scale,
         threadgroup float  * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
@@ -194,73 +198,77 @@ kernel void kernel_soft_max(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-    device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * psrc0 =        src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+    device const float * pmask = src1 ? src1                                      + i01*ne00 : nullptr;
+    device       float * pdst  =        dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    float lmax = tpitg < ne00 ? psrc0[tpitg] : -INFINITY;
+    float lmax = -INFINITY;
 
-    for (int i00 = tpitg + ntg; i00 < ne00; i00 += ntg) {
-        lmax = MAX(lmax, psrc0[i00]);
+    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
+        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f));
     }
 
-    float max = simd_max(lmax);
-    if (tiisg == 0) {
-        buf[sgitg] = max;
-    }
+    // find the max value in the block
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-       if (tpitg < i) {
-           buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
-       }
-    }
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    max = buf[0];
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
 
     // parallel sum
     float lsum = 0.0f;
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        const float exp_psrc0 = exp(psrc0[i00] - max);
+        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
         lsum += exp_psrc0;
-        // Remember the result of exp here. exp is expensive, so we really do not
-        // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
     float sum = simd_sum(lsum);
-    if (tiisg == 0) {
-        buf[sgitg] = sum;
-    }
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-       if (tpitg < i) {
-           buf[tpitg] += buf[tpitg + i];
-       }
-    }
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
 
-    sum = buf[0];
+    const float inv_sum = 1.0f/sum;
 
     for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
-        pdst[i00] /= sum;
+        pdst[i00] *= inv_sum;
     }
 }
 
 kernel void kernel_soft_max_4(
         device const float * src0,
+        device const float * src1,
         device       float * dst,
         constant   int64_t & ne00,
         constant   int64_t & ne01,
         constant   int64_t & ne02,
+        constant     float & scale,
         threadgroup float  * buf [[threadgroup(0)]],
         uint  tgpig[[threadgroup_position_in_grid]],
         uint  tpitg[[thread_position_in_threadgroup]],
@@ -271,64 +279,68 @@ kernel void kernel_soft_max_4(
     const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
     const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);
 
-    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-    device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * psrc4 =        (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device const float4 * pmask = src1 ? (device const float4 *)(src1 +                                      i01*ne00) : nullptr;
+    device       float4 * pdst4 =        (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
 
     // parallel max
-    float4 lmax4 = tpitg < ne00/4 ? psrc4[tpitg] : -INFINITY;
+    float4 lmax4 = -INFINITY;
 
-    for (int i00 = tpitg + ntg; i00 < ne00/4; i00 += ntg) {
-        lmax4 = fmax(lmax4, psrc4[i00]);
+    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
+        lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
     }
 
     const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
-    float max = simd_max(lmax);
-    if (tiisg == 0) {
-        buf[sgitg] = max;
-    }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+    float max_val = simd_max(lmax);
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = -INFINITY;
+        }
 
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-       if (tpitg < i) {
-           buf[tpitg] = MAX(buf[tpitg], buf[tpitg + i]);
-       }
-    }
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        if (tiisg == 0) {
+            buf[sgitg] = max_val;
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    max = buf[0];
+        max_val = buf[tiisg];
+        max_val = simd_max(max_val);
+    }
 
     // parallel sum
     float4 lsum4 = 0.0f;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        const float4 exp_psrc4 = exp(psrc4[i00] - max);
+        const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
         lsum4 += exp_psrc4;
         pdst4[i00] = exp_psrc4;
     }
 
     const float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
     float sum = simd_sum(lsum);
-    if (tiisg == 0) {
-        buf[sgitg] = sum;
-    }
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-       if (tpitg < i) {
-           buf[tpitg] += buf[tpitg + i];
-       }
-    }
+        if (tiisg == 0) {
+            buf[sgitg] = sum;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        sum = buf[tiisg];
+        sum = simd_sum(sum);
+    }
 
-    sum = buf[0];
+    const float inv_sum = 1.0f/sum;
 
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-        pdst4[i00] /= sum;
+        pdst4[i00] *= inv_sum;
     }
 }
 
@@ -435,14 +447,13 @@ kernel void kernel_rms_norm(
         constant   int64_t & ne00,
         constant  uint64_t & nb01,
         constant     float & eps,
-        threadgroup float  * sum [[threadgroup(0)]],
+        threadgroup float  * buf [[threadgroup(0)]],
         uint tgpig[[threadgroup_position_in_grid]],
         uint tpitg[[thread_position_in_threadgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint   ntg[[threads_per_threadgroup]]) {
-    device const float4 * x        = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
-    device const float  * x_scalar = (device const float  *) x;
+    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
 
     float4 sumf = 0;
     float all_sum = 0;
@@ -453,40 +464,30 @@ kernel void kernel_rms_norm(
     }
     all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3];
     all_sum = simd_sum(all_sum);
-    if (tiisg == 0) {
-        sum[sgitg] = all_sum;
-    }
+    if (ntg > N_SIMDWIDTH) {
+        if (sgitg == 0) {
+            buf[tiisg] = 0.0f;
+        }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    // broadcast, simd group number is ntg / 32
-    for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
-       if (tpitg < i) {
-           sum[tpitg] += sum[tpitg + i];
-       }
-    }
-    if (tpitg == 0) {
-        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
-            sum[0] += x_scalar[i];
+        if (tiisg == 0) {
+            buf[sgitg] = all_sum;
         }
-        sum[0] /= ne00;
-    }
 
-    threadgroup_barrier(mem_flags::mem_threadgroup);
+        threadgroup_barrier(mem_flags::mem_threadgroup);
 
-    const float mean  = sum[0];
+        all_sum = buf[tiisg];
+        all_sum = simd_sum(all_sum);
+    }
+
+    const float mean  = all_sum/ne00;
     const float scale = 1.0f/sqrt(mean + eps);
 
     device float4 * y = (device float4 *) (dst + tgpig*ne00);
-    device float * y_scalar = (device float *) y;
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
         y[i00] = x[i00] * scale;
     }
-    if (tpitg == 0) {
-        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
-            y_scalar[i00] = x_scalar[i00] * scale;
-        }
-    }
 }
 
 // function for calculate inner product between half a q4_0 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -576,7 +577,6 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4        // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2  // number of SIMD groups in a thread group
-#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
 //Note: This is a template, but strictly speaking it only applies to
 //      quantizations where the block size is 32. It also does not
 //      giard against the number of rows not being divisible by
diff --git a/ggml.c b/ggml.c
index c522a101f15526fe99fd353fd9900ce638e8a41e..e2687ef4f072f089f190b8d75f2e87910b45d0e2 100644 (file)
--- a/ggml.c
+++ b/ggml.c
@@ -4826,7 +4826,17 @@ struct ggml_tensor * ggml_diag_mask_zero_inplace(
 static struct ggml_tensor * ggml_soft_max_impl(
         struct ggml_context * ctx,
         struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale,
         bool                  inplace) {
+    GGML_ASSERT(ggml_is_contiguous(a));
+    if (mask) {
+        GGML_ASSERT(ggml_is_contiguous(mask));
+        GGML_ASSERT(mask->ne[2] == 1);
+        GGML_ASSERT(mask->ne[3] == 1);
+        GGML_ASSERT(ggml_can_repeat_rows(mask, a));
+    }
+
     bool is_node = false;
 
     if (a->grad) {
@@ -4835,9 +4845,13 @@ static struct ggml_tensor * ggml_soft_max_impl(
 
     struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);
 
+    float params[] = { scale };
+    ggml_set_op_params(result, params, sizeof(params));
+
     result->op   = GGML_OP_SOFT_MAX;
     result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
     result->src[0] = a;
+    result->src[1] = mask;
 
     return result;
 }
@@ -4845,13 +4859,21 @@ static struct ggml_tensor * ggml_soft_max_impl(
 struct ggml_tensor * ggml_soft_max(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_soft_max_impl(ctx, a, false);
+    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, false);
 }
 
 struct ggml_tensor * ggml_soft_max_inplace(
         struct ggml_context * ctx,
         struct ggml_tensor  * a) {
-    return ggml_soft_max_impl(ctx, a, true);
+    return ggml_soft_max_impl(ctx, a, NULL, 1.0f, true);
+}
+
+struct ggml_tensor * ggml_soft_max_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * a,
+        struct ggml_tensor  * mask,
+        float                 scale) {
+    return ggml_soft_max_impl(ctx, a, mask, scale, false);
 }
 
 // ggml_soft_max_back
@@ -10551,20 +10573,25 @@ static void ggml_compute_forward_diag_mask_zero(
 static void ggml_compute_forward_soft_max_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        struct ggml_tensor * dst) {
-    GGML_ASSERT(ggml_is_contiguous(src0));
-    GGML_ASSERT(ggml_is_contiguous(dst));
-    GGML_ASSERT(ggml_are_same_shape(src0, dst));
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
+    assert(ggml_is_contiguous(dst));
+    assert(ggml_are_same_shape(src0, dst));
 
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return;
     }
 
+    float scale = 1.0f;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
     // TODO: handle transposed/permuted matrices
 
     const int ith = params->ith;
     const int nth = params->nth;
 
+    const int64_t ne11 = src1 ? src1->ne[1] : 1;
+
     const int nc = src0->ne[0];
     const int nr = ggml_nrows(src0);
 
@@ -10575,29 +10602,40 @@ static void ggml_compute_forward_soft_max_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);
 
+    float * wp = (float *) params->wdata + (nc + CACHE_LINE_SIZE_F32) * ith;
+
     for (int i1 = ir0; i1 < ir1; i1++) {
-        float *sp = (float *)((char *) src0->data + i1*src0->nb[1]);
-        float *dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
+        float * sp = (float *)((char *) src0->data + i1*src0->nb[1]);
+        float * dp = (float *)((char *)  dst->data +  i1*dst->nb[1]);
+
+        // broadcast the mask across rows
+        float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+
+        ggml_vec_cpy_f32  (nc, wp, sp);
+        ggml_vec_scale_f32(nc, wp, scale);
+        if (mp) {
+            ggml_vec_acc_f32(nc, wp, mp);
+        }
 
 #ifndef NDEBUG
         for (int i = 0; i < nc; ++i) {
             //printf("p[%d] = %f\n", i, p[i]);
-            assert(!isnan(sp[i]));
+            assert(!isnan(wp[i]));
         }
 #endif
 
         float max = -INFINITY;
-        ggml_vec_max_f32(nc, &max, sp);
+        ggml_vec_max_f32(nc, &max, wp);
 
         ggml_float sum = 0.0;
 
         uint16_t scvt;
         for (int i = 0; i < nc; i++) {
-            if (sp[i] == -INFINITY) {
+            if (wp[i] == -INFINITY) {
                 dp[i] = 0.0f;
             } else {
-                // const float val = (sp[i] == -INFINITY) ? 0.0 : exp(sp[i] - max);
-                ggml_fp16_t s = GGML_FP32_TO_FP16(sp[i] - max);
+                // const float val = (wp[i] == -INFINITY) ? 0.0 : exp(wp[i] - max);
+                ggml_fp16_t s = GGML_FP32_TO_FP16(wp[i] - max);
                 memcpy(&scvt, &s, sizeof(scvt));
                 const float val = GGML_FP16_TO_FP32(ggml_table_exp_f16[scvt]);
                 sum += (ggml_float)val;
@@ -10622,11 +10660,12 @@ static void ggml_compute_forward_soft_max_f32(
 static void ggml_compute_forward_soft_max(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
-        struct ggml_tensor * dst) {
+        const struct ggml_tensor * src1,
+              struct ggml_tensor * dst) {
     switch (src0->type) {
         case GGML_TYPE_F32:
             {
-                ggml_compute_forward_soft_max_f32(params, src0, dst);
+                ggml_compute_forward_soft_max_f32(params, src0, src1, dst);
             } break;
         default:
             {
@@ -13863,7 +13902,7 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
             } break;
         case GGML_OP_SOFT_MAX:
             {
-                ggml_compute_forward_soft_max(params, tensor->src[0], tensor);
+                ggml_compute_forward_soft_max(params, tensor->src[0], tensor->src[1], tensor);
             } break;
         case GGML_OP_SOFT_MAX_BACK:
             {
@@ -15899,6 +15938,12 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
                         cur = ggml_type_size(GGML_TYPE_F32) * node->src[0]->ne[0] * n_tasks;
                     }
                 } break;
+            case GGML_OP_SOFT_MAX:
+                {
+                    n_tasks = MIN(MIN(4, n_threads), ggml_nrows(node->src[0]));
+
+                    cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
+                } break;
             case GGML_OP_CONV_TRANSPOSE_1D:
                 {
                     GGML_ASSERT(node->src[0]->ne[3] == 1);
diff --git a/ggml.h b/ggml.h
index 4d6d4edfd933c5ae02890751deb33a2616c43429..2f6787d4e4219b6e32621918ccaedd2efe2158a8 100644 (file)
--- a/ggml.h
+++ b/ggml.h
@@ -1282,6 +1282,14 @@ extern "C" {
             struct ggml_context * ctx,
             struct ggml_tensor  * a);
 
+    // fused soft_max(a*scale + mask)
+    // mask is optional
+    GGML_API struct ggml_tensor * ggml_soft_max_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a,
+            struct ggml_tensor  * mask,
+            float                 scale);
+
     GGML_API struct ggml_tensor * ggml_soft_max_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * a,
index 1e00ea4a97870dfd75553e372cd8e303c90ee8c0..e74fd7234eb8481134d13bccbbc9a1865806b7fb 100644 (file)
--- a/llama.cpp
+++ b/llama.cpp
@@ -3704,22 +3704,28 @@ static struct ggml_tensor * llm_build_kqv(
     struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
     cb(kq, "kq", il);
 
-    kq = ggml_scale(ctx, kq, kq_scale);
-    cb(kq, "kq_scaled", il);
-
     if (max_alibi_bias > 0.0f) {
-        // TODO: n_head or n_head_kv
-        // TODO: K-shift is likely not working
-        // TODO: change to ggml_add
-        kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias);
-        cb(kq, "kq_scaled_alibi", il);
-    }
+        // temporary branch until we figure out how to handle ggml_alibi through ggml_add
+        kq = ggml_scale(ctx, kq, kq_scale);
+        cb(kq, "kq_scaled", il);
+
+        if (max_alibi_bias > 0.0f) {
+            // TODO: n_head or n_head_kv
+            // TODO: K-shift is likely not working
+            // TODO: change to ggml_add
+            kq = ggml_alibi(ctx, kq, /*n_past*/ 0, n_head, max_alibi_bias);
+            cb(kq, "kq_scaled_alibi", il);
+        }
 
-    kq = ggml_add(ctx, kq, kq_mask);
-    cb(kq, "kq_masked", il);
+        kq = ggml_add(ctx, kq, kq_mask);
+        cb(kq, "kq_masked", il);
 
-    kq = ggml_soft_max(ctx, kq);
-    cb(kq, "kq_soft_max", il);
+        kq = ggml_soft_max(ctx, kq);
+        cb(kq, "kq_soft_max", il);
+    } else {
+        kq = ggml_soft_max_ext(ctx, kq, kq_mask, 1.0f/sqrtf(float(n_embd_head)));
+        cb(kq, "kq_soft_max_ext", il);
+    }
 
     // split cached v into n_head heads
     struct ggml_tensor * v =
@@ -5041,6 +5047,7 @@ static const std::unordered_map<const char *, llm_offload_func_e> k_offload_map
     { "kq_scaled_alibi",            OFFLOAD_FUNC_KQ  },
     { "kq_masked",                  OFFLOAD_FUNC_KQ  },
     { "kq_soft_max",                OFFLOAD_FUNC_V   },
+    { "kq_soft_max_ext",            OFFLOAD_FUNC_V   },
     { "v",                          OFFLOAD_FUNC_V   },
     { "kqv",                        OFFLOAD_FUNC_V   },
     { "kqv_merged",                 OFFLOAD_FUNC_V   },