From: Georgi Gerganov Date: Tue, 24 Oct 2023 18:51:12 +0000 (+0300) Subject: sync : llama.cpp (CUDA, Metal, OpenCL, gguf magic, ggml iter) (#592) X-Git-Tag: upstream/0.0.1642~1210 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=dd92cfd4b188e9202dddb7b85eb8bc1e51cf8288;p=pkg%2Fggml%2Fsources%2Fggml sync : llama.cpp (CUDA, Metal, OpenCL, gguf magic, ggml iter) (#592) ggml-ci --- diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index 4b16032f..08bff551 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -231,8 +231,9 @@ #define GGML_EXIT_SUCCESS 0 #define GGML_EXIT_ABORTED 1 -#define GGUF_MAGIC 0x46554747 // "GGUF" -#define GGUF_VERSION 2 +#define GGUF_MAGIC "GGUF" + +#define GGUF_VERSION 3 #define GGUF_DEFAULT_ALIGNMENT 32 @@ -705,6 +706,9 @@ extern "C" { GGML_API struct ggml_tensor * ggml_dup_tensor (struct ggml_context * ctx, const struct ggml_tensor * src); GGML_API struct ggml_tensor * ggml_view_tensor(struct ggml_context * ctx, struct ggml_tensor * src); + // Context tensor enumeration and lookup + GGML_API struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx); + GGML_API struct ggml_tensor * ggml_get_next_tensor (struct ggml_context * ctx, struct ggml_tensor * tensor); GGML_API struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name); GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index 5bd83bb5..d1e874b6 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -29,6 +29,8 @@ #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) #define cublasCreate hipblasCreate #define cublasGemmEx hipblasGemmEx +#define cublasGemmBatchedEx hipblasGemmBatchedEx +#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx #define cublasHandle_t hipblasHandle_t #define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS #define cublasSetStream hipblasSetStream @@ -415,6 +417,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define CUDA_SILU_BLOCK_SIZE 256 #define CUDA_CPY_BLOCK_SIZE 32 #define CUDA_SCALE_BLOCK_SIZE 256 +#define CUDA_CLAMP_BLOCK_SIZE 256 #define CUDA_ROPE_BLOCK_SIZE 256 #define CUDA_ALIBI_BLOCK_SIZE 32 #define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32 @@ -4325,13 +4328,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous const half * x = (const half *) vx; - const int row_x = blockDim.y*blockIdx.y + threadIdx.y; - const int channel = blockDim.z*blockIdx.z + threadIdx.z; + const int row_x = blockDim.y*blockIdx.y + threadIdx.y; + const int channel = blockDim.z*blockIdx.z + threadIdx.z; const int channel_x = channel / channel_x_divisor; - const int nrows_y = ncols_x; + const int nrows_y = ncols_x; const int nrows_dst = nrows_x; - const int row_dst = row_x; + const int row_dst = row_x; const int idst = channel*nrows_dst + row_dst; @@ -4344,13 +4347,13 @@ static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous break; } - const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; - const float xi = __half2float(x[ix]); - const int row_y = col_x; + const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; const int iy = channel*nrows_y + row_y; + const float xi = __half2float(x[ix]); + tmp += xi * y[iy]; } @@ -4585,6 +4588,15 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale dst[i] = scale * x[i]; } +static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]); +} template static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const int nrows, const int ncols, cudaStream_t stream) { @@ -5475,6 +5487,11 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons scale_f32<<>>(x, dst, scale, k); } +static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE; + clamp_f32<<>>(x, dst, min, max, k); +} + template static void rope_cuda(const T * x, T * dst, const int ncols, const int nrows, const int32_t * pos, const float freq_scale, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { @@ -6419,12 +6436,12 @@ inline void ggml_cuda_op_alibi( const int64_t ne02 = src0->ne[2]; const int64_t nrows = ggml_nrows(src0); - const int n_past = ((int32_t *) dst->op_params)[0]; + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - GGML_ASSERT(ne01 + n_past == ne00); + //GGML_ASSERT(ne01 + n_past == ne00); GGML_ASSERT(n_head == ne02); const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); @@ -6500,6 +6517,24 @@ inline void ggml_cuda_op_scale( (void) src1_dd; } +inline void ggml_cuda_op_clamp( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const float min = ((float *) dst->op_params)[0]; + const float max = ((float *) dst->op_params)[1]; + + clamp_f32_cuda(src0_dd, dst_dd, min, max, ggml_nelements(src0), main_stream); + CUDA_CHECK(cudaGetLastError()); + + (void) src1; + (void) dst; + (void) src1_dd; +} + static void ggml_cuda_op_flatten(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const ggml_cuda_op_flatten_t op) { const int64_t nrows0 = ggml_nrows(src0); @@ -6980,7 +7015,8 @@ static void ggml_cuda_mul_mat_vec_p021(const ggml_tensor * src0, const ggml_tens } static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ - GGML_ASSERT(!ggml_is_contiguous(src0) && ggml_is_contiguous(src1)); + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); GGML_ASSERT(!ggml_is_permuted(src0)); GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); GGML_ASSERT(src0->type == GGML_TYPE_F16); @@ -6990,11 +7026,11 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; - const int64_t ne12 = src1->ne[2]; - const int64_t nb01 = src0->nb[1]; const int64_t nb02 = src0->nb[2]; + const int64_t ne12 = src1->ne[2]; + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; @@ -7013,6 +7049,159 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); } +static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst){ + GGML_ASSERT(!ggml_is_transposed(src0)); + GGML_ASSERT(!ggml_is_transposed(src1)); + GGML_ASSERT(src0->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; GGML_UNUSED(ne00); + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t nb01 = src0->nb[1]; + const int64_t nb02 = src0->nb[2]; GGML_UNUSED(nb02); + const int64_t nb03 = src0->nb[3]; GGML_UNUSED(nb03); + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + const int64_t nb11 = src1->nb[1]; + const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12); + const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13); + + const int64_t ne1 = ggml_nelements(src1); + const int64_t ne = ggml_nelements(dst); + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + int id; + CUDA_CHECK(cudaGetDevice(&id)); + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream)); + + ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + void * src0_ddq = src0_extra->data_device[g_main_device]; + half * src0_as_f16 = (half *) src0_ddq; + + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; + + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; + + // convert src1 to fp16 + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); + GGML_ASSERT(to_fp16_cuda != nullptr); + + size_t src1_as = 0; + half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as); + to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream); + + size_t dst_as = 0; + half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as); + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + // broadcast factors + const int64_t r2 = ne12/ne02; + const int64_t r3 = ne13/ne03; + + const half alpha_f16 = 1.0f; + const half beta_f16 = 0.0f; + +#if 0 + // use cublasGemmEx + { + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + int i03 = i13 / r3; + int i02 = i12 / r2; + + CUBLAS_CHECK( + cublasGemmEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha_f16, (const char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3] , CUDA_R_16F, nb01/sizeof(half), + (const char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2, CUDA_R_16F, nb11/sizeof(float), + &beta_f16, ( char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2, CUDA_R_16F, ne01, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } + } + } +#else + if (r2 == 1 && r3 == 1 && src0->nb[2]*src0->ne[2] == src0->nb[3] && src1->nb[2]*src1->ne[2] == src1->nb[3]) { + // there is no broadcast and src0, src1 are contiguous across dims 2, 3 + // use cublasGemmStridedBatchedEx + CUBLAS_CHECK( + cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA + (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB + &beta_f16, ( char *) dst_f16, CUDA_R_16F, ne01, dst->nb[2]/sizeof(float), // strideC + ne12*ne13, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + } else { + // use cublasGemmBatchedEx + // TODO: https://github.com/ggerganov/llama.cpp/pull/3749#discussion_r1369997000 + const int ne23 = ne12*ne13; + + // TODO: avoid this alloc + void ** ptrs = (void **) malloc(3*ne23*sizeof(void *)); + + for (int i13 = 0; i13 < ne13; ++i13) { + for (int i12 = 0; i12 < ne12; ++i12) { + int i03 = i13 / r3; + int i02 = i12 / r2; + + ptrs[0*ne23 + i12 + i13*ne12] = (char *) src0_as_f16 + i02*src0->nb[2] + i03*src0->nb[3]; + ptrs[1*ne23 + i12 + i13*ne12] = (char *) src1_as_f16 + i12*src1->nb[2]/2 + i13*src1->nb[3]/2; + ptrs[2*ne23 + i12 + i13*ne12] = (char *) dst_f16 + i12* dst->nb[2]/2 + i13* dst->nb[3]/2; + } + } + + // allocate device memory for pointers + void ** ptrs_as = nullptr; + CUDA_CHECK(cudaMalloc(&ptrs_as, 3*ne23*sizeof(void *))); + + // TODO: this does not work for some reason -- not sure why? + //size_t ptrs_s = 0; + //ptrs_as = (void **) ggml_cuda_pool_malloc(3*ne23*sizeof(void *), &ptrs_s); + + // copy pointers to device + CUDA_CHECK(cudaMemcpy(ptrs_as, ptrs, 3*ne23*sizeof(void *), cudaMemcpyHostToDevice)); + + free(ptrs); + + CUBLAS_CHECK( + cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha_f16, (const void **) (ptrs_as + 0*ne23), CUDA_R_16F, nb01/sizeof(half), + (const void **) (ptrs_as + 1*ne23), CUDA_R_16F, nb11/sizeof(float), + &beta_f16, ( void **) (ptrs_as + 2*ne23), CUDA_R_16F, ne01, + ne23, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + // free device memory for pointers + CUDA_CHECK(cudaFree(ptrs_as)); + //ggml_cuda_pool_free(ptrs_as, ptrs_s); + } +#endif + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream); + + ggml_cuda_pool_free(src1_as_f16, src1_as); + ggml_cuda_pool_free(dst_f16, dst_as); +} + static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { bool all_on_device = (src0->backend == GGML_BACKEND_GPU || src0->backend == GGML_BACKEND_GPU_SPLIT) && src1->backend == GGML_BACKEND_GPU && dst->backend == GGML_BACKEND_GPU; @@ -7025,10 +7214,22 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } } + // debug helpers + //printf("src0: %8d %8d %8d %8d\n", src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3]); + //printf(" %8d %8d %8d %8d\n", src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3]); + //printf("src1: %8d %8d %8d %8d\n", src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3]); + //printf(" %8d %8d %8d %8d\n", src1->nb[0], src1->nb[1], src1->nb[2], src1->nb[3]); + //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); + //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); + if (all_on_device && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { + // KQ ggml_cuda_mul_mat_vec_p021(src0, src1, dst); - } else if (all_on_device && !ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && src1->ne[1] == 1) { + } else if (all_on_device && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { + // KQV ggml_cuda_mul_mat_vec_nc(src0, src1, dst); + } else if (all_on_device && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { + ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst); } else if (src0->type == GGML_TYPE_F32) { ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false); } else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) { @@ -7061,6 +7262,10 @@ static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } +static void ggml_cuda_clamp(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_clamp); +} + static void ggml_cuda_cpy(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); @@ -7470,6 +7675,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ case GGML_OP_SCALE: func = ggml_cuda_scale; break; + case GGML_OP_CLAMP: + if (!any_on_device) { + return false; + } + func = ggml_cuda_clamp; + break; case GGML_OP_CPY: func = ggml_cuda_cpy; break; diff --git a/src/ggml-metal.m b/src/ggml-metal.m index 29cb3c92..c1901dca 100644 --- a/src/ggml-metal.m +++ b/src/ggml-metal.m @@ -62,6 +62,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(mul); GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast GGML_METAL_DECL_KERNEL(scale); + GGML_METAL_DECL_KERNEL(scale_4); GGML_METAL_DECL_KERNEL(silu); GGML_METAL_DECL_KERNEL(relu); GGML_METAL_DECL_KERNEL(gelu); @@ -73,6 +74,8 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_1); + GGML_METAL_DECL_KERNEL(get_rows_q5_0); + GGML_METAL_DECL_KERNEL(get_rows_q5_1); GGML_METAL_DECL_KERNEL(get_rows_q8_0); GGML_METAL_DECL_KERNEL(get_rows_q2_K); GGML_METAL_DECL_KERNEL(get_rows_q3_K); @@ -87,6 +90,8 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(mul_mv_f16_f32_l4); GGML_METAL_DECL_KERNEL(mul_mv_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q5_0_f32); + GGML_METAL_DECL_KERNEL(mul_mv_q5_1_f32); GGML_METAL_DECL_KERNEL(mul_mv_q8_0_f32); GGML_METAL_DECL_KERNEL(mul_mv_q2_K_f32); GGML_METAL_DECL_KERNEL(mul_mv_q3_K_f32); @@ -97,6 +102,8 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(mul_mm_f16_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_0_f32); GGML_METAL_DECL_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q5_0_f32); + GGML_METAL_DECL_KERNEL(mul_mm_q5_1_f32); GGML_METAL_DECL_KERNEL(mul_mm_q8_0_f32); GGML_METAL_DECL_KERNEL(mul_mm_q2_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q3_K_f32); @@ -243,6 +250,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(mul); GGML_METAL_ADD_KERNEL(mul_row); GGML_METAL_ADD_KERNEL(scale); + GGML_METAL_ADD_KERNEL(scale_4); GGML_METAL_ADD_KERNEL(silu); GGML_METAL_ADD_KERNEL(relu); GGML_METAL_ADD_KERNEL(gelu); @@ -254,6 +262,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_1); + GGML_METAL_ADD_KERNEL(get_rows_q5_0); + GGML_METAL_ADD_KERNEL(get_rows_q5_1); GGML_METAL_ADD_KERNEL(get_rows_q8_0); GGML_METAL_ADD_KERNEL(get_rows_q2_K); GGML_METAL_ADD_KERNEL(get_rows_q3_K); @@ -268,6 +278,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(mul_mv_f16_f32_l4); GGML_METAL_ADD_KERNEL(mul_mv_q4_0_f32); GGML_METAL_ADD_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q5_0_f32); + GGML_METAL_ADD_KERNEL(mul_mv_q5_1_f32); GGML_METAL_ADD_KERNEL(mul_mv_q8_0_f32); GGML_METAL_ADD_KERNEL(mul_mv_q2_K_f32); GGML_METAL_ADD_KERNEL(mul_mv_q3_K_f32); @@ -278,8 +290,10 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(mul_mm_f32_f32); GGML_METAL_ADD_KERNEL(mul_mm_f16_f32); GGML_METAL_ADD_KERNEL(mul_mm_q4_0_f32); - GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); GGML_METAL_ADD_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q5_0_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q5_1_f32); + GGML_METAL_ADD_KERNEL(mul_mm_q8_0_f32); GGML_METAL_ADD_KERNEL(mul_mm_q2_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q3_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); @@ -335,6 +349,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul); GGML_METAL_DEL_KERNEL(mul_row); GGML_METAL_DEL_KERNEL(scale); + GGML_METAL_DEL_KERNEL(scale_4); GGML_METAL_DEL_KERNEL(silu); GGML_METAL_DEL_KERNEL(relu); GGML_METAL_DEL_KERNEL(gelu); @@ -346,6 +361,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(get_rows_f16); GGML_METAL_DEL_KERNEL(get_rows_q4_0); GGML_METAL_DEL_KERNEL(get_rows_q4_1); + GGML_METAL_DEL_KERNEL(get_rows_q5_0); + GGML_METAL_DEL_KERNEL(get_rows_q5_1); GGML_METAL_DEL_KERNEL(get_rows_q8_0); GGML_METAL_DEL_KERNEL(get_rows_q2_K); GGML_METAL_DEL_KERNEL(get_rows_q3_K); @@ -360,6 +377,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mv_f16_f32_l4); GGML_METAL_DEL_KERNEL(mul_mv_q4_0_f32); GGML_METAL_DEL_KERNEL(mul_mv_q4_1_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q5_0_f32); + GGML_METAL_DEL_KERNEL(mul_mv_q5_1_f32); GGML_METAL_DEL_KERNEL(mul_mv_q8_0_f32); GGML_METAL_DEL_KERNEL(mul_mv_q2_K_f32); GGML_METAL_DEL_KERNEL(mul_mv_q3_K_f32); @@ -370,8 +389,10 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mm_f32_f32); GGML_METAL_DEL_KERNEL(mul_mm_f16_f32); GGML_METAL_DEL_KERNEL(mul_mm_q4_0_f32); - GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); GGML_METAL_DEL_KERNEL(mul_mm_q4_1_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q5_0_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q5_1_f32); + GGML_METAL_DEL_KERNEL(mul_mm_q8_0_f32); GGML_METAL_DEL_KERNEL(mul_mm_q2_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q3_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); @@ -779,8 +800,8 @@ void ggml_metal_graph_compute( } break; case GGML_OP_CONCAT: { + const int64_t nb = ne00; - int64_t nb = ne00; [encoder setComputePipelineState:ctx->pipeline_concat]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -812,6 +833,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb length:sizeof(nb) atIndex:27]; const int nth = MIN(1024, ne0); + [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_ADD: @@ -904,13 +926,19 @@ void ggml_metal_graph_compute( const float scale = *(const float *) src1->data; - [encoder setComputePipelineState:ctx->pipeline_scale]; + int64_t n = ggml_nelements(dst); + + if (n % 4 == 0) { + n /= 4; + [encoder setComputePipelineState:ctx->pipeline_scale_4]; + } else { + [encoder setComputePipelineState:ctx->pipeline_scale]; + } + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&scale length:sizeof(scale) atIndex:2]; - const int64_t n = ggml_nelements(dst)/4; - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_OP_UNARY: @@ -921,9 +949,10 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst)/4; + const int64_t n = ggml_nelements(dst); + GGML_ASSERT(n % 4 == 0); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; case GGML_UNARY_OP_RELU: { @@ -941,9 +970,10 @@ void ggml_metal_graph_compute( [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - const int64_t n = ggml_nelements(dst)/4; + const int64_t n = ggml_nelements(dst); + GGML_ASSERT(n % 4 == 0); - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(n/4, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; default: { @@ -1040,7 +1070,7 @@ void ggml_metal_graph_compute( !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1t == GGML_TYPE_F32 && - ne00 % 32 == 0 && + ne00 % 32 == 0 && ne00 >= 64 && ne11 > ne11_mm_min) { //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); switch (src0->type) { @@ -1048,6 +1078,8 @@ void ggml_metal_graph_compute( case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_f16_f32]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_0_f32]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q4_1_f32]; break; + case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_0_f32]; break; + case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q5_1_f32]; break; case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q8_0_f32]; break; case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q2_K_f32]; break; case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_q3_K_f32]; break; @@ -1117,6 +1149,24 @@ void ggml_metal_graph_compute( nth1 = 8; [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32]; } break; + case GGML_TYPE_Q5_0: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32]; + } break; + case GGML_TYPE_Q5_1: + { + GGML_ASSERT(ne02 == 1); + GGML_ASSERT(ne12 == 1); + + nth0 = 8; + nth1 = 8; + [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32]; + } break; case GGML_TYPE_Q8_0: { GGML_ASSERT(ne02 == 1); @@ -1197,7 +1247,8 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q8_0 || + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } @@ -1229,6 +1280,8 @@ void ggml_metal_graph_compute( case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break; case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break; + case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_0]; break; + case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_1]; break; case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q8_0]; break; case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_K]; break; case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_K]; break; @@ -1251,6 +1304,8 @@ void ggml_metal_graph_compute( } break; case GGML_OP_RMS_NORM: { + GGML_ASSERT(ne00 % 4 == 0); + float eps; memcpy(&eps, dst->op_params, sizeof(float)); @@ -1293,7 +1348,7 @@ void ggml_metal_graph_compute( const int nth = MIN(1024, ne00); - const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); diff --git a/src/ggml-metal.metal b/src/ggml-metal.metal index b6288db2..f4b46056 100644 --- a/src/ggml-metal.metal +++ b/src/ggml-metal.metal @@ -18,6 +18,21 @@ typedef struct { uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; +#define QK5_0 32 +typedef struct { + half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; + +#define QK5_1 32 +typedef struct { + half d; // delta + half m; // min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} block_q5_1; + #define QK8_0 32 typedef struct { half d; // delta @@ -110,9 +125,17 @@ kernel void kernel_mul_row( } kernel void kernel_scale( + device const float * src0, + device float * dst, + constant float & scale, + uint tpig[[thread_position_in_grid]]) { + dst[tpig] = src0[tpig] * scale; +} + +kernel void kernel_scale_4( device const float4 * src0, device float4 * dst, - constant float & scale, + constant float & scale, uint tpig[[thread_position_in_grid]]) { dst[tpig] = src0[tpig] * scale; } @@ -345,10 +368,11 @@ kernel void kernel_rms_norm( uint sgitg[[simdgroup_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); - device const float * x_scalar = (device const float *) x; - float4 sumf=0; - float all_sum=0; + device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); + device const float * x_scalar = (device const float *) x; + + float4 sumf = 0; + float all_sum = 0; // parallel sum for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { @@ -361,6 +385,7 @@ kernel void kernel_rms_norm( } threadgroup_barrier(mem_flags::mem_threadgroup); + // broadcast, simd group number is ntg / 32 for (uint i = ntg / 32 / 2; i > 0; i /= 2) { if (tpitg < i) { @@ -368,7 +393,9 @@ kernel void kernel_rms_norm( } } if (tpitg == 0) { - for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];} + for (int i = 4 * (ne00 / 4); i < ne00; i++) { + sum[0] += x_scalar[i]; + } sum[0] /= ne00; } @@ -383,7 +410,9 @@ kernel void kernel_rms_norm( y[i00] = x[i00] * scale; } if (tpitg == 0) { - for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;} + for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) { + y_scalar[i00] = x_scalar[i00] * scale; + } } } @@ -393,8 +422,11 @@ kernel void kernel_rms_norm( // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; + float2 acc = 0.f; + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2); + for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + yl[i + 1] * (qs[i / 2] & 0x0F00); @@ -411,8 +443,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) { float d = qb_curr->d; float m = qb_curr->m; - device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2); + for (int i = 0; i < 8; i+=2) { acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F) + yl[i + 1] * (qs[i / 2] & 0x0F00); @@ -422,6 +457,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre return d * (acc[0] + acc[1]) + sumy * m; } +// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_0/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + return d * (sumy * -16.f + acc[0] + acc[1]); +} + +// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i]) +// il indicates where the q5 quants begin (0 or QK5_1/4) +// we assume that the yl's have been multiplied with the appropriate scale factor +// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096) +inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) { + float d = qb_curr->d; + float m = qb_curr->m; + + float2 acc = 0.f; + + device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2); + const uint32_t qh = *((device const uint32_t *)qb_curr->qh); + + for (int i = 0; i < 8; i+=2) { + acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010)) + + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000)); + acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100)) + + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000)); + } + return d * (acc[0] + acc[1]) + sumy * m; +} + // putting them in the kernel cause a significant performance penalty #define N_DST 4 // each SIMD group works on 4 rows #define N_SIMDGROUP 2 // number of SIMD groups in a thread group @@ -519,6 +597,43 @@ kernel void kernel_mul_mv_q4_1_f32( mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); } +kernel void kernel_mul_mv_q5_0_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); +} + +kernel void kernel_mul_mv_q5_1_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01[[buffer(4)]], + constant int64_t & ne02[[buffer(5)]], + constant int64_t & ne10[[buffer(9)]], + constant int64_t & ne12[[buffer(11)]], + constant int64_t & ne0[[buffer(15)]], + constant int64_t & ne1[[buffer(16)]], + constant uint & gqa[[buffer(17)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); +} + + #define NB_Q8_0 8 kernel void kernel_mul_mv_q8_0_f32( @@ -2143,6 +2258,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg } } +template +void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 3); + const float d = xb->d; + const float md = -16.h * xb->d; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + md; + reg[i/2][2*(i%2)+1] = d * x1 + md; + } +} + +template +void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { + device const uint16_t * qs = ((device const uint16_t *)xb + 4); + const float d = xb->d; + const float m = xb->m; + const ushort mask = il ? 0x00F0 : 0x000F; + + const uint32_t qh = *((device const uint32_t *)xb->qh); + + const int x_mv = il ? 4 : 0; + + const int gh_mv = il ? 12 : 0; + const int gh_bk = il ? 0 : 4; + + for (int i = 0; i < 8; i++) { + // extract the 5-th bits for x0 and x1 + const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10; + const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10; + + // combine the 4-bits from qs with the 5th bit + const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0); + const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1); + + reg[i/2][2*(i%2)+0] = d * x0 + m; + reg[i/2][2*(i%2)+1] = d * x1 + m; + } +} + template void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { device const int8_t * qs = ((device const int8_t *)xb->qs); @@ -2484,6 +2655,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows; +template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows; @@ -2512,6 +2685,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm; +template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; diff --git a/src/ggml-opencl.cpp b/src/ggml-opencl.cpp index 4a331f24..202bcb48 100644 --- a/src/ggml-opencl.cpp +++ b/src/ggml-opencl.cpp @@ -19,7 +19,7 @@ #pragma warning(disable: 4244 4267) // possible loss of data #endif -#define CL_DMMV_BLOCK_SIZE 32 +#define CL_DMMV_LOCAL_SIZE 32 #ifndef K_QUANTS_PER_ITERATION #define K_QUANTS_PER_ITERATION 1 @@ -338,7 +338,7 @@ __kernel void dequantize_mul_mat_vec_q2_K(__global const struct block_q2_K * xx, const int row = get_group_id(0); const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; + const int ib0 = row*num_blocks_per_row + get_global_offset(0); __global const struct block_q2_K * x = xx + ib0; @@ -413,7 +413,7 @@ __kernel void dequantize_mul_mat_vec_q3_K(__global const struct block_q3_K * xx, const int row = get_group_id(0); const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; + const int ib0 = row*num_blocks_per_row + get_global_offset(0); __global const struct block_q3_K * x = xx + ib0; @@ -489,7 +489,7 @@ __kernel void dequantize_mul_mat_vec_q4_K(__global const struct block_q4_K * xx, const int row = get_group_id(0); const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; + const int ib0 = row*num_blocks_per_row + get_global_offset(0); const int tid = get_local_id(0)/K_QUANTS_PER_ITERATION; // 0...15 const int ix = get_local_id(0)%K_QUANTS_PER_ITERATION; @@ -562,7 +562,7 @@ __kernel void dequantize_mul_mat_vec_q5_K(__global const struct block_q5_K * xx, const int row = get_group_id(0); const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; + const int ib0 = row*num_blocks_per_row + get_global_offset(0); const int tid = get_local_id(0)/2; // 0...15 const int ix = get_local_id(0)%2; @@ -641,7 +641,7 @@ __kernel void dequantize_mul_mat_vec_q6_K(__global const struct block_q6_K * xx, const int row = get_group_id(0); const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; + const int ib0 = row*num_blocks_per_row + get_global_offset(0); __global const struct block_q6_K * x = xx + ib0; @@ -745,19 +745,21 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) { std::string dequant_mul_mat_vec_template = MULTILINE_QUOTE( __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) { - const int block_size = get_local_size(0); + const int local_size = get_local_size(0); const int row = get_group_id(0); const int tid = get_local_id(0); const uint qk = QUANT_K; const uint qr = QUANT_R; + const int col_step = local_size * 2; const int y_offset = qr == 1 ? 1 : qk/2; + x += get_global_offset(0); + tmp[tid] = 0; - for (int i = 0; i < ncols/block_size; i += 2) { - const int col = i*block_size + 2*tid; + for (int col = tid*2; col < ncols; col += col_step) { const int ib = (row*ncols + col)/qk; // block index const int iqs = (col%qk)/qr; // quant index const int iybs = col - col%qk; // y block start index @@ -773,7 +775,7 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float // sum up partial sums and write back result barrier(CLK_LOCAL_MEM_FENCE); - for (int s=block_size/2; s>0; s>>=1) { + for (int s=local_size/2; s>0; s>>=1) { if (tid < s) { tmp[tid] += tmp[tid + s]; } @@ -1393,75 +1395,46 @@ static void ggml_cl_mul_f32(const ggml_tensor * src0, const ggml_tensor * src1, const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; const int64_t ne03 = src0->ne[3]; - const int64_t ne0 = ne00 * ne01 * ne02 * ne03; const int64_t ne10 = src1->ne[0]; const int64_t ne11 = src1->ne[1]; const int64_t ne12 = src1->ne[2]; const int64_t ne13 = src1->ne[3]; - const int64_t nb10 = src1->nb[0]; const int nb2 = dst->nb[2]; const int nb3 = dst->nb[3]; size_t x_size; size_t d_size; - cl_mem d_X = ggml_cl_pool_malloc(ne0 * sizeof(float), &x_size); // src0 + cl_mem d_X = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &x_size); // src0 cl_mem d_Y = (cl_mem) src1->extra; // src1 is already on device, broadcasted. - cl_mem d_D = ggml_cl_pool_malloc(ne0 * sizeof(float), &d_size); // dst + cl_mem d_D = ggml_cl_pool_malloc(ne00 * ne01 * sizeof(float), &d_size); // dst for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { - const int i0 = i03*ne02 + i02; - cl_event ev; // copy src0 to device - CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, i0, src0, i03, i02, &ev)); - - if (nb10 == sizeof(float)) { - // Contiguous, avoid overhead from queueing many kernel runs - const int64_t i13 = i03%ne13; - const int64_t i12 = i02%ne12; - const int i1 = i13*ne12*ne11 + i12*ne11; - - cl_int x_offset = 0; - cl_int y_offset = i1*ne10; - cl_int d_offset = 0; - - size_t global = ne00 * ne01; - cl_int ky = ne10; - CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky)); - CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL)); - } else { - for (int64_t i01 = 0; i01 < ne01; i01++) { - const int64_t i13 = i03%ne13; - const int64_t i12 = i02%ne12; - const int64_t i11 = i01%ne11; - const int i1 = i13*ne12*ne11 + i12*ne11 + i11; - - cl_int x_offset = i01*ne00; - cl_int y_offset = i1*ne10; - cl_int d_offset = i01*ne00; + CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, &ev)); - // compute - size_t global = ne00; - cl_int ky = ne10; - CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset)); - CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky)); - CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL)); - } - } + const int64_t i13 = i03%ne13; + const int64_t i12 = i02%ne12; + const int i1 = i13*ne12*ne11 + i12*ne11; + + cl_int x_offset = 0; + cl_int y_offset = i1*ne10; + cl_int d_offset = 0; + + size_t global = ne00 * ne01; + cl_int ky = ne10 * ne11; + + CL_CHECK(clSetKernelArg(mul_f32_cl, 0, sizeof(cl_mem), &d_X)); + CL_CHECK(clSetKernelArg(mul_f32_cl, 1, sizeof(cl_int), &x_offset)); + CL_CHECK(clSetKernelArg(mul_f32_cl, 2, sizeof(cl_mem), &d_Y)); + CL_CHECK(clSetKernelArg(mul_f32_cl, 3, sizeof(cl_int), &y_offset)); + CL_CHECK(clSetKernelArg(mul_f32_cl, 4, sizeof(cl_mem), &d_D)); + CL_CHECK(clSetKernelArg(mul_f32_cl, 5, sizeof(cl_int), &d_offset)); + CL_CHECK(clSetKernelArg(mul_f32_cl, 6, sizeof(cl_int), &ky)); + CL_CHECK(clEnqueueNDRangeKernel(queue, mul_f32_cl, 1, NULL, &global, NULL, 1, &ev, NULL)); CL_CHECK(clReleaseEvent(ev)); CL_CHECK(clFinish(queue)); @@ -1516,46 +1489,45 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr cl_mem d_D = ggml_cl_pool_malloc(sizeof(float) * d_ne, &d_size); size_t x_offset = 0; - int64_t pi02 = -1; - int64_t pi03 = -1; - - for (int64_t i13 = 0; i13 < ne13; i13++) { - int64_t i03 = i13 / r3; - - for (int64_t i12 = 0; i12 < ne12; i12++) { - int64_t i02 = i12 / r2; - - // copy data to device - if (src0->backend == GGML_BACKEND_GPU) { - x_offset = (i03 * ne02 + i02) * x_ne; - } else if (i02 != pi02 || i03 != pi03) { - CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL)); - pi02 = i02; - pi03 = i03; - } - CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL)); - CL_CHECK(clFinish(queue)); + for (int64_t i03 = 0; i03 < ne03; i03++) { + // TODO: copy src0 here when r3>1 + for (int64_t i13 = i03 * r3, e13 = i13 + r3; i13 < e13; i13++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + if (src0->backend == GGML_BACKEND_GPU) { + x_offset = (i03 * ne02 + i02) * x_ne; + } else { + // copy src0 to device + CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL)); + } - // compute - cl_event ev_sgemm; - clblast::StatusCode status = clblast::Gemm(clblast::Layout::kColMajor, - clblast::Transpose::kYes, clblast::Transpose::kNo, - ne01, ne11, ne10, - alpha, - d_X, x_offset, ne00, - d_Y, 0, ne10, - beta, - d_D, 0, ne01, - &queue, &ev_sgemm); - - if (status != clblast::StatusCode::kSuccess) { - GGML_ASSERT(false); - } + for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) { + // copy src1 to device + CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL)); - // copy dst to host - float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); - CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &ev_sgemm, NULL)); + CL_CHECK(clFinish(queue)); + + // compute + cl_event ev_sgemm; + clblast::StatusCode status = clblast::Gemm(clblast::Layout::kColMajor, + clblast::Transpose::kYes, clblast::Transpose::kNo, + ne01, ne11, ne10, + alpha, + d_X, x_offset, ne00, + d_Y, 0, ne10, + beta, + d_D, 0, ne01, + &queue, &ev_sgemm); + + if (status != clblast::StatusCode::kSuccess) { + GGML_ASSERT(false); + } + + // copy dst to host + float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); + CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &ev_sgemm, NULL)); + } + } } } @@ -1566,7 +1538,7 @@ static void ggml_cl_mul_mat_f32(const ggml_tensor * src0, const ggml_tensor * sr ggml_cl_pool_free(d_D, d_size); } -static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t /* wsize */) { +static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata, size_t wsize) { GGML_ASSERT(fp16_support); const int64_t ne00 = src0->ne[0]; @@ -1596,6 +1568,10 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr const int y_ne = ne11 * ne10; const int d_ne = ne11 * ne01; + GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * y_ne); + GGML_ASSERT(wsize >= sizeof(ggml_fp16_t) * d_ne); + ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata; + size_t x_size; size_t y_size; size_t d_size; @@ -1612,74 +1588,70 @@ static void ggml_cl_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor * sr bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float); size_t x_offset = 0; - int64_t pi02 = -1; - int64_t pi03 = -1; - for (int64_t i13 = 0; i13 < ne13; i13++) { - int64_t i03 = i13 / r3; - - for (int64_t i12 = 0; i12 < ne12; i12++) { - int64_t i02 = i12 / r2; - - // copy src0 to device - if (src0->backend == GGML_BACKEND_GPU) { - x_offset = (i03 * ne02 + i02) * x_ne; - } else if (i02 != pi02 || i03 != pi03) { - CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL)); - pi02 = i02; - pi03 = i03; - } - - // convert src1 to fp16 - // TODO: use multiple threads - ggml_fp16_t * const tmp = (ggml_fp16_t *) wdata + (ne11 * ne10) * (i13 * ne12 + i12); - char * src1i = (char *) src1->data + i13*nb13 + i12*nb12; - if (src1_cont_rows) { - if (src1_cont_cols) { - ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11); + for (int64_t i03 = 0; i03 < ne03; i03++) { + // TODO: copy src0 here when r3>1 + for (int64_t i13 = i03 * r3, e13 = i13 + r3; i13 < e13; i13++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + if (src0->backend == GGML_BACKEND_GPU) { + x_offset = (i03 * ne02 + i02) * x_ne; + } else { + // copy src0 to device + CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_X, 0, src0, i03, i02, NULL)); } - else { - for (int64_t i11 = 0; i11 < ne11; i11++) { - ggml_fp32_to_fp16_row((float *) (src1i + i11*nb11), tmp + i11*ne10, ne10); + + for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) { + // convert src1 to fp16 + // TODO: use multiple threads + char * src1i = (char *) src1->data + i13*nb13 + i12*nb12; + if (src1_cont_rows) { + if (src1_cont_cols) { + ggml_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11); + } + else { + for (int64_t i11 = 0; i11 < ne11; i11++) { + ggml_fp32_to_fp16_row((float *) (src1i + i11*nb11), tmp + i11*ne10, ne10); + } + } } - } - } - else { - for (int64_t i11 = 0; i11 < ne11; i11++) { - for (int64_t i10 = 0; i10 < ne10; i10++) { - // very slow due to no inlining - tmp[i11*ne10 + i10] = ggml_fp32_to_fp16(*(float *) (src1i + i11*nb11 + i10*nb10)); + else { + for (int64_t i11 = 0; i11 < ne11; i11++) { + for (int64_t i10 = 0; i10 < ne10; i10++) { + // very slow due to no inlining + tmp[i11*ne10 + i10] = ggml_fp32_to_fp16(*(float *) (src1i + i11*nb11 + i10*nb10)); + } + } } - } - } - // copy src1 to device - CL_CHECK(clEnqueueWriteBuffer(queue, d_Y, false, 0, sizeof(ggml_fp16_t) * y_ne, tmp, 0, NULL, NULL)); + // copy src1 to device + CL_CHECK(clEnqueueWriteBuffer(queue, d_Y, false, 0, sizeof(ggml_fp16_t) * y_ne, tmp, 0, NULL, NULL)); - CL_CHECK(clFinish(queue)); + CL_CHECK(clFinish(queue)); - // compute - cl_event ev_sgemm; - clblast::StatusCode status = clblast::Gemm(clblast::Layout::kColMajor, - clblast::Transpose::kYes, clblast::Transpose::kNo, - ne01, ne11, ne10, - alpha, - d_X, x_offset, ne00, - d_Y, 0, ne10, - beta, - d_D, 0, ne01, - &queue, &ev_sgemm); - - if (status != clblast::StatusCode::kSuccess) { - GGML_ASSERT(false); - } + // compute + cl_event ev_sgemm; + clblast::StatusCode status = clblast::Gemm(clblast::Layout::kColMajor, + clblast::Transpose::kYes, clblast::Transpose::kNo, + ne01, ne11, ne10, + alpha, + d_X, x_offset, ne00, + d_Y, 0, ne10, + beta, + d_D, 0, ne01, + &queue, &ev_sgemm); + + if (status != clblast::StatusCode::kSuccess) { + GGML_ASSERT(false); + } - // copy dst to host, then convert to float - CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(ggml_fp16_t) * d_ne, tmp, 1, &ev_sgemm, NULL)); + // copy dst to host, then convert to float + CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(ggml_fp16_t) * d_ne, tmp, 1, &ev_sgemm, NULL)); - float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); + float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); - ggml_fp16_to_fp32_row(tmp, d, d_ne); + ggml_fp16_to_fp32_row(tmp, d, d_ne); + } + } } } @@ -1704,7 +1676,7 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * const int nb2 = dst->nb[2]; const int nb3 = dst->nb[3]; const ggml_type type = src0->type; - const bool mul_mat_vec = ne11 == 1; + const bool mul_mat_vec = ne11 == 1 && ne00%2 == 0; const int64_t r2 = ne12 / ne02; const int64_t r3 = ne13 / ne03; @@ -1737,90 +1709,86 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * GGML_ASSERT(to_fp32_cl != nullptr); const size_t global_denom = ggml_cl_global_denom(type); - const size_t local = ggml_cl_local_size(type); + const size_t local = mul_mat_vec ? CL_DMMV_LOCAL_SIZE : ggml_cl_local_size(type); size_t ev_idx = 0; std::vector events; - int64_t pi02 = -1; - int64_t pi03 = -1; - - for (int64_t i13 = 0; i13 < ne13; i13++) { - int64_t i03 = i13 / r3; - - for (int64_t i12 = 0; i12 < ne12; i12++) { - int64_t i02 = i12 / r2; - - // copy src0 to device if necessary - if (src0->backend == GGML_BACKEND_CPU) { - if (i02 != pi02 || i03 != pi03) { + for (int64_t i03 = 0; i03 < ne03; i03++) { + // TODO: copy and dequantize src0 here when r3>1 + for (int64_t i13 = i03 * r3, e13 = i13 + r3; i13 < e13; i13++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + // copy src0 to device if necessary + if (src0->backend == GGML_BACKEND_CPU) { events.emplace_back(); CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Q, 0, src0, i03, i02, events.data() + ev_idx++)); - pi02 = i02; - pi03 = i03; - } - } else if (src0->backend == GGML_BACKEND_GPU) { - d_Q = (cl_mem) src0->extra; - } else { - GGML_ASSERT(false); - } - if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel - // copy src1 to device - events.emplace_back(); - CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, events.data() + ev_idx++)); - - // compute - const size_t global = ne01 * CL_DMMV_BLOCK_SIZE; - const size_t local = CL_DMMV_BLOCK_SIZE; - const cl_int ncols = ne00; - events.emplace_back(); - CL_CHECK(clSetKernelArg(*dmmv, 0, sizeof(cl_mem), &d_Q)); - CL_CHECK(clSetKernelArg(*dmmv, 1, sizeof(float) * local, NULL)); - CL_CHECK(clSetKernelArg(*dmmv, 2, sizeof(cl_mem), &d_Y)); - CL_CHECK(clSetKernelArg(*dmmv, 3, sizeof(cl_mem), &d_D)); - CL_CHECK(clSetKernelArg(*dmmv, 4, sizeof(cl_int), &ncols)); - CL_CHECK(clEnqueueNDRangeKernel(queue, *dmmv, 1, NULL, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++)); - } else { // general dequantization kernel + CLBlast matrix matrix multiplication - // convert src0 to fp32 on device - const size_t global = x_ne / global_denom; - const size_t offset = src0->backend == GGML_BACKEND_GPU ? (i03 * ne02 + i02) * x_bps : 0; - CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q)); - CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X)); - CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, offset > 0 ? &offset : NULL, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL)); - - // copy src1 to device - CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL)); - - events.emplace_back(); - - // wait for conversion - CL_CHECK(clFinish(queue)); - - // compute - clblast::StatusCode status = clblast::Gemm(clblast::Layout::kColMajor, - clblast::Transpose::kYes, clblast::Transpose::kNo, - ne01, ne11, ne10, - alpha, - d_X, 0, ne00, - d_Y, 0, ne10, - beta, - d_D, 0, ne01, - &queue, events.data() + ev_idx++); - - if (status != clblast::StatusCode::kSuccess) { + } else if (src0->backend == GGML_BACKEND_GPU) { + d_Q = (cl_mem) src0->extra; + } else { GGML_ASSERT(false); } - } - // copy dst to host - float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); - CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &events[events.size() - 1], NULL)); - for (auto *event : events) { - clReleaseEvent(event); - } + if (!mul_mat_vec) { + // convert src0 to fp32 on device + const size_t global = x_ne / global_denom; + const size_t offset = src0->backend == GGML_BACKEND_GPU ? (i03 * ne02 + i02) * x_bps : 0; + CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q)); + CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X)); + CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, &offset, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL)); + } - ev_idx = 0; - events.clear(); + for (int64_t i12 = i02 * r2, e12 = i12 + r2; i12 < e12; i12++) { + if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel + // copy src1 to device + events.emplace_back(); + CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, events.data() + ev_idx++)); + + // compute + const size_t global = ne01 * local; + const size_t offset = src0->backend == GGML_BACKEND_GPU ? (i03 * ne02 + i02) * x_bps : 0; + const cl_int ncols = ne00; + events.emplace_back(); + CL_CHECK(clSetKernelArg(*dmmv, 0, sizeof(cl_mem), &d_Q)); + CL_CHECK(clSetKernelArg(*dmmv, 1, sizeof(float) * local, NULL)); + CL_CHECK(clSetKernelArg(*dmmv, 2, sizeof(cl_mem), &d_Y)); + CL_CHECK(clSetKernelArg(*dmmv, 3, sizeof(cl_mem), &d_D)); + CL_CHECK(clSetKernelArg(*dmmv, 4, sizeof(cl_int), &ncols)); + CL_CHECK(clEnqueueNDRangeKernel(queue, *dmmv, 1, &offset, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++)); + } else { // CLBlast matrix matrix multiplication + // copy src1 to device + CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i13, i12, NULL)); + + // wait for conversion + CL_CHECK(clFinish(queue)); + + // compute + events.emplace_back(); + clblast::StatusCode status = clblast::Gemm(clblast::Layout::kColMajor, + clblast::Transpose::kYes, clblast::Transpose::kNo, + ne01, ne11, ne10, + alpha, + d_X, 0, ne00, + d_Y, 0, ne10, + beta, + d_D, 0, ne01, + &queue, events.data() + ev_idx++); + + if (status != clblast::StatusCode::kSuccess) { + GGML_ASSERT(false); + } + } + + // copy dst to host + float * d = (float *) ((char *) dst->data + i12*nb2 + i13*nb3); + CL_CHECK(clEnqueueReadBuffer(queue, d_D, true, 0, sizeof(float) * d_ne, d, 1, &events[events.size() - 1], NULL)); + for (auto *event : events) { + clReleaseEvent(event); + } + + ev_idx = 0; + events.clear(); + } + } } } @@ -1895,8 +1863,8 @@ void ggml_cl_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * } size_t ggml_cl_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - if (ggml_cl_mul_mat_use_f16(src0, src1, dst)) { - return ggml_nelements(src1) * sizeof(ggml_fp16_t); + if (src0->type == GGML_TYPE_F16 && ggml_cl_mul_mat_use_f16(src0, src1, dst)) { + return sizeof(ggml_fp16_t) * std::max(src1->ne[0] * src1->ne[1], dst->ne[0] * dst->ne[1]); } return 0; } diff --git a/src/ggml.c b/src/ggml.c index f71cfd8d..6f66bab0 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -5496,6 +5496,39 @@ struct ggml_tensor * ggml_view_tensor( return result; } +struct ggml_tensor * ggml_get_first_tensor(struct ggml_context * ctx) { + struct ggml_object * obj = ctx->objects_begin; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == GGML_OBJECT_TENSOR) { + return (struct ggml_tensor *)(mem_buffer + obj->offs); + } + + obj = obj->next; + } + + return NULL; +} + +struct ggml_tensor * ggml_get_next_tensor(struct ggml_context * ctx, struct ggml_tensor * tensor) { + struct ggml_object * obj = (struct ggml_object *) ((char *)tensor - GGML_OBJECT_SIZE); + obj = obj->next; + + char * const mem_buffer = ctx->mem_buffer; + + while (obj != NULL) { + if (obj->type == GGML_OBJECT_TENSOR) { + return (struct ggml_tensor *)(mem_buffer + obj->offs); + } + + obj = obj->next; + } + + return NULL; +} + struct ggml_tensor * ggml_get_tensor(struct ggml_context * ctx, const char * name) { struct ggml_object * obj = ctx->objects_begin; @@ -8701,6 +8734,7 @@ void ggml_set_param( GGML_ASSERT(tensor->grad == NULL); tensor->grad = ggml_dup_tensor(ctx, tensor); + ggml_format_name(tensor->grad, "%s (grad)", tensor->name); } // ggml_compute_forward_dup @@ -11283,7 +11317,7 @@ static void ggml_compute_forward_silu_f32( #ifndef NDEBUG for (int k = 0; k < nc; k++) { - const float x = ((float *) ((char *) dst->data + i1*( dst->nb[1])))[k]; + const float x = ((float *) ((char *) dst->data + i1*(dst->nb[1])))[k]; UNUSED(x); assert(!isnan(x)); assert(!isinf(x)); @@ -13106,24 +13140,22 @@ static void ggml_compute_forward_alibi_f32( return; } - const int n_past = ((int32_t *) dst->op_params)[0]; UNUSED(n_past); + //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_head = ((int32_t *) dst->op_params)[1]; float max_bias; memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float)); - assert(n_past >= 0); + const int64_t ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 + const int64_t ne1 = src0->ne[1]; // seq_len_without_past + const int64_t ne2 = src0->ne[2]; // n_head -> this is k + //const int64_t ne3 = src0->ne[3]; // 1 -> bsz - const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 - const int ne1 = src0->ne[1]; // seq_len_without_past - const int ne2 = src0->ne[2]; // n_head -> this is k - //const int ne3 = src0->ne[3]; // 1 -> bsz + const int64_t n = ggml_nrows(src0); + const int64_t ne2_ne3 = n/ne1; // ne2*ne3 - const int n = ggml_nrows(src0); - const int ne2_ne3 = n/ne1; // ne2*ne3 - - const int nb0 = src0->nb[0]; - const int nb1 = src0->nb[1]; - const int nb2 = src0->nb[2]; + const size_t nb0 = src0->nb[0]; + const size_t nb1 = src0->nb[1]; + const size_t nb2 = src0->nb[2]; //const int nb3 = src0->nb[3]; GGML_ASSERT(nb0 == sizeof(float)); @@ -13135,9 +13167,9 @@ static void ggml_compute_forward_alibi_f32( const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor); - for (int i = 0; i < ne0; i++) { - for (int j = 0; j < ne1; j++) { - for (int k = 0; k < ne2_ne3; k++) { + for (int64_t i = 0; i < ne0; i++) { + for (int64_t j = 0; j < ne1; j++) { + for (int64_t k = 0; k < ne2_ne3; k++) { float * const src = (float *)((char *) src0->data + i*nb0 + j*nb1 + k*nb2); float * pdst = (float *)((char *) dst->data + i*nb0 + j*nb1 + k*nb2); @@ -13152,7 +13184,6 @@ static void ggml_compute_forward_alibi_f32( } pdst[0] = i * m_k + src[0]; - } } } @@ -13553,7 +13584,7 @@ static void ggml_compute_forward_rope_f16( dst_data[n_dims] = GGML_FP32_TO_FP16(x2*cos_block_theta - x3*sin_block_theta); dst_data[n_dims/2*3] = GGML_FP32_TO_FP16(x2*sin_block_theta + x3*cos_block_theta); } - } if (!is_neox) { + } else if (!is_neox) { for (int64_t i0 = 0; i0 < ne0; i0 += 2) { const float cos_theta = cosf(theta); const float sin_theta = sinf(theta); @@ -16820,6 +16851,10 @@ static void ggml_compute_forward_cross_entropy_loss_back( static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { GGML_ASSERT(params); + if (tensor->op == GGML_OP_NONE) { + return; + } + #ifdef GGML_USE_CUBLAS bool skip_cpu = ggml_cuda_compute_forward(params, tensor); if (skip_cpu) { @@ -19414,6 +19449,7 @@ void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname) { if (idx == -1) { fprintf(stderr, "%s: failed to find tensor, arg = %d, node = %d\n", __func__, j, i); + fclose(fout); return; } @@ -21087,7 +21123,7 @@ struct gguf_kv { }; struct gguf_header { - uint32_t magic; + char magic[4]; uint32_t version; uint64_t n_tensors; // GGUFv2 uint64_t n_kv; // GGUFv2 @@ -21157,7 +21193,7 @@ static bool gguf_fread_str_v1(FILE * file, struct gguf_str * p, size_t * offset) struct gguf_context * gguf_init_empty(void) { struct gguf_context * ctx = GGML_ALIGNED_MALLOC(sizeof(struct gguf_context)); - ctx->header.magic = GGUF_MAGIC; + memcpy(ctx->header.magic, GGUF_MAGIC, sizeof(ctx->header.magic)); ctx->header.version = GGUF_VERSION; ctx->header.n_tensors = 0; ctx->header.n_kv = 0; @@ -21183,16 +21219,18 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p // offset from start of file size_t offset = 0; - uint32_t magic = 0; + char magic[4]; // check the magic before making allocations { gguf_fread_el(file, &magic, sizeof(magic), &offset); - if (magic != GGUF_MAGIC) { - fprintf(stderr, "%s: invalid magic number %08x\n", __func__, magic); - fclose(file); - return NULL; + for (uint32_t i = 0; i < sizeof(magic); i++) { + if (magic[i] != GGUF_MAGIC[i]) { + fprintf(stderr, "%s: invalid magic characters %s.\n", __func__, magic); + fclose(file); + return NULL; + } } } @@ -21202,7 +21240,8 @@ struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_p // read the header { - ctx->header.magic = magic; + strncpy(ctx->header.magic, magic, 4); + ctx->kv = NULL; ctx->infos = NULL;