From: slaren Date: Tue, 5 Dec 2023 12:56:07 +0000 (+0100) Subject: ggml : full broadcast in mul, add, div + ggml_mul_mat_id, ggml_argsort, ggml_top_k... X-Git-Tag: upstream/0.0.1642~1183 X-Git-Url: https://git.djapps.eu/?a=commitdiff_plain;h=703825ffab5b90d16ca295eb8a57d5c81cf0282b;p=pkg%2Fggml%2Fsources%2Fggml ggml : full broadcast in mul, add, div + ggml_mul_mat_id, ggml_argsort, ggml_top_k (#625) * ggml : support broadcasting in dim 0 in add and mul * add cuda add/mul broadcast impl add configurable eps to cuda norm * add metal impl ggml-ci * deduplicate code in cuda impl * try to optimize cuda impl * ggml : support broadcasting in ggml_div * test-backend-ops : allow filtering by op and backend * ggml-cuda : add ggml_div impl * ggml : add ggml_mul_mat_id, ggml_sort, ggml_top_k (CPU only) * fix ggml_div threads * fix ggml_div with accelerate * ggml_sort -> ggml_argsort * whatever * actually fix accelerate div * disable opencl ci * ci : disable ctest error check temporarily until we fix backend ops test * cmake : propagete GGML_USE_xxx compile flags with ggml target * whisper : utlize new ggml_add broadcast for dim 0 * cmake : adendum to ee666ae9 * ggml_backend_graph_copy : fix leak * ggml_cuda : add ggml_sum_rows impl * metal : add ggml_div * metal : add ggml_sum_rows * ggml_cuda : add ggml_argsort impl * move kernel * metal : add ggml_argsort * mul_mat_id : fix missing init task * cuda/metal: fix argsort synchronization * metal : add ggml_mul_mat_id * ggml-cuda : add mul_mat_id for f16 + tensor cores * test-backend-ops : add tests for quants mat mul * ggml : fix q5_0 and q5_1 hist stats * test-backend-ops : use smaller matrices to avoid automatic offloading, add mat-vec tests * metal : fix alibi to match the CPU behavior * metal : check dimensions in supports_op * test-backend-ops : reduce error threshold for mat muls * ggml-cuda : simplify dequantize funs, add supports_op by type for mul_mat_id * ggml-cuda : support quantized types in mul_mat_id with cublas * ggml-cuda : add fallback over CPU for mul_mat_id * test-backend-ops : increase mul mat error threshold * cleanup ggml-ci * test-backend-ops : fix usage * cleanup * ci : re-enable tests * metal : fix compile warnings --------- Co-authored-by: Georgi Gerganov --- diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index e5543d92..ba719f65 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -8,6 +8,7 @@ on: jobs: test-ubuntu-opencl: + if: false runs-on: ubuntu-latest env: GGML_NLOOP: 3 diff --git a/examples/starcoder/starcoder-mmap.cpp b/examples/starcoder/starcoder-mmap.cpp index 1ab039c3..8d2c72dd 100644 --- a/examples/starcoder/starcoder-mmap.cpp +++ b/examples/starcoder/starcoder-mmap.cpp @@ -75,7 +75,7 @@ struct llama_buffer { void resize(size_t len) { #ifdef GGML_USE_METAL free(addr); - int result = posix_memalign((void **) &addr, getpagesize(), len); + int result = posix_memalign((void **) &addr, sysconf(_SC_PAGESIZE), len); if (result == 0) { memset(addr, 0, len); } diff --git a/examples/whisper/whisper.cpp b/examples/whisper/whisper.cpp index f0a0a5a6..37671215 100644 --- a/examples/whisper/whisper.cpp +++ b/examples/whisper/whisper.cpp @@ -1341,10 +1341,10 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con model.e_pe = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_state, n_audio_ctx); model.e_conv_1_w = ggml_new_tensor_3d(ctx, vtype, 3, n_mels, n_audio_state); - model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 2*n_audio_ctx, n_audio_state); + 
model.e_conv_1_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); model.e_conv_2_w = ggml_new_tensor_3d(ctx, vtype, 3, n_audio_state, n_audio_state); - model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_audio_ctx, n_audio_state); + model.e_conv_2_b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1, n_audio_state); model.e_ln_w = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); model.e_ln_b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n_audio_state); @@ -1574,29 +1574,25 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con auto tensor = model.tensors[name.data()]; - const bool is_conv_bias = (name == "encoder.conv1.bias" || name == "encoder.conv2.bias"); - - if (!is_conv_bias) { - if (ggml_nelements(tensor) != nelements) { - WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); - WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", - __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]); - return false; - } + if (ggml_nelements(tensor) != nelements) { + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file\n", __func__, name.data()); + WHISPER_LOG_ERROR("%s: shape: [%d, %d, %d], expected: [%d, %d, %d]\n", + __func__, ne[0], ne[1], ne[2], (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2]); + return false; + } - if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) { - WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", - __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]); - return false; - } + if (tensor->ne[0] != ne[0] || tensor->ne[1] != ne[1] || tensor->ne[2] != ne[2]) { + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong shape in model file: got [%d, %d, %d], expected [%d, %d, %d]\n", + __func__, name.data(), (int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], ne[0], ne[1], ne[2]); + return false; + } - const size_t bpe = ggml_type_size(ggml_type(ttype)); + const size_t bpe = ggml_type_size(ggml_type(ttype)); - if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { - WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", - __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); - return false; - } + if ((nelements*bpe)/ggml_blck_size(tensor->type) != ggml_nbytes(tensor)) { + WHISPER_LOG_ERROR("%s: tensor '%s' has wrong size in model file: got %zu, expected %zu\n", + __func__, name.data(), ggml_nbytes(tensor), nelements*bpe); + return false; } ggml_backend_t backend = wctx.backend; @@ -1607,7 +1603,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con #ifdef GGML_USE_METAL || ggml_backend_is_metal(backend) #endif - ) && !is_conv_bias) { + )) { // for the CPU and Metal backend, we can read directly into the tensor loader->read(loader->context, tensor->data, ggml_nbytes(tensor)); BYTESWAP_TENSOR(tensor); @@ -1618,7 +1614,7 @@ static bool whisper_model_load(struct whisper_model_loader * loader, whisper_con // we repeat the 2 bias tensors along dim 0: // [1, 512] -> [3000, 512] (conv1.bias) // [1, 512] -> [1500, 512] (conv2.bias) - if (is_conv_bias) { + if (false) { loader->read(loader->context, read_buf.data(), read_buf.size() / tensor->ne[0]); float * data_f32 = (float *) read_buf.data(); @@ -1733,21 +1729,11 @@ static struct ggml_cgraph * whisper_build_graph_conv( { cur = ggml_conv_1d_ph(ctx0, 
model.e_conv_1_w, mel, 1, 1); cur = ggml_add(ctx0, cur, model.e_conv_1_b); - //cur = ggml_add(ctx0, - // ggml_repeat(ctx0, - // model.e_conv_1_b, - // cur), - // cur); cur = ggml_gelu(ctx0, cur); cur = ggml_conv_1d_ph(ctx0, model.e_conv_2_w, cur, 2, 1); cur = ggml_add(ctx0, cur, model.e_conv_2_b); - //cur = ggml_add(ctx0, - // ggml_repeat(ctx0, - // model.e_conv_2_b, - // cur), - // cur); cur = ggml_gelu(ctx0, cur); } diff --git a/include/ggml/ggml.h b/include/ggml/ggml.h index b53abaa1..353d52e1 100644 --- a/include/ggml/ggml.h +++ b/include/ggml/ggml.h @@ -284,6 +284,20 @@ const type prefix##3 = (pointer)->array[3]; \ GGML_UNUSED(prefix##3); +#define GGML_TENSOR_UNARY_OP_LOCALS \ + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + +#define GGML_TENSOR_BINARY_OP_LOCALS \ + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ + GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \ + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + #ifdef __cplusplus extern "C" { #endif @@ -382,6 +396,7 @@ extern "C" { GGML_OP_GROUP_NORM, GGML_OP_MUL_MAT, + GGML_OP_MUL_MAT_ID, GGML_OP_OUT_PROD, GGML_OP_SCALE, @@ -408,8 +423,8 @@ extern "C" { GGML_OP_CONV_TRANSPOSE_2D, GGML_OP_POOL_1D, GGML_OP_POOL_2D, - GGML_OP_UPSCALE, // nearest interpolate + GGML_OP_ARGSORT, GGML_OP_FLASH_ATTN, GGML_OP_FLASH_FF, @@ -1033,6 +1048,15 @@ extern "C" { struct ggml_tensor * a, struct ggml_tensor * b); + // indirect matrix multiplication + // ggml_mul_mat_id(ctx, as, ids, id, b) ~= ggml_mul_mat(as[ids[id]], b) + GGML_API struct ggml_tensor * ggml_mul_mat_id( + struct ggml_context * ctx, + struct ggml_tensor * as[], + struct ggml_tensor * ids, + int id, + struct ggml_tensor * b); + // A: m columns, n rows, // B: p columns, n rows, // result is m columns, p rows @@ -1518,6 +1542,23 @@ extern "C" { struct ggml_tensor * a, int scale_factor); + // sort rows + enum ggml_sort_order { + GGML_SORT_ASC, + GGML_SORT_DESC, + }; + + GGML_API struct ggml_tensor * ggml_argsort( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_sort_order order); + + // top k elements per row + GGML_API struct ggml_tensor * ggml_top_k( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k); + GGML_API struct ggml_tensor * ggml_flash_attn( struct ggml_context * ctx, struct ggml_tensor * q, @@ -1579,7 +1620,6 @@ extern "C" { int kh); // used in sam - GGML_API struct ggml_tensor * ggml_add_rel_pos( struct ggml_context * ctx, struct ggml_tensor * a, diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 15a70041..94ee0ce1 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -171,7 +171,7 @@ if (GGML_OPENBLAS) set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} ${OPENBLAS_LIB}) set(GGML_EXTRA_INCS ${GGML_EXTRA_INCS} ${OPENBLAS_INC}) - set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_OPENBLAS) + set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_OPENBLAS) else() message(WARNING "OpenBLAS not found") endif() @@ -213,7 +213,17 @@ if (GGML_CUBLAS) set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h) - add_compile_definitions(GGML_USE_CUBLAS) + set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_CUBLAS) + + if (GGML_CUDA_FORCE_DMMV) + add_compile_definitions(GGML_CUDA_FORCE_DMMV) + endif() + if (GGML_CUDA_FORCE_MMQ) + add_compile_definitions(GGML_CUDA_FORCE_MMQ) + endif() + + # required for dynamic 
parallelism + set(CMAKE_CUDA_SEPARABLE_COMPILATION ON) if (GGML_STATIC) set(GGML_EXTRA_LIBS ${GGML_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) @@ -245,7 +255,9 @@ if (GGML_HIPBLAS) if (${hipblas_FOUND} AND ${hip_FOUND}) message(STATUS "HIP and hipBLAS found") - add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS) + + set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_CUBLAS) + add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h) if (BUILD_SHARED_LIBS) set_target_properties(ggml-rocm PROPERTIES POSITION_INDEPENDENT_CODE ON) @@ -280,7 +292,8 @@ if (GGML_METAL) set(GGML_METAL_SOURCES ggml-metal.m ggml-metal.h) - add_compile_definitions(GGML_USE_METAL) + set(GGML_EXTRA_FLAGS ${GGML_EXTRA_FLAGS} -DGGML_USE_METAL) + #add_compile_definitions(GGML_METAL_NDEBUG) # get full path to the file diff --git a/src/ggml-backend.c b/src/ggml-backend.c index 5f3005b2..9f31b065 100644 --- a/src/ggml-backend.c +++ b/src/ggml-backend.c @@ -1064,8 +1064,6 @@ ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, int n_bac struct ggml_backend_sched * sched = malloc(sizeof(struct ggml_backend_sched)); memset(sched, 0, sizeof(struct ggml_backend_sched)); - fprintf(stderr, "ggml_backend_sched size: %lu KB\n", sizeof(struct ggml_backend_sched)/1024); - sched->n_backends = n_backends; for (int i = 0; i < n_backends; i++) { sched->backends[i] = backends[i]; @@ -1271,6 +1269,7 @@ struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, s free(hash_set.keys); free(node_copies); + free(node_init); return (struct ggml_backend_graph_copy) { /* .buffer = */ buffer, diff --git a/src/ggml-cuda.cu b/src/ggml-cuda.cu index 9a8e40eb..dbe92d97 100644 --- a/src/ggml-cuda.cu +++ b/src/ggml-cuda.cu @@ -69,6 +69,7 @@ #define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize #define cudaSetDevice hipSetDevice #define cudaStreamCreateWithFlags hipStreamCreateWithFlags +#define cudaStreamFireAndForget hipStreamFireAndForget #define cudaStreamNonBlocking hipStreamNonBlocking #define cudaStreamSynchronize hipStreamSynchronize #define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags) @@ -433,8 +434,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_ #define WARP_SIZE 32 #define MATRIX_ROW_PADDING 512 // last row of quant. 
matrices is a multiple of this to avoid out-of-bounds memory accesses -#define CUDA_ADD_BLOCK_SIZE 256 -#define CUDA_MUL_BLOCK_SIZE 256 +#define CUDA_ADDMUL_BLOCK_SIZE 256 #define CUDA_GELU_BLOCK_SIZE 256 #define CUDA_SILU_BLOCK_SIZE 256 #define CUDA_RELU_BLOCK_SIZE 256 @@ -501,40 +501,43 @@ static size_t g_scratch_offset = 0; static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr}; -static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; - - if (i >= kx) { - return; - } - dst[i] = x[i] + y[i%ky]; +static __device__ __forceinline__ float op_add(const float a, const float b) { + return a + b; } -static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; +static __device__ __forceinline__ float op_mul(const float a, const float b) { + return a * b; +} - if (i >= k) { - return; - } - dst[i] = __hadd(x[i], __float2half(y[i])); +static __device__ __forceinline__ float op_div(const float a, const float b) { + return a / b; } -static __global__ void add_f16_f32_f32(const half * x, const float * y, float * dst, const int k) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; +template +static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, + int ne0,/* int ne1, int ne2, */int ne3, + int ne10, int ne11, int ne12, int ne13, + /*int s0, */ int s1, int s2, int s3, + /*int s10,*/ int s11, int s12, int s13) { + const int i0 = blockDim.x*blockIdx.x + threadIdx.x; + const int i1 = blockIdx.y; + const int i2 = blockIdx.z / ne3; + const int i3 = blockIdx.z % ne3; - if (i >= k) { + if (i0 >= ne0) { return; } - dst[i] = __half2float(x[i]) + y[i]; -} -static __global__ void mul_f32(const float * x, const float * y, float * dst, const int kx, const int ky) { - const int i = blockDim.x*blockIdx.x + threadIdx.x; + const int i10 = i0 % ne10; + const int i11 = i1 % ne11; + const int i12 = i2 % ne12; + const int i13 = i3 % ne13; - if (i >= kx) { - return; - } - dst[i] = x[i] * y[i%ky]; + const size_t i_dst = i3*s3 + i2*s2 + i1*s1 + i0; + const size_t i_src0 = i_dst; + const size_t i_src1 = i13*s13 + i12*s12 + i11*s11 + i10; + + dst[i_dst] = (dst_t)bin_op((float)src0[i_src0], (float)src1[i_src1]); } static __global__ void gelu_f32(const float * x, float * dst, const int k) { @@ -577,6 +580,14 @@ static __global__ void sqr_f32(const float * x, float * dst, const int k) { dst[i] = x[i] * x[i]; } +static __device__ __forceinline__ float warp_reduce_sum(float x) { +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x += __shfl_xor_sync(0xffffffff, x, mask, 32); + } + return x; +} + static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { @@ -587,12 +598,10 @@ static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) { } template -static __global__ void norm_f32(const float * x, float * dst, const int ncols) { +static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; const int tid = threadIdx.x; - const float eps = 1e-5f; - float2 mean_var = make_float2(0.f, 0.f); for (int col = tid; col < ncols; col += block_size) { @@ -624,14 +633,6 @@ static __global__ void norm_f32(const float * x, float * dst, const int ncols) { } } -static __device__ __forceinline__ float warp_reduce_sum(float x) { 
-#pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - x += __shfl_xor_sync(0xffffffff, x, mask, 32); - } - return x; -} - template static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) { const int row = blockIdx.x*blockDim.y + threadIdx.y; @@ -4676,6 +4677,65 @@ static __global__ void alibi_f32(const float * x, float * dst, const int ncols, dst[i] = col * m_k + x[i]; } +static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) { + const int row = blockIdx.y; + const int col = threadIdx.x; + + float sum = 0.0f; + for (int i = col; i < ncols; i += blockDim.x) { + sum += x[row * ncols + i]; + } + + sum = warp_reduce_sum(sum); + + if (col == 0) { + dst[row] = sum; + } +} + +template +static inline __device__ void swap(T & a, T & b) { + T tmp = a; + a = b; + b = tmp; +} + +template +static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) { + // bitonic sort + int col = threadIdx.x; + int row = blockIdx.y; + + if (col >= ncols) return; + + const float * x_row = x + row * ncols; + int * dst_row = dst + row * ncols; + + // initialize indices + if (col < ncols) { + dst_row[col] = col; + } + __syncthreads(); + + for (int k = 2; k <= ncols; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { + swap(dst_row[col], dst_row[ixj]); + } + } else { + if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { + swap(dst_row[col], dst_row[ixj]); + } + } + } + __syncthreads(); + } + } +} + static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) { const int col = blockDim.y*blockIdx.y + threadIdx.y; const int row = blockDim.x*blockIdx.x + threadIdx.x; @@ -4780,25 +4840,35 @@ static void get_rows_cuda(const void * x, const int32_t * y, float * dst, const k_get_rows<<>>(x, y, dst, ncols); } -static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { - const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; - add_f32<<>>(x, y, dst, kx, ky); -} +template +struct bin_bcast_cuda { + template + void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, + const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd, + cudaStream_t stream) { -static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; - add_f16_f32_f16<<>>(x, y, dst, k); -} + GGML_TENSOR_BINARY_OP_LOCALS -static void add_f16_f32_f32_cuda(const half * x, const float * y, float * dst, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE; - add_f16_f32_f32<<>>(x, y, dst, k); -} + //size_t s0 = nb0 / sizeof(src1_t); + size_t s1 = nb1 / sizeof(src1_t); + size_t s2 = nb2 / sizeof(src1_t); + size_t s3 = nb3 / sizeof(src1_t); -static void mul_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) { - const int num_blocks = (kx + CUDA_MUL_BLOCK_SIZE - 1) / CUDA_MUL_BLOCK_SIZE; - mul_f32<<>>(x, y, dst, kx, ky); -} + //size_t s10 = nb10 / sizeof(src1_t); + size_t s11 = 
nb11 / sizeof(src1_t); + size_t s12 = nb12 / sizeof(src1_t); + size_t s13 = nb13 / sizeof(src1_t); + + const int num_blocks_x = (ne0 + CUDA_ADDMUL_BLOCK_SIZE - 1) / CUDA_ADDMUL_BLOCK_SIZE; + dim3 num_blocks(num_blocks_x, ne1, ne2*ne3); + + k_bin_bcast<<>>(src0_dd, src1_dd, dst_dd, + ne0,/* ne1, ne2, */ne3, + ne10, ne11, ne12, ne13, + /* s0, */s1, s2, s3, + /* s10,*/ s11, s12, s13); + } +}; static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; @@ -4820,14 +4890,14 @@ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t sqr_f32<<>>(x, dst, k); } -static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { +static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) { GGML_ASSERT(ncols % WARP_SIZE == 0); if (ncols < 1024) { const dim3 block_dims(WARP_SIZE, 1, 1); - norm_f32<<>>(x, dst, ncols); + norm_f32<<>>(x, dst, ncols, eps); } else { const dim3 block_dims(1024, 1, 1); - norm_f32<1024><<>>(x, dst, ncols); + norm_f32<1024><<>>(x, dst, ncols, eps); } } @@ -4849,38 +4919,14 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con quantize_q8_1<<>>(x, vy, kx, kx_padded); } -template -static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); -} - -template -static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); -} - -template -static void dequantize_row_q5_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); -} - -template -static void dequantize_row_q5_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); -} - -template -static void dequantize_row_q8_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +template +static __host__ __device__ void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<<>>(vx, y, k); + dequantize_block<<>>(vx, y, k); } template -static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static __host__ __device__ void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q2_K<<>>(vx, y); @@ -4890,7 +4936,7 @@ static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cu } template -static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static __host__ __device__ void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q3_K<<>>(vx, y); @@ -4900,13 +4946,13 @@ static void 
dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cu } template -static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static __host__ __device__ void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; dequantize_block_q4_K<<>>(vx, y); } template -static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static __host__ __device__ void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q5_K<<>>(vx, y); @@ -4916,7 +4962,7 @@ static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cu } template -static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { +static __host__ __device__ void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) { const int nb = k / QK_K; #if QK_K == 256 dequantize_block_q6_K<<>>(vx, y); @@ -4925,6 +4971,64 @@ static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cu #endif } +static to_fp16_cuda_t __host__ __device__ ggml_get_to_fp16_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return dequantize_block_cuda; + case GGML_TYPE_Q4_1: + return dequantize_block_cuda; + case GGML_TYPE_Q5_0: + return dequantize_block_cuda; + case GGML_TYPE_Q5_1: + return dequantize_block_cuda; + case GGML_TYPE_Q8_0: + return dequantize_block_cuda; + case GGML_TYPE_Q2_K: + return dequantize_row_q2_K_cuda; + case GGML_TYPE_Q3_K: + return dequantize_row_q3_K_cuda; + case GGML_TYPE_Q4_K: + return dequantize_row_q4_K_cuda; + case GGML_TYPE_Q5_K: + return dequantize_row_q5_K_cuda; + case GGML_TYPE_Q6_K: + return dequantize_row_q6_K_cuda; + case GGML_TYPE_F32: + return dequantize_block_cuda<1, 1, convert_f32>; + default: + return nullptr; + } +} + +static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { + switch (type) { + case GGML_TYPE_Q4_0: + return dequantize_block_cuda; + case GGML_TYPE_Q4_1: + return dequantize_block_cuda; + case GGML_TYPE_Q5_0: + return dequantize_block_cuda; + case GGML_TYPE_Q5_1: + return dequantize_block_cuda; + case GGML_TYPE_Q8_0: + return dequantize_block_cuda; + case GGML_TYPE_Q2_K: + return dequantize_row_q2_K_cuda; + case GGML_TYPE_Q3_K: + return dequantize_row_q3_K_cuda; + case GGML_TYPE_Q4_K: + return dequantize_row_q4_K_cuda; + case GGML_TYPE_Q5_K: + return dequantize_row_q5_K_cuda; + case GGML_TYPE_Q6_K: + return dequantize_row_q6_K_cuda; + case GGML_TYPE_F16: + return dequantize_block_cuda<1, 1, convert_f16>; + default: + return nullptr; + } +} + static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; @@ -5013,6 +5117,15 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f dequantize_mul_mat_vec_q6_k<<>>(vx, y, dst, ncols, nrows); } +static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); + const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; + const dim3 block_nums(block_num_y, 1, 1); + const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); + 
dequantize_mul_mat_vec<1, 1, convert_f16> + <<>>(vx, y, dst, ncols, nrows); +} + static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) { GGML_ASSERT(ncols % QK4_0 == 0); const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; @@ -5103,83 +5216,6 @@ static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * <<>>(vx, vy, dst, ncols, nrows); } -static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; - dequantize_block<1, 1, convert_f16><<>>(vx, y, k); -} - -static void convert_fp32_to_fp16_cuda(const void * vx, half * y, const int k, cudaStream_t stream) { - const int num_blocks = (k + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; - dequantize_block<1, 1, convert_f32><<>>(vx, y, k); -} - -static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec<1, 1, convert_f16> - <<>>(vx, y, dst, ncols, nrows); -} - -static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) { - switch (type) { - case GGML_TYPE_Q4_0: - return dequantize_row_q4_0_cuda; - case GGML_TYPE_Q4_1: - return dequantize_row_q4_1_cuda; - case GGML_TYPE_Q5_0: - return dequantize_row_q5_0_cuda; - case GGML_TYPE_Q5_1: - return dequantize_row_q5_1_cuda; - case GGML_TYPE_Q8_0: - return dequantize_row_q8_0_cuda; - case GGML_TYPE_Q2_K: - return dequantize_row_q2_K_cuda; - case GGML_TYPE_Q3_K: - return dequantize_row_q3_K_cuda; - case GGML_TYPE_Q4_K: - return dequantize_row_q4_K_cuda; - case GGML_TYPE_Q5_K: - return dequantize_row_q5_K_cuda; - case GGML_TYPE_Q6_K: - return dequantize_row_q6_K_cuda; - case GGML_TYPE_F32: - return convert_fp32_to_fp16_cuda; - default: - return nullptr; - } -} - -static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) { - switch (type) { - case GGML_TYPE_Q4_0: - return dequantize_row_q4_0_cuda; - case GGML_TYPE_Q4_1: - return dequantize_row_q4_1_cuda; - case GGML_TYPE_Q5_0: - return dequantize_row_q5_0_cuda; - case GGML_TYPE_Q5_1: - return dequantize_row_q5_1_cuda; - case GGML_TYPE_Q8_0: - return dequantize_row_q8_0_cuda; - case GGML_TYPE_Q2_K: - return dequantize_row_q2_K_cuda; - case GGML_TYPE_Q3_K: - return dequantize_row_q3_K_cuda; - case GGML_TYPE_Q4_K: - return dequantize_row_q4_K_cuda; - case GGML_TYPE_Q5_K: - return dequantize_row_q5_K_cuda; - case GGML_TYPE_Q6_K: - return dequantize_row_q6_K_cuda; - case GGML_TYPE_F16: - return convert_fp16_to_fp32_cuda; - default: - return nullptr; - } -} - static void ggml_mul_mat_q4_0_q8_1_cuda( const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) { @@ -5752,6 +5788,27 @@ static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const alibi_f32<<>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1); } +static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums(1, nrows, 1); + k_sum_rows_f32<<>>(x, dst, ncols); +} + 
+static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) { + // bitonic sort requires ncols to be power of 2 + GGML_ASSERT((ncols & (ncols - 1)) == 0); + + const dim3 block_dims(ncols, 1, 1); + const dim3 block_nums(1, nrows, 1); + if (order == GGML_SORT_ASC) { + k_argsort_f32_i32<<>>(x, dst, ncols); + } else if (order == GGML_SORT_DESC) { + k_argsort_f32_i32<<>>(x, dst, ncols); + } else { + GGML_ASSERT(false); + } +} + static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) { const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1); const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE; @@ -6140,44 +6197,46 @@ static void ggml_cuda_op_get_rows( } } -inline void ggml_cuda_op_add( +template +inline void ggml_cuda_op_bin_bcast( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { GGML_ASSERT(src1->type == GGML_TYPE_F32); - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; - if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - add_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream); + op()(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { - add_f16_f32_f16_cuda((const half *) src0_dd, src1_dd, (half *) dst_dd, ggml_nelements(src0), main_stream); + op()(src0, src1, dst, (const half *) src0_dd, src1_dd, (half *) dst_dd, main_stream); } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { - add_f16_f32_f32_cuda((const half *) src0_dd, src1_dd, dst_dd, ggml_nelements(src0), main_stream); + op()(src0, src1, dst, (const half *) src0_dd, src1_dd, dst_dd, main_stream); } else { - fprintf(stderr, "src0->type: %d dst->type: %d\n", src0->type, dst->type); + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); GGML_ASSERT(false); } - (void) src1; - (void) dst; } -inline void ggml_cuda_op_mul( +inline void ggml_cuda_op_add( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - GGML_ASSERT( dst->type == GGML_TYPE_F32); + ggml_cuda_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); +} - const int64_t ne10 = src1->ne[0]; - const int64_t ne11 = src1->ne[1]; +inline void ggml_cuda_op_mul( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { - mul_f32_cuda(src0_dd, src1_dd, dst_dd, ggml_nelements(src0), ne10*ne11, main_stream); + ggml_cuda_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); +} - (void) dst; +inline void ggml_cuda_op_div( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + ggml_cuda_op_bin_bcast>(src0, src1, dst, src0_dd, src1_dd, dst_dd, main_stream); } inline void ggml_cuda_op_gelu( @@ 
-6246,7 +6305,10 @@ inline void ggml_cuda_op_norm( const int64_t ne00 = src0->ne[0]; const int64_t nrows = ggml_nrows(src0); - norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, main_stream); + float eps; + memcpy(&eps, dst->op_params, sizeof(float)); + + norm_f32_cuda(src0_dd, dst_dd, ne00, nrows, eps, main_stream); (void) src1; (void) dst; @@ -6785,6 +6847,42 @@ inline void ggml_cuda_op_im2col( (void) src0_dd; } +inline void ggml_cuda_op_sum_rows( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + sum_rows_f32_cuda(src0_dd, dst_dd, ncols, nrows, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + +inline void ggml_cuda_op_argsort( + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, + const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + + argsort_f32_i32_cuda(src0_dd, (int *)dst_dd, ncols, nrows, order, main_stream); + + (void) src1; + (void) dst; + (void) src1_dd; +} + inline void ggml_cuda_op_diag_mask_inf( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const float * src0_dd, const float * src1_dd, float * dst_dd, const cudaStream_t & main_stream) { @@ -7298,6 +7396,10 @@ static void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, gg ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_mul); } +static void ggml_cuda_div(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_div); +} + static void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_gelu); } @@ -7401,7 +7503,7 @@ static void ggml_cuda_mul_mat_vec_nc(const ggml_tensor * src0, const ggml_tensor ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); } -__global__ void k_compute_batched_ptrs( +static __global__ void k_compute_batched_ptrs( const half * src0_as_f16, const half * src1_as_f16, half * dst_f16, const void ** ptrs_src, void ** ptrs_dst, int ne12, int ne13, @@ -7457,9 +7559,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const CUDA_CHECK(ggml_cuda_set_device(g_main_device)); cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; - int id; - CUDA_CHECK(cudaGetDevice(&id)); - CUBLAS_CHECK(cublasSetStream(g_cublas_handles[id], main_stream)); + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream)); ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; void * src0_ddq = src0_extra->data_device[g_main_device]; @@ -7516,7 +7616,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const // there is no broadcast and src0, src1 are contiguous across dims 2, 3 // use cublasGemmStridedBatchedEx CUBLAS_CHECK( - cublasGemmStridedBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + cublasGemmStridedBatchedEx(g_cublas_handles[g_main_device], 
CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, &alpha_f16, (const char *) src0_as_f16, CUDA_R_16F, nb01/sizeof(half), src0->nb[2]/sizeof(half), // strideA (const char *) src1_as_f16, CUDA_R_16F, nb11/sizeof(float), src1->nb[2]/sizeof(float), // strideB @@ -7550,7 +7650,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const CUDA_CHECK(cudaGetLastError()); CUBLAS_CHECK( - cublasGemmBatchedEx(g_cublas_handles[id], CUBLAS_OP_T, CUBLAS_OP_N, + cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, ne01, ne11, ne10, &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, nb01/sizeof(half), (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, nb11/sizeof(float), @@ -7648,6 +7748,219 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1 } } +#if 0 +template +static __global__ void k_compute_batched_ptrs_id( + const void ** ptrs_src, void ** ptrs_dst, + int ne12, int ne13, + int ne23, + int nb02, int nb03, + int nb12, int nb13, + int nb2, int nb3, + int r2, int r3, + ggml_type src0_type, half * src0_as_f16, int64_t src0_ne, + const half * src1_f16, half * dst_f16, + const int32_t * ids, const int id, + Srcs... src0s) { + + int i = ids[id]; + + half * src0_f16; + const void * srcs_ar[] = { (const half *) src0s... }; + if (src0_type == GGML_TYPE_F16) { + src0_f16 = (half *) srcs_ar[i]; + } else { + src0_f16 = src0_as_f16; + if (threadIdx.x == 0 && threadIdx.y == 0) { + const to_fp16_cuda_t to_fp16 = ggml_get_to_fp16_cuda(src0_type); + to_fp16(srcs_ar[i], src0_f16, src0_ne, cudaStreamFireAndForget); + } + } + + int i13 = blockIdx.x * blockDim.x + threadIdx.x; + int i12 = blockIdx.y * blockDim.y + threadIdx.y; + + if (i13 >= ne13 || i12 >= ne12) { + return; + } + + int i03 = i13 / r3; + int i02 = i12 / r2; + + ptrs_src[0*ne23 + i12 + i13*ne12] = (const char *) src0_f16 + i02*nb02 + i03*nb03; + ptrs_src[1*ne23 + i12 + i13*ne12] = (const char *) src1_f16 + i12*nb12/2 + i13*nb13/2; + ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2; +} + +static void ggml_cuda_mul_mat_id_cublas(ggml_tensor * dst) { + const struct ggml_tensor * ids = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + const struct ggml_tensor * src00 = dst->src[2]; + + const int id = dst->op_params[0]; + + GGML_ASSERT(!ggml_is_transposed(src00)); + GGML_ASSERT(!ggml_is_transposed(src1)); + + GGML_ASSERT(src00->backend != GGML_BACKEND_GPU_SPLIT); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const int64_t ne00 = src00->ne[0]; GGML_UNUSED(ne00); + const int64_t ne01 = src00->ne[1]; + const int64_t ne02 = src00->ne[2]; + const int64_t ne03 = src00->ne[3]; + + //const int64_t nb01 = src00->nb[1]; + const int64_t nb02 = src00->nb[2]; GGML_UNUSED(nb02); + const int64_t nb03 = src00->nb[3]; GGML_UNUSED(nb03); + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + const int64_t ne12 = src1->ne[2]; + const int64_t ne13 = src1->ne[3]; + + //const int64_t nb11 = src1->nb[1]; + const int64_t nb12 = src1->nb[2]; GGML_UNUSED(nb12); + const int64_t nb13 = src1->nb[3]; GGML_UNUSED(nb13); + + const int64_t ne1 = ggml_nelements(src1); + const int64_t ne = ggml_nelements(dst); + + CUDA_CHECK(ggml_cuda_set_device(g_main_device)); + cudaStream_t main_stream = g_cudaStreams[g_main_device][0]; + + CUBLAS_CHECK(cublasSetStream(g_cublas_handles[g_main_device], main_stream)); + + //ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) src0->extra; + //void * src0_ddq = src0_extra->data_device[g_main_device]; + 
//half * src0_as_f16 = (half *) src0_ddq; + + ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) src1->extra; + float * src1_ddf = (float *) src1_extra->data_device[g_main_device]; + + ggml_tensor_extra_gpu * dst_extra = (ggml_tensor_extra_gpu *) dst->extra; + float * dst_ddf = (float *) dst_extra->data_device[g_main_device]; + + // convert src1 to fp16 + const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); + GGML_ASSERT(to_fp16_cuda != nullptr); + + size_t src1_as = 0; + half * src1_as_f16 = (half *) ggml_cuda_pool_malloc(ne1 * sizeof(half), &src1_as); + to_fp16_cuda(src1_ddf, src1_as_f16, ne1, main_stream); + + size_t dst_as = 0; + half * dst_f16 = (half *) ggml_cuda_pool_malloc(ne * sizeof(half), &dst_as); + + GGML_ASSERT(ne12 % ne02 == 0); + GGML_ASSERT(ne13 % ne03 == 0); + + // broadcast factors + const int64_t r2 = ne12/ne02; + const int64_t r3 = ne13/ne03; + + const half alpha_f16 = 1.0f; + const half beta_f16 = 0.0f; + + // use cublasGemmBatchedEx + const int ne23 = ne12*ne13; + + const void ** ptrs_src = nullptr; + void ** ptrs_dst = nullptr; + + size_t ptrs_src_s = 0; + size_t ptrs_dst_s = 0; + + ptrs_src = (const void **) ggml_cuda_pool_malloc(2*ne23*sizeof(void *), &ptrs_src_s); + ptrs_dst = ( void **) ggml_cuda_pool_malloc(1*ne23*sizeof(void *), &ptrs_dst_s); + + int64_t src0_ne = ggml_nelements(src00); + half * src0_as_f16 = nullptr; + size_t src0_as = 0; + if (src00->type != GGML_TYPE_F16) { + src0_as_f16 = (half *) ggml_cuda_pool_malloc(src0_ne * sizeof(half), &src0_as); + } + + static_assert(GGML_MAX_SRC == 6, "GGML_MAX_SRC == 6"); + dim3 block_dims(ne13, ne12); + k_compute_batched_ptrs_id<<<1, block_dims, 0, main_stream>>>( + ptrs_src, ptrs_dst, + ne12, ne13, + ne23, + ne00*ne01*sizeof(half), ne00*ne01*ne02*sizeof(half), + nb12, nb13, + dst->nb[2], dst->nb[3], + r2, r3, + src00->type, src0_as_f16, src0_ne, + src1_as_f16, dst_f16, + (const int *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device], id, + dst->src[2] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[2]->extra)->data_device[g_main_device] : nullptr, + dst->src[3] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[3]->extra)->data_device[g_main_device] : nullptr, + dst->src[4] ? (const half *)((ggml_tensor_extra_gpu *)dst->src[4]->extra)->data_device[g_main_device] : nullptr, + dst->src[5] ? 
(const half *)((ggml_tensor_extra_gpu *)dst->src[5]->extra)->data_device[g_main_device] : nullptr + ); + CUDA_CHECK(cudaGetLastError()); + + CUBLAS_CHECK( + cublasGemmBatchedEx(g_cublas_handles[g_main_device], CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha_f16, (const void **) (ptrs_src + 0*ne23), CUDA_R_16F, ne00, + (const void **) (ptrs_src + 1*ne23), CUDA_R_16F, ne10, + &beta_f16, ( void **) (ptrs_dst + 0*ne23), CUDA_R_16F, ne01, + ne23, + CUBLAS_COMPUTE_16F, + CUBLAS_GEMM_DEFAULT_TENSOR_OP)); + + if (src0_as != 0) { + ggml_cuda_pool_free(src0_as_f16, src0_as); + } + if (ptrs_src_s != 0) { + ggml_cuda_pool_free(ptrs_src, ptrs_src_s); + } + if (ptrs_dst_s != 0) { + ggml_cuda_pool_free(ptrs_dst, ptrs_dst_s); + } + + const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16); + to_fp32_cuda(dst_f16, dst_ddf, ne, main_stream); + + ggml_cuda_pool_free(src1_as_f16, src1_as); + ggml_cuda_pool_free(dst_f16, dst_as); +} +#endif + +static void ggml_cuda_mul_mat_id(const ggml_tensor * _src0, const ggml_tensor * _src1, ggml_tensor * dst) { +#if 0 +//#ifdef CUDA_USE_TENSOR_CORES +// const bool use_tensor_cores = true; +//#else +// const bool use_tensor_cores = false; +//#endif + + ggml_cuda_mul_mat_id_cublas(dst); + + // TODO: mmq/mmv support +#else + const struct ggml_tensor * ids = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + const int id = dst->op_params[0]; + + int32_t * ids_dev = (int32_t *)((ggml_tensor_extra_gpu *)ids->extra)->data_device[g_main_device]; + + int32_t a_id; + CUDA_CHECK(cudaMemcpyAsync(&a_id, ids_dev + id, sizeof(int32_t), cudaMemcpyDeviceToHost, g_cudaStreams[g_main_device][0])); + CUDA_CHECK(cudaStreamSynchronize(g_cudaStreams[g_main_device][0])); + + GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]); + const struct ggml_tensor * src0 = dst->src[a_id + 2]; + + ggml_cuda_mul_mat(src0, src1, dst); +#endif + + (void) _src0; + (void) _src1; +} + static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale); } @@ -7735,6 +8048,16 @@ static void ggml_cuda_im2col(const ggml_tensor * src0, const ggml_tensor * src1, ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_im2col); } +static void ggml_cuda_sum_rows(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_sum_rows); +} + +static void ggml_cuda_argsort(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(ggml_is_contiguous(src0)); + ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_argsort); +} + static void ggml_cuda_nop(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { (void) src0; (void) src1; @@ -8054,6 +8377,9 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ case GGML_OP_MUL: func = ggml_cuda_mul; break; + case GGML_OP_DIV: + func = ggml_cuda_div; + break; case GGML_OP_UNARY: switch (ggml_get_unary_op(tensor)) { case GGML_UNARY_OP_GELU: @@ -8080,6 +8406,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_ } func = ggml_cuda_mul_mat; break; + case GGML_OP_MUL_MAT_ID: + if (!any_on_device && !ggml_cuda_can_mul_mat(tensor->src[2], tensor->src[1], tensor)) { + return false; + } + func = ggml_cuda_mul_mat_id; + break; case GGML_OP_SCALE: func = ggml_cuda_scale; break; @@ -8119,6 +8451,12 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct 
ggml_ case GGML_OP_IM2COL: func = ggml_cuda_im2col; break; + case GGML_OP_SUM_ROWS: + func = ggml_cuda_sum_rows; + break; + case GGML_OP_ARGSORT: + func = ggml_cuda_argsort; + break; default: return false; } @@ -8343,6 +8681,7 @@ static ggml_backend_buffer_t ggml_backend_cuda_host_buffer_type_alloc_buffer(ggm // FIXME: this is a hack to avoid having to implement a new buffer type ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; buffer->iface.free_buffer = ggml_backend_cuda_host_buffer_free_buffer; return buffer; @@ -8515,6 +8854,7 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten return false; } break; + case GGML_OP_MUL_MAT_ID: case GGML_OP_NONE: case GGML_OP_RESHAPE: case GGML_OP_VIEW: @@ -8526,6 +8866,7 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten case GGML_OP_DUP: case GGML_OP_ADD: case GGML_OP_MUL: + case GGML_OP_DIV: case GGML_OP_RMS_NORM: case GGML_OP_MUL_MAT: case GGML_OP_SCALE: @@ -8538,6 +8879,8 @@ static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, const ggml_ten case GGML_OP_ROPE: case GGML_OP_ALIBI: case GGML_OP_IM2COL: + case GGML_OP_SUM_ROWS: + case GGML_OP_ARGSORT: return true; default: return false; @@ -8595,6 +8938,7 @@ static ggml_backend_t ggml_backend_reg_cuda_init(const char * params, void * use static int ggml_backend_cuda_reg_devices() { int device_count = ggml_cuda_get_device_count(); + //int device_count = 1; // DEBUG: some tools require delaying CUDA initialization for (int i = 0; i < device_count; i++) { char name[128]; snprintf(name, sizeof(name), "%s%d", GGML_CUDA_NAME, i); diff --git a/src/ggml-metal.m b/src/ggml-metal.m index 37b291a9..f2267356 100644 --- a/src/ggml-metal.m +++ b/src/ggml-metal.m @@ -62,6 +62,8 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(add_row); // TODO: avoid this extra kernel, instead extend the "add" kernel to support broadcast GGML_METAL_DECL_KERNEL(mul); GGML_METAL_DECL_KERNEL(mul_row); // TODO: avoid this extra kernel, instead extend the "mul" kernel to support broadcast + GGML_METAL_DECL_KERNEL(div); + GGML_METAL_DECL_KERNEL(div_row); GGML_METAL_DECL_KERNEL(scale); GGML_METAL_DECL_KERNEL(scale_4); GGML_METAL_DECL_KERNEL(silu); @@ -112,15 +114,30 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(mul_mm_q4_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q5_K_f32); GGML_METAL_DECL_KERNEL(mul_mm_q6_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_id_f32_f32); + GGML_METAL_DECL_KERNEL(mul_mm_id_f16_f32); + GGML_METAL_DECL_KERNEL(mul_mm_id_q4_0_f32); + GGML_METAL_DECL_KERNEL(mul_mm_id_q4_1_f32); + GGML_METAL_DECL_KERNEL(mul_mm_id_q5_0_f32); + GGML_METAL_DECL_KERNEL(mul_mm_id_q5_1_f32); + GGML_METAL_DECL_KERNEL(mul_mm_id_q8_0_f32); + GGML_METAL_DECL_KERNEL(mul_mm_id_q2_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_id_q3_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_id_q4_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_id_q5_K_f32); + GGML_METAL_DECL_KERNEL(mul_mm_id_q6_K_f32); GGML_METAL_DECL_KERNEL(rope_f32); GGML_METAL_DECL_KERNEL(rope_f16); GGML_METAL_DECL_KERNEL(alibi_f32); GGML_METAL_DECL_KERNEL(im2col_f16); + GGML_METAL_DECL_KERNEL(argsort_f32_i32_asc); + GGML_METAL_DECL_KERNEL(argsort_f32_i32_desc); GGML_METAL_DECL_KERNEL(cpy_f32_f16); GGML_METAL_DECL_KERNEL(cpy_f32_f32); GGML_METAL_DECL_KERNEL(cpy_f16_f16); GGML_METAL_DECL_KERNEL(concat); GGML_METAL_DECL_KERNEL(sqr); + GGML_METAL_DECL_KERNEL(sum_rows); #undef GGML_METAL_DECL_KERNEL }; @@ -289,6 +306,8 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { 
GGML_METAL_ADD_KERNEL(add_row); GGML_METAL_ADD_KERNEL(mul); GGML_METAL_ADD_KERNEL(mul_row); + GGML_METAL_ADD_KERNEL(div); + GGML_METAL_ADD_KERNEL(div_row); GGML_METAL_ADD_KERNEL(scale); GGML_METAL_ADD_KERNEL(scale_4); GGML_METAL_ADD_KERNEL(silu); @@ -340,16 +359,31 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(mul_mm_q4_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q5_K_f32); GGML_METAL_ADD_KERNEL(mul_mm_q6_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_id_f32_f32); + GGML_METAL_ADD_KERNEL(mul_mm_id_f16_f32); + GGML_METAL_ADD_KERNEL(mul_mm_id_q4_0_f32); + GGML_METAL_ADD_KERNEL(mul_mm_id_q4_1_f32); + GGML_METAL_ADD_KERNEL(mul_mm_id_q5_0_f32); + GGML_METAL_ADD_KERNEL(mul_mm_id_q5_1_f32); + GGML_METAL_ADD_KERNEL(mul_mm_id_q8_0_f32); + GGML_METAL_ADD_KERNEL(mul_mm_id_q2_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_id_q3_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_id_q4_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_id_q5_K_f32); + GGML_METAL_ADD_KERNEL(mul_mm_id_q6_K_f32); } GGML_METAL_ADD_KERNEL(rope_f32); GGML_METAL_ADD_KERNEL(rope_f16); GGML_METAL_ADD_KERNEL(alibi_f32); GGML_METAL_ADD_KERNEL(im2col_f16); + GGML_METAL_ADD_KERNEL(argsort_f32_i32_asc); + GGML_METAL_ADD_KERNEL(argsort_f32_i32_desc); GGML_METAL_ADD_KERNEL(cpy_f32_f16); GGML_METAL_ADD_KERNEL(cpy_f32_f32); GGML_METAL_ADD_KERNEL(cpy_f16_f16); GGML_METAL_ADD_KERNEL(concat); GGML_METAL_ADD_KERNEL(sqr); + GGML_METAL_ADD_KERNEL(sum_rows); #undef GGML_METAL_ADD_KERNEL } @@ -367,6 +401,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(add_row); GGML_METAL_DEL_KERNEL(mul); GGML_METAL_DEL_KERNEL(mul_row); + GGML_METAL_DEL_KERNEL(div); + GGML_METAL_DEL_KERNEL(div_row); GGML_METAL_DEL_KERNEL(scale); GGML_METAL_DEL_KERNEL(scale_4); GGML_METAL_DEL_KERNEL(silu); @@ -418,16 +454,31 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(mul_mm_q4_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q5_K_f32); GGML_METAL_DEL_KERNEL(mul_mm_q6_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_id_f32_f32); + GGML_METAL_DEL_KERNEL(mul_mm_id_f16_f32); + GGML_METAL_DEL_KERNEL(mul_mm_id_q4_0_f32); + GGML_METAL_DEL_KERNEL(mul_mm_id_q4_1_f32); + GGML_METAL_DEL_KERNEL(mul_mm_id_q5_0_f32); + GGML_METAL_DEL_KERNEL(mul_mm_id_q5_1_f32); + GGML_METAL_DEL_KERNEL(mul_mm_id_q8_0_f32); + GGML_METAL_DEL_KERNEL(mul_mm_id_q2_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_id_q3_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_id_q4_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_id_q5_K_f32); + GGML_METAL_DEL_KERNEL(mul_mm_id_q6_K_f32); } GGML_METAL_DEL_KERNEL(rope_f32); GGML_METAL_DEL_KERNEL(rope_f16); GGML_METAL_DEL_KERNEL(alibi_f32); GGML_METAL_DEL_KERNEL(im2col_f16); + GGML_METAL_DEL_KERNEL(argsort_f32_i32_asc); + GGML_METAL_DEL_KERNEL(argsort_f32_i32_desc); GGML_METAL_DEL_KERNEL(cpy_f32_f16); GGML_METAL_DEL_KERNEL(cpy_f32_f32); GGML_METAL_DEL_KERNEL(cpy_f16_f16); GGML_METAL_DEL_KERNEL(concat); GGML_METAL_DEL_KERNEL(sqr); + GGML_METAL_DEL_KERNEL(sum_rows); #undef GGML_METAL_DEL_KERNEL @@ -884,6 +935,8 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne1, ne2, ne3) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; case GGML_OP_ADD: + case GGML_OP_MUL: + case GGML_OP_DIV: { GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src1)); @@ -897,11 +950,21 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne11 == 1); nb = ne00 / 4; - [encoder setComputePipelineState:ctx->pipeline_add_row]; + switch (dst->op) { + case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add_row]; break; + case GGML_OP_MUL: [encoder 
setComputePipelineState:ctx->pipeline_mul_row]; break; + case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div_row]; break; + default: GGML_ASSERT(false); + } bcast_row = true; } else { - [encoder setComputePipelineState:ctx->pipeline_add]; + switch (dst->op) { + case GGML_OP_ADD: [encoder setComputePipelineState:ctx->pipeline_add]; break; + case GGML_OP_MUL: [encoder setComputePipelineState:ctx->pipeline_mul]; break; + case GGML_OP_DIV: [encoder setComputePipelineState:ctx->pipeline_div]; break; + default: GGML_ASSERT(false); + } } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -942,31 +1005,6 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } } break; - case GGML_OP_MUL: - { - GGML_ASSERT(ggml_is_contiguous(src0)); - GGML_ASSERT(ggml_is_contiguous(src1)); - - // utilize float4 - GGML_ASSERT(ne00 % 4 == 0); - const int64_t nb = ne00/4; - - if (ggml_nelements(src1) == ne10) { - // src1 is a row - GGML_ASSERT(ne11 == 1); - [encoder setComputePipelineState:ctx->pipeline_mul_row]; - } else { - [encoder setComputePipelineState:ctx->pipeline_mul]; - } - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:3]; - - const int64_t n = ggml_nelements(dst)/4; - - [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; - } break; case GGML_OP_SCALE: { GGML_ASSERT(ggml_is_contiguous(src0)); @@ -1039,6 +1077,40 @@ void ggml_metal_graph_compute( const int64_t n = ggml_nelements(dst); [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_SUM_ROWS: + { + GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); + + [encoder setComputePipelineState:ctx->pipeline_sum_rows]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; + [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; + [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; + [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; + [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; + [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; + [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; + [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; + [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; + [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; + [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; + [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:17]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:19]; + [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:20]; + [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:21]; + [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:22]; + [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:23]; + [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:24]; + [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:25]; + + [encoder 
dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } break; case GGML_OP_SOFT_MAX: { int nth = 32; // SIMD width @@ -1331,6 +1403,96 @@ void ggml_metal_graph_compute( } } } break; + case GGML_OP_MUL_MAT_ID: + { + //GGML_ASSERT(ne00 == ne10); + //GGML_ASSERT(ne03 == ne13); + + GGML_ASSERT(src0t == GGML_TYPE_I32); + + const int n_as = ne00; + + // TODO: make this more general + GGML_ASSERT(n_as <= 8); + + struct ggml_tensor * src2 = gf->nodes[i]->src[2]; + + const int64_t ne20 = src2 ? src2->ne[0] : 0; + const int64_t ne21 = src2 ? src2->ne[1] : 0; + const int64_t ne22 = src2 ? src2->ne[2] : 0; + const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23); + + const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20); + const uint64_t nb21 = src2 ? src2->nb[1] : 0; + const uint64_t nb22 = src2 ? src2->nb[2] : 0; + const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23); + + const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + + GGML_ASSERT(!ggml_is_transposed(src2)); + GGML_ASSERT(!ggml_is_transposed(src1)); + + GGML_ASSERT(ne20 % 32 == 0); + // !!!!!!!!! TODO: this assert is probably required but not sure! + //GGML_ASSERT(ne20 >= 64); + GGML_ASSERT(src1t == GGML_TYPE_F32); + + const uint gqa = ne12/ne22; + + // find the break-even point where the matrix-matrix kernel becomes more efficient compared + // to the matrix-vector kernel + int ne11_mm_min = 0; + + const int idx = ((int32_t *) dst->op_params)[0]; + + // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs + // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel + if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && + ne11 > ne11_mm_min) { + switch (src2->type) { + case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break; + case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break; + case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_0_f32]; break; + case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_1_f32]; break; + case GGML_TYPE_Q5_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_0_f32]; break; + case GGML_TYPE_Q5_1: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_1_f32]; break; + case GGML_TYPE_Q8_0: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q8_0_f32]; break; + case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q2_K_f32]; break; + case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q3_K_f32]; break; + case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q4_K_f32]; break; + case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q5_K_f32]; break; + case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_q6_K_f32]; break; + default: GGML_ASSERT(false && "MUL_MAT_ID not implemented"); + } + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:3]; + [encoder setBytes:&ne22 length:sizeof(ne22) atIndex:4]; + [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:5]; + [encoder setBytes:&nb22 length:sizeof(nb22) atIndex:6]; + [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; + [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:8]; + 
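The new GGML_OP_MUL_MAT_ID case above selects a mul_mm_id pipeline based on the type of the chosen matrix (src2) and, just below, binds one Metal buffer per candidate matrix. For reference, a minimal sketch of how the public ggml_mul_mat_id API added later in this patch (src/ggml.c) is meant to be called; the context, the tensor sizes and the number of candidate matrices are illustrative assumptions, not part of the patch:

    // four candidate matrices with identical shapes (sizes are illustrative)
    struct ggml_tensor * as[4];
    for (int i = 0; i < 4; ++i) {
        as[i] = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 1024);
    }

    // I32 vector with one entry per candidate; before the graph is computed,
    // ids->data[id] must hold the index of the matrix to multiply with
    struct ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);

    struct ggml_tensor * b   = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4096, 32);

    // equivalent to ggml_mul_mat(ctx, as[ids[id]], b), with the selection resolved at compute time
    struct ggml_tensor * out = ggml_mul_mat_id(ctx, as, ids, /*id =*/ 0, b); // out->ne = {1024, 32, 1, 1}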
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:9]; + [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:10]; + [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:11]; + [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:12]; + [encoder setBytes:&gqa length:sizeof(gqa) atIndex:13]; + [encoder setBytes:&idx length:sizeof(idx) atIndex:14]; + // TODO: how to make this an array? read Metal docs + for (int j = 0; j < n_as; ++j) { + struct ggml_tensor * src_cur = dst->src[2 + j]; + + size_t offs_src_cur = 0; + id id_src_cur = ggml_metal_get_buffer(ctx, src_cur, &offs_src_cur); + + [encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:15 + j]; + } + + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; + [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne21 + 63)/64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + } + } break; case GGML_OP_GET_ROWS: { switch (src0->type) { @@ -1549,6 +1711,27 @@ void ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)]; } break; + case GGML_OP_ARGSORT: + { + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_I32); + + const int nrows = ggml_nrows(src0); + + enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + + switch (order) { + case GGML_SORT_ASC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_asc]; break; + case GGML_SORT_DESC: [encoder setComputePipelineState:ctx->pipeline_argsort_f32_i32_desc]; break; + default: GGML_ASSERT(false); + }; + + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + + [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)]; + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: @@ -1809,21 +1992,48 @@ static bool ggml_backend_metal_supports_op(ggml_backend_t backend, const struct case GGML_OP_CONCAT: case GGML_OP_ADD: case GGML_OP_MUL: + case GGML_OP_DIV: case GGML_OP_SCALE: case GGML_OP_SQR: + case GGML_OP_SUM_ROWS: case GGML_OP_SOFT_MAX: - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_MUL_MAT: - case GGML_OP_GET_ROWS: case GGML_OP_RMS_NORM: case GGML_OP_NORM: case GGML_OP_ALIBI: case GGML_OP_ROPE: case GGML_OP_IM2COL: + case GGML_OP_ARGSORT: case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: return true; + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_GET_ROWS: + { + // TODO: also check during graph_compute + return op->ne[0] % 4 == 0; + } break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + { + // TODO: also check during graph_compute + struct ggml_tensor * a; + struct ggml_tensor * b; UNUSED(b); + if (op->op == GGML_OP_MUL_MAT) { + a = op->src[0]; + b = op->src[1]; + } else { + a = op->src[2]; + b = op->src[1]; + } + if (a->ne[3] != 1) { + return false; + } + if (ggml_is_quantized(a->type) && a->ne[2] != 1) { + return false; + } + return true; + } break; default: return false; } diff --git a/src/ggml-metal.metal b/src/ggml-metal.metal index 5d1357cd..4499b1bf 100644 --- a/src/ggml-metal.metal +++ b/src/ggml-metal.metal @@ -3,6 +3,7 @@ using namespace metal; #define MAX(x, y) ((x) > (y) ? 
(x) : (y)) +#define SWAP(x, y) { auto tmp = (x); (x) = (y); (y) = tmp; } #define QK4_0 32 #define QR4_0 2 @@ -39,8 +40,13 @@ typedef struct { int8_t qs[QK8_0]; // quants } block_q8_0; -// general-purpose kernel for addition of two tensors -// pros: works for non-contiguous tensors, supports broadcast across dims 1, 2 and 3 +enum ggml_sort_order { + GGML_SORT_ASC, + GGML_SORT_DESC, +}; + +// general-purpose kernel for addition, multiplication and division of two tensors +// pros: works for non-contiguous tensors, supports broadcast across all dims // cons: not very efficient kernel void kernel_add( device const char * src0, @@ -81,16 +87,111 @@ kernel void kernel_add( const int64_t i12 = i02 % ne12; const int64_t i11 = i01 % ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + tpitg.x*nb00; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11 + tpitg.x*nb10; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + tpitg.x*nb0; + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10)); + } +} + +kernel void kernel_mul( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant int64_t & nb00, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant int64_t & nb0, + constant int64_t & nb1, + constant int64_t & nb2, + constant int64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - ((device float *)dst_ptr)[0] = ((device float *)src0_ptr)[0] + ((device float *)src1_ptr)[0]; + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10)); + } +} - src0_ptr += ntg.x*nb00; - src1_ptr += ntg.x*nb10; - dst_ptr += ntg.x*nb0; +kernel void kernel_div( + device const char * src0, + device const char * src1, + device char * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant int64_t & nb00, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant int64_t & nb10, + constant int64_t & 
nb11, + constant int64_t & nb12, + constant int64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant int64_t & nb0, + constant int64_t & nb1, + constant int64_t & nb2, + constant int64_t & nb3, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t i03 = tgpig.z; + const int64_t i02 = tgpig.y; + const int64_t i01 = tgpig.x; + + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; + device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; + device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + + for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { + const int i10 = i0 % ne10; + *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10)); } } @@ -105,23 +206,22 @@ kernel void kernel_add_row( dst[tpig] = src0[tpig] + src1[tpig % nb]; } -kernel void kernel_mul( +kernel void kernel_mul_row( device const float4 * src0, device const float4 * src1, device float4 * dst, + constant int64_t & nb [[buffer(27)]], uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src1[tpig]; + dst[tpig] = src0[tpig] * src1[tpig % nb]; } -// assumption: src1 is a row -// broadcast src1 into src0 -kernel void kernel_mul_row( +kernel void kernel_div_row( device const float4 * src0, device const float4 * src1, device float4 * dst, - constant int64_t & nb, + constant int64_t & nb [[buffer(27)]], uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src1[tpig % nb]; + dst[tpig] = src0[tpig] / src1[tpig % nb]; } kernel void kernel_scale( @@ -162,6 +262,54 @@ kernel void kernel_sqr( dst[tpig] = src0[tpig] * src0[tpig]; } +kernel void kernel_sum_rows( + device const float * src0, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant int64_t & nb00, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & nb13, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant int64_t & nb0, + constant int64_t & nb1, + constant int64_t & nb2, + constant int64_t & nb3, + uint3 tpig[[thread_position_in_grid]]) { + int64_t i3 = tpig.z; + int64_t i2 = tpig.y; + int64_t i1 = tpig.x; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + device const float * src_row = (device const float *) ((device const char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); + device float * dst_row = (device float *) ((device char *) dst + i1*nb1 + i2*nb2 + i3*nb3); + + float row_sum = 0; + + for (int64_t i0 = 0; i0 < ne00; i0++) { + row_sum += src_row[i0]; + } + + dst_row[0] = row_sum; +} + constant float GELU_COEF_A = 0.044715f; constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; @@ -1120,17 +1268,21 @@ kernel void kernel_alibi_f32( const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + const int64_t k = i3*ne3 + i2; - device float * dst_data = (device 
float *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); float m_k; - if (i2 < n_heads_log2_floor) { - m_k = pow(m0, i2 + 1); + if (k < n_heads_log2_floor) { + m_k = pow(m0, k + 1); } else { - m_k = pow(m1, 2 * (i2 - n_heads_log2_floor) + 1); + m_k = pow(m1, 2 * (k - n_heads_log2_floor) + 1); } + + device char * dst_row = (device char *) dst + i3*nb3 + i2*nb2 + i1*nb1; + device const char * src_row = (device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01; for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const float * src = (device float *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); - dst_data[i00] = src[0] + m_k * (i00 - ne00 + 1); + const float src_v = *(device float *)(src_row + i00*nb00); + device float * dst_v = (device float *)(dst_row + i00*nb0); + *dst_v = i00 * m_k + src_v; } } @@ -1335,6 +1487,58 @@ kernel void kernel_im2col_f16( } } +// bitonic sort implementation following the CUDA kernels as reference +typedef void (argsort_t)( + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]); + +template +kernel void kernel_argsort_f32_i32( + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]]) { + // bitonic sort + int col = tpitg[0]; + int row = tgpig[1]; + + if (col >= ncols) return; + + device const float * x_row = x + row * ncols; + device int32_t * dst_row = dst + row * ncols; + + // initialize indices + if (col < ncols) { + dst_row[col] = col; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + for (int k = 2; k <= ncols; k *= 2) { + for (int j = k / 2; j > 0; j /= 2) { + int ixj = col ^ j; + if (ixj > col) { + if ((col & k) == 0) { + if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { + SWAP(dst_row[col], dst_row[ixj]); + } + } else { + if (order == GGML_SORT_ASC ? 
x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { + SWAP(dst_row[col], dst_row[ixj]); + } + } + } + threadgroup_barrier(mem_flags::mem_threadgroup); + } + } +} + +template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, @@ -2749,7 +2953,7 @@ kernel void kernel_get_rows( // each block_q contains 16*nl weights template -kernel void kernel_mul_mm(device const uchar * src0, +void kernel_mul_mm_impl(device const uchar * src0, device const uchar * src1, device float * dst, constant int64_t & ne00, @@ -2876,14 +3080,112 @@ kernel void kernel_mul_mm(device const uchar * src0, } } +template +kernel void kernel_mul_mm(device const uchar * src0, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & ne12, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & gqa, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mm_impl( + src0, + src1, + dst, + ne00, + ne02, + nb01, + nb02, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + gqa, + shared_memory, + tgpig, + tiitg, + sgitg); +} + +template +kernel void kernel_mul_mm_id( + device const int32_t * ids, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & ne12, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & gqa, + constant int & idx, + device const uchar * src00, + device const uchar * src01, + device const uchar * src02, + device const uchar * src03, + device const uchar * src04, + device const uchar * src05, + device const uchar * src06, + device const uchar * src07, + threadgroup uchar * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07}; + + kernel_mul_mm_impl( + src0[ids[idx]], + src1, + dst, + ne00, + ne02, + nb01, + nb02, + ne12, + nb10, + nb11, + nb12, + ne0, + ne1, + gqa, + shared_memory, + tgpig, + tiitg, + sgitg); +} + #if QK_K == 256 #define QK_NL 16 #else #define QK_NL 4 #endif -typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \ - constant uint64_t &, constant uint64_t &, uint, uint, uint); +typedef void (get_rows_t)( + device const void * src0, + device const int * src1, + device float * dst, + constant int64_t & ne00, + constant uint64_t & nb01, + constant uint64_t & nb1, + uint, uint, uint); template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows; template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows; @@ -2913,7 +3215,8 @@ typedef void (mat_mm_t)( constant int64_t & ne0, constant int64_t & ne1, constant uint & gqa, - threadgroup uchar *, uint3, uint, uint); + threadgroup uchar *, + 
uint3, uint, uint); template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm; @@ -2927,3 +3230,43 @@ template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm; template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm; + +typedef void (mat_mm_id_t)( + device const int32_t * ids, + device const uchar * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant int64_t & nb01, + constant int64_t & nb02, + constant int64_t & ne12, + constant int64_t & nb10, + constant int64_t & nb11, + constant int64_t & nb12, + constant int64_t & ne0, + constant int64_t & ne1, + constant uint & gqa, + constant int & idx, + device const uchar * src00, + device const uchar * src01, + device const uchar * src02, + device const uchar * src03, + device const uchar * src04, + device const uchar * src05, + device const uchar * src06, + device const uchar * src07, + threadgroup uchar *, + uint3, uint, uint); + +template [[host_name("kernel_mul_mm_id_f32_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_f16_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_1_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q8_0_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q2_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q3_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q4_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q5_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; +template [[host_name("kernel_mul_mm_id_q6_K_f32")]] kernel mat_mm_id_t kernel_mul_mm_id; diff --git a/src/ggml.c b/src/ggml.c index 1b192b76..26c86c42 100644 --- a/src/ggml.c +++ b/src/ggml.c @@ -233,24 +233,6 @@ inline static void * ggml_aligned_malloc(size_t size) { #define UNUSED GGML_UNUSED #define SWAP(x, y, T) do { T SWAP = x; x = y; y = SWAP; } while (0) -// -// tensor access macros -// - -#define GGML_TENSOR_UNARY_OP_LOCALS \ - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ - GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - -#define GGML_TENSOR_BINARY_OP_LOCALS \ - GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) \ - GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) \ - GGML_TENSOR_LOCALS(int64_t, ne1, src1, ne) \ - GGML_TENSOR_LOCALS(size_t, nb1, src1, nb) \ - GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) \ - GGML_TENSOR_LOCALS(size_t, nb, dst, nb) - #if defined(GGML_USE_ACCELERATE) #include #if defined(GGML_USE_CLBLAST) // allow usage of CLBlast alongside Accelerate functions @@ -1613,6 +1595,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "GROUP_NORM", "MUL_MAT", + "MUL_MAT_ID", "OUT_PROD", "SCALE", @@ -1640,6 +1623,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "POOL_1D", "POOL_2D", "UPSCALE", + "ARGSORT", "FLASH_ATTN", "FLASH_FF", @@ -1666,7 +1650,7 @@ static const char * 
GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68"); +static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -1695,6 +1679,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "group_norm(x)", "X*Y", + "X[i]*Y", "X*Y", "x*v", @@ -1722,6 +1707,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "pool_1d(x)", "pool_2d(x)", "upscale(x)", + "argsort(x)", "flash_attn(x)", "flash_ff(x)", @@ -1748,7 +1734,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68"); +static_assert(GGML_OP_COUNT == 70, "GGML_OP_COUNT != 70"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -1789,6 +1775,7 @@ static void ggml_setup_op_has_task_pass(void) { p[GGML_OP_ACC ] = true; p[GGML_OP_MUL_MAT ] = true; + p[GGML_OP_MUL_MAT_ID ] = true; p[GGML_OP_OUT_PROD ] = true; p[GGML_OP_SET ] = true; p[GGML_OP_GET_ROWS_BACK ] = true; @@ -3186,9 +3173,7 @@ static struct ggml_tensor * ggml_add_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - // TODO: support less-strict constraint - // GGML_ASSERT(ggml_can_repeat(b, a)); - GGML_ASSERT(ggml_can_repeat_rows(b, a)); + GGML_ASSERT(ggml_can_repeat(b, a)); bool is_node = false; @@ -3403,9 +3388,7 @@ static struct ggml_tensor * ggml_mul_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - // TODO: support less-strict constraint - // GGML_ASSERT(ggml_can_repeat(b, a)); - GGML_ASSERT(ggml_can_repeat_rows(b, a)); + GGML_ASSERT(ggml_can_repeat(b, a)); bool is_node = false; @@ -3450,7 +3433,7 @@ static struct ggml_tensor * ggml_div_impl( struct ggml_tensor * a, struct ggml_tensor * b, bool inplace) { - GGML_ASSERT(ggml_are_same_shape(a, b)); + GGML_ASSERT(ggml_can_repeat(b, a)); bool is_node = false; @@ -4088,6 +4071,49 @@ struct ggml_tensor * ggml_mul_mat( return result; } +// ggml_mul_mat_id + +struct ggml_tensor * ggml_mul_mat_id( + struct ggml_context * ctx, + struct ggml_tensor * as[], + struct ggml_tensor * ids, + int id, + struct ggml_tensor * b) { + + int64_t n_as = ids->ne[0]; + + GGML_ASSERT(ids->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_vector(ids)); + GGML_ASSERT(n_as > 0 && n_as <= GGML_MAX_SRC - 2); + GGML_ASSERT(id >= 0 && id < n_as); + + bool is_node = false; + + if (as[0]->grad || b->grad) { + is_node = true; + } + + const int64_t ne[4] = { as[0]->ne[1], b->ne[1], b->ne[2], b->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(as[0]->n_dims, b->n_dims), ne); + + ggml_set_op_params_i32(result, 0, id); + + result->op = GGML_OP_MUL_MAT_ID; + result->grad = is_node ? 
ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = ids; + result->src[1] = b; + + for (int64_t i = 0; i < n_as; i++) { + struct ggml_tensor * a = as[i]; + GGML_ASSERT(ggml_are_same_shape(as[0], a)); + GGML_ASSERT(ggml_can_mul_mat(a, b)); + GGML_ASSERT(!ggml_is_transposed(a)); + result->src[i + 2] = a; + } + + return result; +} + // ggml_out_prod struct ggml_tensor * ggml_out_prod( @@ -5478,6 +5504,43 @@ struct ggml_tensor * ggml_upscale( return ggml_upscale_impl(ctx, a, scale_factor); } +// ggml_argsort + +struct ggml_tensor * ggml_argsort( + struct ggml_context * ctx, + struct ggml_tensor * a, + enum ggml_sort_order order) { + bool is_node = false; + + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_I32, a->n_dims, a->ne); + + ggml_set_op_params_i32(result, 0, (int32_t) order); + + result->op = GGML_OP_ARGSORT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = a; + + return result; +} + +// ggml_top_k + +struct ggml_tensor * ggml_top_k( + struct ggml_context * ctx, + struct ggml_tensor * a, + int k) { + GGML_ASSERT(a->ne[0] >= k); + + struct ggml_tensor * result = ggml_argsort(ctx, a, GGML_SORT_DESC); + + result = ggml_view_4d(ctx, result, + k, result->ne[1], result->ne[2], result->ne[3], + result->nb[1], result->nb[2], result->nb[3], + 0); + + return result; +} + // ggml_flash_attn struct ggml_tensor * ggml_flash_attn( @@ -6837,7 +6900,7 @@ static void ggml_compute_forward_add_f32( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -6870,16 +6933,19 @@ static void ggml_compute_forward_add_f32( const int64_t i13 = i03 % ne13; const int64_t i12 = i02 % ne12; const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + for (int64_t r = 0; r < nr0; ++r) { #ifdef GGML_USE_ACCELERATE - vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00); + vDSP_vadd(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10); #else - ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr); + ggml_vec_add_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); #endif + } } } else { // src1 is not contiguous @@ -6896,8 +6962,9 @@ static void ggml_compute_forward_add_f32( float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - for (int i0 = 0; i0 < ne0; i0++) { - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10); + for (int64_t i0 = 0; i0 < ne0; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); dst_ptr[i0] = src0_ptr[i0] + *src1_ptr; } @@ -7617,7 +7684,7 @@ static void ggml_compute_forward_mul_f32( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); if (params->type == 
GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; @@ -7640,7 +7707,6 @@ static void ggml_compute_forward_mul_f32( GGML_ASSERT( nb0 == sizeof(float)); GGML_ASSERT(nb00 == sizeof(float)); - GGML_ASSERT(ne00 == ne10); if (nb10 == sizeof(float)) { for (int64_t ir = ith; ir < nr; ir += nth) { @@ -7652,20 +7718,21 @@ static void ggml_compute_forward_mul_f32( const int64_t i13 = i03 % ne13; const int64_t i12 = i02 % ne12; const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + for (int64_t r = 0 ; r < nr0; ++r) { #ifdef GGML_USE_ACCELERATE - UNUSED(ggml_vec_mul_f32); + UNUSED(ggml_vec_mul_f32); - vDSP_vmul( src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00); + vDSP_vmul(src0_ptr + r*ne10, 1, src1_ptr, 1, dst_ptr + r*ne10, 1, ne10); #else - ggml_vec_mul_f32(ne00, dst_ptr, src0_ptr, src1_ptr); + ggml_vec_mul_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); #endif - // } - // } + } } } else { // src1 is not contiguous @@ -7683,8 +7750,9 @@ static void ggml_compute_forward_mul_f32( float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); - for (int64_t i0 = 0; i0 < ne00; i0++) { - float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10); + for (int64_t i0 = 0; i0 < ne00; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); dst_ptr[i0] = src0_ptr[i0] * (*src1_ptr); } @@ -7718,14 +7786,16 @@ static void ggml_compute_forward_div_f32( const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst) { - assert(params->ith == 0); - assert(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_can_repeat(src1, src0) && ggml_are_same_shape(src0, dst)); if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { return; } - const int nr = ggml_nrows(src0); + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); GGML_TENSOR_BINARY_OP_LOCALS @@ -7733,41 +7803,50 @@ static void ggml_compute_forward_div_f32( GGML_ASSERT(nb00 == sizeof(float)); if (nb10 == sizeof(float)) { - for (int ir = 0; ir < nr; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + const int64_t nr0 = ne00 / ne10; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); + + for (int64_t r = 0; r < nr0; ++r) { #ifdef GGML_USE_ACCELERATE - UNUSED(ggml_vec_div_f32); + UNUSED(ggml_vec_div_f32); - vDSP_vdiv( - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + 
i1*nb11), 1, - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1, - ne0); + vDSP_vdiv(src1_ptr, 1, src0_ptr + r*ne10, 1, dst_ptr + r*ne10, 1, ne10); #else - ggml_vec_div_f32(ne0, - (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), - (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), - (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11)); + ggml_vec_div_f32(ne10, dst_ptr + r*ne10, src0_ptr + r*ne10, src1_ptr); #endif - // } - // } + } } } else { // src1 is not contiguous - for (int ir = 0; ir < nr; ++ir) { - // src0, src1 and dst are same shape => same indices - const int i3 = ir/(ne2*ne1); - const int i2 = (ir - i3*ne2*ne1)/ne1; - const int i1 = (ir - i3*ne2*ne1 - i2*ne1); + for (int64_t ir = ith; ir < nr; ir += nth) { + // src0 and dst are same shape => same indices + // src1 is broadcastable across src0 and dst in i1, i2, i3 + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); - float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ); - float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01); - for (int i0 = 0; i0 < ne0; i0++) { - float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); + const int64_t i13 = i03 % ne13; + const int64_t i12 = i02 % ne12; + const int64_t i11 = i01 % ne11; + + float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + for (int64_t i0 = 0; i0 < ne00; ++i0) { + const int64_t i10 = i0 % ne10; + float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i10*nb10); dst_ptr[i0] = src0_ptr[i0] / (*src1_ptr); } @@ -8213,7 +8292,7 @@ static void ggml_compute_forward_repeat_f16( return; } - GGML_TENSOR_UNARY_OP_LOCALS; + GGML_TENSOR_UNARY_OP_LOCALS // guaranteed to be an integer due to the check in ggml_can_repeat const int nr0 = (int)(ne0/ne00); @@ -9526,6 +9605,8 @@ static void ggml_compute_forward_mul_mat( char * wdata = params->wdata; const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type); + assert(params->wsize >= ne11*ne12*ne13*row_size); + for (int64_t i13 = 0; i13 < ne13; ++i13) { for (int64_t i12 = 0; i12 < ne12; ++i12) { for (int64_t i11 = 0; i11 < ne11; ++i11) { @@ -9627,6 +9708,26 @@ static void ggml_compute_forward_mul_mat( } } +// ggml_compute_forward_mul_mat_id + +static void ggml_compute_forward_mul_mat_id( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * ids = dst->src[0]; + const struct ggml_tensor * src1 = dst->src[1]; + + const int id = ggml_get_op_params_i32(dst, 0); + + const int a_id = ((int32_t *)ids->data)[id]; + + GGML_ASSERT(a_id >= 0 && a_id < ids->ne[0]); + + const struct ggml_tensor * src0 = dst->src[a_id + 2]; + + ggml_compute_forward_mul_mat(params, src0, src1, dst); +} + // ggml_compute_forward_out_prod static void ggml_compute_forward_out_prod_f32( @@ -11960,6 +12061,67 @@ static void ggml_compute_forward_upscale( } } +// ggml_compute_forward_argsort + +static void ggml_compute_forward_argsort_f32( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { + return; + } + + GGML_TENSOR_UNARY_OP_LOCALS + + 
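The rewritten ggml_compute_forward_add_f32, ggml_compute_forward_mul_f32 and ggml_compute_forward_div_f32 above replace the old ggml_can_repeat_rows check with ggml_can_repeat, so src1 can now also be repeated along dim 0 (nr0 = ne00/ne10 inner repeats per row). A small sketch of a graph that relies on this, with hypothetical shapes and names; only the relaxed broadcasting rule itself comes from the patch:

    // x: ne = {1024, 256}, bias: ne = {1, 256}, one value per row of x
    struct ggml_tensor * x    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 1024, 256);
    struct ggml_tensor * bias = ggml_new_tensor_2d(ctx, GGML_TYPE_F32,    1, 256);

    // previously bias had to be pre-repeated to {1024, 256}; now each bias
    // value is reused ne00/ne10 = 1024 times along dim 0
    struct ggml_tensor * y = ggml_add(ctx, x, bias);

    // the same relaxed assert applies to ggml_mul and ggml_div
    struct ggml_tensor * z = ggml_div(ctx, ggml_mul(ctx, y, bias), bias);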
GGML_ASSERT(nb0 == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t nr = ggml_nrows(src0); + + enum ggml_sort_order order = (enum ggml_sort_order) ggml_get_op_params_i32(dst, 0); + + for (int64_t i = ith; i < nr; i += nth) { + int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1); + const float * src_data = (float *)((char *) src0->data + i*nb01); + + for (int64_t j = 0; j < ne0; j++) { + dst_data[j] = j; + } + + // C doesn't have a functional sort, so we do a bubble sort instead + for (int64_t j = 0; j < ne0; j++) { + for (int64_t k = j + 1; k < ne0; k++) { + if ((order == GGML_SORT_ASC && src_data[dst_data[j]] > src_data[dst_data[k]]) || + (order == GGML_SORT_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) { + int32_t tmp = dst_data[j]; + dst_data[j] = dst_data[k]; + dst_data[k] = tmp; + } + } + } + } +} + +static void ggml_compute_forward_argsort( + const struct ggml_compute_params * params, + const struct ggml_tensor * src0, + struct ggml_tensor * dst) { + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_argsort_f32(params, src0, dst); + } break; + default: + { + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_flash_attn static void ggml_compute_forward_flash_attn_f32( @@ -13783,6 +13945,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_mul_mat(params, tensor->src[0], tensor->src[1], tensor); } break; + case GGML_OP_MUL_MAT_ID: + { + ggml_compute_forward_mul_mat_id(params, tensor); + } break; case GGML_OP_OUT_PROD: { ggml_compute_forward_out_prod(params, tensor->src[0], tensor->src[1], tensor); @@ -13887,6 +14053,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm { ggml_compute_forward_upscale(params, tensor->src[0], tensor); } break; + case GGML_OP_ARGSORT: + { + ggml_compute_forward_argsort(params, tensor->src[0], tensor); + } break; case GGML_OP_FLASH_ATTN: { const int32_t t = ggml_get_op_params_i32(tensor, 0); @@ -14537,6 +14707,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor zero_table); } } break; + case GGML_OP_MUL_MAT_ID: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_OUT_PROD: { GGML_ASSERT(false); // TODO: not implemented @@ -14875,6 +15049,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor { GGML_ASSERT(false); // TODO: not implemented } break; + case GGML_OP_ARGSORT: + { + GGML_ASSERT(false); // TODO: not implemented + } break; case GGML_OP_FLASH_ATTN: { struct ggml_tensor * flash_grad = NULL; @@ -15471,7 +15649,6 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { n_tasks = n_threads; } break; case GGML_OP_SUB: - case GGML_OP_DIV: case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_LOG: @@ -15510,6 +15687,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { break; case GGML_OP_SILU_BACK: case GGML_OP_MUL: + case GGML_OP_DIV: case GGML_OP_NORM: case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: @@ -15547,6 +15725,11 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } #endif } break; + case GGML_OP_MUL_MAT_ID: + { + // FIXME: blas + n_tasks = n_threads; + } break; case GGML_OP_OUT_PROD: { n_tasks = n_threads; @@ -15603,6 +15786,10 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { { n_tasks = n_threads; } break; + case GGML_OP_ARGSORT: + { + n_tasks = n_threads; + } break; case 
GGML_OP_FLASH_ATTN: { n_tasks = n_threads; @@ -15866,6 +16053,23 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { cur = ggml_type_size(vec_dot_type)*ggml_nelements(node->src[1])/ggml_blck_size(vec_dot_type); } } break; + case GGML_OP_MUL_MAT_ID: + { + const struct ggml_tensor * a = node->src[2]; + const struct ggml_tensor * b = node->src[1]; + const enum ggml_type vec_dot_type = type_traits[a->type].vec_dot_type; +#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) + if (ggml_compute_forward_mul_mat_use_blas(a, b, node)) { + if (a->type != GGML_TYPE_F32) { + // here we need memory just for single 2D matrix from src0 + cur = ggml_type_size(GGML_TYPE_F32)*(a->ne[0]*a->ne[1]); + } + } else +#endif + if (b->type != vec_dot_type) { + cur = ggml_type_size(vec_dot_type)*ggml_nelements(b)/ggml_blck_size(vec_dot_type); + } + } break; case GGML_OP_OUT_PROD: { n_tasks = n_threads; @@ -17749,8 +17953,8 @@ size_t ggml_quantize_q5_0(const float * src, void * dst, int n, int k, int64_t * memcpy(&qh, &y[i].qh, sizeof(qh)); for (int j = 0; j < QK5_0; j += 2) { - const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12)); + const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4; + const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12)); // cast to 16 bins const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2; @@ -17779,8 +17983,8 @@ size_t ggml_quantize_q5_1(const float * src, void * dst, int n, int k, int64_t * memcpy(&qh, &y[i].qh, sizeof(qh)); for (int j = 0; j < QK5_1; j += 2) { - const uint8_t vh0 = ((qh & (1u << (j + 0 ))) >> (j + 0 )) << 4; - const uint8_t vh1 = ((qh & (1u << (j + 16))) >> (j + 12)); + const uint8_t vh0 = ((qh & (1u << (j/2 + 0 ))) >> (j/2 + 0 )) << 4; + const uint8_t vh1 = ((qh & (1u << (j/2 + 16))) >> (j/2 + 12)); // cast to 16 bins const uint8_t vi0 = ((y[i].qs[j/2] & 0x0F) | vh0) / 2; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index b42b0fa9..d30523a2 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2,9 +2,10 @@ #include #include #include +#include #include -#include #include +#include #include #include #include @@ -28,10 +29,12 @@ static void init_tensor_uniform(ggml_tensor * tensor, float min = -1.0f, float m if (tensor->type == GGML_TYPE_F32) { ggml_backend_tensor_set(tensor, data.data(), 0, size * sizeof(float)); - } else if (tensor->type == GGML_TYPE_F16) { - std::vector data16(size); - ggml_fp32_to_fp16_row(data.data(), data16.data(), size); - ggml_backend_tensor_set(tensor, data16.data(), 0, size * sizeof(ggml_fp16_t)); + } else if (ggml_is_quantized(tensor->type) || tensor->type == GGML_TYPE_F16) { + GGML_ASSERT(size % ggml_blck_size(tensor->type) == 0); + std::vector dataq(ggml_type_size(tensor->type)*size/ggml_blck_size(tensor->type)); + int64_t hist[16]; + ggml_quantize_chunk(tensor->type, data.data(), dataq.data(), 0, size, hist); + ggml_backend_tensor_set(tensor, dataq.data(), 0, dataq.size()); } else { GGML_ASSERT(false); } @@ -55,6 +58,8 @@ static std::vector tensor_to_float(const ggml_tensor * t) { v = (float) ggml_fp16_to_fp32(*(ggml_fp16_t*)&buf[i]); } else if (t->type == GGML_TYPE_F32) { v = *(float *) &buf[i]; + } else if (t->type == GGML_TYPE_I32) { + v = *(int32_t *) &buf[i]; } else { GGML_ASSERT(false); } @@ -206,13 +211,17 @@ struct test_case { virtual ggml_tensor * build_graph(ggml_context * ctx) = 0; + virtual double max_nmse_err() { + return 1e-6; + } + virtual void 
initialize_tensors(ggml_context * ctx) { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { init_tensor_uniform(t); } } - bool eval(ggml_backend_t backend1, ggml_backend_t backend2) { + bool eval(ggml_backend_t backend1, ggml_backend_t backend2, const char * op_name) { ggml_init_params params = { /* .mem_size = */ ggml_tensor_overhead()*128 + ggml_graph_overhead(), /* .mem_base = */ NULL, @@ -222,6 +231,12 @@ struct test_case { ggml_tensor * out = build_graph(ctx); + if (op_name != nullptr && strcmp(ggml_op_desc(out), op_name) != 0) { + //printf(" %s: skipping\n", ggml_op_desc(out)); + ggml_free(ctx); + return true; + } + // check if backends support op for (ggml_backend_t backend : {backend1, backend2}) { if (!ggml_backend_supports_op(backend, out)) { @@ -242,18 +257,26 @@ struct test_case { initialize_tensors(ctx); // compare - bool ok = true; + struct callback_userdata { + bool ok; + double max_err; + }; + + callback_userdata ud { + true, + max_nmse_err(), + }; auto callback = [](int index, ggml_tensor * t1, ggml_tensor * t2, void * user_data) -> bool { std::vector f1 = tensor_to_float(t1); std::vector f2 = tensor_to_float(t2); - bool * ok = (bool *) user_data; + callback_userdata * ud = (callback_userdata *) user_data; for (size_t i = 0; i < f1.size(); i++) { // check for nans if (std::isnan(f1[i]) || std::isnan(f2[i])) { - printf(" Error: %s: NaN\n", ggml_op_desc(t1)); - *ok = false; + printf(" Error: %s: NaN at index %zu\n", ggml_op_desc(t1), i); + ud->ok = false; return true; } // check for infs: both must be inf of the same sign, or both must be finite @@ -261,29 +284,29 @@ struct test_case { if (isinf_or_max(f1[i]) && isinf_or_max(f2[i])) { if (std::signbit(f1[i]) != std::signbit(f2[i])) { printf(" Error: %s: inf sign mismatch: %f %f\n", ggml_op_desc(t1), f1[i], f2[i]); - *ok = false; + ud->ok = false; return true; } } else { printf(" Error: %s: inf mismatch: %f %f\n", ggml_op_desc(t1), f1[i], f2[i]); - *ok = false; + ud->ok = false; return true; } } } double err = nmse(f1.data(), f2.data(), f1.size()); - if (err > 1e-6) { + if (err > ud->max_err) { printf(" Error: %s: NMSE = %f\n", ggml_op_desc(t1), err); - *ok = false; + ud->ok = false; } return true; }; - ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ok); + ggml_backend_compare_graph_backend(backend1, backend2, gf, callback, &ud); printf(" %s(%s): ", ggml_op_desc(out), vars().c_str()); - if (ok) { + if (ud.ok) { printf("\033[1;32mOK\033[0m\n"); } else { printf("\033[1;31mFAIL\033[0m\n"); @@ -293,7 +316,7 @@ struct test_case { ggml_free(ctx); - return ok; + return ud.ok; } }; @@ -444,30 +467,11 @@ struct test_cont : public test_case { }; // GGML_OP_ADD -struct test_add : public test_case { - const ggml_type type; - const std::array ne; - const std::array nr; - - std::string vars() override { - return VARS_TO_STR3(type, ne, nr); - } - - test_add(ggml_type type = GGML_TYPE_F32, - std::array ne = {10, 10, 1, 1}, - std::array nr = {1, 2, 1, 1}) - : type(type), ne(ne), nr(nr) {} - - ggml_tensor * build_graph(ggml_context * ctx) override { - ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]); - ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_tensor * out = ggml_add(ctx, a, b); - return out; - } -}; - // GGML_OP_MUL -struct test_mul : public test_case { +// GGML_OP_DIV +struct test_bin_bcast : public test_case { + using op_t = std::function; + op_t op; const ggml_type type; const std::array ne; 
const std::array nr; @@ -476,15 +480,15 @@ struct test_mul : public test_case { return VARS_TO_STR3(type, ne, nr); } - test_mul(ggml_type type = GGML_TYPE_F32, + test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 1, 1}, std::array nr = {1, 2, 1, 1}) - : type(type), ne(ne), nr(nr) {} + : op(op), type(type), ne(ne), nr(nr) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0]*nr[0], ne[1]*nr[1], ne[2]*nr[2], ne[3]*nr[3]); ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data()); - ggml_tensor * out = ggml_mul(ctx, a, b); + ggml_tensor * out = op(ctx, a, b); return out; } }; @@ -568,6 +572,10 @@ struct test_mul_mat : public test_case { return VARS_TO_STR7(type_a, type_b, m, n, k, bs, nr); } + double max_nmse_err() override { + return 5e-4; + } + test_mul_mat(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, int64_t m = 32, int64_t n = 32, int64_t k = 32, std::array bs = {10, 10}, @@ -794,7 +802,128 @@ struct test_concat : public test_case { } }; -static bool test_backend(ggml_backend_t backend) { +// GGML_OP_ARGSORT +struct test_argsort : public test_case { + const ggml_type type; + const std::array ne; + ggml_sort_order order; + + std::string vars() override { + return VARS_TO_STR3(type, ne, order); + } + + test_argsort(ggml_type type = GGML_TYPE_F32, + std::array ne = {16, 10, 10, 10}, + ggml_sort_order order = GGML_SORT_ASC) + : type(type), ne(ne), order(order) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * out = ggml_argsort(ctx, a, order); + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I32) { + std::vector data(ggml_nelements(t)); + for (int i = 0; i < ggml_nelements(t); i++) { + data[i] = rand(); + } + std::shuffle(data.begin(), data.end(), std::default_random_engine(std::random_device()())); + ggml_backend_tensor_set(t, data.data(), 0, ne[0]*ne[1]*ne[2]*ne[3] * sizeof(int)); + } else { + init_tensor_uniform(t); + } + } + } +}; + + +// GGML_OP_MUL_MAT_ID +struct test_mul_mat_id : public test_case { + const ggml_type type_a; + const ggml_type type_b; + const int n_mats; + const int id; + const int64_t m; + const int64_t n; + const int64_t k; + const std::array bs; // dims 3 and 4 + const std::array nr; // repeat in dims 3 and 4 + + std::string vars() override { + return VARS_TO_STR9(type_a, type_b, n_mats, id, m, n, k, bs, nr); + } + + double max_nmse_err() override { + return 5e-4; + } + + test_mul_mat_id(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, + int n_mats = 2, int id = 0, + int64_t m = 32, int64_t n = 32, int64_t k = 32, + std::array bs = {10, 10}, + std::array nr = {2, 2}) + : type_a(type_a), type_b(type_b), n_mats(n_mats), id(id), + m(m), n(n), k(k), bs(bs), nr(nr) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + // C^T = A * B^T: (k, m) * (k, n) => (m, n) + std::vector mats; + for (int i = 0; i < n_mats; i++) { + ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, k, m, bs[0]*nr[0], bs[1]*nr[1]); + mats.push_back(a); + } + ggml_tensor * ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_mats); + ggml_tensor * b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0]*nr[0], bs[1]*nr[1]); + ggml_tensor * out = ggml_mul_mat_id(ctx, mats.data(), ids, id, b); + return out; + } + + void 
initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + if (t->type == GGML_TYPE_I32) { + // ids + std::vector data(n_mats); + for (int i = 0; i < n_mats; i++) { + data[i] = i; + } + std::shuffle(data.begin(), data.end(), std::default_random_engine(std::random_device()())); + ggml_backend_tensor_set(t, data.data(), 0, n_mats * sizeof(int)); + } else { + init_tensor_uniform(t); + } + } + } +}; + +// GGML_OP_SUM_ROWS +struct test_sum_rows : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_sum_rows(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 10, 10, 10}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_tensor * out = ggml_sum_rows(ctx, a); + return out; + } +}; + +enum test_mode { + MODE_TEST, + MODE_PERF, +}; + +static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op_name) { ggml_backend_t backend_cpu = ggml_backend_cpu_init(); std::vector> test_cases; @@ -814,27 +943,22 @@ static bool test_backend(ggml_backend_t backend) { test_cases.emplace_back(new test_cpy()); test_cases.emplace_back(new test_cont()); - test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1})); - test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1})); - test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 1})); - //test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 1, 1, 1})); // broadcasting dim 0 is not supported - test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 1, 1})); - test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 1})); - test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 2})); - test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 2})); - test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 2, 2})); - //test_cases.emplace_back(new test_add(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 2, 2, 2})); - - test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1})); - test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1})); - test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 1})); - //test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 1, 1, 1})); // broadcasting dim 0 is not supported - test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 1, 1})); - test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 1})); - test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 2})); - test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 2})); - test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 2, 2})); - //test_cases.emplace_back(new test_mul(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 2, 2, 2})); + auto add_test_bin_bcast = [&](ggml_type type, std::array ne, std::array nr) { + for (auto op : {ggml_add, ggml_mul, ggml_div}) { + test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr)); + } + }; + + add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 1, 1}, {1, 1, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 1}, {1, 1, 1, 1}); + 
add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 1, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 1, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 1}); + add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 1, 2}); + add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 1, 2, 2}); + add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {1, 2, 2, 2}); + add_test_bin_bcast(GGML_TYPE_F32, {16, 10, 10, 10}, {2, 2, 2, 2}); test_cases.emplace_back(new test_scale()); @@ -843,16 +967,34 @@ static bool test_backend(ggml_backend_t backend) { test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 10, 10, 10}, eps)); } - for (ggml_type t0 : {GGML_TYPE_F32, GGML_TYPE_F16}) { - for (ggml_type t1 : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { + ggml_type all_types[] = { + GGML_TYPE_F32, GGML_TYPE_F16, + GGML_TYPE_Q4_0, GGML_TYPE_Q4_1, + GGML_TYPE_Q5_0, GGML_TYPE_Q5_1, + GGML_TYPE_Q8_0, + GGML_TYPE_Q2_K, GGML_TYPE_Q3_K, + GGML_TYPE_Q4_K, GGML_TYPE_Q5_K, + GGML_TYPE_Q6_K + }; + + for (ggml_type type_a : all_types) { + for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { // FIXME: CPU crashes on f16xf16 - test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, { 1, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10, 1}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10, 1}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10, 10}, {1, 1})); - test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10, 10}, {2, 1})); - test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10, 10}, {1, 2})); - test_cases.emplace_back(new test_mul_mat(t0, t1, 32, 32, 32, {10, 10}, {2, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 1}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {1, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 1, 256, {10, 10}, {2, 2})); + + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, { 1, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 1}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 1})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {1, 2})); + test_cases.emplace_back(new test_mul_mat(type_a, type_b, 16, 16, 256, {10, 10}, {2, 2})); } } @@ -881,9 +1023,26 @@ static bool test_backend(ggml_backend_t backend) { test_cases.emplace_back(new test_im2col()); test_cases.emplace_back(new test_concat()); + for (ggml_sort_order order : {GGML_SORT_ASC, GGML_SORT_DESC}) { + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); + } + + for (ggml_type type_a : all_types) { + for (ggml_type type_b : {GGML_TYPE_F32 /*, GGML_TYPE_F16 */}) { + for (int n_mats : {1, 2, 
4}) { + for (int id = 0; id < n_mats; id++) { + test_cases.emplace_back(new test_mul_mat_id(type_a, type_b, n_mats, id, 16, 16, 256, {1, 1}, {1, 1})); + } + } + } + } + + test_cases.emplace_back(new test_sum_rows()); + + // run tests size_t n_ok = 0; for (auto & test : test_cases) { - if (test->eval(backend, backend_cpu)) { + if (test->eval(backend, backend_cpu, op_name)) { n_ok++; } } @@ -895,7 +1054,44 @@ static bool test_backend(ggml_backend_t backend) { return n_ok == test_cases.size(); } -int main() { +static void usage(char ** argv) { + // command line: test-backend-ops [mode] [-o op] [-b backend] + // modes are correctness (compare with CPU) or performance + printf("Usage: %s [mode] [-o op] [-b backend]\n", argv[0]); + printf(" valid modes are: test (compare with CPU backend for correctness) or perf (performance evaluation) [not implemented]\n"); + printf(" op names are as given ggml_op_desc()\n"); +} + +int main(int argc, char ** argv) { + test_mode mode = MODE_TEST; + const char * op_name = NULL; + const char * backend = NULL; + + for (int i = 1; i < argc; i++) { + if (strcmp(argv[i], "test") == 0) { + mode = MODE_TEST; + } else if (strcmp(argv[i], "perf") == 0) { + mode = MODE_PERF; + } else if (strcmp(argv[i], "-o") == 0) { + if (i + 1 < argc) { + op_name = argv[++i]; + } else { + usage(argv); + return 1; + } + } else if (strcmp(argv[i], "-b") == 0) { + if (i + 1 < argc) { + backend = argv[++i]; + } else { + usage(argv); + return 1; + } + } else { + usage(argv); + return 1; + } + } + // enumerate backends printf("Testing %zu backends\n\n", ggml_backend_reg_get_count()); @@ -904,11 +1100,17 @@ int main() { for (size_t i = 0; i < ggml_backend_reg_get_count(); i++) { printf("Backend %zu/%zu (%s)\n", i + 1, ggml_backend_reg_get_count(), ggml_backend_reg_get_name(i)); + if (backend != NULL && strcmp(backend, ggml_backend_reg_get_name(i)) != 0) { + printf(" Skipping\n"); + n_ok++; + continue; + } + ggml_backend_t backend = ggml_backend_reg_init_backend(i, NULL); GGML_ASSERT(backend != NULL); printf(" Backend name: %s\n", ggml_backend_name(backend)); - bool ok = test_backend(backend); + bool ok = test_backend(backend, mode, op_name); printf(" Backend %s: ", ggml_backend_name(backend)); if (ok) { diff --git a/tests/test-conv1d.cpp b/tests/test-conv1d.cpp index 0e5c75f7..5b741568 100644 --- a/tests/test-conv1d.cpp +++ b/tests/test-conv1d.cpp @@ -21,6 +21,13 @@ #include #include +static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + struct test_model { struct ggml_tensor * a; struct ggml_tensor * b; diff --git a/tests/test-conv2d.cpp b/tests/test-conv2d.cpp index 64eee784..f50a53af 100644 --- a/tests/test-conv2d.cpp +++ b/tests/test-conv2d.cpp @@ -21,6 +21,13 @@ #include #include +static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + struct test_model { struct ggml_tensor * a; struct ggml_tensor * b; diff --git a/tests/test-mul-mat.cpp b/tests/test-mul-mat.cpp index 1811492c..2bee7339 100644 --- a/tests/test-mul-mat.cpp +++ b/tests/test-mul-mat.cpp @@ -21,6 +21,13 @@ #include #include +static void ggml_log_callback_default(ggml_log_level level, const char * text, void * user_data) { + (void) level; + (void) user_data; + fputs(text, stderr); + fflush(stderr); +} + struct test_model { struct ggml_tensor * a; struct ggml_tensor * b;
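On the API side, the new sorting ops are exposed as ggml_argsort and ggml_top_k, where ggml_top_k is implemented as a descending ggml_argsort followed by a k-column view. A minimal usage sketch; the context and the shapes are illustrative assumptions:

    // row-wise index sort of a {64, 8} F32 tensor
    struct ggml_tensor * x    = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 8);
    struct ggml_tensor * perm = ggml_argsort(ctx, x, GGML_SORT_ASC); // I32, ne = {64, 8, 1, 1}

    // indices of the 10 largest values in each row
    struct ggml_tensor * top  = ggml_top_k(ctx, x, 10);              // I32 view, ne = {10, 8, 1, 1}

The reworked test-backend-ops harness above exercises these paths against the CPU backend and can now be restricted to a single op and backend, for example test-backend-ops test -o ARGSORT -b CPU, with op names as returned by ggml_op_desc() and backend names as registered in the backend registry.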