From: slaren Date: Thu, 7 Dec 2023 08:51:46 +0000 (+0100) Subject: test-backend-ops : add performance eval mode + improve CUDA repeat and binary broadca... X-Git-Tag: upstream/0.0.1642~1180 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=990f931f674ab8e9735ed2693faf20495be58604;p=pkg%2Fggml%2Fsources%2Fggml test-backend-ops : add performance eval mode + improve CUDA repeat and binary broadcast ops performance (#636) * ggml-cuda : implement repeat with bin_bcast * ggml-cuda : change supports_op for mul_mat to match compute_forward * test-backend-ops : add performance eval mode * improve formatting * add sd test cases * fix test case * ggml-cuda : bin_bcast: better block sizes, two elements per thread * metal : add dim3 broadcast support for mul mat * cleanup * typo * metal : enable mul mat-vec for dim2 > 1 * metal : mul mat-vec support dim3 broadcasts ggml-ci * ggml-cuda : fix bin_bcast for ne0=1 ggml-ci * ggml-cuda : limit block size z dim to 64 * test-backend-ops : add test cases * test-backend-ops : add warmup run, print test type before trying to compute * ggml-cuda : bin_bcast: collapse dimensions when possible, add fallback kernel for large tensors ggml-ci * test-backend-ops : avoid division by zero --------- Co-authored-by: Georgi Gerganov --- diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index dbe92d97..8fbd5c7e 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -434,7 +434,6 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define WARP_SIZE 32 #define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses -#define CUDA_ADDMUL_BLOCK_SIZE 256 #define CUDA_GELU_BLOCK_SIZE 256 #define CUDA_SILU_BLOCK_SIZE 256 #define CUDA_RELU_BLOCK_SIZE 256 @@ -501,6 +500,10 @@ static size_t g_scratch_offset = 0; static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; +static __device__ __forceinline__ float op_repeat(const float a, const float b) { + return b; +} + static __device__ __forceinline__ float op_add(const float a, const float b) { return a + b; } @@ -515,29 +518,69 @@ static __device__ __forceinline__ float op_div(const float a, const float b) { template static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, - int ne0,/* int ne1, int ne2, */int ne3, + int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13, /*int s0, */ int s1, int s2, int s3, /*int s10,*/ int s11, int s12, int s13) { - const int i0 = blockDim.x*blockIdx.x + threadIdx.x; - const int i1 = blockIdx.y; - const int i2 = blockIdx.z / ne3; - const int i3 = blockIdx.z % ne3; + const int i0s = blockDim.x*blockIdx.x + threadIdx.x; + const int i1 = (blockDim.y*blockIdx.y + threadIdx.y); + const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3; + const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3; - if (i0 >= ne0) { + if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { + return; + } + + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; + + const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i_src0; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; + + for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) { + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + } +} + +template +static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0, int ne1, int ne2, int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s10,*/ int s11, int s12, int s13) { + + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + const int i3 = i/(ne2*ne1*ne0); + const int i2 = (i/(ne1*ne0)) % ne2; + const int i1 = (i/ne0) % ne1; + const int i0 = i % ne0; + + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { return; } - const int i10 = i0 % ne10; const int i11 = i1 % ne11; const int i12 = i2 % ne12; const int i13 = i3 % ne13; - const size_t i_dst = i3*s3 + i2*s2 + i1*s1 + i0; - const size_t i_src0 = i_dst; - const size_t i_src1 = i13*s13 + i12*s12 + i11*s11 + i10; + const size_t i_src0 = i3*s3 + i2*s2 + i1*s1; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; + const size_t i_dst = i_src0; + + const src0_t * src0_row = src0 + i_src0; + const src1_t * src1_row = src1 + i_src1; + dst_t * dst_row = dst + i_dst; - dst[i_dst] = (dst_t)bin_op((float)src0[i_src0], (float)src1[i_src1]); + const int i10 = i0 % ne10; + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); } static __global__ void gelu_f32(const float * x, float * dst, const int k) { @@ -4849,24 +4892,108 @@ struct bin_bcast_cuda { GGML_TENSOR_BINARY_OP_LOCALS - //size_t s0 = nb0 / sizeof(src1_t); - size_t s1 = nb1 / sizeof(src1_t); - size_t s2 = nb2 / sizeof(src1_t); - size_t s3 = nb3 / sizeof(src1_t); - - //size_t s10 = nb10 / sizeof(src1_t); - size_t s11 = nb11 / sizeof(src1_t); - size_t s12 = nb12 / sizeof(src1_t); - size_t s13 = nb13 / sizeof(src1_t); - const int num_blocks_x = (ne0 + CUDA_ADDMUL_BLOCK_SIZE - 1) / CUDA_ADDMUL_BLOCK_SIZE; - dim3 num_blocks(num_blocks_x, ne1, ne2*ne3); + int nr0 = ne10/ne0; + int nr1 = ne11/ne1; + int nr2 = ne12/ne2; + int nr3 = ne13/ne3; + + int nr[4] = { nr0, nr1, nr2, nr3 }; + + // collapse dimensions until first broadcast dimension + int64_t cne0[] = {ne0, ne1, ne2, ne3}; + int64_t cne1[] = {ne10, ne11, ne12, ne13}; + size_t cnb0[] = {nb0, nb1, nb2, nb3}; + size_t cnb1[] = {nb10, nb11, nb12, nb13}; + auto collapse = [](int64_t cne[]) { + cne[0] *= cne[1]; + cne[1] = cne[2]; + cne[2] = cne[3]; + cne[3] = 1; + }; + + auto collapse_nb = [](size_t cnb[], int64_t cne[]) { + cnb[1] *= cne[1]; + cnb[2] *= cne[2]; + cnb[3] *= cne[3]; + }; + + for (int i = 0; i < 4; i++) { + if (nr[i] != 1) { + break; + } + if (i > 0) { + collapse_nb(cnb0, cne0); + collapse_nb(cnb1, cne1); + collapse(cne0); + collapse(cne1); + } + } + { + int64_t ne0 = cne0[0]; + int64_t ne1 = cne0[1]; + int64_t ne2 = cne0[2]; + int64_t ne3 = cne0[3]; + + int64_t ne10 = cne1[0]; + int64_t ne11 = cne1[1]; + int64_t ne12 = cne1[2]; + int64_t ne13 = cne1[3]; + + //size_t nb0 = cnb0[0]; + size_t nb1 = cnb0[1]; + size_t nb2 = cnb0[2]; + size_t nb3 = cnb0[3]; + + //size_t nb10 = cnb1[0]; + size_t nb11 = cnb1[1]; + size_t nb12 = cnb1[2]; + size_t nb13 = cnb1[3]; + + //size_t s0 = nb0 / sizeof(src1_t); + size_t s1 = nb1 / sizeof(src1_t); + size_t s2 = nb2 / sizeof(src1_t); + size_t s3 = nb3 / sizeof(src1_t); + + //size_t s10 = nb10 / sizeof(src1_t); + size_t s11 = nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); + + + const int block_size = 128; + + int64_t hne0 = std::max(ne0/2LL, 1LL); + + dim3 block_dims; + block_dims.x = std::min(hne0, block_size); + block_dims.y = std::min(ne1, block_size / block_dims.x); + block_dims.z = std::min(std::min(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U); + + dim3 block_nums( + (hne0 + block_dims.x - 1) / block_dims.x, + (ne1 + block_dims.y - 1) / block_dims.y, + (ne2*ne3 + block_dims.z - 1) / block_dims.z + ); - k_bin_bcast<<>>(src0_dd, src1_dd, dst_dd, - ne0,/* ne1, ne2, */ne3, - ne10, ne11, ne12, ne13, - /* s0, */s1, s2, s3, - /* s10,*/ s11, s12, s13); + if (block_nums.z > 65535) { + // this is the maximum number of blocks in z direction, fallback to 1D grid kernel + int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size; + k_bin_bcast_unravel<<>>( + src0_dd, src1_dd, dst_dd, + ne0, ne1, ne2, ne3, + ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s10, */ s11, s12, s13); + } else { + k_bin_bcast<<>>( + src0_dd, src1_dd, dst_dd, + ne0, ne1, ne2, ne3, + ne10, ne11, ne12, ne13, + /* s0, */ s1, s2, s3, + /* s10, */ s11, s12, s13); + } + } } }; @@ -6096,63 +6223,6 @@ static cudaError_t ggml_cuda_cpy_tensor_2d( } } -static void ggml_cuda_op_repeat( - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, - const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) { - // guaranteed to be an integer due to the check in ggml_can_repeat - const int64_t ne0 = dst->ne[0]; - const int64_t ne1 = dst->ne[1]; - const int64_t ne2 = dst->ne[2]; - const int64_t ne3 = dst->ne[3]; - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - const int64_t ne03 = src0->ne[3]; - - const size_t nb0 = dst->nb[0]; - const size_t nb1 = dst->nb[1]; - const size_t nb2 = dst->nb[2]; - const size_t nb3 = dst->nb[3]; - - const size_t nb00 = src0->nb[0]; - const size_t nb01 = src0->nb[1]; - const size_t nb02 = src0->nb[2]; - const size_t nb03 = src0->nb[3]; - - const int nr0 = (int)(ne0/ne00); - const int nr1 = (int)(ne1/ne01); - const int nr2 = (int)(ne2/ne02); - const int nr3 = (int)(ne3/ne03); - - // TODO: support for transposed / permuted tensors - GGML_ASSERT(nb0 == sizeof(float)); - GGML_ASSERT(nb00 == sizeof(float)); - - // TODO: very inefficient, implement in a kernel, or fewer cudaMemcpyAsync calls for contiguous tensors - for (int i3 = 0; i3 < nr3; i3++) { - for (int k3 = 0; k3 < ne03; k3++) { - for (int i2 = 0; i2 < nr2; i2++) { - for (int k2 = 0; k2 < ne02; k2++) { - for (int i1 = 0; i1 < nr1; i1++) { - for (int k1 = 0; k1 < ne01; k1++) { - for (int i0 = 0; i0 < nr0; i0++) { - CUDA_CHECK(cudaMemcpyAsync( - (char *) dst_d + (i3*ne03 + k3)*nb3 + (i2*ne02 + k2)*nb2 + (i1*ne01 + k1)*nb1 + (i0*ne00)*nb0, - (const char *) src0_d + ( k3)*nb03 + ( k2)*nb02 + ( k1)*nb01, - ne00*nb0, cudaMemcpyDeviceToDevice, stream)); - } - } - } - } - } - } - } - - (void) src1; - (void) src1_d; -} - static void ggml_cuda_op_get_rows( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & stream) { @@ -6215,7 +6285,16 @@ inline void ggml_cuda_op_bin_bcast( ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); GGML_ASSERT(false); } +} +static void ggml_cuda_op_repeat( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_d, const float * src1_d, float * dst_d, const cudaStream_t & main_stream) { + + ggml_cuda_op_bin_bcast>(dst, src0, dst, nullptr, src0_d, dst_d, main_stream); + + (void) src1; + (void) src1_d; } inline void ggml_cuda_op_add( @@ -8393,7 +8472,8 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ break; default: return false; - } break; + } + break; case GGML_OP_NORM: func = ggml_cuda_norm; break; @@ -8842,10 +8922,10 @@ static void ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph UNUSED(backend); } -static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * tensor) { - switch (tensor->op) { +static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_tensor * op) { + switch (op->op) { case GGML_OP_UNARY: - switch (ggml_get_unary_op(tensor)) { + switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: @@ -8854,7 +8934,23 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten return false; } break; + case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: + { + struct ggml_tensor * a; + struct ggml_tensor * b; + if (op->op == GGML_OP_MUL_MAT) { + a = op->src[0]; + b = op->src[1]; + } else { + a = op->src[2]; + b = op->src[1]; + } + if (a->ne[3] != b->ne[3]) { + return false; + } + return true; + } break; case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -8868,7 +8964,6 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_RMS_NORM: - case GGML_OP_MUL_MAT: case GGML_OP_SCALE: case GGML_OP_SQR: case GGML_OP_CLAMP: @@ -8913,6 +9008,9 @@ ggml_backend_t ggml_backend_cuda_init(int device) { return nullptr; } + // not strictly necessary, but it may reduce the overhead of the first graph_compute + ggml_cuda_set_main_device(device); + ggml_backend_context_cuda * ctx = new ggml_backend_context_cuda { /* .device = */ device }; diff --git a/src/ggml-metal.m b/src/ggml-metal.m index cff9d5bc..0c2f86ff 100644 --- a/src/ggml-metal.m +++ b/src/ggml-metal.m @@ -805,32 +805,14 @@ static bool ggml_metal_supports_op(const struct ggml_tensor * op) { case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: return true; case GGML_OP_DIAG_MASK_INF: case GGML_OP_GET_ROWS: { return op->ne[0] % 4 == 0; } break; - case GGML_OP_MUL_MAT: - case GGML_OP_MUL_MAT_ID: - { - struct ggml_tensor * a; - struct ggml_tensor * b; UNUSED(b); - if (op->op == GGML_OP_MUL_MAT) { - a = op->src[0]; - b = op->src[1]; - } else { - a = op->src[2]; - b = op->src[1]; - } - if (a->ne[3] != 1) { - return false; - } - if (ggml_is_quantized(a->type) && a->ne[2] != 1) { - return false; - } - return true; - } break; default: return false; } @@ -1222,9 +1204,13 @@ void ggml_metal_graph_compute( case GGML_OP_MUL_MAT: { GGML_ASSERT(ne00 == ne10); - GGML_ASSERT(ne03 == ne13); - const uint gqa = ne12/ne02; + // TODO: assert that dim2 and dim3 are contiguous + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + const uint r2 = ne12/ne02; + const uint r3 = ne13/ne03; // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel @@ -1289,9 +1275,10 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; - [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { int nth0 = 32; int nth1 = 1; @@ -1327,90 +1314,60 @@ void ggml_metal_graph_compute( } break; case GGML_TYPE_Q4_0: { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - nth0 = 8; nth1 = 8; [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_0_f32]; } break; case GGML_TYPE_Q4_1: { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - nth0 = 8; nth1 = 8; [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_1_f32]; } break; case GGML_TYPE_Q5_0: { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - nth0 = 8; nth1 = 8; [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_0_f32]; } break; case GGML_TYPE_Q5_1: { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - nth0 = 8; nth1 = 8; [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_1_f32]; } break; case GGML_TYPE_Q8_0: { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - nth0 = 8; nth1 = 8; [encoder setComputePipelineState:ctx->pipeline_mul_mv_q8_0_f32]; } break; case GGML_TYPE_Q2_K: { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - nth0 = 2; nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mv_q2_K_f32]; } break; case GGML_TYPE_Q3_K: { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - nth0 = 2; nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mv_q3_K_f32]; } break; case GGML_TYPE_Q4_K: { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - nth0 = 4; //1; nth1 = 8; //32; [encoder setComputePipelineState:ctx->pipeline_mul_mv_q4_K_f32]; } break; case GGML_TYPE_Q5_K: { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - nth0 = 2; nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mv_q5_K_f32]; } break; case GGML_TYPE_Q6_K: { - GGML_ASSERT(ne02 == 1); - GGML_ASSERT(ne12 == 1); - nth0 = 2; nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mv_q6_K_f32]; @@ -1439,31 +1396,32 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; - [encoder setBytes:&gqa length:sizeof(gqa) atIndex:17]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:17]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:18]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q4_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q3_K) { #ifdef GGML_QKK_64 - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #else - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; #endif } else if (src0t == GGML_TYPE_Q5_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else if (src0t == GGML_TYPE_Q6_K) { - [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { int64_t ny = (ne11 + nrows - 1)/nrows; - [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } } } break; @@ -1501,7 +1459,8 @@ void ggml_metal_graph_compute( //GGML_ASSERT(ne20 >= 64); GGML_ASSERT(src1t == GGML_TYPE_F32); - const uint gqa = ne12/ne22; + const uint r2 = ne12/ne22; + const uint r3 = ne13/ne23; // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel @@ -1541,8 +1500,9 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; - [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13]; - [encoder setBytes:&idx length:sizeof(idx) atIndex:14]; + [encoder setBytes:&r2 length:sizeof(r2) atIndex:13]; + [encoder setBytes:&r3 length:sizeof(r3) atIndex:14]; + [encoder setBytes:&idx length:sizeof(idx) atIndex:15]; // TODO: how to make this an array? read Metal docs for (int j = 0; j < n_as; ++j) { struct ggml_tensor * src_cur = dst->src[2 + j]; @@ -1550,11 +1510,11 @@ void ggml_metal_graph_compute( size_t offs_src_cur = 0; id id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur); - [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:15 + j]; + [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:16 + j]; } [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } } break; case GGML_OP_GET_ROWS: diff --git a/src/ggml-metal.metal b/src/ggml-metal.metal index 4499b1bf..2def5f82 100644 --- a/src/ggml-metal.metal +++ b/src/ggml-metal.metal @@ -730,9 +730,20 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre // giard against the number of rows not being divisible by // N_DST, so this is another explicit assumption of the implementation. template -void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, - int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne1, uint gqa, - uint3 tgpig, uint tiisg, uint sgitg) { +void mul_vec_q_n_f32( + device const void * src0, + device const float * src1, + device float * dst, + int64_t ne00, + int64_t ne01, + int64_t ne02, + int64_t ne10, + int64_t ne12, + int64_t ne0, + int64_t ne1, + uint r2, + uint r3, + uint3 tgpig, uint tiisg, uint sgitg) { const int nb = ne00/QK4_0; const int r0 = tgpig.x; @@ -741,7 +752,10 @@ void mul_vec_q_n_f32(device const void * src0, device const float * src1, device const int first_row = (r0 * nsg + sgitg) * nr; - const uint offset0 = first_row * nb + im/gqa*(nb*ne0); + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); device const block_q_type * x = (device const block_q_type *) src0 + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; @@ -791,13 +805,14 @@ kernel void kernel_mul_mv_q4_0_f32( constant int64_t & ne02[[buffer(5)]], constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], + constant int64_t & ne0 [[buffer(15)]], + constant int64_t & ne1 [[buffer(16)]], + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -809,13 +824,14 @@ kernel void kernel_mul_mv_q4_1_f32( constant int64_t & ne02[[buffer(5)]], constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], + constant int64_t & ne0 [[buffer(15)]], + constant int64_t & ne1 [[buffer(16)]], + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -827,13 +843,14 @@ kernel void kernel_mul_mv_q5_0_f32( constant int64_t & ne02[[buffer(5)]], constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], + constant int64_t & ne0 [[buffer(15)]], + constant int64_t & ne1 [[buffer(16)]], + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -845,13 +862,14 @@ kernel void kernel_mul_mv_q5_1_f32( constant int64_t & ne02[[buffer(5)]], constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], + constant int64_t & ne0 [[buffer(15)]], + constant int64_t & ne1 [[buffer(16)]], + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg); + mul_vec_q_n_f32(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,r2,r3,tgpig,tiisg,sgitg); } @@ -866,9 +884,10 @@ kernel void kernel_mul_mv_q8_0_f32( constant int64_t & ne02[[buffer(5)]], constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], + constant int64_t & ne0 [[buffer(15)]], + constant int64_t & ne1 [[buffer(16)]], + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -880,8 +899,14 @@ kernel void kernel_mul_mv_q8_0_f32( const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; + const int first_row = (r0 * nsg + sgitg) * nr; - const uint offset0 = first_row * nb + im/gqa*(nb*ne0); + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_q8_0 * x = (device const block_q8_0 *) src0 + offset0; device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; @@ -939,6 +964,8 @@ kernel void kernel_mul_mv_f32_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -946,7 +973,12 @@ kernel void kernel_mul_mv_f32_f32( const int64_t rb = tgpig.y*N_F32_F32; const int64_t im = tgpig.z; - device const float * x = (device const float *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const float * x = (device const float *) (src0 + offset0); if (ne00 < 128) { for (int row = 0; row < N_F32_F32; ++row) { @@ -1012,6 +1044,8 @@ kernel void kernel_mul_mv_f16_f16( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -1019,7 +1053,12 @@ kernel void kernel_mul_mv_f16_f16( const int64_t rb = tgpig.y*N_F16_F16; const int64_t im = tgpig.z; - device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const half * x = (device const half *) (src0 + offset0); if (ne00 < 128) { for (int row = 0; row < N_F16_F16; ++row) { @@ -1083,6 +1122,8 @@ kernel void kernel_mul_mv_f16_f32_1row( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -1090,7 +1131,12 @@ kernel void kernel_mul_mv_f16_f32_1row( const int64_t r1 = tgpig.y; const int64_t im = tgpig.z; - device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const half * x = (device const half *) (src0 + offset0); device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12); float sumf = 0; @@ -1137,6 +1183,8 @@ kernel void kernel_mul_mv_f16_f32( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -1144,7 +1192,12 @@ kernel void kernel_mul_mv_f16_f32( const int64_t rb = tgpig.y*N_F16_F32; const int64_t im = tgpig.z; - device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const half * x = (device const half *) (src0 + offset0); if (ne00 < 128) { for (int row = 0; row < N_F16_F32; ++row) { @@ -1209,6 +1262,8 @@ kernel void kernel_mul_mv_f16_f32_l4( constant uint64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -1216,7 +1271,12 @@ kernel void kernel_mul_mv_f16_f32_l4( const int64_t r0 = tgpig.x; const int64_t im = tgpig.z; - device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb02*ne02; + + device const half4 * x4 = (device const half4 *) (src0 + offset0); for (int r1 = 0; r1 < nrows; ++r1) { device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12); @@ -1276,7 +1336,7 @@ kernel void kernel_alibi_f32( } else { m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1); } - + device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1; device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01; for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { @@ -1821,23 +1881,30 @@ kernel void kernel_mul_mv_q2_K_f32( constant int64_t & ne02[[buffer(5)]], constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], + constant int64_t & ne0 [[buffer(15)]], + constant int64_t & ne1 [[buffer(16)]], + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; - const int r2 = tgpig.z; + const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int ib_row = first_row * nb; - const uint offset0 = r2/gqa*(nb*ne0); + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -1846,11 +1913,11 @@ kernel void kernel_mul_mv_q2_K_f32( #if QK_K == 256 const int ix = tiisg/8; // 0...3 const int it = tiisg%8; // 0...7 - const int im = it/4; // 0 or 1 + const int iq = it/4; // 0 or 1 const int ir = it%4; // 0...3 const int is = (8*ir)/16;// 0 or 1 - device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir; + device const float * y4 = y + ix * QK_K + 128 * iq + 8 * ir; for (int ib = ix; ib < nb; ib += 4) { @@ -1862,8 +1929,8 @@ kernel void kernel_mul_mv_q2_K_f32( yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; } - device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is; - device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir; + device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*iq + is; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; device const half * dh = &x[ib].d; for (int row = 0; row < N_DST; row++) { @@ -1950,7 +2017,7 @@ kernel void kernel_mul_mv_q2_K_f32( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum; + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; } } } @@ -1965,9 +2032,10 @@ kernel void kernel_mul_mv_q3_K_f32( constant int64_t & ne02[[buffer(5)]], constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], + constant int64_t & ne0 [[buffer(15)]], + constant int64_t & ne1 [[buffer(16)]], + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1976,12 +2044,17 @@ kernel void kernel_mul_mv_q3_K_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - const int64_t r2 = tgpig.z; + const int64_t im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - const uint offset0 = r2/gqa*(nb*ne0); + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; + device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; float yl[32]; @@ -2103,7 +2176,7 @@ kernel void kernel_mul_mv_q3_K_f32( } if (tiisg == 0) { for (int row = 0; row < 2; ++row) { - dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row]; + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row]; } } } @@ -2117,26 +2190,33 @@ kernel void kernel_mul_mv_q3_K_f32( constant int64_t & ne02[[buffer(5)]], constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], + constant int64_t & ne0 [[buffer(15)]], + constant int64_t & ne1 [[buffer(16)]], + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - const int64_t r2 = tgpig.z; + const int64_t im = tgpig.z; const int row = 2 * r0 + sgitg; - const uint offset0 = r2/gqa*(nb*ne0); + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; + device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + const int ix = tiisg/4; const int il = 4 * (tiisg%4);// 0, 4, 8, 12 - const int im = il/8; // 0, 0, 1, 1 + const int iq = il/8; // 0, 0, 1, 1 const int in = il%8; // 0, 4, 0, 4 float2 sum = {0.f, 0.f}; @@ -2156,7 +2236,7 @@ kernel void kernel_mul_mv_q3_K_f32( const float d4 = d_all * ((int32_t)(s[0] & 0xF000) - 32768) * 1.f/262144.f; for (int l = 0; l < 4; l += 2) { - const uint16_t hm = h[l/2] >> im; + const uint16_t hm = h[l/2] >> iq; sum[0] += y[l+ 0] * d1 * ((int32_t)(q[l/2] & 0x0003) - ((hm & 0x0001) ? 0 : 4)) + y[l+16] * d2 * ((int32_t)(q[l/2] & 0x000c) - ((hm & 0x0004) ? 0 : 16)) + y[l+32] * d3 * ((int32_t)(q[l/2] & 0x0030) - ((hm & 0x0010) ? 0 : 64)) @@ -2172,7 +2252,7 @@ kernel void kernel_mul_mv_q3_K_f32( const float tot = simd_sum(sumf); if (tiisg == 0) { - dst[r1*ne0 + r2*ne0*ne1 + row] = tot; + dst[r1*ne0 + im*ne0*ne1 + row] = tot; } } @@ -2190,10 +2270,11 @@ kernel void kernel_mul_mv_q4_K_f32( constant int64_t & ne12 [[buffer(11)]], constant int64_t & ne0 [[buffer(15)]], constant int64_t & ne1 [[buffer(16)]], - constant uint & gqa [[buffer(17)]], + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; @@ -2201,26 +2282,32 @@ kernel void kernel_mul_mv_q4_K_f32( const int ix = tiisg/8; // 0...3 const int it = tiisg%8; // 0...7 - const int im = it/4; // 0 or 1 + const int iq = it/4; // 0 or 1 const int ir = it%4; // 0...3 const int nb = ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; - const int r2 = tgpig.z; + const int im = tgpig.z; //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int first_row = r0 * N_DST; const int ib_row = first_row * nb; - const uint offset0 = r2/gqa*(nb*ne0); + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + float yl[16]; float yh[16]; float sumf[N_DST]={0.f}, all_sum; const int step = sizeof(block_q4_K) * nb / 2; - device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir; + device const float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; uint16_t sc16[4]; thread const uint8_t * sc8 = (thread const uint8_t *)sc16; @@ -2235,8 +2322,8 @@ kernel void kernel_mul_mv_q4_K_f32( yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; } - device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im; - device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir; + device const uint16_t * sc = (device const uint16_t *)x[ib].scales + iq; + device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * iq + 4 * ir; device const half * dh = &x[ib].d; for (int row = 0; row < N_DST; row++) { @@ -2280,7 +2367,7 @@ kernel void kernel_mul_mv_q4_K_f32( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = all_sum; + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; } } } @@ -2294,9 +2381,10 @@ kernel void kernel_mul_mv_q4_K_f32( constant int64_t & ne02[[buffer(5)]], constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], + constant int64_t & ne0 [[buffer(15)]], + constant int64_t & ne1 [[buffer(16)]], + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2307,12 +2395,18 @@ kernel void kernel_mul_mv_q4_K_f32( const int nb = ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; - const int r2 = tgpig.z; + const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int ib_row = first_row * nb; - const uint offset0 = r2/gqa*(nb*ne0); + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0; - device const float * y = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; + device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1; + float yl[8]; float yh[8]; float sumf[N_DST]={0.f}, all_sum; @@ -2368,7 +2462,7 @@ kernel void kernel_mul_mv_q4_K_f32( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0+ r2*ne0*ne1 + first_row + row] = all_sum; + dst[r1*ne0+ im*ne0*ne1 + first_row + row] = all_sum; } } } @@ -2383,9 +2477,10 @@ kernel void kernel_mul_mv_q5_K_f32( constant int64_t & ne02[[buffer(5)]], constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], + constant int64_t & ne0 [[buffer(15)]], + constant int64_t & ne1 [[buffer(16)]], + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2394,12 +2489,17 @@ kernel void kernel_mul_mv_q5_K_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - const int r2 = tgpig.z; + const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - const uint offset0 = r2/gqa*(nb*ne0); + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; + device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; float sumf[2]={0.f}; @@ -2415,15 +2515,15 @@ kernel void kernel_mul_mv_q5_K_f32( const int tid = tiisg/4; const int ix = tiisg%4; - const int im = tid/4; + const int iq = tid/4; const int ir = tid%4; const int n = 8; const int l0 = n*ir; - const int q_offset = 32*im + l0; - const int y_offset = 64*im + l0; + const int q_offset = 32*iq + l0; + const int y_offset = 64*iq + l0; - const uint8_t hm1 = 1u << (2*im); + const uint8_t hm1 = 1u << (2*iq); const uint8_t hm2 = hm1 << 1; const uint8_t hm3 = hm1 << 4; const uint8_t hm4 = hm2 << 4; @@ -2438,7 +2538,7 @@ kernel void kernel_mul_mv_q5_K_f32( device const uint8_t * q1 = x[i].qs + q_offset; device const uint8_t * qh = x[i].qh + l0; device const half * dh = &x[i].d; - device const uint16_t * a = (device const uint16_t *)x[i].scales + im; + device const uint16_t * a = (device const uint16_t *)x[i].scales + iq; device const float * y2 = y1 + 128; float4 sumy = {0.f, 0.f, 0.f, 0.f}; @@ -2494,7 +2594,7 @@ kernel void kernel_mul_mv_q5_K_f32( const int il = 4 * (tiisg/8); // 0, 4, 8, 12 const int ix = tiisg%8; - const int im = il/8; // 0, 0, 1, 1 + const int iq = il/8; // 0, 0, 1, 1 const int in = il%8; // 0, 4, 0, 4 device const float * y = yy + ix*QK_K + il; @@ -2519,7 +2619,7 @@ kernel void kernel_mul_mv_q5_K_f32( float2 acc = {0.f, 0.f}; for (int l = 0; l < 4; ++l) { - const uint8_t hl = h[l] >> im; + const uint8_t hl = h[l] >> iq; acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16)) + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16)); acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256)) @@ -2541,7 +2641,7 @@ kernel void kernel_mul_mv_q5_K_f32( for (int row = 0; row < 2; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot; + dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; } } @@ -2556,9 +2656,10 @@ kernel void kernel_mul_mv_q6_K_f32( constant int64_t & ne02[[buffer(5)]], constant int64_t & ne10[[buffer(9)]], constant int64_t & ne12[[buffer(11)]], - constant int64_t & ne0[[buffer(15)]], - constant int64_t & ne1[[buffer(16)]], - constant uint & gqa[[buffer(17)]], + constant int64_t & ne0 [[buffer(15)]], + constant int64_t & ne1 [[buffer(16)]], + constant uint & r2 [[buffer(17)]], + constant uint & r3 [[buffer(18)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -2572,12 +2673,17 @@ kernel void kernel_mul_mv_q6_K_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - const int r2 = tgpig.z; + const int im = tgpig.z; const int row = 2 * r0 + sgitg; - const uint offset0 = r2/gqa*(nb*ne0); + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0; - device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne00*ne1; + device const float * yy = (device const float *) src1 + r1*ne10 + im*ne00*ne1; float sumf = 0; @@ -2643,7 +2749,7 @@ kernel void kernel_mul_mv_q6_K_f32( const float tot = simd_sum(sumf); if (tiisg == 0) { - dst[r1*ne0 + r2*ne0*ne1 + row] = tot; + dst[r1*ne0 + im*ne0*ne1 + row] = tot; } } @@ -2954,23 +3060,24 @@ kernel void kernel_get_rows( // each block_q contains 16*nl weights template void kernel_mul_mm_impl(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant int64_t & nb01, - constant int64_t & nb02, - constant int64_t & ne12, - constant int64_t & nb10, - constant int64_t & nb11, - constant int64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & gqa, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & ne12, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & r2, + constant uint & r3, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { threadgroup half * sa = (threadgroup half *)(shared_memory); threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); @@ -2996,7 +3103,10 @@ void kernel_mul_mm_impl(device const uchar * src0, short il = (tiitg % THREAD_PER_ROW); - uint offset0 = im/gqa*nb02; + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); ushort offset1 = il/nl; device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; @@ -3094,7 +3204,8 @@ kernel void kernel_mul_mm(device const uchar * src0, constant int64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & gqa, + constant uint & r2, + constant uint & r3, threadgroup uchar * shared_memory [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], @@ -3113,7 +3224,8 @@ kernel void kernel_mul_mm(device const uchar * src0, nb12, ne0, ne1, - gqa, + r2, + r3, shared_memory, tgpig, tiitg, @@ -3135,7 +3247,8 @@ kernel void kernel_mul_mm_id( constant int64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & gqa, + constant uint & r2, + constant uint & r3, constant int & idx, device const uchar * src00, device const uchar * src01, @@ -3165,7 +3278,8 @@ kernel void kernel_mul_mm_id( nb12, ne0, ne1, - gqa, + r2, + r3, shared_memory, tgpig, tiitg, @@ -3214,7 +3328,8 @@ typedef void (mat_mm_t)( constant int64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & gqa, + constant uint & r2, + constant uint & r3, threadgroup uchar *, uint3, uint, uint); @@ -3245,7 +3360,8 @@ typedef void (mat_mm_id_t)( constant int64_t & nb12, constant int64_t & ne0, constant int64_t & ne1, - constant uint & gqa, + constant uint & r2, + constant uint & r3, constant int & idx, device const uchar * src00, device const uchar * src01, diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 22ed469e..a5110c66 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include @@ -20,12 +21,35 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m std::vector data(size); std::random_device rd; + +#if 0 std::default_random_engine generator(rd()); std::uniform_real_distribution distribution(min, max); for (size_t i = 0; i < size; i++) { data[i] = distribution(generator); } +#endif + auto init_thread = [&](size_t start, size_t end) { + std::default_random_engine generator(rd()); + std::uniform_real_distribution distribution(min, max); + + for (size_t i = start; i < end; i++) { + data[i] = distribution(generator); + } + }; + + size_t n_threads = std::thread::hardware_concurrency(); + std::vector threads; + threads.reserve(n_threads); + for (size_t i = 0; i < n_threads; i++) { + size_t start = i*size/n_threads; + size_t end = (i+1)*size/n_threads; + threads.emplace_back(init_thread, start, end); + } + for (auto & t : threads) { + t.join(); + } if (tensor->type == GGML_TYPE_F32) { ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float)); @@ -202,6 +226,10 @@ static bool isinf_or_max(float f) { return std::isinf(f) || f == FLT_MAX || f == -FLT_MAX; } +static bool ggml_is_view_op(enum ggml_op op) { + return op == GGML_OP_VIEW || op == GGML_OP_RESHAPE || op == GGML_OP_PERMUTE || op == GGML_OP_TRANSPOSE; +} + struct test_case { virtual ~test_case() {} @@ -221,12 +249,23 @@ struct test_case { } } + virtual size_t op_size(ggml_tensor * t) { + size_t size = ggml_nbytes(t); + // add source tensors + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (t->src[i] != NULL) { + size += ggml_nbytes(t->src[i]); + } + } + return size; + } + bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name) { ggml_init_params params = { /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(), /* .mem_base = */ NULL, /* .no_alloc = */ true, - }; + }; ggml_context * ctx = ggml_init(params); ggml_tensor * out = build_graph(ctx); @@ -237,10 +276,13 @@ struct test_case { return true; } + printf(" %s(%s): ", ggml_op_desc(out), vars().c_str()); + fflush(stdout); + // check if backends support op for (ggml_backend_t backend : {backend1, backend2}) { if (!ggml_backend_supports_op(backend, out)) { - printf(" %s: not supported\n", ggml_op_desc(out)); + printf("not supported\n"); ggml_free(ctx); return true; } @@ -275,7 +317,7 @@ struct test_case { for (size_t i = 0; i < f1.size(); i++) { // check for nans if (std::isnan(f1[i]) || std::isnan(f2[i])) { - printf(" Error: %s: NaN at index %zu\n", ggml_op_desc(t1), i); + printf("NaN at index %zu ", i); ud->ok = false; return true; } @@ -283,12 +325,12 @@ struct test_case { if (isinf_or_max(f1[i]) || isinf_or_max(f2[i])) { if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) { if (std::signbit(f1[i]) != std::signbit(f2[i])) { - printf(" Error: %s: inf sign mismatch: %f %f\n", ggml_op_desc(t1), f1[i], f2[i]); + printf("inf sign mismatch: %f %f ", f1[i], f2[i]); ud->ok = false; return true; } } else { - printf(" Error: %s: inf mismatch: %f %f\n", ggml_op_desc(t1), f1[i], f2[i]); + printf("inf mismatch: %f %f ", f1[i], f2[i]); ud->ok = false; return true; } @@ -297,15 +339,14 @@ struct test_case { double err = nmse(f1.data(), f2.data(), f1.size()); if (err > ud->max_err) { - printf(" Error: %s: NMSE = %f\n", ggml_op_desc(t1), err); + printf("NMSE = %f ", err); ud->ok = false; } return true; - }; + }; ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud); - printf(" %s(%s): ", ggml_op_desc(out), vars().c_str()); if (ud.ok) { printf("\033[1;32mOK\033[0m\n"); } else { @@ -318,6 +359,103 @@ struct test_case { return ud.ok; } + + bool eval_perf(ggml_backend_t backend, const char * op_name) { + static const size_t graph_nodes = 8192; + + ggml_init_params params = { + /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead_custom(graph_nodes, false), + /* .mem_base = */ NULL, + /* .no_alloc = */ true, + }; + ggml_context * ctx = ggml_init(params); + + ggml_tensor * out = build_graph(ctx); + + if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) { + //printf(" %s: skipping\n", ggml_op_desc(out)); + ggml_free(ctx); + return true; + } + + int len = printf(" %s(%s): ", ggml_op_desc(out), vars().c_str()); + fflush(stdout); + + // check if backends support op + if (!ggml_backend_supports_op(backend, out)) { + printf("not supported\n"); + ggml_free(ctx); + return true; + } + + // align while also leaving some margin for variations in parameters + int align = 20; + int last = (len + align - 1) / align * align; + if (last - len < 5) { + last += align; + } + last = std::max(last, 60); + printf("%*s", last - len, ""); + + // allocate + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend); + + // randomize tensors + initialize_tensors(ctx); + + // build graph + ggml_cgraph * gf = ggml_new_graph_custom(ctx, graph_nodes, false); + ggml_build_forward_expand(gf, out); + + // warmup run + ggml_backend_graph_compute(backend, gf); + + // duplicate the op + size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU + int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1; + for (int i = 1; i < n_runs; i++) { + gf->nodes[gf->n_nodes++] = out; + } + + // calculate memory + size_t mem = n_runs * op_size(out); + auto tensor_op_size = [](ggml_tensor * t) { + size_t size = ggml_nbytes(t); + // add source tensors + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (t->src[i] != NULL) { + size += ggml_nbytes(t->src[i]); + } + } + return size; + }; + for (int i = 0; i < gf->n_nodes; i++) { + if (ggml_is_view_op(gf->nodes[i]->op) || gf->nodes[i] == out) + continue; + mem += tensor_op_size(gf->nodes[i]); + } + + // run + ggml_backend_synchronize(backend); + + int64_t start_time = ggml_time_us(); + ggml_backend_graph_compute(backend, gf); + ggml_backend_synchronize(backend); + int64_t end_time = ggml_time_us(); + double time_us = end_time - start_time; + + printf(" %5d runs - %8.2f us/run - %8zu kB/run - \033[1;34m%7.2f GB/s\033[0m\n", + n_runs, + time_us / n_runs, + op_size(out) / 1024, + mem / (time_us/1e6) / 1024.0 / 1024.0 / 1024.0); + + ggml_backend_buffer_free(buf); + + ggml_free(ctx); + + return true; + } }; // GGML_OP_UNARY @@ -389,6 +527,10 @@ struct test_repeat : public test_case { return VARS_TO_STR3(type, ne, nr); } + size_t op_size(ggml_tensor * t) override { + return ggml_nbytes(t) * 2; + } + test_repeat(ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 10, 10}, std::array nr = {2, 2, 2, 2}) @@ -432,6 +574,10 @@ struct test_cpy : public test_case { return VARS_TO_STR3(type_src, type_dst, ne); } + size_t op_size(ggml_tensor * t) override { + return ggml_nbytes(t) + ggml_nbytes(t->src[0]); + } + test_cpy(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32, std::array ne = {10, 10, 10, 1}) : type_src(type_src), type_dst(type_dst), ne(ne) {} @@ -470,16 +616,20 @@ struct test_cont : public test_case { // GGML_OP_MUL // GGML_OP_DIV struct test_bin_bcast : public test_case { - using op_t = std::function; + using op_t = ggml_tensor * (*) (ggml_context *, ggml_tensor *, ggml_tensor *); op_t op; const ggml_type type; const std::array ne; - const std::array nr; + const std::array nr; std::string vars() override { return VARS_TO_STR3(type, ne, nr); } + size_t op_size(ggml_tensor * t) override { + return ggml_nbytes(t) * 3; + } + test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 1, 1}, std::array nr = {1, 2, 1, 1}) @@ -491,6 +641,17 @@ struct test_bin_bcast : public test_case { ggml_tensor * out = op(ctx, a, b); return out; } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (op == ggml_div) { + // avoid division by zero + init_tensor_uniform(t, 1.0f, 2.0f); + } else { + init_tensor_uniform(t); + } + } + } }; // GGML_OP_SCALE @@ -576,6 +737,15 @@ struct test_mul_mat : public test_case { return 5e-4; } + size_t op_size(ggml_tensor * t) override { + size_t a = ggml_nbytes(t->src[0]) * n * nr[0] * nr[1]; + size_t b = ggml_nbytes(t->src[1]) * m; + size_t c = ggml_nbytes(t); + return a + b + c; + + GGML_UNUSED(t); + } + test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, int64_t m = 32, int64_t n = 32, int64_t k = 32, std::array bs = {10, 10}, @@ -584,13 +754,80 @@ struct test_mul_mat : public test_case { ggml_tensor * build_graph(ggml_context * ctx) override { // C^T = A * B^T: (k, m) * (k, n) => (m, n) - ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0]*nr[0], bs[1]*nr[1]); + ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0] , bs[1]); ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]); ggml_tensor * out = ggml_mul_mat(ctx, a, b); return out; } }; +// GGML_OP_MUL_MAT_ID +struct test_mul_mat_id : public test_case { + const ggml_type type_a; + const ggml_type type_b; + const int n_mats; + const int id; + const int64_t m; + const int64_t n; + const int64_t k; + const std::array bs; // dims 3 and 4 + const std::array nr; // repeat in dims 3 and 4 + + std::string vars() override { + return VARS_TO_STR9(type_a, type_b, n_mats, id, m, n, k, bs, nr); + } + + double max_nmse_err() override { + return 5e-4; + } + + size_t op_size(ggml_tensor * t) override { + size_t a = ggml_nbytes(t->src[2]) * n * nr[0] * nr[1]; + size_t b = ggml_nbytes(t->src[1]) * m; + size_t c = ggml_nbytes(t); + return a + b + c; + + GGML_UNUSED(t); + } + + test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, + int n_mats = 2, int id = 0, + int64_t m = 32, int64_t n = 32, int64_t k = 32, + std::array bs = {10, 10}, + std::array nr = {2, 2}) + : type_a(type_a), type_b(type_b), n_mats(n_mats), id(id), + m(m), n(n), k(k), bs(bs), nr(nr) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + // C^T = A * B^T: (k, m) * (k, n) => (m, n) + std::vector mats; + for (int i = 0; i < n_mats; i++) { + ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0], bs[1]); + mats.push_back(a); + } + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_mats); + ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]); + ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), ids, id, b); + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I32) { + // ids + std::vector data(n_mats); + for (int i = 0; i < n_mats; i++) { + data[i] = i; + } + std::shuffle(data.begin(), data.end(), std::default_random_engine(std::random_device()())); + ggml_backend_tensor_set(t, data.data(), 0, n_mats * sizeof(int)); + } else { + init_tensor_uniform(t); + } + } + } +}; + // GGML_OP_SQR struct test_sqr : public test_case { const ggml_type type; @@ -852,64 +1089,6 @@ struct test_argsort : public test_case { } }; -// GGML_OP_MUL_MAT_ID -struct test_mul_mat_id : public test_case { - const ggml_type type_a; - const ggml_type type_b; - const int n_mats; - const int id; - const int64_t m; - const int64_t n; - const int64_t k; - const std::array bs; // dims 3 and 4 - const std::array nr; // repeat in dims 3 and 4 - - std::string vars() override { - return VARS_TO_STR9(type_a, type_b, n_mats, id, m, n, k, bs, nr); - } - - double max_nmse_err() override { - return 5e-4; - } - - test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, - int n_mats = 2, int id = 0, - int64_t m = 32, int64_t n = 32, int64_t k = 32, - std::array bs = {10, 10}, - std::array nr = {2, 2}) - : type_a(type_a), type_b(type_b), n_mats(n_mats), id(id), - m(m), n(n), k(k), bs(bs), nr(nr) {} - - ggml_tensor * build_graph(ggml_context * ctx) override { - // C^T = A * B^T: (k, m) * (k, n) => (m, n) - std::vector mats; - for (int i = 0; i < n_mats; i++) { - ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0]*nr[0], bs[1]*nr[1]); - mats.push_back(a); - } - ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_mats); - ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]); - ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), ids, id, b); - return out; - } - - void initialize_tensors(ggml_context * ctx) override { - for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { - if (t->type == GGML_TYPE_I32) { - // ids - std::vector data(n_mats); - for (int i = 0; i < n_mats; i++) { - data[i] = i; - } - std::shuffle(data.begin(), data.end(), std::default_random_engine(std::random_device()())); - ggml_backend_tensor_set(t, data.data(), 0, n_mats * sizeof(int)); - } else { - init_tensor_uniform(t); - } - } - } -}; - // GGML_OP_SUM_ROWS struct test_sum_rows : public test_case { const ggml_type type; @@ -936,8 +1115,6 @@ enum test_mode { }; static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) { - ggml_backend_t backend_cpu = ggml_backend_cpu_init(); - std::vector> test_cases; // unary ops @@ -950,7 +1127,12 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_get_rows(type, 16, 5, 3)); } - test_cases.emplace_back(new test_repeat()); + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 1})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {2, 1, 1, 1})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 2, 1, 1})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 2, 1})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 10, 10, 10}, {1, 1, 1, 2})); + test_cases.emplace_back(new test_dup()); test_cases.emplace_back(new test_cpy()); test_cases.emplace_back(new test_cont()); @@ -961,6 +1143,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } }; + add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 8, 1}, {1, 1, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 320, 320}, {1, 1, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1}); add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 1}); @@ -972,6 +1156,23 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 2, 2}); add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 2, 2, 2}); + // stable diffusion + add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 1, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 16, 16, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {1280, 16, 16, 1}, {1, 1, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {1280, 1, 1, 1}, {1, 256, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {16, 16, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {16, 16, 1280, 1}, {1, 1, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {16, 16, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 2560, 1}, {16, 16, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1280, 1}, {32, 32, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 1920, 1}, {32, 32, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {1, 1, 640, 1}, {32, 32, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {5120, 1, 1, 1}, {1, 256, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {640, 1, 1, 1}, {1, 1, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {1, 1, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {3, 3, 2560, 1280}, {2, 1, 1, 1}); + test_cases.emplace_back(new test_scale()); for (float eps : {1e-6f, 1e-5f, 1e-3f, 1e-1f}) { @@ -979,7 +1180,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps)); } - ggml_type all_types[] = { + const ggml_type all_types[] = { GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, @@ -1010,6 +1211,16 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } + for (ggml_type type_a : all_types) { + for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { + for (int n_mats : {1, 2, 4}) { + for (int id = 0; id < n_mats; id++) { + test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, {1, 1}, {1, 1})); + } + } + } + } + test_cases.emplace_back(new test_sqr()); test_cases.emplace_back(new test_clamp()); @@ -1028,7 +1239,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_rope(type, { 64, 71, 10, 1}, 64, 2, 512)); // neox (falcon 7B) test_cases.emplace_back(new test_rope(type, { 64, 8, 10, 1}, 64, 2, 512)); // neox (falcon 40B) test_cases.emplace_back(new test_rope(type, { 64, 128, 10, 1}, 64, 2, 512)); // neox (falcon 40B) - //test_cases.emplace_back(new test_rope(type, {80, 32, 10, 1}, 20, 2, 512)); // neox rope (stablelm) (TODO: enable after llama.cpp sync) + //test_cases.emplace_back(new test_rope(type, {80, 32, 10, 1}, 20, 2, 512)); // neox (stablelm) (TODO: enable after llama.cpp sync) } test_cases.emplace_back(new test_alibi()); @@ -1039,39 +1250,37 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); } - for (ggml_type type_a : all_types) { - for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { - for (int n_mats : {1, 2, 4}) { - for (int id = 0; id < n_mats; id++) { - test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, {1, 1}, {1, 1})); - } - } - } - } - test_cases.emplace_back(new test_sum_rows()); // run tests - size_t n_ok = 0; - for (auto & test : test_cases) { - if (test->eval(backend, backend_cpu, op_name)) { - n_ok++; - } - } + if (mode == MODE_TEST) { + ggml_backend_t backend_cpu = ggml_backend_cpu_init(); - printf(" %zu/%zu tests passed\n", n_ok, test_cases.size()); + size_t n_ok = 0; + for (auto & test : test_cases) { + if (test->eval(backend, backend_cpu, op_name)) { + n_ok++; + } + } + printf(" %zu/%zu tests passed\n", n_ok, test_cases.size()); - ggml_backend_free(backend_cpu); + ggml_backend_free(backend_cpu); - return n_ok == test_cases.size(); + return n_ok == test_cases.size(); + } else if (mode == MODE_PERF) { + for (auto & test : test_cases) { + test->eval_perf(backend, op_name); + } + return true; + } else { + GGML_ASSERT(false); + } } static void usage(char ** argv) { - // command line: test-backend-ops [mode] [-o op] [-b backend] - // modes are correctness (compare with CPU) or performance printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]); - printf(" valid modes are: test (compare with CPU backend for correctness) or perf (performance evaluation) [not implemented]\n"); - printf(" op names are as given ggml_op_desc()\n"); + printf(" valid modes are: test (compare with CPU backend for correctness) or perf (performance evaluation)\n"); + printf(" op names are as given by ggml_op_desc()\n"); } int main(int argc, char ** argv) {