]> git.djapps.eu Git - pkg/ggml/sources/ggml/commitdiff
sync : llama.cpp (CUDA, Metal, OpenCL, gguf magic, ggml iter) (#592)
authorGeorgi Gerganov <redacted>
Tue, 24 Oct 2023 18:51:12 +0000 (21:51 +0300)
committerGitHub <redacted>
Tue, 24 Oct 2023 18:51:12 +0000 (21:51 +0300)
ggml-ci

include/ggml/ggml.h
src/ggml-cuda.cu
src/ggml-metal.m
src/ggml-metal.metal
src/ggml-opencl.cpp
src/ggml.c

index 4b16032f02a8160e10608bbfdd174db195fb98d7..08bff5511c2254d942046b332d91984b04497ebc 100644 (file)
 #define GGML_EXIT_SUCCESS 0
 #define GGML_EXIT_ABORTED 1
 
-#define GGUF_MAGIC   0x46554747 // "GGUF"
-#define GGUF_VERSION 2
+#define GGUF_MAGIC "GGUF"
+
+#define GGUF_VERSION 3
 
 #define GGUF_DEFAULT_ALIGNMENT 32
 
@@ -705,6 +706,9 @@ extern "C" {
     GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src);
     GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src);
 
+    // Context tensor enumeration and lookup
+    GGML_API struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx);
+    GGML_API struct ggml_tensor * ggml_get_next_tensor (struct ggml_context * ctx, struct ggml_tensor * tensor);
     GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name);
 
     GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor);
index 5bd83bb5c056a42ff32497b8daa1cfe050acce00..d1e874b6c778af32d887cc5cf5e8970ceb92310f 100644 (file)
@@ -29,6 +29,8 @@
 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
 #define cublasCreate hipblasCreate
 #define cublasGemmEx hipblasGemmEx
+#define cublasGemmBatchedEx hipblasGemmBatchedEx
+#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
 #define cublasHandle_t hipblasHandle_t
 #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
 #define cublasSetStream hipblasSetStream
@@ -415,6 +417,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
 #define CUDA_SILU_BLOCK_SIZE 256
 #define CUDA_CPY_BLOCK_SIZE 32
 #define CUDA_SCALE_BLOCK_SIZE 256
+#define CUDA_CLAMP_BLOCK_SIZE 256
 #define CUDA_ROPE_BLOCK_SIZE 256
 #define CUDA_ALIBI_BLOCK_SIZE 32
 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
@@ -4325,13 +4328,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
 
     const half * x = (const half *) vx;
 
-    const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
-    const int channel = blockDim.z*blockIdx.z + threadIdx.z;
+    const int row_x     = blockDim.y*blockIdx.y + threadIdx.y;
+    const int channel   = blockDim.z*blockIdx.z + threadIdx.z;
     const int channel_x = channel / channel_x_divisor;
 
-    const int nrows_y = ncols_x;
+    const int nrows_y   = ncols_x;
     const int nrows_dst = nrows_x;
-    const int row_dst = row_x;
+    const int row_dst   = row_x;
 
     const int idst = channel*nrows_dst + row_dst;
 
@@ -4344,13 +4347,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous
             break;
         }
 
-        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
-        const float xi = __half2float(x[ix]);
-
         const int row_y = col_x;
 
+        const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x;
         const int iy = channel*nrows_y + row_y;
 
+        const float xi = __half2float(x[ix]);
+
         tmp += xi * y[iy];
     }
 
@@ -4585,6 +4588,15 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
     dst[i] = scale * x[i];
 }
 
+static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
+    const int i = blockDim.x*blockIdx.x + threadIdx.x;
+
+    if (i >= k) {
+        return;
+    }
+
+    dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
+}
 
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) {
@@ -5475,6 +5487,11 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
     scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
 }
 
+static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
+    const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
+    clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
+}
+
 template<typename T>
 static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale,
                           const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
@@ -6419,12 +6436,12 @@ inline void ggml_cuda_op_alibi(
     const int64_t ne02 = src0->ne[2];
     const int64_t nrows = ggml_nrows(src0);
 
-    const int n_past = ((int32_t *) dst->op_params)[0];
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
-    GGML_ASSERT(ne01 + n_past == ne00);
+    //GGML_ASSERT(ne01 + n_past == ne00);
     GGML_ASSERT(n_head == ne02);
 
     const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@@ -6500,6 +6517,24 @@ inline void ggml_cuda_op_scale(
     (void) src1_dd;
 }
 
+inline void ggml_cuda_op_clamp(
+    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
+    const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) {
+
+    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT( dst->type == GGML_TYPE_F32);
+
+    const float min = ((float *) dst->op_params)[0];
+    const float max = ((float *) dst->op_params)[1];
+
+    clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream);
+    CUDA_CHECK(cudaGetLastError());
+
+    (void) src1;
+    (void) dst;
+    (void) src1_dd;
+}
+
 static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) {
     const int64_t nrows0 = ggml_nrows(src0);
 
@@ -6980,7 +7015,8 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens
 }
 
 static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
-    GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1));
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
     GGML_ASSERT(!ggml_is_permuted(src0));
     GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
     GGML_ASSERT(src0->type == GGML_TYPE_F16);
@@ -6990,11 +7026,11 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
 
-    const int64_t ne12 = src1->ne[2];
-
     const int64_t nb01 = src0->nb[1];
     const int64_t nb02 = src0->nb[2];
 
+    const int64_t ne12 = src1->ne[2];
+
     CUDA_CHECK(ggml_cuda_set_device(g_main_device));
     cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
 
@@ -7013,6 +7049,159 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor
     ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream);
 }
 
+static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
+    GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT);
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00);
+    const int64_t ne01 = src0->ne[1];
+    const int64_t ne02 = src0->ne[2];
+    const int64_t ne03 = src0->ne[3];
+
+    const int64_t nb01 = src0->nb[1];
+    const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02);
+    const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03);
+
+    const int64_t ne10 = src1->ne[0];
+    const int64_t ne11 = src1->ne[1];
+    const int64_t ne12 = src1->ne[2];
+    const int64_t ne13 = src1->ne[3];
+
+    const int64_t nb11 = src1->nb[1];
+    const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12);
+    const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13);
+
+    const int64_t ne1 = ggml_nelements(src1);
+    const int64_t ne  = ggml_nelements(dst);
+
+    CUDA_CHECK(ggml_cuda_set_device(g_main_device));
+    cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    int id;
+    CUDA_CHECK(cudaGetDevice(&id));
+    CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream));
+
+    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra;
+    void * src0_ddq = src0_extra->data_device[g_main_device];
+    half * src0_as_f16 = (half *) src0_ddq;
+
+    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra;
+    float * src1_ddf = (float *) src1_extra->data_device[g_main_device];
+
+    ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra;
+    float * dst_ddf = (float *) dst_extra->data_device[g_main_device];
+
+    // convert src1 to fp16
+    const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
+    GGML_ASSERT(to_fp16_cuda != nullptr);
+
+    size_t src1_as = 0;
+    half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as);
+    to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream);
+
+    size_t dst_as = 0;
+    half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as);
+
+    GGML_ASSERT(ne12 % ne02 == 0);
+    GGML_ASSERT(ne13 % ne03 == 0);
+
+    // broadcast factors
+    const int64_t r2 = ne12/ne02;
+    const int64_t r3 = ne13/ne03;
+
+    const half alpha_f16 = 1.0f;
+    const half beta_f16  = 0.0f;
+
+#if 0
+    // use cublasGemmEx
+    {
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                int i03 = i13 / r3;
+                int i02 = i12 / r2;
+
+                CUBLAS_CHECK(
+                        cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                            ne01, ne11, ne10,
+                            &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3]  , CUDA_R_16F, nb01/sizeof(half),
+                                        (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float),
+                            &beta_f16,  (      char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01,
+                            CUBLAS_COMPUTE_16F,
+                            CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+            }
+        }
+    }
+#else
+    if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) {
+        // there is no broadcast and src0, src1 are contiguous across dims 2, 3
+        // use cublasGemmStridedBatchedEx
+        CUBLAS_CHECK(
+        cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half),  src0->nb[2]/sizeof(half),  // strideA
+                            (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB
+                &beta_f16,  (      char *)     dst_f16, CUDA_R_16F, ne01,                dst->nb[2]/sizeof(float), // strideC
+                ne12*ne13,
+                CUBLAS_COMPUTE_16F,
+                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+    } else {
+        // use cublasGemmBatchedEx
+        // TODO: https://github.com/ggerganov/llama.cpp/pull/3749#discussion_r1369997000
+        const int ne23 = ne12*ne13;
+
+        // TODO: avoid this alloc
+        void ** ptrs = (void **) malloc(3*ne23*sizeof(void *));
+
+        for (int i13 = 0; i13 < ne13; ++i13) {
+            for (int i12 = 0; i12 < ne12; ++i12) {
+                int i03 = i13 / r3;
+                int i02 = i12 / r2;
+
+                ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2]   + i03*src0->nb[3];
+                ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2;
+                ptrs[2*ne23 + i12 + i13*ne12] = (char *)     dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2;
+            }
+        }
+
+        // allocate device memory for pointers
+        void ** ptrs_as = nullptr;
+        CUDA_CHECK(cudaMalloc(&ptrs_as, 3*ne23*sizeof(void *)));
+
+        // TODO: this does not work for some reason -- not sure why?
+        //size_t ptrs_s = 0;
+        //ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s);
+
+        // copy pointers to device
+        CUDA_CHECK(cudaMemcpy(ptrs_as, ptrs, 3*ne23*sizeof(void *), cudaMemcpyHostToDevice));
+
+        free(ptrs);
+
+        CUBLAS_CHECK(
+        cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N,
+                ne01, ne11, ne10,
+                &alpha_f16, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half),
+                            (const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float),
+                &beta_f16,  (      void **) (ptrs_as + 2*ne23), CUDA_R_16F, ne01,
+                ne23,
+                CUBLAS_COMPUTE_16F,
+                CUBLAS_GEMM_DEFAULT_TENSOR_OP));
+
+        // free device memory for pointers
+        CUDA_CHECK(cudaFree(ptrs_as));
+        //ggml_cuda_pool_free(ptrs_as, ptrs_s);
+    }
+#endif
+
+    const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
+    to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream);
+
+    ggml_cuda_pool_free(src1_as_f16, src1_as);
+    ggml_cuda_pool_free(dst_f16, dst_as);
+}
+
 static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) &&
         src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU;
@@ -7025,10 +7214,22 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
         }
     }
 
+    // debug helpers
+    //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]);
+    //printf("      %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]);
+    //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]);
+    //printf("      %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]);
+    //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
+    //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
+
     if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) {
+        // KQ
         ggml_cuda_mul_mat_vec_p021(src0, src1, dst);
-    } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) {
+    } else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
+        // KQV
         ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
+    } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) {
+        ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
     } else if (src0->type == GGML_TYPE_F32) {
         ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
     } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
@@ -7061,6 +7262,10 @@ static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1,
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
 }
 
+static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+    ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp);
+}
+
 static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
@@ -7470,6 +7675,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
         case GGML_OP_SCALE:
             func = ggml_cuda_scale;
             break;
+        case GGML_OP_CLAMP:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_clamp;
+            break;
         case GGML_OP_CPY:
             func = ggml_cuda_cpy;
             break;
index 29cb3c922daeba4e912f1b2defcca6fbc30e18f4..c1901dca75269dbc86971b5dccccaaae8e2c871c 100644 (file)
@@ -62,6 +62,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul);
     GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast
     GGML_METAL_DECL_KERNEL(scale);
+    GGML_METAL_DECL_KERNEL(scale_4);
     GGML_METAL_DECL_KERNEL(silu);
     GGML_METAL_DECL_KERNEL(relu);
     GGML_METAL_DECL_KERNEL(gelu);
@@ -73,6 +74,8 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DECL_KERNEL(get_rows_q5_0);
+    GGML_METAL_DECL_KERNEL(get_rows_q5_1);
     GGML_METAL_DECL_KERNEL(get_rows_q8_0);
     GGML_METAL_DECL_KERNEL(get_rows_q2_K);
     GGML_METAL_DECL_KERNEL(get_rows_q3_K);
@@ -87,6 +90,8 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4);
     GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32);
@@ -97,6 +102,8 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(mul_mm_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32);
+    GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32);
     GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32);
@@ -243,6 +250,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul);
         GGML_METAL_ADD_KERNEL(mul_row);
         GGML_METAL_ADD_KERNEL(scale);
+        GGML_METAL_ADD_KERNEL(scale_4);
         GGML_METAL_ADD_KERNEL(silu);
         GGML_METAL_ADD_KERNEL(relu);
         GGML_METAL_ADD_KERNEL(gelu);
@@ -254,6 +262,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
+        GGML_METAL_ADD_KERNEL(get_rows_q5_0);
+        GGML_METAL_ADD_KERNEL(get_rows_q5_1);
         GGML_METAL_ADD_KERNEL(get_rows_q8_0);
         GGML_METAL_ADD_KERNEL(get_rows_q2_K);
         GGML_METAL_ADD_KERNEL(get_rows_q3_K);
@@ -268,6 +278,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4);
         GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32);
+        GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32);
         GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32);
@@ -278,8 +290,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
             GGML_METAL_ADD_KERNEL(mul_mm_f32_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_f16_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32);
-            GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32);
+            GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32);
             GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32);
@@ -335,6 +349,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul);
     GGML_METAL_DEL_KERNEL(mul_row);
     GGML_METAL_DEL_KERNEL(scale);
+    GGML_METAL_DEL_KERNEL(scale_4);
     GGML_METAL_DEL_KERNEL(silu);
     GGML_METAL_DEL_KERNEL(relu);
     GGML_METAL_DEL_KERNEL(gelu);
@@ -346,6 +361,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(get_rows_f16);
     GGML_METAL_DEL_KERNEL(get_rows_q4_0);
     GGML_METAL_DEL_KERNEL(get_rows_q4_1);
+    GGML_METAL_DEL_KERNEL(get_rows_q5_0);
+    GGML_METAL_DEL_KERNEL(get_rows_q5_1);
     GGML_METAL_DEL_KERNEL(get_rows_q8_0);
     GGML_METAL_DEL_KERNEL(get_rows_q2_K);
     GGML_METAL_DEL_KERNEL(get_rows_q3_K);
@@ -360,6 +377,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4);
     GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32);
+    GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32);
     GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32);
@@ -370,8 +389,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
         GGML_METAL_DEL_KERNEL(mul_mm_f32_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_f16_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32);
-        GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32);
+        GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32);
         GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32);
@@ -779,8 +800,8 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_CONCAT:
                         {
+                            const int64_t nb = ne00;
 
-                            int64_t nb = ne00;
                             [encoder setComputePipelineState:ctx->pipeline_concat];
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
@@ -812,6 +833,7 @@ void ggml_metal_graph_compute(
                             [encoder setBytes:&nb   length:sizeof(nb)   atIndex:27];
 
                             const int nth = MIN(1024, ne0);
+
                             [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
                         } break;
                     case GGML_OP_ADD:
@@ -904,13 +926,19 @@ void ggml_metal_graph_compute(
 
                             const float scale = *(const float *) src1->data;
 
-                            [encoder setComputePipelineState:ctx->pipeline_scale];
+                            int64_t n = ggml_nelements(dst);
+
+                            if (n % 4 == 0) {
+                                n /= 4;
+                                [encoder setComputePipelineState:ctx->pipeline_scale_4];
+                            } else {
+                                [encoder setComputePipelineState:ctx->pipeline_scale];
+                            }
+
                             [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                             [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
                             [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
-                            const int64_t n = ggml_nelements(dst)/4;
-
                             [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                         } break;
                     case GGML_OP_UNARY:
@@ -921,9 +949,10 @@ void ggml_metal_graph_compute(
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                    const int64_t n = ggml_nelements(dst)/4;
+                                    const int64_t n = ggml_nelements(dst);
+                                    GGML_ASSERT(n % 4 == 0);
 
-                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                                 } break;
                             case GGML_UNARY_OP_RELU:
                                 {
@@ -941,9 +970,10 @@ void ggml_metal_graph_compute(
                                     [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                                     [encoder setBuffer:id_dst  offset:offs_dst  atIndex:1];
 
-                                    const int64_t n = ggml_nelements(dst)/4;
+                                    const int64_t n = ggml_nelements(dst);
+                                    GGML_ASSERT(n % 4 == 0);
 
-                                    [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                                    [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
                                 } break;
                             default:
                                 {
@@ -1040,7 +1070,7 @@ void ggml_metal_graph_compute(
                                 !ggml_is_transposed(src0) &&
                                 !ggml_is_transposed(src1) &&
                                 src1t == GGML_TYPE_F32 &&
-                                ne00 % 32 == 0 &&
+                                ne00 % 32 == 0 && ne00 >= 64 &&
                                 ne11 > ne11_mm_min) {
                                 //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
                                 switch (src0->type) {
@@ -1048,6 +1078,8 @@ void ggml_metal_graph_compute(
                                     case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32];  break;
                                     case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break;
                                     case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break;
+                                    case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break;
+                                    case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break;
                                     case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break;
                                     case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break;
                                     case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break;
@@ -1117,6 +1149,24 @@ void ggml_metal_graph_compute(
                                             nth1 = 8;
                                             [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32];
                                         } break;
+                                    case GGML_TYPE_Q5_0:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32];
+                                        } break;
+                                    case GGML_TYPE_Q5_1:
+                                        {
+                                            GGML_ASSERT(ne02 == 1);
+                                            GGML_ASSERT(ne12 == 1);
+
+                                            nth0 = 8;
+                                            nth1 = 8;
+                                            [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32];
+                                        } break;
                                     case GGML_TYPE_Q8_0:
                                         {
                                             GGML_ASSERT(ne02 == 1);
@@ -1197,7 +1247,8 @@ void ggml_metal_graph_compute(
                                 [encoder setBytes:&ne1  length:sizeof(ne1)  atIndex:16];
                                 [encoder setBytes:&gqa  length:sizeof(gqa)  atIndex:17];
 
-                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 ||
+                                if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
+                                    src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
                                     src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
                                     [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                                 }
@@ -1229,6 +1280,8 @@ void ggml_metal_graph_compute(
                                 case GGML_TYPE_F16:  [encoder setComputePipelineState:ctx->pipeline_get_rows_f16];  break;
                                 case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
                                 case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
+                                case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break;
+                                case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break;
                                 case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break;
                                 case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break;
                                 case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break;
@@ -1251,6 +1304,8 @@ void ggml_metal_graph_compute(
                         } break;
                     case GGML_OP_RMS_NORM:
                         {
+                            GGML_ASSERT(ne00 % 4 == 0);
+
                             float eps;
                             memcpy(&eps, dst->op_params, sizeof(float));
 
@@ -1293,7 +1348,7 @@ void ggml_metal_graph_compute(
 
                             const int nth = MIN(1024, ne00);
 
-                            const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+                            //const int n_past = ((int32_t *) dst->op_params)[0];
                             const int n_head = ((int32_t *) dst->op_params)[1];
                             float max_bias;
                             memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
index b6288db28660dc9e8902cb7a27369e9af30db695..f4b460564453c5b3cfb1e0b545fea0d5e2c6182e 100644 (file)
@@ -18,6 +18,21 @@ typedef struct {
     uint8_t qs[QK4_1 / 2];  // nibbles / quants
 } block_q4_1;
 
+#define QK5_0 32
+typedef struct {
+    half d;                // delta
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+
+#define QK5_1 32
+typedef struct {
+    half d;                 // delta
+    half m;                 // min
+    uint8_t qh[4];          // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2];  // nibbles / quants
+} block_q5_1;
+
 #define QK8_0 32
 typedef struct {
     half    d;         // delta
@@ -110,9 +125,17 @@ kernel void kernel_mul_row(
 }
 
 kernel void kernel_scale(
+        device const float * src0,
+        device       float * dst,
+        constant     float & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
         device const float4 * src0,
         device       float4 * dst,
-        constant     float & scale,
+        constant     float  & scale,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }
@@ -345,10 +368,11 @@ kernel void kernel_rms_norm(
         uint sgitg[[simdgroup_index_in_threadgroup]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint   ntg[[threads_per_threadgroup]]) {
-    device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
-    device const float * x_scalar = (device const float *) x;
-    float4 sumf=0;
-    float all_sum=0;
+    device const float4 * x        = (device const float4 *) ((device const char *) src0 + tgpig*nb01);
+    device const float  * x_scalar = (device const float  *) x;
+
+    float4 sumf = 0;
+    float all_sum = 0;
 
     // parallel sum
     for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
@@ -361,6 +385,7 @@ kernel void kernel_rms_norm(
     }
 
     threadgroup_barrier(mem_flags::mem_threadgroup);
+
     // broadcast, simd group number is ntg / 32
     for (uint i = ntg / 32 / 2; i > 0; i /= 2) {
        if (tpitg < i) {
@@ -368,7 +393,9 @@ kernel void kernel_rms_norm(
        }
     }
     if (tpitg == 0) {
-        for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];}
+        for (int i = 4 * (ne00 / 4); i < ne00; i++) {
+            sum[0] += x_scalar[i];
+        }
         sum[0] /= ne00;
     }
 
@@ -383,7 +410,9 @@ kernel void kernel_rms_norm(
         y[i00] = x[i00] * scale;
     }
     if (tpitg == 0) {
-        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;}
+        for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {
+            y_scalar[i00] = x_scalar[i00] * scale;
+        }
     }
 }
 
@@ -393,8 +422,11 @@ kernel void kernel_rms_norm(
 // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
 inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
+
     float2 acc = 0.f;
+
     device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+
     for (int i = 0; i < 8; i+=2) {
         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                 + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -411,8 +443,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
 inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
     float m = qb_curr->m;
-    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
     float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
     for (int i = 0; i < 8; i+=2) {
         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                 + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -422,6 +457,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
     return d * (acc[0] + acc[1]) + sumy * m;
 }
 
+// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 3 + il/2);
+           const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (sumy * -16.f + acc[0] + acc[1]);
+}
+
+// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_1/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs =  ((device const uint16_t *)qb_curr + 4 + il/2);
+           const uint32_t   qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (acc[0] + acc[1]) + sumy * m;
+}
+
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4        // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2  // number of SIMD groups in a thread group
@@ -519,6 +597,43 @@ kernel void kernel_mul_mv_q4_1_f32(
      mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+kernel void kernel_mul_mv_q5_0_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
+        constant   uint    & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q5_1_f32(
+        device const  void * src0,
+        device const float * src1,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01[[buffer(4)]],
+        constant   int64_t & ne02[[buffer(5)]],
+        constant   int64_t & ne10[[buffer(9)]],
+        constant   int64_t & ne12[[buffer(11)]],
+        constant   int64_t & ne0[[buffer(15)]],
+        constant   int64_t & ne1[[buffer(16)]],
+        constant   uint    & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+
 #define NB_Q8_0 8
 
 kernel void kernel_mul_mv_q8_0_f32(
@@ -2143,6 +2258,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     }
 }
 
+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+    const float d = xb->d;
+    const float md = -16.h * xb->d;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + md;
+        reg[i/2][2*(i%2)+1] = d * x1 + md;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+    const float d = xb->d;
+    const float m = xb->m;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + m;
+        reg[i/2][2*(i%2)+1] = d * x1 + m;
+    }
+}
+
 template <typename type4x4>
 void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
     device const int8_t * qs = ((device const int8_t *)xb->qs);
@@ -2484,6 +2655,8 @@ template [[host_name("kernel_get_rows_f32")]]  kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_f16")]]  kernel get_rows_t kernel_get_rows<half4x4,    1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
 template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -2512,6 +2685,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]]  kernel mat_mm_t kernel_mul_mm<f
 template [[host_name("kernel_mul_mm_f16_f32")]]  kernel mat_mm_t kernel_mul_mm<half4x4,    1,     dequantize_f16>;
 template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2,     dequantize_q4_0>;
 template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2,     dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2,     dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2,     dequantize_q5_1>;
 template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2,     dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
index 4a331f24a92ae340f7c910f219344d8d0467d7af..202bcb4853893c7332906418059c9111ddedda8d 100644 (file)
@@ -19,7 +19,7 @@
 #pragma warning(disable: 4244 4267) // possible loss of data
 #endif
 
-#define CL_DMMV_BLOCK_SIZE 32
+#define CL_DMMV_LOCAL_SIZE 32
 
 #ifndef K_QUANTS_PER_ITERATION
 #define K_QUANTS_PER_ITERATION 1
@@ -338,7 +338,7 @@ __kernel void dequantize_mul_mat_vec_q2_K(__global const struct block_q2_K * xx,
     const int row = get_group_id(0);
 
     const int num_blocks_per_row = ncols / QK_K;
-    const int ib0 = row*num_blocks_per_row;
+    const int ib0 = row*num_blocks_per_row + get_global_offset(0);
 
     __global const struct block_q2_K * x = xx + ib0;
 
@@ -413,7 +413,7 @@ __kernel void dequantize_mul_mat_vec_q3_K(__global const struct block_q3_K * xx,
     const int row = get_group_id(0);
 
     const int num_blocks_per_row = ncols / QK_K;
-    const int ib0 = row*num_blocks_per_row;
+    const int ib0 = row*num_blocks_per_row + get_global_offset(0);
 
     __global const struct block_q3_K * x = xx + ib0;
 
@@ -489,7 +489,7 @@ __kernel void dequantize_mul_mat_vec_q4_K(__global const struct block_q4_K * xx,
 
     const int row = get_group_id(0);
     const int num_blocks_per_row = ncols / QK_K;
-    const int ib0 = row*num_blocks_per_row;
+    const int ib0 = row*num_blocks_per_row + get_global_offset(0);
 
     const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION;  // 0...15
     const int ix  = get_local_id(0)%K_QUANTS_PER_ITERATION;
@@ -562,7 +562,7 @@ __kernel void dequantize_mul_mat_vec_q5_K(__global const struct block_q5_K * xx,
 
     const int row = get_group_id(0);
     const int num_blocks_per_row = ncols / QK_K;
-    const int ib0 = row*num_blocks_per_row;
+    const int ib0 = row*num_blocks_per_row + get_global_offset(0);
 
     const int tid = get_local_id(0)/2;  // 0...15
     const int ix  = get_local_id(0)%2;
@@ -641,7 +641,7 @@ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx,
     const int row = get_group_id(0);
 
     const int num_blocks_per_row = ncols / QK_K;
-    const int ib0 = row*num_blocks_per_row;
+    const int ib0 = row*num_blocks_per_row + get_global_offset(0);
 
     __global const struct block_q6_K * x = xx + ib0;
 
@@ -745,19 +745,21 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
 
 std::string dequant_mul_mat_vec_template = MULTILINE_QUOTE(
 __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
-    const int block_size = get_local_size(0);
+    const int local_size = get_local_size(0);
     const int row = get_group_id(0);
     const int tid = get_local_id(0);
 
     const uint qk = QUANT_K;
     const uint qr = QUANT_R;
 
+    const int col_step = local_size * 2;
     const int y_offset = qr == 1 ? 1 : qk/2;
 
+    x += get_global_offset(0);
+
     tmp[tid] = 0;
 
-    for (int i = 0; i < ncols/block_size; i += 2) {
-        const int col = i*block_size + 2*tid;
+    for (int col = tid*2; col < ncols; col += col_step) {
         const int ib = (row*ncols + col)/qk; // block index
         const int iqs = (col%qk)/qr; // quant index
         const int iybs = col - col%qk; // y block start index
@@ -773,7 +775,7 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
 
     // sum up partial sums and write back result
     barrier(CLK_LOCAL_MEM_FENCE);
-    for (int s=block_size/2; s>0; s>>=1) {
+    for (int s=local_size/2; s>0; s>>=1) {
         if (tid < s) {
             tmp[tid] += tmp[tid + s];
         }
@@ -1393,75 +1395,46 @@ static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1,
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
     const int64_t ne03 = src0->ne[3];
-    const int64_t ne0 = ne00 * ne01 * ne02 * ne03;
     const int64_t ne10 = src1->ne[0];
     const int64_t ne11 = src1->ne[1];
     const int64_t ne12 = src1->ne[2];
     const int64_t ne13 = src1->ne[3];
-    const int64_t nb10 = src1->nb[0];
     const int nb2  = dst->nb[2];
     const int nb3  = dst->nb[3];
     size_t x_size;
     size_t d_size;
 
-    cl_mem d_X = ggml_cl_pool_malloc(ne0 * sizeof(float), &x_size); // src0
+    cl_mem d_X = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &x_size); // src0
     cl_mem d_Y = (cl_mem) src1->extra; // src1 is already on device, broadcasted.
-    cl_mem d_D = ggml_cl_pool_malloc(ne0 * sizeof(float), &d_size); // dst
+    cl_mem d_D = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &d_size); // dst
 
 
     for (int64_t i03 = 0; i03 < ne03; i03++) {
         for (int64_t i02 = 0; i02 < ne02; i02++) {
-            const int i0 = i03*ne02 + i02;
-
             cl_event ev;
 
             // copy src0 to device
-            CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, i0, src0, i03, i02, &ev));
-
-            if (nb10 == sizeof(float)) {
-                // Contiguous, avoid overhead from queueing many kernel runs
-                const int64_t i13 = i03%ne13;
-                const int64_t i12 = i02%ne12;
-                const int i1 = i13*ne12*ne11 + i12*ne11;
-
-                cl_int x_offset = 0;
-                cl_int y_offset = i1*ne10;
-                cl_int d_offset = 0;
-
-                size_t global = ne00 * ne01;
-                cl_int ky = ne10;
-                CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
-                CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
-                CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
-                CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
-                CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
-                CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
-                CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
-                CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
-            } else {
-                for (int64_t i01 = 0; i01 < ne01; i01++) {
-                    const int64_t i13 = i03%ne13;
-                    const int64_t i12 = i02%ne12;
-                    const int64_t i11 = i01%ne11;
-                    const int i1 = i13*ne12*ne11 + i12*ne11 + i11;
-
-                    cl_int x_offset = i01*ne00;
-                    cl_int y_offset = i1*ne10;
-                    cl_int d_offset = i01*ne00;
+            CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, &ev));
 
-                    // compute
-                    size_t global = ne00;
-                    cl_int ky = ne10;
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
-                    CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
-                    CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
-                }
-            }
+            const int64_t i13 = i03%ne13;
+            const int64_t i12 = i02%ne12;
+            const int i1 = i13*ne12*ne11 + i12*ne11;
+
+            cl_int x_offset = 0;
+            cl_int y_offset = i1*ne10;
+            cl_int d_offset = 0;
+
+            size_t global = ne00 * ne01;
+            cl_int ky = ne10 * ne11;
+
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset));
+            CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky));
+            CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL));
 
             CL_CHECK(clReleaseEvent(ev));
             CL_CHECK(clFinish(queue));
@@ -1516,46 +1489,45 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
     cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size);
 
     size_t x_offset = 0;
-    int64_t pi02 = -1;
-    int64_t pi03 = -1;
-
-    for (int64_t i13 = 0; i13 < ne13; i13++) {
-        int64_t i03 = i13 / r3;
-
-        for (int64_t i12 = 0; i12 < ne12; i12++) {
-            int64_t i02 = i12 / r2;
-
-            // copy data to device
-            if (src0->backend == GGML_BACKEND_GPU) {
-                x_offset = (i03 * ne02 + i02) * x_ne;
-            } else if (i02 != pi02 || i03 != pi03) {
-                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
-                pi02 = i02;
-                pi03 = i03;
-            }
-            CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
 
-            CL_CHECK(clFinish(queue));
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        // TODO: copy src0 here when r3>1
+        for (int64_t i13 = i03 * r3, e13 = i13 + r3; i13 < e13; i13++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                if (src0->backend == GGML_BACKEND_GPU) {
+                    x_offset = (i03 * ne02 + i02) * x_ne;
+                } else {
+                    // copy src0 to device
+                    CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
+                }
 
-            // compute
-            cl_event ev_sgemm;
-            clblast::StatusCode status = clblast::Gemm<cl_float>(clblast::Layout::kColMajor,
-                                                       clblast::Transpose::kYes, clblast::Transpose::kNo,
-                                                       ne01, ne11, ne10,
-                                                       alpha,
-                                                       d_X, x_offset, ne00,
-                                                       d_Y, 0, ne10,
-                                                       beta,
-                                                       d_D, 0, ne01,
-                                                       &queue, &ev_sgemm);
-
-            if (status != clblast::StatusCode::kSuccess) {
-                GGML_ASSERT(false);
-            }
+                for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) {
+                    // copy src1 to device
+                    CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
 
-            // copy dst to host
-            float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
-            CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &ev_sgemm, NULL));
+                    CL_CHECK(clFinish(queue));
+
+                    // compute
+                    cl_event ev_sgemm;
+                    clblast::StatusCode status = clblast::Gemm<cl_float>(clblast::Layout::kColMajor,
+                                                               clblast::Transpose::kYes, clblast::Transpose::kNo,
+                                                               ne01, ne11, ne10,
+                                                               alpha,
+                                                               d_X, x_offset, ne00,
+                                                               d_Y, 0, ne10,
+                                                               beta,
+                                                               d_D, 0, ne01,
+                                                               &queue, &ev_sgemm);
+
+                    if (status != clblast::StatusCode::kSuccess) {
+                        GGML_ASSERT(false);
+                    }
+
+                    // copy dst to host
+                    float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
+                    CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &ev_sgemm, NULL));
+                }
+            }
         }
     }
 
@@ -1566,7 +1538,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr
     ggml_cl_pool_free(d_D, d_size);
 }
 
-static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) {
+static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) {
     GGML_ASSERT(fp16_support);
 
     const int64_t ne00 = src0->ne[0];
@@ -1596,6 +1568,10 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
     const int y_ne = ne11 * ne10;
     const int d_ne = ne11 * ne01;
 
+    GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * y_ne);
+    GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * d_ne);
+    ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata;
+
     size_t x_size;
     size_t y_size;
     size_t d_size;
@@ -1612,74 +1588,70 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr
     bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float);
 
     size_t x_offset = 0;
-    int64_t pi02 = -1;
-    int64_t pi03 = -1;
 
-    for (int64_t i13 = 0; i13 < ne13; i13++) {
-        int64_t i03 = i13 / r3;
-
-        for (int64_t i12 = 0; i12 < ne12; i12++) {
-            int64_t i02 = i12 / r2;
-
-            // copy src0 to device
-            if (src0->backend == GGML_BACKEND_GPU) {
-                x_offset = (i03 * ne02 + i02) * x_ne;
-            } else if (i02 != pi02 || i03 != pi03) {
-                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
-                pi02 = i02;
-                pi03 = i03;
-            }
-
-            // convert src1 to fp16
-            // TODO: use multiple threads
-            ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i13 * ne12 + i12);
-            char * src1i = (char *) src1->data + i13*nb13 + i12*nb12;
-            if (src1_cont_rows) {
-                if (src1_cont_cols) {
-                    ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        // TODO: copy src0 here when r3>1
+        for (int64_t i13 = i03 * r3, e13 = i13 + r3; i13 < e13; i13++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                if (src0->backend == GGML_BACKEND_GPU) {
+                    x_offset = (i03 * ne02 + i02) * x_ne;
+                } else {
+                    // copy src0 to device
+                    CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL));
                 }
-                else {
-                    for (int64_t i11 = 0; i11 < ne11; i11++) {
-                        ggml_fp32_to_fp16_row((float *) (src1i + i11*nb11), tmp + i11*ne10, ne10);
+
+                for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) {
+                    // convert src1 to fp16
+                    // TODO: use multiple threads
+                    char * src1i = (char *) src1->data + i13*nb13 + i12*nb12;
+                    if (src1_cont_rows) {
+                        if (src1_cont_cols) {
+                            ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11);
+                        }
+                        else {
+                            for (int64_t i11 = 0; i11 < ne11; i11++) {
+                                ggml_fp32_to_fp16_row((float *) (src1i + i11*nb11), tmp + i11*ne10, ne10);
+                            }
+                        }
                     }
-                }
-            }
-            else {
-                for (int64_t i11 = 0; i11 < ne11; i11++) {
-                    for (int64_t i10 = 0; i10 < ne10; i10++) {
-                        // very slow due to no inlining
-                        tmp[i11*ne10 + i10] = ggml_fp32_to_fp16(*(float *) (src1i + i11*nb11 + i10*nb10));
+                    else {
+                        for (int64_t i11 = 0; i11 < ne11; i11++) {
+                            for (int64_t i10 = 0; i10 < ne10; i10++) {
+                                // very slow due to no inlining
+                                tmp[i11*ne10 + i10] = ggml_fp32_to_fp16(*(float *) (src1i + i11*nb11 + i10*nb10));
+                            }
+                        }
                     }
-                }
-            }
 
-            // copy src1 to device
-            CL_CHECK(clEnqueueWriteBuffer(queue, d_Y, false, 0, sizeof(ggml_fp16_t) * y_ne, tmp, 0, NULL, NULL));
+                    // copy src1 to device
+                    CL_CHECK(clEnqueueWriteBuffer(queue, d_Y, false, 0, sizeof(ggml_fp16_t) * y_ne, tmp, 0, NULL, NULL));
 
-            CL_CHECK(clFinish(queue));
+                    CL_CHECK(clFinish(queue));
 
-            // compute
-            cl_event ev_sgemm;
-            clblast::StatusCode status = clblast::Gemm<cl_half>(clblast::Layout::kColMajor,
-                                                       clblast::Transpose::kYes, clblast::Transpose::kNo,
-                                                       ne01, ne11, ne10,
-                                                       alpha,
-                                                       d_X, x_offset, ne00,
-                                                       d_Y, 0, ne10,
-                                                       beta,
-                                                       d_D, 0, ne01,
-                                                       &queue, &ev_sgemm);
-
-            if (status != clblast::StatusCode::kSuccess) {
-                GGML_ASSERT(false);
-            }
+                    // compute
+                    cl_event ev_sgemm;
+                    clblast::StatusCode status = clblast::Gemm<cl_half>(clblast::Layout::kColMajor,
+                                                               clblast::Transpose::kYes, clblast::Transpose::kNo,
+                                                               ne01, ne11, ne10,
+                                                               alpha,
+                                                               d_X, x_offset, ne00,
+                                                               d_Y, 0, ne10,
+                                                               beta,
+                                                               d_D, 0, ne01,
+                                                               &queue, &ev_sgemm);
+
+                    if (status != clblast::StatusCode::kSuccess) {
+                        GGML_ASSERT(false);
+                    }
 
-            // copy dst to host, then convert to float
-            CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(ggml_fp16_t) * d_ne, tmp, 1, &ev_sgemm, NULL));
+                    // copy dst to host, then convert to float
+                    CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(ggml_fp16_t) * d_ne, tmp, 1, &ev_sgemm, NULL));
 
-            float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
+                    float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
 
-            ggml_fp16_to_fp32_row(tmp, d, d_ne);
+                    ggml_fp16_to_fp32_row(tmp, d, d_ne);
+                }
+            }
         }
     }
 
@@ -1704,7 +1676,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
     const int nb2  = dst->nb[2];
     const int nb3  = dst->nb[3];
     const ggml_type type = src0->type;
-    const bool mul_mat_vec = ne11 == 1;
+    const bool mul_mat_vec = ne11 == 1 && ne00%2 == 0;
 
     const int64_t r2 = ne12 / ne02;
     const int64_t r3 = ne13 / ne03;
@@ -1737,90 +1709,86 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
     GGML_ASSERT(to_fp32_cl != nullptr);
 
     const size_t global_denom = ggml_cl_global_denom(type);
-    const size_t local = ggml_cl_local_size(type);
+    const size_t local = mul_mat_vec ? CL_DMMV_LOCAL_SIZE : ggml_cl_local_size(type);
 
     size_t ev_idx = 0;
     std::vector<cl_event> events;
 
-    int64_t pi02 = -1;
-    int64_t pi03 = -1;
-
-    for (int64_t i13 = 0; i13 < ne13; i13++) {
-        int64_t i03 = i13 / r3;
-
-        for (int64_t i12 = 0; i12 < ne12; i12++) {
-            int64_t i02 = i12 / r2;
-
-            // copy src0 to device if necessary
-            if (src0->backend == GGML_BACKEND_CPU) {
-                if (i02 != pi02 || i03 != pi03) {
+    for (int64_t i03 = 0; i03 < ne03; i03++) {
+        // TODO: copy and dequantize src0 here when r3>1
+        for (int64_t i13 = i03 * r3, e13 = i13 + r3; i13 < e13; i13++) {
+            for (int64_t i02 = 0; i02 < ne02; i02++) {
+                // copy src0 to device if necessary
+                if (src0->backend == GGML_BACKEND_CPU) {
                     events.emplace_back();
                     CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++));
-                    pi02 = i02;
-                    pi03 = i03;
-                }
-            } else if (src0->backend == GGML_BACKEND_GPU) {
-                d_Q = (cl_mem) src0->extra;
-            } else {
-                GGML_ASSERT(false);
-            }
-            if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
-                // copy src1 to device
-                events.emplace_back();
-                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, events.data() + ev_idx++));
-
-                // compute
-                const size_t global = ne01 * CL_DMMV_BLOCK_SIZE;
-                const size_t local = CL_DMMV_BLOCK_SIZE;
-                const cl_int ncols = ne00;
-                events.emplace_back();
-                CL_CHECK(clSetKernelArg(*dmmv, 0, sizeof(cl_mem), &d_Q));
-                CL_CHECK(clSetKernelArg(*dmmv, 1, sizeof(float) * local, NULL));
-                CL_CHECK(clSetKernelArg(*dmmv, 2, sizeof(cl_mem), &d_Y));
-                CL_CHECK(clSetKernelArg(*dmmv, 3, sizeof(cl_mem), &d_D));
-                CL_CHECK(clSetKernelArg(*dmmv, 4, sizeof(cl_int), &ncols));
-                CL_CHECK(clEnqueueNDRangeKernel(queue, *dmmv, 1, NULL, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++));
-            } else { // general dequantization kernel + CLBlast matrix matrix multiplication
-                // convert src0 to fp32 on device
-                const size_t global = x_ne / global_denom;
-                const size_t offset = src0->backend == GGML_BACKEND_GPU ? (i03 * ne02 + i02) * x_bps : 0;
-                CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q));
-                CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X));
-                CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, offset > 0 ? &offset : NULL, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
-
-                // copy src1 to device
-                CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
-
-                events.emplace_back();
-
-                // wait for conversion
-                CL_CHECK(clFinish(queue));
-
-                // compute
-                clblast::StatusCode status = clblast::Gemm<cl_float>(clblast::Layout::kColMajor,
-                                                           clblast::Transpose::kYes, clblast::Transpose::kNo,
-                                                           ne01, ne11, ne10,
-                                                           alpha,
-                                                           d_X, 0, ne00,
-                                                           d_Y, 0, ne10,
-                                                           beta,
-                                                           d_D, 0, ne01,
-                                                           &queue, events.data() + ev_idx++);
-
-                if (status != clblast::StatusCode::kSuccess) {
+                } else if (src0->backend == GGML_BACKEND_GPU) {
+                    d_Q = (cl_mem) src0->extra;
+                } else {
                     GGML_ASSERT(false);
                 }
-            }
 
-            // copy dst to host
-            float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
-            CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &events[events.size() - 1], NULL));
-            for (auto *event : events) {
-                clReleaseEvent(event);
-            }
+                if (!mul_mat_vec) {
+                    // convert src0 to fp32 on device
+                    const size_t global = x_ne / global_denom;
+                    const size_t offset = src0->backend == GGML_BACKEND_GPU ? (i03 * ne02 + i02) * x_bps : 0;
+                    CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q));
+                    CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X));
+                    CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, &offset, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
+                }
 
-            ev_idx = 0;
-            events.clear();
+                for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) {
+                    if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel
+                        // copy src1 to device
+                        events.emplace_back();
+                        CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, events.data() + ev_idx++));
+
+                        // compute
+                        const size_t global = ne01 * local;
+                        const size_t offset = src0->backend == GGML_BACKEND_GPU ? (i03 * ne02 + i02) * x_bps : 0;
+                        const cl_int ncols = ne00;
+                        events.emplace_back();
+                        CL_CHECK(clSetKernelArg(*dmmv, 0, sizeof(cl_mem), &d_Q));
+                        CL_CHECK(clSetKernelArg(*dmmv, 1, sizeof(float) * local, NULL));
+                        CL_CHECK(clSetKernelArg(*dmmv, 2, sizeof(cl_mem), &d_Y));
+                        CL_CHECK(clSetKernelArg(*dmmv, 3, sizeof(cl_mem), &d_D));
+                        CL_CHECK(clSetKernelArg(*dmmv, 4, sizeof(cl_int), &ncols));
+                        CL_CHECK(clEnqueueNDRangeKernel(queue, *dmmv, 1, &offset, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++));
+                    } else { // CLBlast matrix matrix multiplication
+                        // copy src1 to device
+                        CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL));
+
+                        // wait for conversion
+                        CL_CHECK(clFinish(queue));
+
+                        // compute
+                        events.emplace_back();
+                        clblast::StatusCode status = clblast::Gemm<cl_float>(clblast::Layout::kColMajor,
+                                                                   clblast::Transpose::kYes, clblast::Transpose::kNo,
+                                                                   ne01, ne11, ne10,
+                                                                   alpha,
+                                                                   d_X, 0, ne00,
+                                                                   d_Y, 0, ne10,
+                                                                   beta,
+                                                                   d_D, 0, ne01,
+                                                                   &queue, events.data() + ev_idx++);
+
+                        if (status != clblast::StatusCode::kSuccess) {
+                            GGML_ASSERT(false);
+                        }
+                    }
+
+                    // copy dst to host
+                    float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3);
+                    CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &events[events.size() - 1], NULL));
+                    for (auto *event : events) {
+                        clReleaseEvent(event);
+                    }
+
+                    ev_idx = 0;
+                    events.clear();
+                }
+            }
         }
     }
 
@@ -1895,8 +1863,8 @@ void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor *
 }
 
 size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) {
-    if (ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
-        return ggml_nelements(src1) * sizeof(ggml_fp16_t);
+    if (src0->type == GGML_TYPE_F16 && ggml_cl_mul_mat_use_f16(src0, src1, dst)) {
+        return sizeof(ggml_fp16_t) * std::max(src1->ne[0] * src1->ne[1], dst->ne[0] * dst->ne[1]);
     }
     return 0;
 }
index f71cfd8da465ccd520bd9df124c16a9fcd7e3b60..6f66bab051cea482d07ce52fa564e57b2df45077 100644 (file)
@@ -5496,6 +5496,39 @@ struct ggml_tensor * ggml_view_tensor(
     return result;
 }
 
+struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx) {
+    struct ggml_object * obj = ctx->objects_begin;
+
+    char * const mem_buffer = ctx->mem_buffer;
+
+    while (obj != NULL) {
+        if (obj->type == GGML_OBJECT_TENSOR) {
+            return (struct ggml_tensor *)(mem_buffer + obj->offs);
+        }
+
+        obj = obj->next;
+    }
+
+    return NULL;
+}
+
+struct ggml_tensor * ggml_get_next_tensor(struct ggml_context * ctx, struct ggml_tensor * tensor) {
+    struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE);
+    obj = obj->next;
+
+    char * const mem_buffer = ctx->mem_buffer;
+
+    while (obj != NULL) {
+        if (obj->type == GGML_OBJECT_TENSOR) {
+            return (struct ggml_tensor *)(mem_buffer + obj->offs);
+        }
+
+        obj = obj->next;
+    }
+
+    return NULL;
+}
+
 struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) {
     struct ggml_object * obj = ctx->objects_begin;
 
@@ -8701,6 +8734,7 @@ void ggml_set_param(
 
     GGML_ASSERT(tensor->grad == NULL);
     tensor->grad = ggml_dup_tensor(ctx, tensor);
+    ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
 }
 
 // ggml_compute_forward_dup
@@ -11283,7 +11317,7 @@ static void ggml_compute_forward_silu_f32(
 
 #ifndef NDEBUG
         for (int k = 0; k < nc; k++) {
-            const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k];
+            const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k];
             UNUSED(x);
             assert(!isnan(x));
             assert(!isinf(x));
@@ -13106,24 +13140,22 @@ static void ggml_compute_forward_alibi_f32(
         return;
     }
 
-    const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past);
+    //const int n_past = ((int32_t *) dst->op_params)[0];
     const int n_head = ((int32_t *) dst->op_params)[1];
     float max_bias;
     memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
 
-    assert(n_past >= 0);
+    const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
+    const int64_t ne1 = src0->ne[1]; // seq_len_without_past
+    const int64_t ne2 = src0->ne[2]; // n_head -> this is k
+    //const int64_t ne3 = src0->ne[3]; // 1 -> bsz
 
-    const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
-    const int ne1 = src0->ne[1]; // seq_len_without_past
-    const int ne2 = src0->ne[2]; // n_head -> this is k
-    //const int ne3 = src0->ne[3]; // 1 -> bsz
+    const int64_t n  = ggml_nrows(src0);
+    const int64_t ne2_ne3 = n/ne1; // ne2*ne3
 
-    const int n  = ggml_nrows(src0);
-    const int ne2_ne3 = n/ne1; // ne2*ne3
-
-    const int nb0 = src0->nb[0];
-    const int nb1 = src0->nb[1];
-    const int nb2 = src0->nb[2];
+    const size_t nb0 = src0->nb[0];
+    const size_t nb1 = src0->nb[1];
+    const size_t nb2 = src0->nb[2];
     //const int nb3 = src0->nb[3];
 
     GGML_ASSERT(nb0 == sizeof(float));
@@ -13135,9 +13167,9 @@ static void ggml_compute_forward_alibi_f32(
     const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
 
-    for (int i = 0; i < ne0; i++) {
-        for (int j = 0; j < ne1; j++) {
-            for (int k = 0; k < ne2_ne3; k++) {
+    for (int64_t i = 0; i < ne0; i++) {
+        for (int64_t j = 0; j < ne1; j++) {
+            for (int64_t k = 0; k < ne2_ne3; k++) {
                 float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2);
                 float *      pdst = (float *)((char *)  dst->data + i*nb0 + j*nb1 + k*nb2);
 
@@ -13152,7 +13184,6 @@ static void ggml_compute_forward_alibi_f32(
                 }
 
                 pdst[0] = i * m_k + src[0];
-
             }
         }
     }
@@ -13553,7 +13584,7 @@ static void ggml_compute_forward_rope_f16(
                         dst_data[n_dims]     = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta);
                         dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta);
                     }
-                } if (!is_neox) {
+                } else if (!is_neox) {
                     for (int64_t i0 = 0; i0 < ne0; i0 += 2) {
                         const float cos_theta = cosf(theta);
                         const float sin_theta = sinf(theta);
@@ -16820,6 +16851,10 @@ static void ggml_compute_forward_cross_entropy_loss_back(
 static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
     GGML_ASSERT(params);
 
+    if (tensor->op == GGML_OP_NONE) {
+        return;
+    }
+
 #ifdef GGML_USE_CUBLAS
     bool skip_cpu = ggml_cuda_compute_forward(params, tensor);
     if (skip_cpu) {
@@ -19414,6 +19449,7 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) {
 
                             if (idx == -1) {
                                 fprintf(stderr, "%s: failed to find tensor, arg = %d, node = %d\n", __func__, j, i);
+                                fclose(fout);
                                 return;
                             }
 
@@ -21087,7 +21123,7 @@ struct gguf_kv {
 };
 
 struct gguf_header {
-    uint32_t magic;
+    char magic[4];
     uint32_t version;
     uint64_t n_tensors; // GGUFv2
     uint64_t n_kv;      // GGUFv2
@@ -21157,7 +21193,7 @@ static bool gguf_fread_str_v1(FILE * file, struct gguf_str * p, size_t * offset)
 struct gguf_context * gguf_init_empty(void) {
     struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context));
 
-    ctx->header.magic     = GGUF_MAGIC;
+    memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic));
     ctx->header.version   = GGUF_VERSION;
     ctx->header.n_tensors = 0;
     ctx->header.n_kv      = 0;
@@ -21183,16 +21219,18 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
     // offset from start of file
     size_t offset = 0;
 
-    uint32_t magic = 0;
+    char magic[4];
 
     // check the magic before making allocations
     {
         gguf_fread_el(file, &magic, sizeof(magic), &offset);
 
-        if (magic != GGUF_MAGIC) {
-            fprintf(stderr, "%s: invalid magic number %08x\n", __func__, magic);
-            fclose(file);
-            return NULL;
+        for (uint32_t i = 0; i < sizeof(magic); i++) {
+            if (magic[i] != GGUF_MAGIC[i]) {
+                fprintf(stderr, "%s: invalid magic characters %s.\n", __func__, magic);
+                fclose(file);
+                return NULL;
+            }
         }
     }
 
@@ -21202,7 +21240,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p
 
     // read the header
     {
-        ctx->header.magic = magic;
+        strncpy(ctx->header.magic, magic, 4);
+
 
         ctx->kv    = NULL;
         ctx->infos = NULL;